# List learning - recognition

In [57]:
# Number of studied words; the same value is used below as the number of lures drawn.
list_length = 50
# Intended number of interjected (intervening) words.
# NOTE(review): confirm this value is actually consumed where interjection_words is built.
interjection_length = 200
# When True, interjected words are added to the test list (label 2) and
# plot_results reports separate false-alarm rates for new vs. interjected items.
include_interjected_in_test = False

import random
import numpy as np
import matplotlib.pyplot as plt
from helpers import * # general helper functions
%reload_ext autoreload
%autoreload 2

## task-specific helper functions

In [56]:
def query_model_recognition(pipe, inp):
    """Query the model and parse its reply into binary recognition responses.

    The prompt ``inp`` is sent through ``query_model_helper``; the reply text
    is taken from the ``'content'`` of the last entry of the first output's
    ``"generated_text"``. Anything after the trailing "Wait" marker is
    discarded, and the remainder is split on ``', '`` into individual answers.

    Args:
        pipe: Generation pipeline, passed through to ``query_model_helper``.
        inp: Prompt, passed through to ``query_model_helper``.

    Returns:
        A list of ints, one per answer: 1 for "yes", 0 for "no".

    Raises:
        ValueError: If the parsed answers are not made up of exactly the two
            tokens "yes" and "no". A reply that is all "yes" or all "no" is
            rejected as well, as is any reply containing another token.
    """
    outputs = query_model_helper(pipe, inp)
    reply = outputs[0]["generated_text"][-1]['content']
    # NOTE(review): this marker looks like a mis-encoded em dash followed by
    # " Wait"; confirm it matches what the model actually emits.
    answer_text = reply.split('â€” Wait')[0]
    lst_words = [w.strip().lower() for w in answer_text.split(', ')]  # 'yes'/'no'
    if set(lst_words) != {'yes', 'no'}:
        # Raise a real exception: raising a bare tuple fails with
        # "TypeError: exceptions must derive from BaseException" and hides
        # the offending output.
        raise ValueError(f"ERROR - unexpected output: {lst_words}")
    return [int(w == 'yes') for w in lst_words]

from scipy.stats import norm
def plot_results(true, resp):
    """Plot hit rate, false-alarm rate and d' for one recognition test.

    Args:
        true: Ground-truth label per test item. 1 marks a studied word; any
            other value counts as unstudied. When the module-level flag
            ``include_interjected_in_test`` is set, 0 marks a new word and 2
            an interjected word.
        resp: Response per test item, 1 for "yes" and anything else for "no".
            Must have the same length as ``true``.

    Raises:
        ValueError: If ``true`` and ``resp`` do not have the same shape.
    """
    true = np.array(true)
    resp = np.array(resp)
    if true.shape != resp.shape:
        raise ValueError(
            f"expected one response per test item: got {resp.size} responses "
            f"for {true.size} labels"
        )

    said_yes = resp == 1
    is_old = true == 1
    # Boolean arrays must be combined element-wise with `&`; Python's `and`
    # raises "truth value of an array ... is ambiguous" on multi-element arrays.
    hit_rate = np.sum(said_yes & is_old) / np.sum(is_old)
    false_alarm_rate = np.sum(said_yes & ~is_old) / np.sum(~is_old)
    # d' = z(hit rate) - z(false-alarm rate); infinite when either rate is 0 or 1.
    d_prime = norm.ppf(hit_rate) - norm.ppf(false_alarm_rate)
    x = [1, 2, 3]
    y = [hit_rate, false_alarm_rate, d_prime]
    labels = ['hit rate', 'false alarm rate', 'd\'']

    if include_interjected_in_test:
        # Split the false alarms by lure type: brand-new (0) vs. interjected (2).
        x.extend([4, 5])
        labels.extend(['false alarm rate on new', 'false alarm rate on interjected'])
        y.extend([
            np.sum(said_yes & (true == 0)) / np.sum(true == 0),
            np.sum(said_yes & (true == 2)) / np.sum(true == 2),
        ])

    plt.figure()
    plt.suptitle('Recognition test performance')
    plt.bar(x, y)
    plt.xticks(x, labels=labels, rotation=30)
    # NOTE(review): the d' bar is clipped by this range whenever d' > 1 or d' < 0.
    plt.ylim(0, 1)
    plt.show()

## define task

In [58]:
# Word pool from which every list in this task is drawn.
all_words = txt_to_list('_data/wasnorm_wordpool.txt')

# Studied list.
words = subsample_words(all_words, list_length, seed=0)

# Unstudied lures (presumably `avoid` excludes the given words from the draw).
lures = subsample_words(all_words, list_length, seed=0, avoid=words)

