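# Paired associate learning

Study–test memory task for an LLM. The model sees a list of word pairs. It is then cued with the first word of each pair, in shuffled order, and asked to produce the partner word. Conditions vary along three dimensions:

- whether the model is told in advance that it will be tested ("informed");
- whether a second, interjected list of pairs appears between study and test;
- whether recall is gated to pairs whose second word starts with a–h.

Output is scored for recall and precision against the true second words.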

In [18]:
# Task size: list_length study pairs (first_words/second_words each get this many words),
# and interjection_length words in each half of the interjected list.
list_length, interjection_length = 50, 200

import random
import numpy as np
import matplotlib.pyplot as plt
from helpers import * # general helper functions
%reload_ext autoreload
%autoreload 2

## task-specific helper functions

In [19]:
def _plot_membership(ax, hits, title, xlabel, ylabel):
    """Draw one 0/1 membership curve on `ax`, with the mean hit rate appended to the title.

    An empty `hits` sequence is drawn as an empty curve and reported as ``nan``.
    """
    hits = np.asarray(hits, dtype=float)
    rate = hits.mean() if hits.size else float('nan')
    ax.plot(hits)
    ax.set_title(f'{title} ({rate:.2f})')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_ylim(-0.05, 1.05)  # values are booleans; fix the range so panels are comparable


def plot_results(true_list, output_list, interjected_list=None, gated=False, gate_letters='abcdefgh'):
    """Plot recall/precision of the model's recalled words against the studied targets.

    Panels, left to right: recall, precision, then (optionally) false alarms to the
    interjected list, then (optionally) gated recall and gated precision.

    Parameters
    ----------
    true_list : sequence of str
        Words the model should produce (in this notebook, the second word of each pair).
    output_list : sequence of str
        Words parsed from the model's response; membership is tested with ``in``.
    interjected_list : sequence of str, optional
        Words from the interjected list. When non-empty, a panel shows which unique
        output words came from it. ``None`` (default) or empty omits that panel.
    gated : bool
        If True, add recall/precision panels restricted to the targets that start
        with one of `gate_letters`.
    gate_letters : str or iterable of str
        Initial letters that define the gated subset (default ``'abcdefgh'``).

    Returns
    -------
    fig, ax
        The figure and the array of axes from ``plt.subplots``.
    """
    # None instead of a mutable [] default; an empty list still means "no interjection".
    interjected_list = [] if interjected_list is None else interjected_list
    unique_output = np.unique(output_list)
    # Build each set once; the original rebuilt set(...) for every output word.
    true_set = set(true_list)

    # Each panel: (hits, title, xlabel, ylabel).
    panels = [
        ([w in output_list for w in true_list],
         'Recall', 'Word in true list', 'Present in recalled list'),
        ([w in true_set for w in unique_output],
         'Precision', 'Word in output (unique)', 'Present in true list'),
    ]

    if len(interjected_list) > 0:
        interjected_set = set(interjected_list)
        panels.append((
            [w in interjected_set for w in unique_output],
            'False alarms to interjections', 'Word in output (unique)', 'Present in interjection',
        ))

    if gated:
        prefixes = tuple(gate_letters)
        label = ''.join(prefixes)
        # startswith (rather than w[0]) excludes an empty string instead of raising IndexError.
        gated_words = [w for w in true_list if w.startswith(prefixes)]
        gated_set = set(gated_words)
        panels.append((
            [w in output_list for w in gated_words],
            'Gated recall', f'Word in GATED true list ({label})', 'Present in recalled list',
        ))
        panels.append((
            [w in gated_set for w in unique_output],
            'Gated precision', 'Word in output (unique)', f'Present in GATED true list ({label})',
        ))

    ncols = len(panels)  # always >= 2, so `ax` is an array
    fig, ax = plt.subplots(1, ncols, figsize=(4 * ncols, 3.5))
    for axis, (hits, title, xlabel, ylabel) in zip(ax, panels):
        _plot_membership(axis, hits, title, xlabel, ylabel)
    fig.tight_layout()
    return fig, ax

## define task

In [20]:
all_words = txt_to_list('_data/wasnorm_wordpool.txt')

# Study list: first half are cue words, second half are their partners (the recall targets).
words = subsample_words(all_words, list_length*2, seed=0)
first_words = words[:list_length]
second_words = words[list_length:]
pairs = make_pairs(first_words, second_words)

# Test order: cue words in a fixed random order. `targets` holds the matching second
# words in that same order, and is what the model output is scored against.
np.random.seed(0)
perm = np.random.permutation(len(pairs))
shuffled_first_words = [first_words[i] for i in perm]
targets = [second_words[i] for i in perm]

# Interjected list; avoid=words is meant to keep it disjoint from the study list.
interjection_words = subsample_words(all_words, interjection_length*2, seed=0, avoid=words)
interjection_pairs = make_pairs(interjection_words[:interjection_length], interjection_words[interjection_length:])

# Sanity checks: duplicate or overlapping words would make recall/precision ambiguous.
assert len(words) == 2 * list_length, "unexpected study list size"
assert len(set(words)) == len(words), "study words must be unique"
assert len(pairs) == len(first_words), "expected one pair per cue word"
assert not set(interjection_words) & set(words), "interjection list overlaps the study list"


In [27]:
# Test cues shared across conditions. Each is defined once so that conditions meant
# to share wording cannot drift apart through copy-paste.
CUE_BASIC = (
    "Now we'll begin the memory test. Here is a list of words, one from each pair, in random order. "
    "Please output a list of corresponding word pairs from the studied list. "
    "Do not think step by step, only output the list of words. Here are the test words: "
)
CUE_ORIGINAL_LIST = (
    "Now we'll begin the memory test. Here is a list of words, one from each pair in the original list, in random order. "
    "Please output a list of corresponding word pairs from the studied list. "
    "Do not think step by step, only output the list of words. Here are the test words:"
)
CUE_GATED = (
    "Now we'll begin the memory test. Here is a list of words, one from each pair in the original list, in random order. "
    "Please output a list of each word's pair in the studied list, but ONLY for words whose pair (second word) starts with a,b,c,d,e,f,g, or h. "
    "Do not think step by step, only output the list of words. Here are the test words:"
)
INTERJECTION_INFORMED = (
    "You will now see another list of word pairs. "
    "Later, in the memory test, you should not base your responses on the pairs in this list:"
)

# Baseline: no advance warning of the test.
pa_inp = make_inp(
    pairs, shuffled_first_words,
    preamble="Here is a list of word pairs:",
    cue=CUE_BASIC,
)
print(pa_inp)

# Informed: told in advance that memory for the pairs will be tested.
pa_inp_informed = make_inp(
    pairs, shuffled_first_words,
    preamble="Here is a list of word pairs. After seeing this list, you will be tested on your memory for the pairs:",
    cue=CUE_BASIC,
)
print(pa_inp_informed)

# Interjection: a second list of pairs is shown between study and test.
pa_inp_interjection = make_inp(
    pairs, shuffled_first_words,
    preamble="Here is a list of word pairs:",
    cue=CUE_ORIGINAL_LIST,
    interjection="You will now see another list of word pairs:",
    interjection_words=interjection_pairs,
)
print(pa_inp_interjection)

# Interjection, informed: warned about the test and told to ignore the interjected list.
pa_inp_interjection_informed = make_inp(
    pairs, shuffled_first_words,
    preamble="Here is a list of word pairs. After seeing this list, you will see another list of pairs, and then be tested on your memory for these original pairs:",
    cue=CUE_ORIGINAL_LIST,
    interjection=INTERJECTION_INFORMED,
    interjection_words=interjection_pairs,
)
print(pa_inp_interjection_informed)

# Gated: report only pairs whose second word starts with a-h; the gate is revealed only at test.
pa_inp_gated = make_inp(
    pairs, shuffled_first_words,
    preamble="Here is a list of word pairs:",
    cue=CUE_GATED,
)
print(pa_inp_gated)

# Gated, informed: the gate is announced before study. There is no interjection in this
# condition, so the preamble must not promise another list.
pa_inp_gated_informed = make_inp(
    pairs, shuffled_first_words,
    preamble=(
        "Here is a list of word pairs. After seeing this list, you will be tested on your memory "
        "ONLY for the pairs where the SECOND word starts with a,b,c,d,e,f,g, or h:"
    ),
    cue=CUE_GATED,
)
print(pa_inp_gated_informed)

# Gated, informed, with interjection.
pa_inp_gated_informed_interjection = make_inp(
    pairs, shuffled_first_words,
    preamble=(
        "Here is a list of word pairs. After seeing this list, you will see another list of pairs, and then "
        "be tested on your memory ONLY for the original pairs where the SECOND word starts with a,b,c,d,e,f,g, or h:"
    ),
    cue=CUE_GATED,
    interjection=INTERJECTION_INFORMED,
    interjection_words=interjection_pairs,
)
print(pa_inp_gated_informed_interjection)

Here is a list of word pairs: knuckle-toenail, vehicle-lid, man-thicket, bank-widow, employee-stallion, perch-shell, ozone-acorn, list-screw, whistle-paper, garbage-hand, onion-driveway, igloo-transplant, roach-guard, deodorant-teapot, pearl-bird, cheddar-cousin, fleet-raisin, cheek-detergent, van-dough, bulletin-chemist, senate-postage, zucchini-moss, earring-bug, plumber-brain, telephone-grasshopper, sap-pen, chimney-paint, giraffe-camera, butcher-gallon, traitor-producer, boss-fountain, sugar-temple, hall-castle, oatmeal-priest, quail-hamper, button-porcupine, house-crutch, microphone-saturn, goo-president, scout-rodent, sister-foot, cub-monster, proton-buggy, monastery-sage, pill-zebra, enemy-gorilla, biologist-refrigerator, prince-dresser, antler-fort, bulb-copier. Now we'll begin the memory test. Here is a list of words, one from each pair, in random order Please output a list of corresponding word pairs from the studied list. Do not think step by step, only output the list of wo

## run task

In [None]:
# Load the model once; every condition below queries this same pipeline.
pipe = make_pipe(
    model_id="Qwen/Qwen3-4B-Instruct-2507",
)

In [None]:
# Ungated conditions: each is run once and scored against the true second words.
# Interjection conditions also pass the interjected words, so plot_results adds a
# false-alarm panel for intrusions from that list (all its words count as intrusions).
ungated_conditions = {
    'baseline': (pa_inp, []),
    'informed': (pa_inp_informed, []),
    'interjection': (pa_inp_interjection, interjection_words),
    'interjection, informed': (pa_inp_interjection_informed, interjection_words),
}
for condition, (inp, interjected) in ungated_conditions.items():
    print(f'=== {condition} ===')
    print(inp)
    output = query_model(pipe, inp)
    plot_results(targets, output, interjected_list=interjected)
    plt.show()  # render each figure next to its printed condition, not all at the end

In [None]:
# Gated conditions: use the gated prompts (the original re-ran the ungated ones) and add
# gated recall/precision panels. Only the last condition shows the interjected list.
gated_conditions = {
    'gated': (pa_inp_gated, []),
    'gated, informed': (pa_inp_gated_informed, []),
    'gated, informed, interjection': (pa_inp_gated_informed_interjection, interjection_words),
}
for condition, (inp, interjected) in gated_conditions.items():
    print(f'=== {condition} ===')
    print(inp)
    output = query_model(pipe, inp)
    plot_results(targets, output, interjected_list=interjected, gated=True)
    plt.show()  # render each figure next to its printed condition, not all at the end