# Functions



In [None]:
"""
***Memory Inference Experiments***

Train on associative inference task, assess integrative encoding (proactive) vs. retrieval-time inference (reactive)

Paired associate inference setup: goal is to measure A-C association
- Simplest version: AB BC -> A? -> measure output probability of B and C
- Fan in & fan out versions (multiple As to one B, multiple Bs to one A)

Other task ideas (not implemented here):
- Acquired equivalence: AB CB AD -> C?
- Other task: AB CB XY -> A? AB AC XY -> B?

"""

import sys, random, uuid, os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import nltk
nltk.download('wordnet')
nltk.download('names')
nltk.download('gazetteers')
from nltk.corpus import wordnet, names, gazetteers

# for running on colab: fetch the project repo and make it importable.
# The clone and the sys.path entry are both guarded so that re-running this cell
# does not re-clone into an existing directory or add duplicate path entries.
# NOTE(review): `dir` shadows the Python builtin of the same name; the name is kept
# because other cells/modules may read it.
if 'google.colab' in sys.modules:
    dir = 'models-of-memory'
    if not os.path.isdir(dir):
        os.system("git clone https://github.com/sflippl/models-of-memory.git")
    if dir not in sys.path:
        sys.path.append(dir)
else:
    print("Running locally")
    dir = '.'

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package names to /root/nltk_data...
[nltk_data]   Package names is already up-to-date!
[nltk_data] Downloading package gazetteers to /root/nltk_data...
[nltk_data]   Package gazetteers is already up-to-date!


