# Memory Inference Experiments

Train on an associative inference task and assess integrative encoding (proactive) vs. retrieval-time inference (reactive).

Paired associate inference setup: the goal is to measure the A-C association, which is never shown directly and must be inferred from A-B and B-C.
- Simplest version: AB BC -> A?
- Fan-in and fan-out versions (fan-in: several A's share one B; fan-out: one A links to several B's)

Other task ideas (not implemented here):
- Acquired equivalence: AB CB AD -> C?
- Other task: AB CB XY -> A? AB AC XY -> B?

QUESTIONS TO THINK THROUGH
- task details: 
    - many targets or just one? 
    - distractors? 
    - rewards & decision probes?
    - instructions?
- how to get logits on output tokens
- how to assess internal representations & dynamics

In [None]:
import sys, random, uuid
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer

# for running on colab
# On Google Colab, clone the project repo and append it to sys.path so its
# modules become importable; otherwise assume we are already in the repo root.
# `dir` holds the repo root path (presumably used by later cells -- not shown here).
# NOTE(review): `dir` shadows the builtin dir(); consider renaming (e.g. repo_dir)
# once it is confirmed no other cell depends on this name.
# NOTE: `!git clone` is IPython shell-escape syntax, so this cell only runs
# inside Jupyter/Colab, not as a plain .py script.
if 'google.colab' in sys.modules:
    !git clone https://github.com/sflippl/models-of-memory.git
    sys.path.append('models-of-memory')
    dir = 'models-of-memory'
else:
    print("Running locally")
    dir = '.'

Running locally