# Interjected (intervening) list, sized by `interjection_length`. Using
# `list_length` here left `interjection_length` unused and made the
# subsampling of `interjection_words` below a no-op.
interjection_words = subsample_words(all_words, interjection_length, seed=0, avoid=words+lures)

# Test items as (word, label) pairs: 1 = studied, 0 = new lure.
combined = [(w, 1) for w in words] + [(l, 0) for l in lures]
if include_interjected_in_test:
    # Add a list_length-sized sample of interjected words under label 2.
    interjection_sample = subsample_words(interjection_words, list_length, seed=0)
    combined = combined + [(l, 2) for l in interjection_sample]

# Shuffle the test items with a fixed seed so runs are reproducible.
np.random.seed(0)
perm = np.random.permutation(len(combined))
test_words = [combined[i][0] for i in perm]
test_labels = [combined[i][1] for i in perm]

In [63]:
# Test-phase instructions, defined once per condition so the variants cannot
# drift apart. Both end with a trailing space: without it the first test word
# is glued onto the colon ("Here are the test words:chimney").
cue_plain = (
    "Now we'll begin the memory test. Here is another list of words, "
    "containing some words from the original list and some new words. "
    "Please output a list of \"yes\" or \"no\" corresponding to whether each word was in the original list. "
    "Only output the list of \"yes\" and \"no\", nothing else. Here are the test words: "
)
cue_interjection = (
    "Now we'll begin the memory test. Here is another list of words, "
    "containing some words from the original list, some from the intervening list, and some new words. "
    "Please output a list of \"yes\" or \"no\" corresponding to whether each word was in the ORIGINAL list (not the intervening list). "
    "Only output the list of \"yes\" and \"no\", nothing else. Here are the test words: "
)

# Baseline: study list followed directly by the test, no forewarning.
recognition_inp = make_inp(
    words, test_words,
    preamble="Here is a list of words:",
    cue=cue_plain,
)
print(recognition_inp)

# Informed: the preamble announces the upcoming memory test.
recognition_inp_informed = make_inp(
    words, test_words,
    preamble="Here is a list of words. After seeing this list, you will be tested on your memory for words in this list:",
    cue=cue_plain,
)
print(recognition_inp_informed)


# Interjection: a second word list is passed to make_inp along with its intro text.
recognition_inp_interjection = make_inp(
    words, test_words,
    preamble="Here is a list of words:",
    cue=cue_interjection,
    interjection="You will now see another list of words:",
    interjection_words=interjection_words,
)
print(recognition_inp_interjection)


# Interjection + informed: both the preamble and the interjection text warn
# that only the first list will be tested.
recognition_inp_interjection_informed = make_inp(
    words, test_words,
    preamble="Here is a list of words. After seeing this list, you will see another list, "
        "and then a memory test for the words in the original (first) list:",
    cue=cue_interjection,
    interjection="You will now see another list of words. Later, when doing the memory test, "
        "you should not base your responses on the words in this list:",
    interjection_words=interjection_words,
)
print(recognition_inp_interjection_informed)

Here is a list of words: knuckle, vehicle, man, bank, employee, perch, ozone, list, whistle, garbage, onion, igloo, roach, deodorant, pearl, cheddar, fleet, cheek, van, bulletin, senate, zucchini, earring, plumber, telephone, sap, chimney, giraffe, butcher, traitor, boss, sugar, hall, oatmeal, quail, button, house, microphone, goo, scout, sister, cub, proton, monastery, pill, enemy, biologist, prince, antler, bulb. Now we'll begin the memory test. Here is another list of words, containing some words from the original list and some new words. Please output a list of "yes" or "no" corresponding to whether each word was in the original list. Only output the list of "yes" and "no", nothing else. Here are the test words:chimney, guardian, man, planet, grasshopper, family, fleet, shark, face, railroad, banner, pool, boulevard, deodorant, list, boss, earring, telephone, oatmeal, whistle, monastery, dice, bank, prison, enemy, antler, ozone, lollipop, reptile, cabin, jewel, heater, rattle, bus,

## run task

In [None]:
# Identifier of the model handed to make_pipe for every condition in this notebook.
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
pipe = make_pipe(model_id=MODEL_ID)

In [None]:
# Run each prompt variant in turn: echo the prompt, query the model, and plot
# its yes/no responses against the ground-truth test labels.
prompt_variants = (
    recognition_inp,
    recognition_inp_informed,
    recognition_inp_interjection,
    recognition_inp_interjection_informed,
)
for prompt in prompt_variants:
    print(prompt)
    plot_results(test_labels, query_model_recognition(pipe, prompt))