In [None]:
def generate_stimuli(all_tokens, n_sets, n_distractor_pairs=0,
                     fan_in_pct=0, fan_out_pct=0, fan_in_degree=0, fan_out_degree=0,
                     stimulus_type='words', stimulus_dict=None):
    """
    Build the training pairs and probe bookkeeping for an associative inference task.

    Each "set" is A->B->C, but strictly it is the number of C items (because of fan in/fan out).
    Distractor XY pairs are inserted at random positions if requested.

    Fan structure:
        - fan_in_pct: proportion (0-1) of sets with fan-in structure (multiple As -> one B -> one C)
        - fan_out_pct: proportion (0-1) of sets with fan-out structure
          (one A -> multiple Bs; only the FIRST B of the set is paired with C)
        - fan_in_degree: how many A tokens lead to the same B (e.g., 3 means A1->B, A2->B, A3->B->C)
        - fan_out_degree: how many B tokens each A leads to (e.g., 5 means A->B1, ..., A->B5, B1->C)

    Args:
        all_tokens: pool of tokens to sample from. Used for every role when `stimulus_dict`
            is not given; used only for the XY distractors when it is given.
        n_sets: number of sets (i.e. number of C items / BC pairs).
        n_distractor_pairs: number of unrelated XY pairs to mix into the training pairs.
        stimulus_type: accepted for interface symmetry with `generate_prompt`; not used here.
        stimulus_dict: optional dict with an 'A' list and a 'BC_pairs' list of (B, C) tuples.
            When given, B and C always come from the same tuple. Fan-out sets draw their
            additional Bs from further 'BC_pairs' entries whose C is left unused.

    Returns:
        A 6-tuple, in this order:
        - train_pairs: list of (cue, target) tuples: all AB pairs, then the shuffled BC pairs,
          with XY pairs inserted at random positions
        - pair_types: list parallel to train_pairs with 'AB', 'BC' or 'XY'
        - fan_types: per probe A, one of 'simple', 'fan_in', 'fan_out'
        - a_tokens: the possible probes (A), as returned by numpy sampling
        - direct_targets: per probe A, list of its direct targets B (more than one for fan out)
        - indirect_targets: per probe A, its indirect target C

    Raises:
        ValueError: if fan_in_pct + fan_out_pct > 1, if fan sets are requested with a
            degree below 1, or (from numpy) if a pool is too small to sample from.
    """
    if fan_in_pct + fan_out_pct > 1:
        raise ValueError("fan_in_pct + fan_out_pct cannot exceed 1.0")

    # determine how many sets of each type
    n_sets_fan_in = int(n_sets * fan_in_pct)
    n_sets_fan_out = int(n_sets * fan_out_pct)
    n_sets_simple = n_sets - n_sets_fan_in - n_sets_fan_out

    # a degree of 0 would silently create sets with no A (fan in) or no B/BC pair (fan out)
    if n_sets_fan_in > 0 and fan_in_degree < 1:
        raise ValueError("fan_in_degree must be >= 1 when fan-in sets are requested")
    if n_sets_fan_out > 0 and fan_out_degree < 1:
        raise ValueError("fan_out_degree must be >= 1 when fan-out sets are requested")

    # calculate and generate tokens
    n_a_tokens = n_sets_simple + (n_sets_fan_in * fan_in_degree) + n_sets_fan_out
    n_b_tokens = n_sets_simple + n_sets_fan_in + (n_sets_fan_out * fan_out_degree)
    n_c_tokens = n_sets  # one C per set
    n_x_tokens = n_y_tokens = n_distractor_pairs

    if stimulus_dict:
        a_tokens = np.random.choice(stimulus_dict['A'], n_a_tokens, replace=False)
        # One BC pair per B token (equals n_sets when there are no fan-out sets), so that
        # fan-out sets have enough Bs.
        bc_indices = np.random.choice(len(stimulus_dict['BC_pairs']), n_b_tokens, replace=False)
        selected_bc = [stimulus_dict['BC_pairs'][i] for i in bc_indices]
        b_tokens = [bc[0] for bc in selected_bc]
        # C for each set must be the partner of the B that ends up in its BC pair:
        # simple/fan-in sets use one B each; a fan-out set pairs only its first B with C.
        n_aligned = n_sets_simple + n_sets_fan_in
        c_source = list(range(n_aligned)) + [n_aligned + k * fan_out_degree
                                             for k in range(n_sets_fan_out)]
        c_tokens = [selected_bc[i][1] for i in c_source]
        # Draw X and Y together so no token can serve as both an X and a Y.
        xy_tokens = np.random.choice(all_tokens, n_x_tokens + n_y_tokens, replace=False)
        x_tokens, y_tokens = xy_tokens[:n_x_tokens], xy_tokens[n_x_tokens:]
    else:
        n_total_tokens = int(n_a_tokens + n_b_tokens + n_c_tokens + n_x_tokens + n_y_tokens)
        tokens = np.random.choice(all_tokens, n_total_tokens, replace=False)
        split_indices = np.cumsum([n_a_tokens, n_b_tokens, n_c_tokens, n_x_tokens, n_y_tokens], dtype=int)[:-1]
        a_tokens, b_tokens, c_tokens, x_tokens, y_tokens = np.split(tokens, split_indices)

    # build pairs and track mappings
    ab_pairs, bc_pairs = [], []
    direct_targets = []  # list of lists: each entry corresponds to each A, containing a list of its Bs
    indirect_targets = []  # list: each entry corresponds to each A, containing its C
    fan_types = []  # track whether each A is simple, fan_in, or fan_out
    a_idx, b_idx, c_idx = 0, 0, 0  # keep track of where we are in the tokens

    # simple sets: 1 A -> 1 B -> 1 C
    for _ in range(n_sets_simple):
        ab_pairs.append((a_tokens[a_idx], b_tokens[b_idx]))
        bc_pairs.append((b_tokens[b_idx], c_tokens[c_idx]))
        direct_targets.append([b_tokens[b_idx]])
        indirect_targets.append(c_tokens[c_idx])
        fan_types.append('simple')
        a_idx += 1
        b_idx += 1
        c_idx += 1

    # Fan-in sets: multiple As -> 1 B -> 1 C
    for _ in range(n_sets_fan_in):
        for _ in range(fan_in_degree):  # A1->B, A2->B, ...
            ab_pairs.append((a_tokens[a_idx], b_tokens[b_idx]))
            direct_targets.append([b_tokens[b_idx]])
            indirect_targets.append(c_tokens[c_idx])
            fan_types.append('fan_in')
            a_idx += 1
        bc_pairs.append((b_tokens[b_idx], c_tokens[c_idx]))  # only one BC pair per set
        b_idx += 1
        c_idx += 1

    # Fan-out sets: 1 A -> multiple Bs; only first B -> C
    for _ in range(n_sets_fan_out):
        b_list = []
        for i in range(fan_out_degree):  # A->B1, A->B2, ...
            ab_pairs.append((a_tokens[a_idx], b_tokens[b_idx]))
            b_list.append(b_tokens[b_idx])
            if i == 0:  # Only pair the first B with C
                bc_pairs.append((b_tokens[b_idx], c_tokens[c_idx]))
            b_idx += 1
        direct_targets.append(b_list)
        indirect_targets.append(c_tokens[c_idx])
        fan_types.append('fan_out')
        a_idx += 1
        c_idx += 1

    xy_pairs = list(zip(x_tokens, y_tokens))

    # Shuffle BC pairs
    bc_perm = np.random.permutation(len(bc_pairs))
    bc_shuffled = [bc_pairs[i] for i in bc_perm]

    # Combine training pairs and track types (all AB pairs first, then the BC pairs)
    train_pairs = ab_pairs + bc_shuffled
    pair_types = ['AB'] * len(ab_pairs) + ['BC'] * len(bc_pairs)

    # Insert XY pairs at random locations, keeping pair_types aligned with train_pairs
    for xy_pair in xy_pairs:
        idx = np.random.randint(0, len(train_pairs) + 1)
        train_pairs.insert(idx, xy_pair)
        pair_types.insert(idx, 'XY')

    return train_pairs, pair_types, fan_types, a_tokens, direct_targets, indirect_targets



