# Functions



In [1]:
"""
***Memory Inference Experiments***

Train on associative inference task, assess integrative encoding (proactive) vs. retrieval-time inference (reactive)

Paired associate inference setup: goal is to measure A-C association
- Simplest version: AB BC -> A? -> measure output probability of B and C
- Fan in & fan out versions (multiple As to one B, multiple Bs to one A)

Other task ideas (not implemented here):
- Acquired equivalence: AB CB AD -> C?
- Other task: AB CB XY -> A? AB AC XY -> B?

"""

import sys, random, uuid, os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import nltk
nltk.download('wordnet')
nltk.download('names')
nltk.download('gazetteers')
from nltk.corpus import wordnet, names, gazetteers

# for running on colab
if 'google.colab' in sys.modules:
    os.system("git clone https://github.com/sflippl/models-of-memory.git")
    sys.path.append('models-of-memory')
    dir = 'models-of-memory'
else:
    print("Running locally")
    dir = '.'

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package names to /root/nltk_data...
[nltk_data]   Unzipping corpora/names.zip.
[nltk_data] Downloading package gazetteers to /root/nltk_data...
[nltk_data]   Unzipping corpora/gazetteers.zip.


In [8]:
def load_model(model_id):
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer

# filter stimuli for only single tokens
def get_single_tokens(tokenizer, arr):
  return [a for a in arr if len( tokenizer(a, add_special_tokens=False)["input_ids"] ) == 1 ]

def get_single_token_nouns(tokenizer):
    """get list of nouns from wordnet that are single tokens in the given model"""
    nouns = [lemma.name() for syn in wordnet.all_synsets("n") for lemma in syn.lemmas()]
    nouns = [n for n in nouns if n.isalpha()]
    return get_single_tokens(tokenizer, nouns)

def get_single_token_names(tokenizer):
    """get list of names from nltk names corpus that are single tokens"""
    all_names = names.words('male.txt') + names.words('female.txt')
    return list(set(get_single_tokens(tokenizer, all_names)))

def get_single_token_geography(tokenizer):
    return [('Paris', 'France'), ('Berlin', 'Germany'), ('Rome', 'Italy'), ('Madrid', 'Spain'),
            ('Tokyo', 'Japan'), ('Beijing', 'China'), ('Delhi', 'India'), ('Cairo', 'Egypt'),
             ('Moscow', 'Russia'), ('Stockholm', 'Sweden'), ('London', 'UK'), ('Ottawa', 'Canada'),
              ('Mexico', 'Mexico'), ('Brasilia', 'Brazil'), ('Sydney', 'Australia'), ('Amman', 'Jordan'),
               ('Baghdad', 'Iraq'), ('Tehran', 'Iran'), ('Riyadh', 'Saudi'), ('Islamabad', 'Pakistan')]

def get_fictional_single_tokens(tokenizer, count):
    """
    Scavenges the model's own vocabulary for strings it treats as atomic
    but that are not recognized as real English nouns.
    """
    # Get all potential strings from the model's vocabulary
    # (Handling various tokenizer formats like GPT, Llama, and BERT)
    vocab = tokenizer.get_vocab().keys()
    fictional_pool = []

    for t in vocab:
        # 1. Clean up subword/whitespace markers (like 'Ġ', ' ', '##')
        clean = t.replace('Ġ', '').replace(' ', '').replace('##', '')

        # 2. Heuristics for a "Country" name (alphabetic, 5-8 chars)
        if clean.isalpha() and 5 <= len(clean) <= 8:
            # 3. Validation: Verify it stays a single token and isn't a real word
            if len(tokenizer.encode(clean, add_special_tokens=False)) == 1:
                if not wordnet.synsets(clean.lower()):
                    fictional_pool.append(clean.capitalize())

        if len(fictional_pool) >= count + 50: # Get a buffer then stop
            break

    return random.sample(fictional_pool, min(count, len(fictional_pool)))

In [10]:
### honestly this function is way more complicated than we need at this point just for testing
# this implements fan in/out and a bunch of stuff that we're not ready to do