In [161]:
def get_token_list(model_id, exclude=('Ġ', '->')):
    """Return the model's vocabulary strings ordered by token id, minus excluded ones.

    A vocabulary entry is dropped if it contains any of the substrings in
    ``exclude``. The defaults drop '->' (the pair separator used in the
    prompts built in this notebook) and 'Ġ' (presumably the byte-level BPE
    marker for a leading space in this tokenizer -- TODO confirm for other
    model families).

    NOTE(review): entries are returned in their raw vocabulary form, not as
    decoded text. When such a string is written into a prompt it is
    re-tokenized and may not map back to a single token; consider
    ``tokenizer.convert_tokens_to_string`` if single-token stimuli matter.

    Args:
        model_id: Hugging Face model id or local path passed to
            ``AutoTokenizer.from_pretrained``.
        exclude: iterable of substrings; any token containing one is skipped.
            A tuple by default to avoid a shared mutable default argument.

    Returns:
        list[str]: vocabulary strings sorted by token id.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    vocab = tokenizer.get_vocab()  # token string -> token id
    return [
        token
        for token in sorted(vocab, key=vocab.__getitem__)  # ids are unique, so order is by id
        if not any(e in token for e in exclude)
    ]


def generate_prompt(all_tokens, n_sets=1, n_xy_pairs=0, query_index=0, fan_in=1, fan_out=1):
    """Build a paired-associate-inference prompt (AB, BC -> A?).

    Each "set" is one B->C pair plus the A->B pair(s) that feed into its B:

    * ``fan_in`` A tokens lead to the same B (A1->B->C, A2->B->C);
    * each A leads to ``fan_out`` different B tokens (A->B1->C1, A->B2->C2).

    All A->B pairs are listed first (in set order), followed by the B->C pairs
    in shuffled order. ``n_xy_pairs`` unrelated X->Y filler pairs are inserted
    at random positions. The probe is the first A token of set
    ``query_index`` and the target is that set's C token, so answering needs
    the inferred A->C link rather than the directly shown A->B pair. With
    ``fan_out > 1`` the probe also reaches other sets' C tokens through its
    other B's; the returned target is the one for set ``query_index``.

    Args:
        all_tokens: pool of candidate stimulus strings, sampled without replacement.
        n_sets: number of B->C pairs; must be a multiple of ``fan_out``.
        n_xy_pairs: number of filler X->Y pairs.
        query_index: which set is probed (Python indexing into ``range(n_sets)``).
        fan_in: A tokens per B (positive).
        fan_out: B tokens per A (positive); at most one of fan_in/fan_out may exceed 1.

    Returns:
        (prompt, test_probe, test_target)

    Raises:
        ValueError: invalid fan structure, or (from ``np.random.choice``)
            ``all_tokens`` too small for the requested number of stimuli.
        IndexError: ``query_index`` outside ``range(n_sets)``.
    """
    if fan_in < 1 or fan_out < 1:
        raise ValueError("fan_in and fan_out must be positive")
    if n_sets % fan_out != 0:
        raise ValueError("n_sets must be a multiple of fan_out")
    if fan_in > 1 and fan_out > 1:
        raise ValueError("fan_in and fan_out should not both be more than 1")
    query_set = range(n_sets)[query_index]  # normalizes negatives; IndexError if out of range

    # fan_in = 2 -> 2 A tokens per B; fan_out = 2 -> 1 A token per 2 B's.
    # Exact integer division because one of fan_in/fan_out is 1 and n_sets % fan_out == 0.
    n_a_tokens = n_sets * fan_in // fan_out
    counts = [n_a_tokens, n_sets, n_sets, n_xy_pairs, n_xy_pairs]  # A, B, C, X, Y
    tokens = np.random.choice(all_tokens, sum(counts), replace=False)
    a_tokens, b_tokens, c_tokens, x_tokens, y_tokens = np.split(tokens, np.cumsum(counts)[:-1])

    def first_a(set_i):
        # Index of the first A feeding set_i's B: step through A's when fanning out,
        # skip blocks of fan_in A's when fanning in.
        return set_i * fan_in // fan_out

    ab_pairs = [
        (a_tokens[a_idx], b_tokens[set_i])
        for set_i in range(n_sets)
        for a_idx in range(first_a(set_i), first_a(set_i) + fan_in)
    ]

    bc_pairs = list(zip(b_tokens, c_tokens))
    np.random.shuffle(bc_pairs)  # B->C order carries no information about set order
    xy_pairs = list(zip(x_tokens, y_tokens))

    train_pairs = ab_pairs + bc_pairs  # AB pairs come before BC pairs
    for xy_pair in xy_pairs:
        train_pairs.insert(np.random.randint(0, len(train_pairs) + 1), xy_pair)

    test_probe = a_tokens[first_a(query_set)]
    test_target = c_tokens[query_set]

    pairs_text = " ".join(f"{token1}->{token2}" for token1, token2 in train_pairs)
    prompt = (
        "Learn the associations below to make subsequent decisions.\n\n"
        f"{pairs_text}\n\n"
        f"Probe token: {test_probe}\n"
        "Target: "
    )
    return prompt, test_probe, test_target


def make_pipe(model_id):
    """Create a Hugging Face ``text-generation`` pipeline for ``model_id``.

    Weights are loaded in bfloat16 and placed on devices via ``device_map="auto"``.
    """
    return pipeline(
        "text-generation",
        model=model_id,
        dtype=torch.bfloat16,
        device_map="auto",
    )


def query_model(pipe, inp, *, max_new_tokens=256, temperature=0.7, top_p=0.8,
                top_k=20, min_p=0.1, lowercase=True, **generate_kwargs):
    """Send ``inp`` as a single user chat message to ``pipe`` and return its reply.

    The sampling settings keep their previous hard-coded values as defaults.
    Any extra keyword arguments (e.g. ``do_sample=True``) are forwarded to
    the pipeline call unchanged.

    NOTE(review): temperature/top_p/top_k/min_p only take effect when
    sampling is enabled (``do_sample=True``, via the model's generation
    config or passed here) -- confirm for the model in use.

    Args:
        pipe: a chat-capable text-generation pipeline (e.g. from ``make_pipe``).
        inp: the user prompt text.
        max_new_tokens, temperature, top_p, top_k, min_p: generation settings.
        lowercase: lowercase the reply (default kept for compatibility).
            Stimulus tokens are case-sensitive (e.g. 'getApplication'), so
            pass ``lowercase=False`` when comparing replies against targets.
        **generate_kwargs: forwarded to ``pipe``.

    Returns:
        (output, all_outputs): the stripped (optionally lowercased) reply text,
        and the raw pipeline result.
    """
    messages = [{"role": "user", "content": inp}]
    all_outputs = pipe(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        **generate_kwargs,
    )
    # With chat-style input, 'generated_text' holds the conversation with the
    # model's reply appended, so the last message is taken as the reply.
    output = all_outputs[0]["generated_text"][-1]["content"].strip()
    if lowercase:
        output = output.lower()
    return output, all_outputs

In [None]:
# Text-generation pipeline for the Qwen3 4B instruct model (built by make_pipe).
pipe = make_pipe("Qwen/Qwen3-4B-Instruct-2507")

In [34]:
# Candidate stimulus tokens drawn from the same model's vocabulary.
all_tokens = get_token_list(model_id="Qwen/Qwen3-4B-Instruct-2507")

In [154]:
# Smallest case: a single A->B / B->C set, no filler pairs, probing set 0.
# fan_in: A1->B->C, A2->B->C   |   fan_out: A->B1->C, A->B2->C
prompt, probe, target = generate_prompt(
    all_tokens,
    fan_out=1,
    fan_in=1,
    query_index=0,
    n_xy_pairs=0,
    n_sets=1,
)
print(prompt)

Learn the associations below to make subsequent decisions.

çĳª->flashdata flashdata->(',

Probe token: çĳª
Target: 


In [None]:
# analyze how the network is doing the task

In [170]:
## VALUE/DECISION VERSION

def generate_prompt_value(all_tokens, n_sets=1, n_xy_pairs=0, query_index=0, fan_in=1, fan_out=1):
    """Build a value/decision version of the paired-associate-inference prompt.

    A->B associations are shown first, with ``n_xy_pairs`` X->Y filler pairs
    inserted at random positions. Then each B token is paired with a reward
    ``$c``, where the rewards are a random permutation of ``0..n_sets-1``.
    Finally the model must choose between two A tokens: one feeding the B of
    set ``query_index`` and a foil feeding a different set's B. The correct
    choice is the A whose B carries the larger reward.

    Fan structure follows ``generate_prompt``: ``fan_in`` A tokens per B, or
    each A feeding ``fan_out`` B tokens. NOTE(review): with ``fan_out > 1`` an
    A reaches several rewards, but only the probed/foil set's reward decides
    ``correct_token`` -- confirm this is the intended value rule.

    Args:
        all_tokens: pool of candidate stimulus strings. Tokens equal to a
            reward's text (e.g. "0") are excluded so they cannot be confused
            with rewards.
        n_sets: number of B->reward pairs. Must be a multiple of ``fan_out``.
            There must be at least one set whose A tokens differ from the
            probed set's (so n_sets >= 2, and n_sets > fan_out).
        n_xy_pairs: number of filler X->Y pairs.
        query_index: which set supplies the target A (Python indexing into ``range(n_sets)``).
        fan_in: A tokens per B (positive).
        fan_out: B tokens per A (positive); at most one of fan_in/fan_out may exceed 1.

    Returns:
        (prompt, test_probe, correct_token). ``test_probe`` is the
        (first, second) pair of A tokens in the order shown. ``correct_token``
        is the higher-value one.

    Raises:
        ValueError: invalid fan structure, no valid foil set, or (from
            ``np.random.choice``) too few tokens available.
        IndexError: ``query_index`` outside ``range(n_sets)``.
    """
    if fan_in < 1 or fan_out < 1:
        raise ValueError("fan_in and fan_out must be positive")
    if n_sets % fan_out != 0:
        raise ValueError("n_sets must be a multiple of fan_out")
    if fan_in > 1 and fan_out > 1:
        raise ValueError("fan_in and fan_out should not both be more than 1")
    query_set = range(n_sets)[query_index]  # normalizes negatives; IndexError if out of range

    # In the value/decision version, C "tokens" are integer rewards. They are
    # all distinct, so the correct choice is never a tie.
    c_tokens = np.random.permutation(n_sets)
    # Compare as strings: the rewards are ints, the vocabulary entries are str.
    reward_strs = {str(c) for c in c_tokens}
    pool = [t for t in all_tokens if t not in reward_strs]

    # fan_in = 2 -> 2 A tokens per B; fan_out = 2 -> 1 A token per 2 B's.
    n_a_tokens = n_sets * fan_in // fan_out
    counts = [n_a_tokens, n_sets, n_xy_pairs, n_xy_pairs]  # A, B, X, Y
    tokens = np.random.choice(pool, sum(counts), replace=False)
    a_tokens, b_tokens, x_tokens, y_tokens = np.split(tokens, np.cumsum(counts)[:-1])

    # a_for_set[i] lists the A token(s) feeding b_tokens[i].
    a_for_set = []
    for set_i in range(n_sets):
        a_start = set_i * fan_in // fan_out  # skip through A's when fanning out
        a_for_set.append(list(a_tokens[a_start:a_start + fan_in]))
    ab_pairs = [(a, b_tokens[i]) for i, a_list in enumerate(a_for_set) for a in a_list]

    bc_pairs = list(zip(b_tokens, c_tokens))
    np.random.shuffle(bc_pairs)
    xy_pairs = list(zip(x_tokens, y_tokens))

    train_pairs = list(ab_pairs)  # copy so inserting XY pairs does not mutate ab_pairs
    for xy_pair in xy_pairs:
        train_pairs.insert(np.random.randint(0, len(train_pairs) + 1), xy_pair)

    # TEST PROBE: one random target/foil pair (not exhaustive over all pairs).
    target_value = c_tokens[query_set]
    target_a = np.random.choice(a_for_set[query_set])
    # A foil set must not share the target's A (possible with fan_out > 1).
    foil_candidates = [i for i in range(n_sets) if target_a not in a_for_set[i]]
    if not foil_candidates:
        raise ValueError(
            "need a set whose A tokens differ from the probed set's "
            "(requires n_sets >= 2 and n_sets > fan_out)"
        )
    foil_index = np.random.choice(foil_candidates)
    foil_a = np.random.choice(a_for_set[foil_index])
    foil_value = c_tokens[foil_index]
    correct_token = target_a if target_value > foil_value else foil_a
    test_probe = (target_a, foil_a) if np.random.rand() < 0.5 else (foil_a, target_a)

    ab_text = " ".join(f"{a}->{b}" for a, b in train_pairs)
    bc_text = " ".join(f"{b}->${c}" for b, c in bc_pairs)
    prompt = (
        "Learn the associations below to make subsequent decisions.\n\n"
        f"{ab_text}"
        "\n\nNow learn the associated rewards.\n\n"
        f"{bc_text}"
        "\n\nNow decide between the following tokens, to receive associated rewards. Only output one token.\n\n"
        f"{test_probe[0]} or {test_probe[1]}?\n\n"
        "Chosen token: "
    )
    return prompt, test_probe, correct_token

# Two-set decision probe: choose between an A from set 0 and a foil A.
# fan_in: A1->B->C, A2->B->C   |   fan_out: A->B1->C, A->B2->C
prompt, probe, target = generate_prompt_value(
    all_tokens,
    fan_out=1,
    fan_in=1,
    query_index=0,
    n_xy_pairs=0,
    n_sets=2,
)
print(prompt)

Learn the associations below to make subsequent decisions.

getApplication->sun å¤©åĽ½->ĉself

Now learn the associated rewards.

ĉself->$0 sun->$1

Now decide between the following tokens, to receive associated rewards. Only output one token.

å¤©åĽ½ or getApplication?

Chosen token: 