def generate_prompt(train_pairs, pair_types, test_probe, stimulus_type='words',
                    prompt_type='standard', target=None, foil=None):
    """
    Render the training pairs followed by a query about `test_probe` as one prompt string.

    Args:
        train_pairs: list of (cue, target) tuples, in presentation order.
        pair_types: list parallel to train_pairs with 'AB', 'BC' or 'XY'. Only consulted
            when stimulus_type is not 'words'.
        test_probe: the item to query.
        stimulus_type: 'words' renders pairs as "x->y "; any other value renders sentences
            ("x is from y. " for AB, "x is in y. " for BC, one of the two at random for XY).
        prompt_type: 'standard' ends the prompt with an open completion for the probe;
            'afc' asks a two-alternative question with target and foil in random order.
        target, foil: the two alternatives; required when prompt_type is 'afc'.

    Returns:
        The prompt string.

    Raises:
        ValueError: if prompt_type is unknown, or if prompt_type is 'afc' and target or
            foil is missing.
    """
    if prompt_type not in ('standard', 'afc'):
        raise ValueError(f"Unknown prompt_type: {prompt_type!r} (expected 'standard' or 'afc')")
    if prompt_type == 'afc' and (target is None or foil is None):
        raise ValueError("prompt_type='afc' requires both target and foil")

    use_arrows = stimulus_type == 'words'
    parts = []

    # training pairs
    if use_arrows:
        for p in train_pairs:
            parts.append(f"{p[0]}->{p[1]} ")
    else:
        for p, p_type in zip(train_pairs, pair_types):
            if p_type == 'AB':
                parts.append(f"{p[0]} is from {p[1]}. ")
            elif p_type == 'BC':
                parts.append(f"{p[0]} is in {p[1]}. ")
            elif random.random() < 0.5:  # XY: phrasing chosen at random
                parts.append(f"{p[0]} is from {p[1]}. ")
            else:
                parts.append(f"{p[0]} is in {p[1]}. ")

    # Query structure
    if prompt_type == 'standard':
        if use_arrows:
            parts.append(f"{test_probe}->")
        else:
            parts.append(f"{test_probe} is from ")
    else:  # 'afc'
        choices = [target, foil]
        random.shuffle(choices)  # randomise which alternative is listed first
        if use_arrows:
            parts.append(f"{test_probe}->{choices[0]} or {choices[1]}? ")
        else:
            parts.append(f"Is {test_probe} from {choices[0]} or {choices[1]}? ")
    return "".join(parts)

In [None]:
def load_model(model_id):
    """
    Load a causal language model and its tokenizer by identifier.

    The model is requested in bfloat16 with device_map="auto".
    Returns a (model, tokenizer) tuple.
    """
    model_kwargs = {"dtype": torch.bfloat16, "device_map": "auto"}
    return (
        AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs),
        AutoTokenizer.from_pretrained(model_id),
    )


def get_single_tokens(tokenizer, arr):
    """Return the items of `arr`, in order, that the tokenizer encodes as exactly one
    input id (special tokens excluded)."""
    single_token_items = []
    for item in arr:
        input_ids = tokenizer(item, add_special_tokens=False)["input_ids"]
        if len(input_ids) == 1:
            single_token_items.append(item)
    return single_token_items