def generate_stimuli(all_tokens, n_sets, n_distractor_pairs=0,
                     fan_in_pct=0, fan_out_pct=0, fan_in_degree=0, fan_out_degree=0,
                     stimulus_type='words', stimulus_dict=None):
    """
    Each "set" is A->B->C, but strictly its the number of C items (because of fan in/fan out)
    Distractor XY pairs are randomly intermixed if requested
    Fan structure:
        - fan_in_pct: proportion of sets with fan-in structure (multiple As -> one B -> one C)
        - fan_out_pct: proportion of sets with fan-out structure (one A -> multiple Bs -> one C)
        - fan_in_degree: how many A tokens lead to the same B (e.g., 3 means A1->B, A2->B, A3->B->C)
        - fan_out_degree: how many B tokens each A leads to (e.g., 5 means A->B1, A->B2, ..., A->B5, then A->C)
    stimulus_type: 'words', 'names', or 'fakenames'
    stimulus_dict: optional dict with 'A' list and 'BC_pairs' list of tuples for cities/countries mode
    Returns:
        - list of tuples of training pairs
        - list of possible probes (A)
        - list of lists of direct targets B (may be more than one for fan out)
        - list of indirect targets C (only one per A)
        - list of fan types (simple, fan_in, fan_out)
    """
    if fan_in_pct + fan_out_pct > 1:
        raise ValueError("fan_in_pct + fan_out_pct cannot exceed 1.0")

    # determine how many sets of each type
    n_sets_fan_in = int(n_sets * fan_in_pct)
    n_sets_fan_out = int(n_sets * fan_out_pct)
    n_sets_simple = n_sets - n_sets_fan_in - n_sets_fan_out

    # calculate and generate tokens
    n_a_tokens = n_sets_simple + (n_sets_fan_in * fan_in_degree) + n_sets_fan_out
    n_b_tokens = n_sets_simple + n_sets_fan_in + (n_sets_fan_out * fan_out_degree)
    n_c_tokens = n_sets  # one C per set
    n_x_tokens = n_y_tokens = n_distractor_pairs

    if stimulus_dict:
        a_tokens = np.random.choice(stimulus_dict['A'], n_a_tokens, replace=False)
        # Select matching BC pairs
        bc_indices = np.random.choice(len(stimulus_dict['BC_pairs']), n_sets, replace=False)
        selected_bc = [stimulus_dict['BC_pairs'][i] for i in bc_indices]
        b_tokens = [bc[0] for bc in selected_bc]
        c_tokens = [bc[1] for bc in selected_bc]
        x_tokens = np.random.choice(all_tokens, n_x_tokens, replace=False)
        y_tokens = np.random.choice(all_tokens, n_y_tokens, replace=False)
    else:
        n_total_tokens = int(n_a_tokens + n_b_tokens + n_c_tokens + n_x_tokens + n_y_tokens)
        tokens = np.random.choice(all_tokens, n_total_tokens, replace=False)
        split_indices = np.cumsum([n_a_tokens, n_b_tokens, n_c_tokens, n_x_tokens, n_y_tokens], dtype=int)[:-1]
        a_tokens, b_tokens, c_tokens, x_tokens, y_tokens = np.split(tokens, split_indices)

    # build pairs and track mappings
    ab_pairs, bc_pairs, xy_pairs = [],[],[]
    direct_targets = []  # list of lists: each entry corresponds to each A, containing a list of its Bs
    indirect_targets = []  # list: each entry corresponds to each A, containing its C
    pair_types = [] # track whether each pair is AB, BC, or XY
    fan_types = []  # track whether each A is simple, fan_in, or fan_out
    a_idx, b_idx, c_idx = 0, 0, 0 # keep track of where we are in the tokens

    # simple sets: 1 A -> 1 B -> 1 C
    for _ in range(n_sets_simple):
        ab_pairs.append((a_tokens[a_idx], b_tokens[b_idx]))
        bc_pairs.append((b_tokens[b_idx], c_tokens[c_idx]))
        direct_targets.append([b_tokens[b_idx]])
        indirect_targets.append(c_tokens[c_idx])
        fan_types.append('simple')
        a_idx += 1
        b_idx += 1
        c_idx += 1

    # Fan-in sets: multiple As -> 1 B -> 1 C
    for _ in range(n_sets_fan_in):
        for _ in range(fan_in_degree): # A1->B, A2->B, ...
            ab_pairs.append((a_tokens[a_idx], b_tokens[b_idx]))
            direct_targets.append([b_tokens[b_idx]])
            indirect_targets.append(c_tokens[c_idx])
            fan_types.append('fan_in')
            a_idx += 1
        bc_pairs.append((b_tokens[b_idx], c_tokens[c_idx])) # only one BC pair per set
        b_idx += 1
        c_idx += 1

    # Fan-out sets: 1 A -> multiple Bs; only first B -> C
    for _ in range(n_sets_fan_out):
        b_list = []
        for i in range(fan_out_degree): # A->B1, A->B2, ...
            ab_pairs.append((a_tokens[a_idx], b_tokens[b_idx]))
            b_list.append(b_tokens[b_idx])
            if i == 0:  # Only pair the first B with C
                bc_pairs.append((b_tokens[b_idx], c_tokens[c_idx]))
            b_idx += 1
        direct_targets.append(b_list)
        indirect_targets.append(c_tokens[c_idx])
        fan_types.append('fan_out')
        a_idx += 1
        c_idx += 1

    xy_pairs = list(zip(x_tokens, y_tokens))

    # Shuffle BC pairs
    bc_perm = np.random.permutation(len(bc_pairs))
    bc_shuffled = [bc_pairs[i] for i in bc_perm]

    # Combine training pairs and track types
    train_pairs = ab_pairs + bc_shuffled
    pair_types = ['AB'] * len(ab_pairs) + ['BC'] * len(bc_pairs)

    # Insert XY pairs at random locations
    for xy_pair in xy_pairs:
        idx = np.random.randint(0, len(train_pairs) + 1)
        train_pairs.insert(idx, xy_pair)
        pair_types.insert(idx, 'XY')

    return train_pairs, pair_types, fan_types, a_tokens, direct_targets, indirect_targets

