# Verbal N-back Task Evaluation with Gemini 2.5 Flash
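This notebook sends each N-back block to the Gemini API as a chat session, one letter per message, and scores the model's responses.
It saves the per-trial results into `gemini_results.json` for downstream analysis and prints a summary table (hit rate, false alarm rate, accuracy, d prime).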

In [4]:
pip install google-genai python-dotenv numpy scipy tabulate

Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import json
import time
import math
import numpy as np
from scipy.stats import norm
from tabulate import tabulate
from concurrent.futures import ThreadPoolExecutor, as_completed
from google import genai
from dotenv import load_dotenv

# Load variables from .env into the environment
# (done before the client is created so a key stored in .env is visible to it)
load_dotenv()

# The client automatically looks for the GEMINI_API_KEY environment variable
client = genai.Client()

# Configurations
n_list = [1, 2, 3]  # N-back levels to evaluate
blocks = 50  # blocks per level; block files are read as 0.txt .. {blocks-1}.txt
model_name = "gemini-2.5-flash"  # model used for every chat session
MAX_WORKERS = 5 # Adjust this based on your Gemini API rate limits (e.g., Free vs Paid tier)

# Initialize data structures
# Dataset root. Defaults to the original checkout location; set NBACK_DATA_ROOT
# (in the environment or in .env) to run from another machine or directory.
data_root = os.environ.get(
    'NBACK_DATA_ROOT',
    '/home/white/FTEC5660/projects/individual/ChatGPT-WM',
)

# Maps '{n}back_{b}' -> list of trial dicts for that block.
all_trials = {}
for n in n_list:
    for b in range(blocks):
        path = os.path.join(data_root, 'datasets', 'letters', f'{n}back', f'{b}.txt')
        with open(path, 'r') as f:
            # Line 1: the letter sequence; line 2: one condition character per
            # letter (compared against 'm' / '-' when scoring).
            seq = f.readline().strip()
            cond = f.readline().strip()

        # A short condition line used to fail with a bare IndexError; report
        # which file is malformed instead.
        if len(cond) < len(seq):
            raise ValueError(
                f"{path}: condition line has {len(cond)} entries "
                f"but the sequence has {len(seq)} letters"
            )

        # 'response' and 'correct' start empty and are filled in during evaluation.
        all_trials[f'{n}back_{b}'] = [
            {
                'stimulus': letter,
                'target': cond[i],
                'response': '',
                'correct': ''
            }
            for i, letter in enumerate(seq)
        ]

def get_system_instruction(n):
    """Build the system prompt that explains the n-back task to the model.

    Levels 1-3 get a spelled-out description of the comparison letter; any
    other level falls back to a numeric "letter {n} trials ago" phrasing.
    """
    spelled_out = (
        (1, "previous one"),
        (2, "letter two trials ago"),
        (3, "letter three trials ago"),
    )
    target_desc = next(
        (phrase for level, phrase in spelled_out if n == level),
        f"letter {n} trials ago",
    )

    return (
        f"You are asked to perform a {n}-back task. "
        "You will see a sequence of letters. "
        "Your task is to respond with 'm' (no quotation marks, just the letter m) "
        f"whenever the current letter is the same as the {target_desc}, "
        "and '-' (no quotation marks, just the dash sign) otherwise. "
        "Only 'm' and '-' are allowed responses. "
        "No explanations needed: please don't output any extra words!! "
        "The sequence will be presented one letter at a time. "
        "Now begins the task."
    )

# Worker: drives one block through its own chat session, trial by trial
def process_block(n, b, block_trials):
    """Run one N-back block against the model and score its trials in place.

    A chat is created with the system instruction for level `n`, then each
    trial's 'stimulus' is sent as one message. The stripped, lower-cased
    reply is reduced to 'm', '-' or 'invalid' (only its first character is
    considered) and stored under 'response'; 'correct' records whether that
    value equals the trial's 'target'.

    Args:
        n: N-back level; selects the system instruction and labels log lines.
        b: Block index; labels log lines and the returned key.
        block_trials: List of trial dicts for this block, updated in place.

    Returns:
        Tuple of the key f'{n}back_{b}' and the same `block_trials` list.
    """
    accepted = ('m', '-')
    max_retries = 5

    instruction = get_system_instruction(n)
    chat = client.chats.create(
        model=model_name,
        config={"system_instruction": instruction, "temperature": 0.0},
    )

    for i, trial in enumerate(block_trials):
        letter = trial['stimulus']
        expected = trial['target']

        # Any exception triggers a retry with exponential backoff (1, 2, 4, 8 s);
        # after the final failed attempt the trial is recorded as invalid.
        reply = ""
        attempt = 0
        while attempt < max_retries:
            try:
                reply = chat.send_message(letter).text.strip().lower()
            except Exception as e:
                if attempt == max_retries - 1:
                    print(f"Error on {n}-back block {b} trial {i}: {e}")
                    reply = "invalid"
                else:
                    time.sleep(2 ** attempt)
                attempt += 1
            else:
                break

        # Score on the first character so that replies with trailing extra
        # text still count; anything else (including an empty reply) is invalid.
        leading = reply[:1]
        if leading in accepted:
            trial['response'] = leading
            trial['correct'] = (leading == expected)
        else:
            trial['response'] = 'invalid'
            trial['correct'] = False

    print(f"Finished {n}-back, block {b}")
    return f'{n}back_{b}', block_trials


# Fan the blocks out over a thread pool; each block runs sequentially inside its worker
print("Starting parallel evaluation...")
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Queue every (level, block) pair, level by level
    futures = [
        executor.submit(process_block, n, b, all_trials[f'{n}back_{b}'])
        for n in n_list
        for b in range(blocks)
    ]

    # Store each block's trials under its key as soon as that block finishes
    for finished in as_completed(futures):
        block_key, updated_trials = finished.result()
        all_trials[block_key] = updated_trials

# Persist the per-trial results for downstream analysis
with open('gemini_results.json', 'w') as f:
    json.dump(all_trials, f)

print("Evaluation completed. Computing stats...")

# Compute summary stats and raw_data for distribution plots
def compute_summary_stats(all_trials, n_list, blocks):
    """Aggregate per-block signal-detection metrics for each N-back level.

    Args:
        all_trials: Mapping from f"{n}back_{b}" to a list of trial dicts.
            Each trial must provide 'target' ('m' marks a match trial) and
            'correct' (truthy when the response equalled the target).
            'response' is read when present.
        n_list: Iterable of N-back levels to summarise.
        blocks: Number of blocks per level; blocks 0..blocks-1 are read.

    Returns:
        Tuple (summary_stats, raw_data). raw_data[n] maps 'hit_rate',
        'false_alarm_rate', 'accuracy' (all in percent) and 'd_prime' to a
        list with one value per block. summary_stats[n] maps the same keys
        to {'mean', 'stderr'}, where stderr is np.std (default ddof=0)
        divided by sqrt(blocks).

    Raises:
        KeyError: if a block key or a required trial field is missing.
    """
    summary_stats = {}
    raw_data = {}
    for n in n_list:
        hit_rate, false_alarm_rate, accuracy, d_prime = [], [], [], []
        for b in range(blocks):
            trials = all_trials[f"{n}back_{b}"]

            hits = false_alarms = correct_rejections = 0
            total_targets = total_lures = 0
            for trial in trials:
                if trial['target'] == 'm':
                    total_targets += 1
                    if trial['correct']:
                        hits += 1
                else:
                    total_lures += 1
                    if trial['correct']:
                        correct_rejections += 1
                    elif trial.get('response') != 'invalid':
                        # Only an actual wrong answer is a false alarm. An
                        # 'invalid' response (API failure or unparseable
                        # reply) is still scored as incorrect for accuracy
                        # but must not inflate the false-alarm rate or d'.
                        false_alarms += 1

            total_trials = total_targets + total_lures
            hit_fraction = hits / total_targets if total_targets > 0 else 0
            fa_fraction = false_alarms / total_lures if total_lures > 0 else 0
            acc = ((hits + correct_rejections) / total_trials) * 100 if total_trials > 0 else 0

            hit_rate.append(hit_fraction * 100)
            false_alarm_rate.append(fa_fraction * 100)
            accuracy.append(acc)

            # Compute d prime. Rates are clipped away from 0 and 1 so the
            # inverse normal CDF stays finite; the absolute value of the
            # difference is reported.
            hit_rate_adjusted = np.clip(hit_fraction, 0.01, 0.99)
            false_alarm_rate_adjusted = np.clip(fa_fraction, 0.01, 0.99)
            d_prime.append(np.abs(norm.ppf(hit_rate_adjusted) - norm.ppf(false_alarm_rate_adjusted)))

        raw_data[n] = {
            "hit_rate": hit_rate,
            "false_alarm_rate": false_alarm_rate,
            "accuracy": accuracy,
            "d_prime": d_prime
        }

        # Same metric keys and order as raw_data, reduced to mean and standard error.
        summary_stats[n] = {
            metric: {"mean": np.mean(values), "stderr": np.std(values) / math.sqrt(blocks)}
            for metric, values in raw_data[n].items()
        }
    return summary_stats, raw_data

# Aggregate the scored trials: per-level mean/stderr plus the per-block values
summary_stats_gemini, raw_data_gemini = compute_summary_stats(all_trials, n_list, blocks)

def create_table(summary_stats, n_list):
    """Render the summary statistics as a grid-formatted text table.

    Args:
        summary_stats: Mapping from N-back level to per-metric dicts holding
            'mean' and 'stderr' for 'hit_rate', 'false_alarm_rate',
            'accuracy' and 'd_prime'.
        n_list: Levels to include, one table row each, in the given order.

    Returns:
        The table as a string; each metric cell reads "mean ± stderr" with
        two decimals.
    """
    # (key in summary_stats, column title), in display order
    columns = (
        ("hit_rate", "Hit Rate (%)"),
        ("false_alarm_rate", "False Alarm Rate (%)"),
        ("accuracy", "Accuracy (%)"),
        ("d_prime", "D Prime"),
    )
    headers = ["N-back"] + [title for _, title in columns]

    def format_cell(stat):
        return f"{stat['mean']:.2f} ± {stat['stderr']:.2f}"

    table_data = [
        [f"{n}-back"] + [format_cell(summary_stats[n][key]) for key, _ in columns]
        for n in n_list
    ]
    return tabulate(table_data, headers=headers, tablefmt="grid")

# Print the summary table, one row per N-back level
print("Gemini-2.5-Flash Results:")
print(create_table(summary_stats_gemini, n_list))

Starting parallel evaluation...
Finished 1-back, block 1
Finished 1-back, block 0
Finished 1-back, block 4
Finished 1-back, block 3
Finished 1-back, block 2
Finished 1-back, block 5
Finished 1-back, block 7
Finished 1-back, block 8
Finished 1-back, block 6
Finished 1-back, block 9
Finished 1-back, block 10
Finished 1-back, block 12
Finished 1-back, block 13
Finished 1-back, block 11
Finished 1-back, block 14
Finished 1-back, block 17
Finished 1-back, block 15
Finished 1-back, block 16
Finished 1-back, block 18
Finished 1-back, block 19
Finished 1-back, block 22
Finished 1-back, block 21
Finished 1-back, block 20
Finished 1-back, block 23
Finished 1-back, block 24
Finished 1-back, block 25
Finished 1-back, block 26
Finished 1-back, block 27
Finished 1-back, block 28
Finished 1-back, block 29
Finished 1-back, block 30
Finished 1-back, block 31
Finished 1-back, block 32
Finished 1-back, block 33
Finished 1-back, block 34
Finished 1-back, block 35
Finished 1-back, block 36
Finished 1-back,