def get_single_token_nouns(tokenizer):
    """
    Get the list of WordNet noun lemmas that are single tokens in the given model.

    Only purely alphabetic lemma names are kept. The same lemma name occurs in many
    synsets, so names are de-duplicated (first-seen order) before tokenizing; otherwise
    sampling "without replacement" from the result could still return the same word twice.
    """
    nouns = (lemma.name() for syn in wordnet.all_synsets("n") for lemma in syn.lemmas())
    unique_nouns = list(dict.fromkeys(n for n in nouns if n.isalpha()))
    return get_single_tokens(tokenizer, unique_nouns)

def get_single_token_names(tokenizer):
    """
    Get the unique names from the nltk names corpus (male and female lists) that are
    single tokens in the given model.

    The result is sorted so its order is the same on every run; the order of a plain
    set of strings can change between interpreter sessions, which would make seeded
    sampling from this list irreproducible.
    """
    all_names = names.words('male.txt') + names.words('female.txt')
    return sorted(set(get_single_tokens(tokenizer, all_names)))

def get_single_token_geography(tokenizer):
    """
    Return (city, country) pairs whose country is a single token in the given model.

    Only the country has to be single-token; cities are not checked. Pairs in which the
    city and country strings are identical (e.g. "Mexico"/"Mexico") are dropped, because
    using them as B->C would make the direct and indirect targets the same string.
    """
    pairs = [
        ("Paris", "France"), ("Berlin", "Germany"), ("Rome", "Italy"), ("Madrid", "Spain"), ("Lisbon", "Portugal"),
        ("Vienna", "Austria"), ("Brussels", "Belgium"), ("Athens", "Greece"), ("Warsaw", "Poland"), ("Prague", "Czechia"),
        ("Tokyo", "Japan"), ("Seoul", "Korea"), ("Beijing", "China"), ("Bangkok", "Thailand"), ("Hanoi", "Vietnam"),
        ("Delhi", "India"), ("Manila", "Philippines"), ("Jakarta", "Indonesia"), ("Cairo", "Egypt"), ("Nairobi", "Kenya"),
        ("Lagos", "Nigeria"), ("Accra", "Ghana"), ("Tunis", "Tunisia"), ("Algiers", "Algeria"), ("Moscow", "Russia"),
        ("Kyiv", "Ukraine"), ("Oslo", "Norway"), ("Stockholm", "Sweden"), ("Helsinki", "Finland"), ("Copenhagen", "Denmark"),
        ("Dublin", "Ireland"), ("London", "UK"), ("Ottawa", "Canada"), ("Mexico", "Mexico"), ("Havana", "Cuba"),
        ("Kingston", "Jamaica"), ("Panama", "Panama"), ("Bogota", "Colombia"), ("Quito", "Ecuador"), ("Lima", "Peru"),
        ("Santiago", "Chile"), ("Brasilia", "Brazil"), ("Caracas", "Venezuela"), ("Sydney", "Australia"), ("Suva", "Fiji"),
        ("Amman", "Jordan"), ("Beirut", "Lebanon"), ("Baghdad", "Iraq"), ("Tehran", "Iran"), ("Riyadh", "Saudi"),
        ("Kuwait", "Kuwait"), ("Doha", "Qatar"), ("Muscat", "Oman"), ("Kabul", "Afghanistan"), ("Islamabad", "Pakistan")
    ]
    single_token_pairs = []
    for city, country in pairs:
        if city == country:
            continue  # degenerate pair: B and C would be indistinguishable
        country_ids = tokenizer(country, add_special_tokens=False)["input_ids"]  # this is the only one that needs to be single-token
        if len(country_ids) == 1:
            single_token_pairs.append((city, country))
    return single_token_pairs