In [11]:
def generate_prompt(train_pairs, pair_types, test_probe, stimulus_type='words',
                    prompt_type='standard', target=None, foil=None, ):
    prompt = ""
    # training pairs
    if stimulus_type == 'words':
        for p in train_pairs:
            prompt += f"{p[0]}->{p[1]} "
    else:
        for p, p_type in zip(train_pairs, pair_types):
            if p_type == 'AB' or ():
                prompt += f"{p[0]} is from {p[1]}. "
            elif p_type == 'BC':
                prompt += f"{p[0]} is in {p[1]}. "
            elif random.random() < 0.5: # XY
                prompt += f"{p[0]} is from {p[1]}. "
            else:
                prompt += f"{p[0]} is in {p[1]}. "

    # Query structure
    if prompt_type == 'standard':
        if stimulus_type == 'words':
            prompt += f"{test_probe}->"
        else:
            prompt += f"{test_probe} is from "
    elif prompt_type == 'afc':
        choices = [target, foil]
        random.shuffle(choices)
        if stimulus_type == 'words':
            prompt += f"{test_probe}->{choices[0]} or {choices[1]}? "
        else:
            prompt += f"Is {test_probe} from {choices[0]} or {choices[1]}? "
    return prompt


def query_model(model, tokenizer, prompt, target_tokens=None):
    """
    Get logits, probabilities, and ranks for next predicted tokens.
    If target_tokens is None, returns top 10 ranked tokens.
    Returns dict mapping each target token to {logit, prob, rank}
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
        next_token_logits = outputs.logits[0, -1, :]  # logits for next token
    probs = torch.softmax(next_token_logits, dim=-1)  # probabilites for next token

    if target_tokens is None:
        top_probs, top_indices = torch.topk(probs, 10)
        results = {}
        for i in range(10):
            token_id = top_indices[i].item()
            token = tokenizer.decode([token_id])
            results[token] = {
                'logit': next_token_logits[token_id].item(),
                'prob': top_probs[i].item(),
                'rank': i + 1
            }
        return results

    target_token_ids = [tokenizer.encode(t, add_special_tokens=False)[0] for t in target_tokens]
    results = {}
    for token, token_id in zip(target_tokens, target_token_ids):
        results[token] = {
            'logit': next_token_logits[token_id].item(),
            'prob': probs[token_id].item(),
            'rank': (probs > probs[token_id]).sum().item() + 1
        }
    return results

# Load model

In [4]:
model_id = "Qwen/Qwen3-4B"
model, tokenizer = load_model(model_id)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]



model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

# Run (words)

In [12]:
n_sets = 3 # how many C items (BC pairs)
n_distractor_pairs = 0 # how many XY pairs
fan_in_pct = 0.0 # percentage of sets that fan in
fan_out_pct = 0.0 # percentage of sets that fan out
fan_in_degree = 2 # how many As per B
fan_out_degree = 2 # how many Bs per A

stimulus_type = 'words'

In [13]:
all_tokens = get_single_token_nouns(tokenizer)

training_pairs, pair_types, fan_types, test_probes, test_direct_targets, test_indirect_targets = generate_stimuli(
    all_tokens, n_sets, stimulus_type=stimulus_type)

probe_idx = 0
probe = test_probes[probe_idx]
direct_targets = test_direct_targets[probe_idx]
indirect_target = test_indirect_targets[probe_idx]
fan_type = fan_types[probe_idx]

In [14]:
# STANDARD PROBE
prompt_type = 'standard'
prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type = prompt_type)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

results = query_model(model, tokenizer, prompt, target_tokens=None) # just get top 10
                      # target_tokens = list(set([a for a,b in training_pairs] + [b for a,b in training_pairs])))


# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
_ = [print(i) for i in sorted_results]

illusion->e page->partition Virgin->cross partition->block cross->git e->letter illusion-> 

Probe: illusion, direct target: e, indirect target: letter

Results (ranked):
('e', {'logit': 16.625, 'prob': 0.388671875, 'rank': 1})
('page', {'logit': 14.3125, 'prob': 0.038330078125, 'rank': 2})
(' e', {'logit': 13.4375, 'prob': 0.0159912109375, 'rank': 3})
('cross', {'logit': 12.9375, 'prob': 0.00970458984375, 'rank': 4})
('\n\n', {'logit': 12.625, 'prob': 0.007110595703125, 'rank': 5})
(' page', {'logit': 12.4375, 'prob': 0.005889892578125, 'rank': 6})
('c', {'logit': 12.1875, 'prob': 0.00457763671875, 'rank': 7})
('s', {'logit': 12.125, 'prob': 0.004302978515625, 'rank': 8})
('p', {'logit': 11.875, 'prob': 0.00335693359375, 'rank': 9})
('v', {'logit': 11.8125, 'prob': 0.0031585693359375, 'rank': 10})


In [23]:
# AFC PROBE
prompt_type = 'afc'
foil = None
if prompt_type == 'afc':
    other_targets = [t for t in test_indirect_targets if t != indirect_target]
    foil = random.choice(other_targets)

prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type = prompt_type, target=indirect_target, foil=foil)
prompt += 'The answer is '

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

results = query_model(model, tokenizer, prompt) #, list(set([a for a,b in training_pairs] + [b for a,b in training_pairs])))


# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
_ = [print(i) for i in sorted_results]

illusion->e page->partition Virgin->cross partition->block cross->git e->letter illusion->git or letter? The answer is  

Probe: illusion, direct target: e, indirect target: letter

Results (ranked):
('2', {'logit': 20.75, 'prob': 0.3046875, 'rank': 1})
('1', {'logit': 20.75, 'prob': 0.3046875, 'rank': 2})
('3', {'logit': 19.25, 'prob': 0.06787109375, 'rank': 3})
(' "', {'logit': 18.875, 'prob': 0.046630859375, 'rank': 4})
('4', {'logit': 18.875, 'prob': 0.046630859375, 'rank': 5})
('7', {'logit': 18.625, 'prob': 0.036376953125, 'rank': 6})
('6', {'logit': 18.5, 'prob': 0.031982421875, 'rank': 7})
('5', {'logit': 18.125, 'prob': 0.02197265625, 'rank': 8})
('8', {'logit': 18.125, 'prob': 0.02197265625, 'rank': 9})
('0', {'logit': 18.0, 'prob': 0.0194091796875, 'rank': 10})


# Run (names/cities/countries)

In [24]:
n_sets=3
stimulus_type = 'names'

In [25]:
names_list = get_single_token_names(tokenizer)
geo_pairs = get_single_token_geography(tokenizer)
stimulus_dict = {'A': names_list,'BC_pairs': geo_pairs}

training_pairs, pair_types, fan_types, test_probes, test_direct_targets, test_indirect_targets = generate_stimuli(
    all_tokens=[], n_sets=n_sets, stimulus_type='names', stimulus_dict=stimulus_dict)

probe_idx = 0
probe = test_probes[probe_idx]
direct_targets = test_direct_targets[probe_idx]
indirect_target = test_indirect_targets[probe_idx]
fan_type = fan_types[probe_idx]

In [29]:
# STANDARD PROBE
prompt_type = 'standard'
prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type = prompt_type)
print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

results = query_model(model, tokenizer, prompt) #, list(set([a for a,b in training_pairs] + [b for a,b in training_pairs])))


# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
_ = [print(i) for i in sorted_results]

Berry is from Islamabad. Kelly is from Tokyo. Richard is from Paris. Tokyo is in Japan. Paris is in France. Islamabad is in Pakistan. Berry is from  

Probe: Berry, direct target: Islamabad, indirect target: Pakistan

Results (ranked):
(' Islamabad', {'logit': 22.875, 'prob': 0.609375, 'rank': 1})
(' Pakistan', {'logit': 21.25, 'prob': 0.11962890625, 'rank': 2})
('1', {'logit': 20.375, 'prob': 0.050048828125, 'rank': 3})
(' a', {'logit': 20.375, 'prob': 0.050048828125, 'rank': 4})
('2', {'logit': 19.75, 'prob': 0.0267333984375, 'rank': 5})
(' ______', {'logit': 19.5, 'prob': 0.0208740234375, 'rank': 6})
(' the', {'logit': 19.125, 'prob': 0.0142822265625, 'rank': 7})
(' where', {'logit': 19.0, 'prob': 0.01263427734375, 'rank': 8})
(' __', {'logit': 18.625, 'prob': 0.0086669921875, 'rank': 9})
(' ?', {'logit': 18.5, 'prob': 0.007659912109375, 'rank': 10})


In [32]:
# AFC PROBE
prompt_type = 'afc'
foil = None
if prompt_type == 'afc':
    other_targets = [t for t in test_indirect_targets if t != indirect_target]
    foil = random.choice(other_targets)

prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type = prompt_type, target=indirect_target, foil=foil)
prompt += 'Berry is from '
print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

results = query_model(model, tokenizer, prompt) # list(set([a for a,b in training_pairs] + [b for a,b in training_pairs])))


# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
_ = [print(i) for i in sorted_results]

Berry is from Islamabad. Kelly is from Tokyo. Richard is from Paris. Tokyo is in Japan. Paris is in France. Islamabad is in Pakistan. Is Berry from Pakistan or Japan? Berry is from  

Probe: Berry, direct target: Islamabad, indirect target: Pakistan

Results (ranked):
(' Pakistan', {'logit': 23.0, 'prob': 0.458984375, 'rank': 1})
(' Islamabad', {'logit': 23.0, 'prob': 0.458984375, 'rank': 2})
(' (', {'logit': 19.375, 'prob': 0.01220703125, 'rank': 3})
(' from', {'logit': 19.25, 'prob': 0.01080322265625, 'rank': 4})
(' -', {'logit': 18.5, 'prob': 0.005096435546875, 'rank': 5})
('1', {'logit': 18.375, 'prob': 0.004486083984375, 'rank': 6})
(' .', {'logit': 18.375, 'prob': 0.004486083984375, 'rank': 7})
(' ?', {'logit': 18.25, 'prob': 0.00396728515625, 'rank': 8})
(' __', {'logit': 17.875, 'prob': 0.0027313232421875, 'rank': 9})
('2', {'logit': 17.625, 'prob': 0.0021209716796875, 'rank': 10})


# Run (fake geography)

In [33]:
n_sets=3
stimulus_type = 'fakenames'

In [34]:
names_list = get_single_token_names(tokenizer)
fake_geo = get_fictional_single_tokens(tokenizer, 100)
geo_pairs = list(zip(fake_geo[0:50], fake_geo[50:100]))
stimulus_dict = {'A': names_list,'BC_pairs': geo_pairs}

training_pairs, pair_types, fan_types, test_probes, test_direct_targets, test_indirect_targets = generate_stimuli(
    all_tokens=[], n_sets=n_sets, stimulus_type='names', stimulus_dict=stimulus_dict)

probe_idx = 0
probe = test_probes[probe_idx]
direct_targets = test_direct_targets[probe_idx]
indirect_target = test_indirect_targets[probe_idx]
fan_type = fan_types[probe_idx]

In [38]:
# STANDARD PROBE
prompt_type = 'standard'
prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type = prompt_type)

print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

results = query_model(model, tokenizer, prompt) # list(set([a for a,b in training_pairs] + [b for a,b in training_pairs])))


# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
_ = [print(i) for i in sorted_results]

Joy is from Istani. Mario is from Maint. Price is from Readonly. Istani is in Metry. Maint is in Getwidth. Readonly is in Celebr. Joy is from  

Probe: Joy, direct target: Istani, indirect target: Metry

Results (ranked):
('1', {'logit': 18.5, 'prob': 0.177734375, 'rank': 1})
('2', {'logit': 18.125, 'prob': 0.1220703125, 'rank': 2})
('3', {'logit': 17.25, 'prob': 0.051025390625, 'rank': 3})
(' where', {'logit': 17.125, 'prob': 0.044921875, 'rank': 4})
(' ______', {'logit': 17.0, 'prob': 0.039794921875, 'rank': 5})
('4', {'logit': 16.875, 'prob': 0.034912109375, 'rank': 6})
(' a', {'logit': 16.75, 'prob': 0.0308837890625, 'rank': 7})
(' Met', {'logit': 16.75, 'prob': 0.0308837890625, 'rank': 8})
(' ?', {'logit': 16.75, 'prob': 0.0308837890625, 'rank': 9})
('7', {'logit': 16.625, 'prob': 0.0272216796875, 'rank': 10})


In [40]:
# AFC PROBE
prompt_type = 'afc'
foil = None
if prompt_type == 'afc':
    other_targets = [t for t in test_indirect_targets if t != indirect_target]
    foil = random.choice(other_targets)

prompt = generate_prompt(training_pairs, pair_types, probe, stimulus_type=stimulus_type,
                         prompt_type = prompt_type, target=indirect_target, foil=foil)
prompt += "Diamond is from "
print(prompt, '\n')
print(f'Probe: {probe}, direct target: {direct_targets[0]}, indirect target: {indirect_target}\n')

results = query_model(model, tokenizer, prompt) # list(set([a for a,b in training_pairs] + [b for a,b in training_pairs])))


# print results, sorted by the "rank" value of the value of each key
sorted_results = sorted(results.items(), key=lambda item: item[1]['rank'])
print('Results (ranked):')
_ = [print(i) for i in sorted_results]

Joy is from Istani. Mario is from Maint. Price is from Readonly. Istani is in Metry. Maint is in Getwidth. Readonly is in Celebr. Is Joy from Metry or Celebr? Diamond is from  

Probe: Joy, direct target: Istani, indirect target: Metry

Results (ranked):
('1', {'logit': 17.375, 'prob': 0.1904296875, 'rank': 1})
('2', {'logit': 17.125, 'prob': 0.1484375, 'rank': 2})
('3', {'logit': 16.75, 'prob': 0.1015625, 'rank': 3})
('4', {'logit': 16.5, 'prob': 0.0791015625, 'rank': 4})
('5', {'logit': 16.0, 'prob': 0.048095703125, 'rank': 5})
('7', {'logit': 16.0, 'prob': 0.048095703125, 'rank': 6})
('0', {'logit': 15.75, 'prob': 0.037353515625, 'rank': 7})
('6', {'logit': 15.4375, 'prob': 0.02734375, 'rank': 8})
('8', {'logit': 15.375, 'prob': 0.0257568359375, 'rank': 9})
('9', {'logit': 15.3125, 'prob': 0.024169921875, 'rank': 10})