def get_fictional_single_tokens(tokenizer, count):
    """
    Scavenges the model's own vocabulary for strings it treats as atomic
    but that are not recognized as real English words by WordNet.

    Candidates are vocabulary entries with subword/whitespace markers stripped,
    capitalized, alphabetic and 5-8 characters long. The single-token check is done on
    the capitalized string that is actually returned, and each string is returned at
    most once.

    Args:
        tokenizer: tokenizer providing get_vocab() and encode().
        count: maximum number of strings to return.

    Returns:
        A random sample of up to `count` unique strings (fewer if the pool is smaller).
    """
    # Get all potential strings from the model's vocabulary
    # (Handling various tokenizer formats like GPT, Llama, and BERT)
    vocab = tokenizer.get_vocab().keys()
    fictional_pool = []
    considered = set()  # capitalized candidates already examined (avoids duplicates and re-encoding)

    for t in vocab:
        # 1. Clean up subword/whitespace markers ('Ġ' byte-level BPE, '▁' SentencePiece, ' ', '##' WordPiece)
        clean = t.replace('Ġ', '').replace('▁', '').replace(' ', '').replace('##', '')

        # 2. Heuristics for a "Country" name (alphabetic, 5-8 chars)
        if not (clean.isalpha() and 5 <= len(clean) <= 8):
            continue

        candidate = clean.capitalize()
        if candidate in considered:
            continue  # e.g. 'Ġhello', 'hello' and 'Hello' all map to 'Hello'
        considered.add(candidate)

        # 3. Validation: the string we return must itself be a single token and not a real word
        if len(tokenizer.encode(candidate, add_special_tokens=False)) != 1:
            continue
        if wordnet.synsets(clean.lower()):
            continue
        fictional_pool.append(candidate)

        if len(fictional_pool) >= count + 50:  # Get a buffer then stop
            break

    return random.sample(fictional_pool, min(count, len(fictional_pool)))

# Load model

In [None]:
# Identifier of the pretrained model to load; load_model passes it to from_pretrained
# for both the model and the tokenizer.
model_id = "Qwen/Qwen3-4B"
model, tokenizer = load_model(model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

# Run (words)

In [None]:
n_sets = 3 # how many C items (BC pairs)
n_distractor_pairs = 0 # how many XY pairs
fan_in_pct = 0.0 # proportion (0-1, not 0-100) of sets that fan in
fan_out_pct = 0.0 # proportion (0-1, not 0-100) of sets that fan out; fan_in_pct + fan_out_pct must not exceed 1
fan_in_degree = 2 # how many As per B (only used by fan-in sets)
fan_out_degree = 2 # how many Bs per A (only used by fan-out sets)

# 'words' makes generate_prompt render pairs as "x->y"; other values use sentences
stimulus_type = 'words'

In [None]:
# Token pool: WordNet nouns that are a single token for this tokenizer
all_tokens = get_single_token_nouns(tokenizer)

# Pass the settings from the config cell explicitly; otherwise changing them there
# would have no effect on the generated stimuli.
training_pairs, pair_types, fan_types, test_probes, test_direct_targets, test_indirect_targets = generate_stimuli(
    all_tokens, n_sets, n_distractor_pairs=n_distractor_pairs,
    fan_in_pct=fan_in_pct, fan_out_pct=fan_out_pct,
    fan_in_degree=fan_in_degree, fan_out_degree=fan_out_degree,
    stimulus_type=stimulus_type)

# Select which A item to probe, and look up its targets and set type
probe_idx = 0
probe = test_probes[probe_idx]
direct_targets = test_direct_targets[probe_idx]
indirect_target = test_indirect_targets[probe_idx]
fan_type = fan_types[probe_idx]

In [None]:
# STANDARD PROBE
prompt_type = 'standard'
prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type=prompt_type)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

# Candidate tokens: every distinct item that appears in the training pairs,
# de-duplicated in first-seen order so the candidate order is the same on every run.
candidate_tokens = list(dict.fromkeys(tok for pair in training_pairs for tok in pair))

# NOTE(review): query_model is not defined in the visible part of this notebook;
# it must be defined or imported before this cell runs on a fresh kernel.
results = query_model(model, tokenizer, prompt, candidate_tokens)

# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
for entry in sorted_results:
    print(entry)

scene->dog box->being oka->progress being->lug progress->comfort dog->Nation scene-> 

Probe: scene, direct target: dog, indirect target: Nation

Results (ranked):
(np.str_('dog'), {'logit': 13.5625, 'prob': 0.05419921875, 'rank': 2})
(np.str_('scene'), {'logit': 12.6875, 'prob': 0.0225830078125, 'rank': 3})
(np.str_('being'), {'logit': 11.5, 'prob': 0.00689697265625, 'rank': 7})
(np.str_('Nation'), {'logit': 10.875, 'prob': 0.003692626953125, 'rank': 16})
(np.str_('progress'), {'logit': 10.125, 'prob': 0.00174713134765625, 'rank': 47})
(np.str_('box'), {'logit': 9.9375, 'prob': 0.00144195556640625, 'rank': 56})
(np.str_('comfort'), {'logit': 9.375, 'prob': 0.000823974609375, 'rank': 120})
(np.str_('lug'), {'logit': 6.59375, 'prob': 5.1021575927734375e-05, 'rank': 2177})
(np.str_('oka'), {'logit': 4.21875, 'prob': 4.738569259643555e-06, 'rank': 12634})


In [None]:
# AFC PROBE
prompt_type = 'afc'

# The foil is another set's indirect target (C), chosen at random
other_targets = [t for t in test_indirect_targets if t != indirect_target]
if not other_targets:
    raise ValueError("AFC probe needs at least one other indirect target to use as a foil")
foil = random.choice(other_targets)

prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type=prompt_type, target=indirect_target, foil=foil)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

# Candidate tokens: every distinct item that appears in the training pairs,
# de-duplicated in first-seen order so the candidate order is the same on every run.
candidate_tokens = list(dict.fromkeys(tok for pair in training_pairs for tok in pair))

# NOTE(review): query_model is not defined in the visible part of this notebook;
# it must be defined or imported before this cell runs on a fresh kernel.
results = query_model(model, tokenizer, prompt, candidate_tokens)

# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
for entry in sorted_results:
    print(entry)

scene->dog box->being oka->progress being->lug progress->comfort dog->Nation scene->comfort or Nation?  

Probe: scene, direct target: dog, indirect target: Nation

Results (ranked):
(np.str_('scene'), {'logit': 5.875, 'prob': 4.842877388000488e-07, 'rank': 16137})
(np.str_('comfort'), {'logit': 5.09375, 'prob': 2.2165477275848389e-07, 'rank': 24362})
(np.str_('progress'), {'logit': 4.1875, 'prob': 8.987262845039368e-08, 'rank': 37816})
(np.str_('lug'), {'logit': 2.59375, 'prob': 1.8277205526828766e-08, 'rank': 74465})
(np.str_('dog'), {'logit': 1.96875, 'prob': 9.778887033462524e-09, 'rank': 93759})
(np.str_('oka'), {'logit': 1.515625, 'prob': 6.199115887284279e-09, 'rank': 105245})
(np.str_('being'), {'logit': 1.484375, 'prob': 5.995389074087143e-09, 'rank': 106016})
(np.str_('Nation'), {'logit': 0.97265625, 'prob': 3.5943230614066124e-09, 'rank': 117660})
(np.str_('box'), {'logit': -0.3125, 'prob': 9.968061931431293e-10, 'rank': 138842})


# Run (names/cities/countries)

In [None]:
# Settings for the names/cities/countries run
n_sets=3  # how many C items (BC pairs)
stimulus_type = 'names'  # any value other than 'words' makes generate_prompt use sentence phrasing

In [None]:
# A items are names; (B, C) pairs are (city, country) tuples
names_list = get_single_token_names(tokenizer)
geo_pairs = get_single_token_geography(tokenizer)
stimulus_dict = {'A': names_list,'BC_pairs': geo_pairs}

# all_tokens is only used for XY distractors in stimulus_dict mode; none are requested here.
# Use the configured stimulus_type rather than repeating the literal.
training_pairs, pair_types, fan_types, test_probes, test_direct_targets, test_indirect_targets = generate_stimuli(
    all_tokens=[], n_sets=n_sets, stimulus_type=stimulus_type, stimulus_dict=stimulus_dict)

# Select which A item to probe, and look up its targets and set type
probe_idx = 0
probe = test_probes[probe_idx]
direct_targets = test_direct_targets[probe_idx]
indirect_target = test_indirect_targets[probe_idx]
fan_type = fan_types[probe_idx]

In [None]:
# STANDARD PROBE
prompt_type = 'standard'
prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type=prompt_type)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

# Candidate tokens: every distinct item that appears in the training pairs,
# de-duplicated in first-seen order so the candidate order is the same on every run.
candidate_tokens = list(dict.fromkeys(tok for pair in training_pairs for tok in pair))

# NOTE(review): query_model is not defined in the visible part of this notebook;
# it must be defined or imported before this cell runs on a fresh kernel.
results = query_model(model, tokenizer, prompt, candidate_tokens)

# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
for entry in sorted_results:
    print(entry)

Tam is from Delhi. Mitch is from Rome. Chelsea is from Riyadh. Riyadh is in Saudi. Delhi is in India. Rome is in Italy. Tam is from  

Probe: Tam, direct target: Delhi, indirect target: India

Results (ranked):
('India', {'logit': 12.9375, 'prob': 0.0002918243408203125, 'rank': 77})
('Saudi', {'logit': 9.6875, 'prob': 1.1324882507324219e-05, 'rank': 506})
(np.str_('Chelsea'), {'logit': 8.75, 'prob': 4.4405460357666016e-06, 'rank': 924})
('Italy', {'logit': 8.5, 'prob': 3.4570693969726562e-06, 'rank': 1076})
('Delhi', {'logit': 7.90625, 'prob': 1.9073486328125e-06, 'rank': 1633})
(np.str_('Tam'), {'logit': 5.5625, 'prob': 1.8347054719924927e-07, 'rank': 7960})
(np.str_('Mitch'), {'logit': 1.6953125, 'prob': 3.841705620288849e-09, 'rank': 63256})
('Riyadh', {'logit': 1.0546875, 'prob': 2.0227162167429924e-09, 'rank': 81347})
('Rome', {'logit': 1.0546875, 'prob': 2.0227162167429924e-09, 'rank': 81347})


In [None]:
# AFC PROBE
prompt_type = 'afc'

# The foil is another set's indirect target (C), chosen at random
other_targets = [t for t in test_indirect_targets if t != indirect_target]
if not other_targets:
    raise ValueError("AFC probe needs at least one other indirect target to use as a foil")
foil = random.choice(other_targets)

prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type=prompt_type, target=indirect_target, foil=foil)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

# Candidate tokens: every distinct item that appears in the training pairs,
# de-duplicated in first-seen order so the candidate order is the same on every run.
candidate_tokens = list(dict.fromkeys(tok for pair in training_pairs for tok in pair))

# NOTE(review): query_model is not defined in the visible part of this notebook;
# it must be defined or imported before this cell runs on a fresh kernel.
results = query_model(model, tokenizer, prompt, candidate_tokens)

# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
for entry in sorted_results:
    print(entry)

Tam is from Delhi. Mitch is from Rome. Chelsea is from Riyadh. Riyadh is in Saudi. Delhi is in India. Rome is in Italy. Is Tam from India or Saudi?  

Probe: Tam, direct target: Delhi, indirect target: India

Results (ranked):
(np.str_('Tam'), {'logit': 11.4375, 'prob': 1.4722347259521484e-05, 'rank': 315})
(np.str_('Chelsea'), {'logit': 9.8125, 'prob': 2.8908252716064453e-06, 'rank': 779})
('India', {'logit': 6.15625, 'prob': 7.497146725654602e-08, 'rank': 5940})
('Saudi', {'logit': 4.875, 'prob': 2.0721927285194397e-08, 'rank': 11285})
('Italy', {'logit': 3.015625, 'prob': 3.230525180697441e-09, 'rank': 26658})
(np.str_('Mitch'), {'logit': 2.203125, 'prob': 1.433363649994135e-09, 'rank': 37331})
('Delhi', {'logit': 2.15625, 'prob': 1.367880031466484e-09, 'rank': 37995})
('Riyadh', {'logit': -0.98828125, 'prob': 5.911715561524034e-11, 'rank': 103650})
('Rome', {'logit': -0.98828125, 'prob': 5.911715561524034e-11, 'rank': 103650})


# Run (fake geography)

In [None]:
# Settings for the fake geography run
n_sets=3  # how many C items (BC pairs)
stimulus_type = 'fakenames'  # any value other than 'words' makes generate_prompt use sentence phrasing

In [None]:
# A items are names; (B, C) pairs are built from made-up place names
names_list = get_single_token_names(tokenizer)
fake_geo = get_fictional_single_tokens(tokenizer, 100)
# Pair the first half with the second half. Splitting by the actual length keeps the
# pairing balanced if fewer than the requested 100 strings come back
# (with exactly 100 this is the same as [0:50] and [50:100]).
half = len(fake_geo) // 2
geo_pairs = list(zip(fake_geo[:half], fake_geo[half:2 * half]))
stimulus_dict = {'A': names_list,'BC_pairs': geo_pairs}

# all_tokens is only used for XY distractors in stimulus_dict mode; none are requested here.
# Use the configured stimulus_type rather than a literal that disagrees with it.
training_pairs, pair_types, fan_types, test_probes, test_direct_targets, test_indirect_targets = generate_stimuli(
    all_tokens=[], n_sets=n_sets, stimulus_type=stimulus_type, stimulus_dict=stimulus_dict)

# Select which A item to probe, and look up its targets and set type
probe_idx = 0
probe = test_probes[probe_idx]
direct_targets = test_direct_targets[probe_idx]
indirect_target = test_indirect_targets[probe_idx]
fan_type = fan_types[probe_idx]

In [None]:
# STANDARD PROBE
prompt_type = 'standard'
prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type=prompt_type)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

# Candidate tokens: every distinct item that appears in the training pairs,
# de-duplicated in first-seen order so the candidate order is the same on every run.
candidate_tokens = list(dict.fromkeys(tok for pair in training_pairs for tok in pair))

# NOTE(review): query_model is not defined in the visible part of this notebook;
# it must be defined or imported before this cell runs on a fresh kernel.
results = query_model(model, tokenizer, prompt, candidate_tokens)

# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
for entry in sorted_results:
    print(entry)

Diamond is from Abble. Tab is from Przedsi. Jose is from Esting. Abble is in Uracy. Esting is in Indrical. Przedsi is in Duction. Diamond is from  

Probe: Diamond, direct target: Abble, indirect target: Uracy

Results (ranked):
('Abble', {'logit': 9.6875, 'prob': 1.9550323486328125e-05, 'rank': 586})
('Uracy', {'logit': 9.625, 'prob': 1.8358230590820312e-05, 'rank': 623})
('Indrical', {'logit': 5.8125, 'prob': 4.0605664253234863e-07, 'rank': 11414})
(np.str_('Jose'), {'logit': 4.375, 'prob': 9.639188647270203e-08, 'rank': 28131})
(np.str_('Diamond'), {'logit': 4.28125, 'prob': 8.800998330116272e-08, 'rank': 29630})
('Esting', {'logit': 4.1875, 'prob': 8.009374141693115e-08, 'rank': 31307})
('Przedsi', {'logit': 3.78125, 'prob': 5.3318217396736145e-08, 'rank': 39199})
(np.str_('Tab'), {'logit': 3.59375, 'prob': 4.423782229423523e-08, 'rank': 43110})
('Duction', {'logit': 2.65625, 'prob': 1.7345882952213287e-08, 'rank': 66046})


In [None]:
# AFC PROBE
prompt_type = 'afc'

# The foil is another set's indirect target (C), chosen at random
other_targets = [t for t in test_indirect_targets if t != indirect_target]
if not other_targets:
    raise ValueError("AFC probe needs at least one other indirect target to use as a foil")
foil = random.choice(other_targets)

prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type=prompt_type, target=indirect_target, foil=foil)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

# Candidate tokens: every distinct item that appears in the training pairs,
# de-duplicated in first-seen order so the candidate order is the same on every run.
candidate_tokens = list(dict.fromkeys(tok for pair in training_pairs for tok in pair))

# NOTE(review): query_model is not defined in the visible part of this notebook;
# it must be defined or imported before this cell runs on a fresh kernel.
results = query_model(model, tokenizer, prompt, candidate_tokens)

# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
for entry in sorted_results:
    print(entry)

Diamond is from Abble. Tab is from Przedsi. Jose is from Esting. Abble is in Uracy. Esting is in Indrical. Przedsi is in Duction. Is Diamond from Duction or Uracy?  

Probe: Diamond, direct target: Abble, indirect target: Uracy

Results (ranked):
(np.str_('Diamond'), {'logit': 8.9375, 'prob': 5.62518835067749e-07, 'rank': 2699})
('Abble', {'logit': 3.4375, 'prob': 2.2992026060819626e-09, 'rank': 41382})
('Indrical', {'logit': 2.65625, 'prob': 1.0550138540565968e-09, 'rank': 54691})
('Uracy', {'logit': 2.015625, 'prob': 5.529727786779404e-10, 'rank': 67484})
(np.str_('Jose'), {'logit': 1.8984375, 'prob': 4.94765117764473e-10, 'rank': 70037})
('Duction', {'logit': 1.453125, 'prob': 3.1650415621697903e-10, 'rank': 79832})
('Przedsi', {'logit': -0.5859375, 'prob': 4.1154635255225e-11, 'rank': 122009})
(np.str_('Tab'), {'logit': -1.375, 'prob': 1.864464138634503e-11, 'rank': 133103})
('Esting', {'logit': -2.359375, 'prob': 6.963318810448982e-12, 'rank': 142367})
