# Heuristics analysis across training notebook

This notebook contains random code for analyzing the models across training checkpoints.

The code here is messier than the rest of the repository and should not be used as a reference.

The code focuses on Pythia-6.9B. It contains initial experiments (analyzed further in ``script_analyze_model_heuristics.py``) as well as the visualizations shown in Section 5 of the paper.

### Imports and Setup

In [2]:
# Imports and setup
%load_ext autoreload
%autoreload 2

from general_utils import set_deterministic, set_cuda_device
set_cuda_device(0)

import random
import os
import torch
import pickle
import transformer_lens as lens
import re
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
from itertools import chain
from fancy_einsum import einsum
from pprint import pprint
from scipy.stats import pearsonr
from transformers import GPTNeoXForCausalLM
from prompt_generation import generate_all_prompts_for_operator, OPERATORS, OPERATOR_NAMES, _is_number, _maximize_unique_answers
from visualization_utils import line, imshow, multiple_lines
from general_utils import generate_activations, set_deterministic
from evaluation_utils import model_accuracy, circuit_faithfulness_with_mean_ablation
from model_analysis_consts import PYTHIA_6_9B_CONSTS
from heuristics_classification import load_heuristic_classes
from script_analyze_model_heuristics import HEURISTIC_MATCH_THRESHOLD
from component import Component
from script_eval_pythia_faithfulness_only_mutual_neurons import build_circuit, get_heuristic_neurons, get_intersection_neurons


# This notebook only runs analysis, so autograd is switched off globally.
torch.set_grad_enabled(False)
device = 'cuda'
seed = 42
# Prefix of the per-checkpoint directory names under ./data/ (see find_all_pythia_steps).
PYTHIA_PREFIX = "pythia-6.9b"

# Hex colour palette; presumably colour-blind friendly, judging by the name (not verified here).
COLORBLIND_COLORS = ['#0173b2', '#de8f05', '#029e73','#d55e00', '#cc78bc', '#ca9161', '#fbafe4', '#949494', '#ece133', '#56b4e9']

In [3]:
# Utility functions, similar to those defined in the main notebook

def reverse_heuristic_dictionary(d):
    """
    Invert a mapping of heuristic -> list of (layer, neuron, score) entries.

    The result is keyed by the first two items of each entry (the (layer, neuron) pair) and
    maps to the list of (heuristic, score) pairs under which that entry was listed, in
    iteration order of the input.
    """
    neuron_to_heuristics = {}
    for heuristic_name in d:
        for entry in d[heuristic_name]:
            neuron_key = entry[:2]
            if neuron_key not in neuron_to_heuristics:
                neuron_to_heuristics[neuron_key] = []
            neuron_to_heuristics[neuron_key].append((heuristic_name, entry[2]))
    return neuron_to_heuristics

def _get_top_op1_op2_indices(layer, neuron_idx, prompts_activations, top_k=None, is_top=True):
    """
    Return the operand pairs on which a neuron has its most extreme activations.

    Looks up the activation vector stored under (layer, neuron_idx), takes its top_k largest
    (is_top=True) or smallest (is_top=False) entries, and indexes the notebook-level
    op1_op2_pairs array with them. When top_k is None every entry is returned, ordered by activation.
    """
    neuron_activations = prompts_activations[(layer, neuron_idx)]
    k = len(neuron_activations) if top_k is None else top_k
    selected_indices = neuron_activations.topk(k, largest=is_top).indices
    return op1_op2_pairs[selected_indices.cpu().numpy()].tolist()

def present_neuron(layer, neuron, use_kv_maps=True):
    """
    Display diagnostics for a single MLP neuron: a heatmap of its activations over operand
    pairs, the 50 operand pairs with the highest activations (and their results), and the
    heuristics recorded for it.

    Relies on notebook-level globals that are not defined in this section (model,
    arithmetic_tokens, op1_op2_pairs, min_op, max_op, operator_idx, simple_eval,
    rev_heuristic_classes, kv_prompts_activations, k_prompts_activations), so it only works
    after those have been created.

    Args:
        layer: Index of the MLP layer.
        neuron: Index of the neuron inside that layer.
        use_kv_maps: If True, read activations from kv_prompts_activations; otherwise from
            k_prompts_activations (and additionally print the top arithmetic tokens).
    """
    if use_kv_maps:
        prompts_activations = kv_prompts_activations
    else:
        prompts_activations = k_prompts_activations
    v_tokens = 10
    # Multiply the neuron's output weights by the unembedding matrix to get one logit per vocabulary token.
    vector_logits = model.blocks[layer].mlp.W_out[neuron] @ model.W_U
    topk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[vector_logits[arithmetic_tokens].view(-1).topk(v_tokens).indices], prepend_bos=False)
    # NOTE(review): bottomk_arithmetic_tokens is computed but never used below.
    bottomk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[vector_logits[arithmetic_tokens].view(-1).topk(v_tokens, largest=False).indices], prepend_bos=False)
    # NOTE(review): this header is printed unconditionally, but the token list itself is only printed when use_kv_maps is False.
    print(f'Neuron {neuron} logit lens:')
    # Scatter the flat per-prompt activations into an (operand1, operand2) grid; cells without a prompt stay 0.
    activation_img = torch.zeros((max_op - min_op, max_op - min_op))
    for i, (op1, op2) in enumerate(op1_op2_pairs):
        activation_img[op1 - min_op, op2 - min_op] = prompts_activations[(layer, neuron)][i]

    imshow(activation_img, x=list(range(min_op, max_op)), y=list(range(min_op, max_op)), labels={'x': 'Operand2', 'y': 'Operand1'}, width=600,
        title=f'Neuron {neuron} activations in MLP {layer} as function of operands')
    # NOTE(review): bottomk_op1_op2_pairs is computed but never used below.
    topk_op1_op2_pairs, bottomk_op1_op2_pairs =  _get_top_op1_op2_indices(layer, neuron, prompts_activations, top_k=50), \
                                                 _get_top_op1_op2_indices(layer, neuron, prompts_activations, top_k=50, is_top=False)
    print("Top 50 op1,op2 values in activation map: ", topk_op1_op2_pairs)
    print("Top results: ", [simple_eval(f"{op1}{OPERATORS[operator_idx]}{op2}") for (op1, op2) in topk_op1_op2_pairs])
    if not use_kv_maps:
        print(f'Top arithmetic {v_tokens} tokens: {topk_arithmetic_tokens}')
    # Heuristics recorded for this neuron, highest score first.
    # NOTE(review): raises KeyError if rev_heuristic_classes is a plain dict without an entry for this neuron.
    print(sorted(rev_heuristic_classes[(layer, neuron)], key=lambda x:x[1], reverse=True))


def _get_neuron_importance_scores_across_operators(model_name):
    """
    Average the per-layer neuron importance scores of a model over all operators.

    Loads the (non-2-shot) importance scores of every operator and returns a dict that maps
    each layer index in 0..31 to the mean of that layer's scores across operators.
    """
    per_operator_scores = []
    for op_idx in range(len(OPERATORS)):
        per_operator_scores.append(_get_neuron_importance_scores(model_name, op_idx, use_2shot_prompting=False))

    averaged_scores = {}
    for layer in range(0, 32):
        layer_total = sum([operator_scores[layer] for operator_scores in per_operator_scores])
        averaged_scores[layer] = layer_total / len(OPERATORS)
    return averaged_scores


def _get_neuron_importance_scores(model_name, operator_idx, use_2shot_prompting):
    """
    Override of function in general_utils to support 2shot prompting experimentation.

    Loads the saved node attribution scores of one operator and returns a dict mapping each
    layer index in 0..31 to a per-neuron score: the mean plus the standard deviation (both
    taken over the first dimension) of the 'blocks.<layer>.mlp.hook_post' attributions at
    the last position.
    """
    def mean_plus_std(attribution_scores, pos):
        scores_at_position = attribution_scores[:, pos]
        return scores_at_position.mean(dim=0) + scores_at_position.std(dim=0)

    neuron_attribution_scores = torch.load(f"./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_node_attribution_scores{'_with_2shot_prompting' if use_2shot_prompting else ''}.pt")

    neurons_scores = {}
    for layer in range(0, 32):
        layer_attributions = neuron_attribution_scores[f'blocks.{layer}.mlp.hook_post']
        neurons_scores[layer] = mean_plus_std(layer_attributions, pos=-1)
    return neurons_scores
    

def find_all_pythia_steps(data_dir='./data/', prefix=None):
    """
    Find all analyzed training checkpoint names in the Pythia 6.9b data directory.

    Only entries whose name ends with '<prefix>-step<N>' (N being digits) are counted; any
    other entry (the final-model directory, notes, backups, ...) is ignored instead of
    aborting the scan with an IndexError / ValueError.

    Args:
        data_dir: Directory that holds one entry per analyzed model. Defaults to './data/'.
        prefix: Model-name prefix to look for. Defaults to PYTHIA_PREFIX.

    Returns:
        The training steps found, as a sorted list of ints.
    """
    if prefix is None:
        prefix = PYTHIA_PREFIX
    # re.escape: the prefix contains a '.', which must match literally.
    step_pattern = re.compile(re.escape(prefix) + r"-step(\d+)$")
    steps = []
    for entry in os.listdir(data_dir):
        match = step_pattern.search(entry)
        if match is not None:
            steps.append(int(match.group(1)))
    return sorted(steps)

### Random experimentations

In [None]:
# Measure model accuracy: show the saved correctly-completed prompts of the first two operators
# and print the saved accuracy of every operator for one checkpoint.

model_name = "pythia-6.9b-step143000"
results_dir = f"./data/{model_name}/"

# NOTE: pickle can execute arbitrary code on load; only open files produced by this project.
with open(os.path.join(results_dir, 'large_prompts_and_answers_max_op=300.pkl'), "rb") as f:
    large_prompts_and_answers = pickle.load(f)
pprint(large_prompts_and_answers[0])
pprint(large_prompts_and_answers[1])

# Read the accuracy file once instead of once per operator.
accuracy_per_operator = torch.load(os.path.join(results_dir, 'accuracy.pt'))
for operator_idx in range(len(OPERATORS)):
    op = OPERATORS[operator_idx]
    print(op)
    print("Model accuracy: ", accuracy_per_operator[op])

In [None]:
# How many prompts completed correctly by the model are "junk" prompts (for example, completed using simple copy mechanisms when the result is equal to one of the operands)?
# A prompt counts as junk when its answer or either of its operands is 0 or 1.

steps = find_all_pythia_steps()
steps = sorted(set(steps) - set([3000, 13000]))

for step in steps:
    model_name = f"{PYTHIA_PREFIX}-step{step}"
    # Use a context manager so the file handle is closed after every checkpoint.
    with open(fr'./data/{model_name}/large_prompts_and_answers_max_op=300.pkl', 'rb') as f:
        large_prompts_and_answers = pickle.load(f)
    for operator_idx in range(len(OPERATORS)):
        junk_count = 0
        for (prompt, answer) in large_prompts_and_answers[operator_idx]:
            answer = int(answer)
            # The prompt's last character is dropped before splitting on the operator symbol.
            op1, op2 = list(map(int, prompt[:-1].split(OPERATORS[operator_idx])))
            if answer in (0, 1) or op1 in (0, 1) or op2 in (0, 1):
                junk_count += 1
        print(f"{step=}, {OPERATORS[operator_idx]}: {junk_count/len(large_prompts_and_answers[operator_idx])}")

In [None]:
# Look at correctly completed prompts for a specific training checkpoint

def load_prompts(model_name):
    """
    Load the saved prompts of a checkpoint and pick 50 non-trivial ones per operator.

    Returns a tuple (large_prompts_and_answers, correct_prompts_and_answers): the unfiltered
    per-operator lists as stored on disk, and the per-operator selection produced by
    _maximize_unique_answers from the filtered lists.
    """
    with open(fr'./data/{model_name}/large_prompts_and_answers_max_op=300.pkl', 'rb') as prompts_file:
        large_prompts_and_answers = pickle.load(prompts_file)

    set_deterministic(42)
    wanted_size = 50

    def is_non_trivial(prompt, answer):
        # Filter out simple prompts (x/0, x*1, etc): the operands are the last two numbers in the prompt.
        op1, op2 = tuple(map(int, re.findall(r'\d+', prompt)))[-2:]
        return op1 > 5 and op2 > 5 and int(answer) > 2

    filtered_prompts_and_answers = []
    for operator_position, operator_pairs in enumerate(large_prompts_and_answers):
        kept_pairs = [(prompt, answer) for prompt, answer in operator_pairs if is_non_trivial(prompt, answer)]
        missing = wanted_size - len(kept_pairs)
        if missing > 0:
            # Not enough non-trivial prompts: top up with random prompts from the unfiltered list.
            print(len(kept_pairs), f' length of new_pa for operator {operator_position}')
            kept_pairs = kept_pairs + random.sample(operator_pairs, k=missing)
        filtered_prompts_and_answers.append(kept_pairs)

    correct_prompts_and_answers = []
    for kept_pairs in filtered_prompts_and_answers:
        correct_prompts_and_answers.append(_maximize_unique_answers(kept_pairs, k=wanted_size))
    return large_prompts_and_answers, correct_prompts_and_answers

# Load the prompts of the last checkpoint; `large` is the unfiltered list, `correct` the filtered selection.
large, correct = load_prompts(f'{PYTHIA_PREFIX}-step143000')
# Bare expression: displayed as the cell output when run in a notebook.
correct

In [None]:
# Present model accuracy across timesteps: one line per operator, then the mean over operators.

steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000}) # Remove too early checkpoints (where the model doesn't really solve arithmetics with a non-random mechanism) and the extra test checkpoints at the end
# Load every checkpoint's accuracy file once (instead of once per operator).
accuracy_per_step = [torch.load(os.path.join('./data', f'{PYTHIA_PREFIX}-step{step}', 'accuracy.pt')) for step in steps]
# accuracies[operator][step]
accuracies = [[step_accuracy[op] for step_accuracy in accuracy_per_step] for op in OPERATORS]
accuracies_tensor = torch.tensor(accuracies)
multiple_lines(x=steps, y=accuracies_tensor, line_titles=OPERATORS, width=400)
line(accuracies_tensor.mean(dim=0), x=steps, width=400)

In [None]:
# How many heuristics neurons are shared between each timestep?
# For every pair of checkpoints, compare the sets of important (layer, neuron) pairs:
#   - use_threshold=True : neurons whose score exceeds `threshold`; the value is a true IoU.
#   - use_threshold=False: top-k neurons per layer; the value is |A & B| / |A|. Both sets hold
#     exactly k neurons per layer, so this is the (symmetric) fraction of shared neurons.

threshold = 0.001
topk_neuron_per_layer = 50
use_threshold = False

steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})
print(steps)
shared_neurons_iou = torch.zeros((len(steps), len(steps)))

neuron_scores = {step: _get_neuron_importance_scores_across_operators(f"{PYTHIA_PREFIX}-step{step}") for step in steps} # Caching for faster calculations
# neuron_scores = {step: _get_neuron_importance_scores(f"{PYTHIA_PREFIX}-step{step}", operator_idx=0, use_2shot_prompting=False) for step in steps}


def _select_heuristic_neurons(layer_scores, layers):
    """Return the set of important (layer, neuron) pairs of one checkpoint (reads the cell's settings)."""
    selected = set()
    for layer in layers:
        if use_threshold:
            neurons = (layer_scores[layer] > threshold).nonzero().view(-1).tolist()
        else:
            neurons = layer_scores[layer].topk(topk_neuron_per_layer).indices.tolist()
        selected.update((layer, neuron) for neuron in neurons)
    return selected


for layers_to_analyze in [list(range(PYTHIA_6_9B_CONSTS.first_heuristics_layer, 32))]:
    print(layers_to_analyze)

    # Build each checkpoint's neuron set once instead of once per (i, j) pair.
    neurons_per_step = {}
    for step in steps:
        try:
            neurons_per_step[step] = _select_heuristic_neurons(neuron_scores[step], layers_to_analyze)
        except (KeyError, IndexError, RuntimeError) as e:
            # Best effort: a checkpoint with missing/short score tensors gets zero overlap with everything.
            print(f"step {step}: {e}")
            neurons_per_step[step] = None

    for i in range(len(steps)):
        for j in tqdm(range(i, len(steps))):
            heuristic_neurons_i = neurons_per_step[steps[i]]
            heuristic_neurons_j = neurons_per_step[steps[j]]
            if heuristic_neurons_i is None or heuristic_neurons_j is None:
                overlap = 0.0
            else:
                shared_count = len(heuristic_neurons_i & heuristic_neurons_j)
                if use_threshold:
                    denominator = len(heuristic_neurons_i | heuristic_neurons_j)
                else:
                    denominator = len(heuristic_neurons_i)
                # Empty sets (e.g. nothing passes the threshold) count as no overlap.
                overlap = shared_count / denominator if denominator else 0.0
            shared_neurons_iou[i, j] = overlap
            shared_neurons_iou[j, i] = overlap

    px.imshow(shared_neurons_iou, width=500, x=steps, y=steps, zmin=0.0, zmax=1.0, labels={'x': 'Step', 'y': 'Step'}, title=f'Shared Neurons between Steps', color_continuous_scale="blues").show()

In [None]:
# Draw a graph of the mean clean and mean ablated accs across operators and averaged across operators

def get_mean_baseline_and_ablated_accs(step, operator_idx, use_2shot_prompting=False):
    """
    Return (mean baseline accuracy, mean ablated accuracy) over the heuristic knockout
    results saved for one checkpoint and operator.

    Only heuristics with at least MIN_NEURONS_PER_HEURISTIC ablated neurons and a matching
    score of at least MIN_SCORE_SUM_PER_HEURISTIC are averaged; for "KV" maps only heuristics
    whose name starts with 'result' are kept. If anything fails while loading or filtering,
    the exception is printed and (0.0, 0.0) is returned.
    """
    try:
        model_name = f"{PYTHIA_PREFIX}-step{step}"
        loaded_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_ablation_results{"_with_2shot_prompting" if use_2shot_prompting else ""}_thres={match_threshold}_{knockout_map_type}_maps.pt')
        kept_results = []
        for result in loaded_results:
            has_enough_neurons = len(result.ablated_neurons) >= MIN_NEURONS_PER_HEURISTIC
            if has_enough_neurons and result.ablated_neuron_matching_score >= MIN_SCORE_SUM_PER_HEURISTIC:
                kept_results.append(result)
        if knockout_map_type == "KV":
            # No point in looking at operand heuristics
            kept_results = [result for result in kept_results if result.heuristic_name.startswith('result')]
    except Exception as e:
        print(f"Exception {e}")
        return 0.0, 0.0

    baseline_values = [result.baseline_related for result in kept_results]
    ablated_values = [result.ablated_related for result in kept_results]
    baseline_related = torch.tensor(baseline_values)
    ablated_related = torch.tensor(ablated_values)
    return baseline_related.mean().item(), ablated_related.mean().item()

# Settings read by get_mean_baseline_and_ablated_accs (as globals).
match_threshold = 0.6
MIN_NEURONS_PER_HEURISTIC = 20.0
MIN_SCORE_SUM_PER_HEURISTIC = 20.0
knockout_map_type = ["KV", "K", "HYBRID"][2]
use_2shot_prompting = False
operators_to_process = range(len(OPERATORS))

accs_per_step = []
ablated_accs_per_step = []
steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})
for operator_idx in operators_to_process:
    # One call per step returns both values, so each result file is loaded only once.
    baseline_and_ablated = [get_mean_baseline_and_ablated_accs(step, operator_idx, use_2shot_prompting) for step in steps]
    accs_per_step.append([baseline for baseline, _ in baseline_and_ablated])
    ablated_accs_per_step.append([ablated for _, ablated in baseline_and_ablated])

lines = []
line_titles = []
for i in operators_to_process:
    lines.extend([accs_per_step[i], ablated_accs_per_step[i]])
    line_titles.extend([f"Pre-ablation {OPERATORS[i]}", f"Post-ablation {OPERATORS[i]}"])
# The "mean" lines average over all processed operators (previously only the first two were averaged).
lines.extend([torch.tensor(accs_per_step).mean(dim=0).tolist(),
              torch.tensor(ablated_accs_per_step).mean(dim=0).tolist()])
line_titles.extend([f"Pre-ablation mean", f"Post-ablation mean"])
colors = px.colors.qualitative.Plotly

multiple_lines(y=lines, x=steps, line_titles=line_titles,
        title=f'Heuristic knockout accuracies <br>(Clean & Mean over heuristic ablations) <br>over training steps',
        xaxis_title="Training step",
        yaxis_title="Accuracy",
        yaxis=dict(range=(0.0, 1.0)),
        height=400,
        width=500)

# Pearson correlation between the clean accuracy and the accuracy drop caused by the knockout, per operator.
avg_corr = 0
for op_idx in range(len(OPERATORS)):
    knockout_diff_per_step = [accs_per_step[op_idx][i] - ablated_accs_per_step[op_idx][i] for i in range(len(steps))]
    corr = pearsonr(accs_per_step[op_idx], knockout_diff_per_step)[0]
    print(f"{op_idx=}, Correlation of acc to knockout diff: {corr}")
    avg_corr += corr
avg_corr /= len(OPERATORS)
print(f"{avg_corr=}")

In [None]:
# Generate the heuristic statistics as function of training steps:
# for every operator and checkpoint, count the distinct (layer, neuron) pairs matched to each heuristic category.

THRESHOLD = 0.6
OPERATORS_TO_ANALYZE = range(len(OPERATORS))
knockout_type = "HYBRID"
use_2shot_prompting = False
k_neuron_filter = 50 # How many topk neurons (per layer) to consider (Only allow heuristic matches of these neurons)

steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})
heuristic_category_patterns = [r"op\d_\d+mod\d+", r"result_\d+mod\d+", 
                            r"op\d_region_\d+_\d+", r"result_region_\d+_\d+", 
                            r"op\d_value_\d+", r"result_value_\d+", 
                            r"op\d_pattern_.*", r"result_pattern_.*"]
heuristic_counts_across_steps = torch.zeros((len(heuristic_category_patterns), len(steps), len(OPERATORS_TO_ANALYZE)))

# model name -> {layer: set of top-k neuron indices}. The importance scores are averaged across
# operators, so they are computed once per checkpoint rather than once per (operator, checkpoint).
top_neurons_per_model = {}

for operator_idx in OPERATORS_TO_ANALYZE:
    for step_idx in range(len(steps)):
        model_name = f"{PYTHIA_PREFIX}-step{steps[step_idx]}"
        try:
            heuristic_classes = load_heuristic_classes(f"./data/{model_name}", 
                                                    operator_idx, knockout_type, 
                                                    override_fileprefix=f'{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict{"_with_2shot_prompting" if use_2shot_prompting else ""}')
            heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} # Filter by threshold

            # Filter only top-k neurons (sets give O(1) membership tests)
            if model_name not in top_neurons_per_model:
                neuron_scores = _get_neuron_importance_scores_across_operators(model_name)
                top_neurons_per_model[model_name] = {layer: set(neuron_scores[layer].topk(k_neuron_filter).indices.tolist()) for layer in neuron_scores}
            top_neurons_of_model = top_neurons_per_model[model_name]
            heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if n in top_neurons_of_model[l]] for name, layer_neuron_scores in heuristic_classes.items()}

            # Unify the heuristic classes according to pattern, keeping each (layer, neuron) pair once,
            # and count the neurons in each category
            unified_heuristic_classes = {}
            for i, pattern in enumerate(heuristic_category_patterns):
                category_neurons = {(l, n)
                                    for name, layer_neuron_scores in heuristic_classes.items() if re.search(pattern, name)
                                    for (l, n, s) in layer_neuron_scores}
                unified_heuristic_classes[pattern] = list(category_neurons)
                heuristic_counts_across_steps[i, step_idx, operator_idx] = len(category_neurons)
        except Exception as e:
            # Best effort: report the failing checkpoint and count it as zero neurons in every category.
            print(f"Skipping {model_name} (operator {operator_idx}): {e}")
            heuristic_counts_across_steps[:, step_idx, operator_idx] = 0.0


for op_idx in range(len(OPERATORS_TO_ANALYZE)):
    fig = multiple_lines(y=heuristic_counts_across_steps[:, :, op_idx], x=steps, line_titles=heuristic_category_patterns,
        title=f'Neurons per heuristic category over training steps<br>({OPERATOR_NAMES[op_idx]})',
        xaxis_title="Training step",
        yaxis_title="Neurons count",
        width=500,
        show_fig=False)
    fig.show()

# Average across operators 
multiple_lines(y=heuristic_counts_across_steps.mean(dim=2), x=steps, line_titles=heuristic_category_patterns,
        title=f'Neurons per heuristic category over training steps<br>(Averaged across operators)',
        xaxis_title="Training step",
        yaxis_title="Neurons count",
        width=500)

In [None]:
# Generate the heuristics intersection with a chosen training step (without categorization, weighted mean across all heuristics)

def get_heuristics_that_appear_across_all_steps(operator_idx, steps):
    """
    Return the set of heuristic names present both at step 143000 and at every step in `steps`.

    Reads the notebook-level knockout_type and use_2shot_prompting settings.
    """
    def heuristic_names_in(checkpoint_dir):
        return set(load_heuristic_classes(checkpoint_dir, operator_idx, knockout_type, override_fileprefix=f'{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict{"_with_2shot_prompting" if use_2shot_prompting else ""}').keys())

    common_names = heuristic_names_in(f"./data/pythia-6.9b-step143000")
    for step in steps:
        common_names = common_names & heuristic_names_in(f"./data/pythia-6.9b-step{step}")
    return common_names


# For every operator and checkpoint: the fraction of each heuristic's neurons in the reference
# ("gt") checkpoint that is matched to the same heuristic in that checkpoint, averaged
# (plain, unweighted mean) over heuristics.
# NOTE: THRESHOLD, knockout_type, steps and intersection_neurons are notebook globals that later cells read.
THRESHOLD = 0.6
knockout_type = "HYBRID"
use_2shot_prompting = False

steps = find_all_pythia_steps()
steps = sorted(list(set(steps) - {3000, 13000}))
print(steps)

# Rows: operators, columns: checkpoints.
mean_heuristic_intersections = torch.zeros((len(OPERATORS), len(steps)))
# Raises ValueError if step 143000 was not found on disk.
step_idx_to_compare_to = steps.index(143000)
gt_model_name = f"{PYTHIA_PREFIX}-step{steps[step_idx_to_compare_to]}"

# (operator_idx, heuristic_name, step) -> set of (layer, neuron) pairs shared with the reference checkpoint.
intersection_neurons = {}

print(f"Comparing to step {steps[step_idx_to_compare_to]}")
for operator_idx in range(len(OPERATORS)):
    gt_heuristic_classes = load_heuristic_classes(f"./data/{gt_model_name}", operator_idx, knockout_type, override_fileprefix=f'{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict{"_with_2shot_prompting" if use_2shot_prompting else ""}')
    # The '0'..'3' prints report how many heuristics remain after each filtering stage.
    print('0', len(gt_heuristic_classes.keys()))

    # Filter by threshold
    gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= THRESHOLD] for name, layer_neuron_scores in gt_heuristic_classes.items()}
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0}
    print('1', len(gt_heuristic_classes.keys()))

    # Filter only top-k neurons
    k = 200
    neuron_scores = _get_neuron_importance_scores(gt_model_name, operator_idx, use_2shot_prompting)
    top_neuron_of_model = {layer: neuron_scores[layer].topk(k).indices.tolist() for layer in neuron_scores.keys()}
    gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if n in top_neuron_of_model[l]] for name, layer_neuron_scores in gt_heuristic_classes.items()}
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0}
    print('2', len(gt_heuristic_classes.keys()))

    # Filter only heuristics that appear in all training steps
    gt_heuristic_names = get_heuristics_that_appear_across_all_steps(operator_idx, steps)
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if name in gt_heuristic_names}
    print('3', len(gt_heuristic_classes.keys()))

    # Fix the heuristic order so that row indices of heuristic_intersections are stable.
    gt_heuristic_names = sorted(gt_heuristic_classes.keys())
    heuristic_intersections = torch.zeros(len(gt_heuristic_names), len(steps))
    for step_idx, step in enumerate(steps):
        model_name = f"{PYTHIA_PREFIX}-step{steps[step_idx]}"
        heuristic_classes = load_heuristic_classes(f"./data/{model_name}", 
                                                    operator_idx, knockout_type, 
                                                    override_fileprefix=f'{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict{"_with_2shot_prompting" if use_2shot_prompting else ""}')
        heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} # Filter by threshold

        for heuristic_idx, heuristic_name in enumerate(gt_heuristic_names):
            # gt_neurons is never empty here: heuristics without neurons were dropped above.
            gt_neurons = set([(l, n) for (l, n, s) in gt_heuristic_classes[heuristic_name]])
            # Unlike the reference side, the compared checkpoint is filtered by threshold only (no top-k filter).
            neurons = set([(l, n) for (l, n, s) in heuristic_classes[heuristic_name]])
            intersection_neurons[(operator_idx, heuristic_name, step)] = neurons.intersection(gt_neurons)
            heuristic_intersections[heuristic_idx, step_idx] = len(neurons.intersection(gt_neurons)) / len(gt_neurons)

    # Unweighted mean over heuristics.
    # NOTE(review): if no heuristic survives the filters, the mean of an empty tensor is presumably nan - confirm.
    mean_heuristic_intersections[operator_idx, :] = heuristic_intersections.mean(dim=0)

fig = multiple_lines(y=mean_heuristic_intersections, x=steps, line_titles=OPERATORS,
        title=f'Heuristic neuron intersection between steps',
        xaxis_title="Training step",
        yaxis_title="Intersection ratio",
        width=1000,
        height=300,
        show_fig=False)
fig.update_layout(
    yaxis_range=[0.0, 1.0]
)
fig.show()

In [None]:
# Find the neurons that are matched to the same heuristic both in the last (GT) checkpoint and in step 142000.
# NOTE(review): operator_idx is not assigned in this cell; it is whatever value an earlier cell left
# behind (after a `for operator_idx in ...` loop, that is the last operator) - confirm the intended operator.

# Generate the heuristic list in the last (GT) checkpoint
HEURISTIC_THRESHOLD = 0.6
step_to_compare_to = "143000"
gt_model_name = f"{PYTHIA_PREFIX}-step{step_to_compare_to}"
model_name = f"{PYTHIA_PREFIX}-step142000"

gt_heuristic_classes = load_heuristic_classes(f"./data/{gt_model_name}", operator_idx, "HYBRID")
# Filter by threshold
gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= HEURISTIC_THRESHOLD] for name, layer_neuron_scores in gt_heuristic_classes.items()}
gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0}
# Flatten to (heuristic name, layer, neuron) triplets.
gt_heuristic_neuron_pairs = [(h_name, l, n) for h_name, lns in gt_heuristic_classes.items() for (l, n, s) in lns]

# Generate the heuristic list in the current model
heuristic_classes = load_heuristic_classes(f"./data/{model_name}", operator_idx, "HYBRID")
# Filter by threshold
heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= HEURISTIC_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()}
heuristic_classes = {name: lns for name, lns in heuristic_classes.items() if len(lns) > 0}
heuristic_neuron_pairs = [(h_name, l, n) for h_name, lns in heuristic_classes.items() for (l, n, s) in lns]

# Get the intersection of the neurons: a neuron counts as mutual only if it has the same
# heuristic in both checkpoints; the heuristic name is then dropped and duplicates removed.
mutual_neurons = list(set([(l, n) for (h_name, l, n) in set(gt_heuristic_neuron_pairs).intersection(set(heuristic_neuron_pairs))]))
# Regroup as layer -> list of neuron indices for layers 0..31 (layers without mutual neurons get an empty list).
mutual_neurons = {layer: [n for (l, n) in mutual_neurons if l == layer] for layer in range(32)}
# Total number of mutual neurons.
print(len(list(chain.from_iterable(mutual_neurons.values()))))
print(mutual_neurons)

In [None]:
# Print, for every checkpoint, how many distinct heuristic neurons exist before and after the
# score-threshold filter, and how many of them intersect with the reference checkpoint.
# NOTE(review): the original header mentioned generating faithfulness with top-200 neurons, but
# the code below only prints neuron counts.
# NOTE(review): this cell relies on globals left behind by earlier cells: steps, knockout_type,
# THRESHOLD and intersection_neurons (keyed by (operator_idx, heuristic_name, step)).

# How many neurons are shared between each timestep?
operator_idx = 0
# `steps` is not recomputed here; the checkpoints are removed from whatever list an earlier cell produced.
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})

for step in steps:
    model_name = f"{PYTHIA_PREFIX}-step{step}"
    heuristic_classes = load_heuristic_classes(f"./data/{model_name}", operator_idx, knockout_type)
    # Distinct (layer, neuron) pairs over all heuristics.
    heuristic_neurons = set([(v[0], v[1]) for v in chain.from_iterable(heuristic_classes.values())])
    print(step, 'Pre-filtering', len(heuristic_neurons))

    # Filter by threshold
    heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()}
    heuristic_classes = {name: lns for name, lns in heuristic_classes.items() if len(lns) > 0}
    heuristic_neurons = set([(v[0], v[1]) for v in chain.from_iterable(heuristic_classes.values())])
    print(step, 'Post-threshold-filtering', len(heuristic_neurons))

    # Distinct neurons, over all heuristics of this operator, that are shared with the reference checkpoint at this step.
    print(step, len(set(chain.from_iterable([v for k, v in intersection_neurons.items() if k[0] == operator_idx and k[2] == step]))))

In [None]:
# Generate the prompt knockout score as a func of training steps:
# one ablated (solid) and one control (dashed) curve per neuron limit, both averaged over all operators.

steps = find_all_pythia_steps()
steps = sorted(list(set(steps) - {3000, 13000}))
colors = ['lightblue', 'cadetblue', 'darkblue', 'midnightblue']
fig = go.Figure()
fig.add_trace(go.Scatter(x=[steps[0], steps[-1]], y=[1, 1], mode='lines', name='Baseline', line=dict(color='black')))

# Load every result file once; all neuron limits below read from this cache.
prompt_ablation_results = {}
for operator_idx in range(len(OPERATORS)):
    for step in steps:
        model_name = f"{PYTHIA_PREFIX}-step{step}"
        prompt_ablation_results[(operator_idx, step)] = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_prompt_ablation_results_thres=0.6_HYBRID_maps.pt')

for i, chosen_neuron_hard_limit in enumerate([5, 25, 50]):
    ablated_results_per_step = torch.zeros((len(OPERATORS), len(steps)))
    control_results_per_step = torch.zeros((len(OPERATORS), len(steps)))

    # Collect the data to present from the loaded files
    for operator_idx in range(len(OPERATORS)):
        for step_idx, step in enumerate(steps):
            neuron_hard_limits, baseline_results, ablated_results, ablated_neuron_counts, control_results = prompt_ablation_results[(operator_idx, step)]
            # The y=1 baseline trace is only valid if every baseline result equals 1.
            assert torch.all(baseline_results == 1)

            chosen_neuron_hard_limit_idx = list(neuron_hard_limits).index(chosen_neuron_hard_limit)
            ablated_results_per_step[operator_idx, step_idx] = ablated_results[chosen_neuron_hard_limit_idx]
            control_results_per_step[operator_idx, step_idx] = control_results[chosen_neuron_hard_limit_idx]

    # Average both curves over the same set of operators (the ablated curve used to cover only the first two).
    fig.add_trace(go.Scatter(x=steps, y=ablated_results_per_step.mean(dim=0), mode='lines', name=f'Mean (ablated, neurons={chosen_neuron_hard_limit})', line=dict(color=colors[i])))
    fig.add_trace(go.Scatter(x=steps, y=control_results_per_step.mean(dim=0), mode='lines', name=f'Mean (control, neurons={chosen_neuron_hard_limit})', line=dict(color=colors[i], dash='dash')))

fig.update_layout(
    title='Ablated and Control Accuracies',
    # The x values are training steps; the neuron limit is encoded in the trace names.
    xaxis_title='Training step',
    yaxis_title='Accuracy',
    legend_title='',
    width=700,
    yaxis_range=[0.0, 1.02],
    font=dict(size=14)
)

fig.show()

### Figures

In [None]:
# Draw a graph of the mean clean and mean ablated accs across operators and averaged across operators

def get_mean_baseline_and_ablated_accs(step, operator_idx, use_2shot_prompting=False):
    """
    Return (mean baseline accuracy, mean ablated accuracy) over the heuristic knockout
    results saved for one checkpoint and operator.

    Only heuristics with at least MIN_NEURONS_PER_HEURISTIC ablated neurons and a matching
    score of at least MIN_SCORE_SUM_PER_HEURISTIC are averaged; for "KV" maps only heuristics
    whose name starts with 'result' are kept. If anything fails while loading or filtering,
    the exception is printed and (0.0, 0.0) is returned.
    """
    try:
        model_name = f"{PYTHIA_PREFIX}-step{step}"
        loaded_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_ablation_results{"_with_2shot_prompting" if use_2shot_prompting else ""}_thres={match_threshold}_{knockout_map_type}_maps.pt')
        kept_results = []
        for result in loaded_results:
            has_enough_neurons = len(result.ablated_neurons) >= MIN_NEURONS_PER_HEURISTIC
            if has_enough_neurons and result.ablated_neuron_matching_score >= MIN_SCORE_SUM_PER_HEURISTIC:
                kept_results.append(result)
        if knockout_map_type == "KV":
            # No point in looking at operand heuristics
            kept_results = [result for result in kept_results if result.heuristic_name.startswith('result')]
    except Exception as e:
        print(f"Exception {e}")
        return 0.0, 0.0

    baseline_values = [result.baseline_related for result in kept_results]
    ablated_values = [result.ablated_related for result in kept_results]
    baseline_related = torch.tensor(baseline_values)
    ablated_related = torch.tensor(ablated_values)
    return baseline_related.mean().item(), ablated_related.mean().item()

# Settings read by get_mean_baseline_and_ablated_accs (as globals).
match_threshold = 0.6
MIN_NEURONS_PER_HEURISTIC = 20.0
MIN_SCORE_SUM_PER_HEURISTIC = 20.0
knockout_map_type = "HYBRID"
use_2shot_prompting = False
operators_to_process = range(len(OPERATORS))

accs_per_step = []
ablated_accs_per_step = []
steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})
for operator_idx in operators_to_process:
    # One call per step returns both values, so each result file is loaded only once.
    baseline_and_ablated = [get_mean_baseline_and_ablated_accs(step, operator_idx, use_2shot_prompting) for step in steps]
    accs_per_step.append([baseline for baseline, _ in baseline_and_ablated])
    ablated_accs_per_step.append([ablated for _, ablated in baseline_and_ablated])

# One colour per operator: dashed line = clean accuracy, solid line = accuracy after ablation.
colors = px.colors.qualitative.Plotly
fig = go.Figure()
for i, op in enumerate(OPERATORS):
    fig.add_trace(go.Scatter(x=steps, y=accs_per_step[i], mode='lines', name=f'{op}', line=dict(color=colors[i], dash='dash')))
    fig.add_trace(go.Scatter(x=steps, y=ablated_accs_per_step[i], mode='lines', name=f'{op} (Ablated)', line=dict(color=colors[i])))
# fig.add_trace(go.Scatter(x=steps, y=torch.tensor(accs_per_step).mean(dim=0).tolist(), mode='lines', name=f'Mean', line=dict(color=colors[i+1], dash='dash')))
# fig.add_trace(go.Scatter(x=steps, y=torch.tensor(ablated_accs_per_step).mean(dim=0).tolist(), mode='lines', name=f'Mean (Ablated)', line=dict(color=colors[i+1])))
fig.update_layout(
    title='Heuristic-based accuracies over training steps',
    xaxis_title='Training step',
    yaxis_title='Accuracy',
    legend_title='',
    width=500,
    height=400,
    font=dict(size=14)
)
fig.show()

# Make sure the output directory exists before exporting the figure.
os.makedirs('./figs', exist_ok=True)
pio.write_image(fig, f'./figs/heuristics_knockout_across_training.pdf')
pio.write_image(fig, f'./figs/heuristics_knockout_across_training.svg')

# Pearson correlation between the clean accuracy and the accuracy drop caused by the knockout, per operator.
avg_corr = 0
for op_idx in range(len(OPERATORS)):
    knockout_diff_per_step = [accs_per_step[op_idx][i] - ablated_accs_per_step[op_idx][i] for i in range(len(steps))]
    corr = pearsonr(accs_per_step[op_idx], knockout_diff_per_step)[0]
    print(f"{op_idx=}, Correlation of acc to knockout diff: {corr}")
    avg_corr += corr
avg_corr /= len(OPERATORS)
print(f"{avg_corr=}")

In [None]:
# Generate the heuristics intersection with a chosen training step 
# (without categorization, without considering heuristics, just full set intersection of (h, l, n) sets)

def get_heuristics_that_appear_across_all_steps(operator_idx, steps):
    """Return the heuristic names shared by the step-143000 checkpoint and every checkpoint in `steps`.

    Args:
        operator_idx: Index of the operator whose heuristic classes are loaded.
        steps: Iterable of training steps; the heuristic classes of each one are
            loaded from ``./data/pythia-6.9b-step{step}``.

    Returns:
        The set of heuristic names (keys of the loaded heuristic classes) that are
        present at step 143000 and at each of the given steps.

    Note:
        The knockout type is read from the module-level ``knockout_type`` variable,
        which therefore has to be defined before this function is called.
    """
    def heuristic_names_at(step):
        # Only the heuristic names (the keys) matter here, not the neurons attached to them.
        classes = load_heuristic_classes(f"./data/pythia-6.9b-step{step}", operator_idx, knockout_type)
        return set(classes.keys())

    # Start from the final checkpoint and keep narrowing the set down step by step.
    shared_names = heuristic_names_at(143000)
    for step in steps:
        shared_names &= heuristic_names_at(step)
    return shared_names

THRESHOLD = 0.6  # minimal heuristic-match score for a neuron to count as implementing a heuristic
knockout_type = "HYBRID"
use_2shot_prompting = False

steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})  # these checkpoints are left out of the analysis
print(steps)

# new_mean_heuristic_intersections[operator, step] is the fraction of the reference checkpoint's
# (heuristic, layer, neuron) triplets that also show up at the given training step.
new_mean_heuristic_intersections = torch.zeros((len(OPERATORS), len(steps)))
step_idx_to_compare_to = steps.index(143000)  # raises ValueError if the reference checkpoint is missing
gt_model_name = f"{PYTHIA_PREFIX}-step{steps[step_idx_to_compare_to]}"

print(f"Comparing to step {steps[step_idx_to_compare_to]}")
for operator_idx in range(len(OPERATORS)):
    matches_fileprefix = f'{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict'
    gt_heuristic_classes = load_heuristic_classes(f"./data/{gt_model_name}", operator_idx, knockout_type, override_fileprefix=matches_fileprefix)

    # Filter by threshold
    gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= THRESHOLD] for name, layer_neuron_scores in gt_heuristic_classes.items()}
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0}

    # Filter only top-k neurons (the 50 highest-scoring neurons of each layer in the reference checkpoint)
    neuron_scores = _get_neuron_importance_scores(gt_model_name, operator_idx, use_2shot_prompting)
    top_neuron_of_model = {layer: neuron_scores[layer].topk(50).indices.tolist() for layer in neuron_scores.keys()}
    gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if n in top_neuron_of_model[l]] for name, layer_neuron_scores in gt_heuristic_classes.items()}
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0}

    # Filter only heuristics that appear in all training steps
    gt_heuristic_names = get_heuristics_that_appear_across_all_steps(operator_idx, steps)
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if name in gt_heuristic_names}

    # Create a list of (heuristic_name, layer, neuron) tuples
    gt_heuristic_neuron_pairs = [(h_name, l, n) for h_name, lns in gt_heuristic_classes.items() for (l, n, s) in lns]
    if not gt_heuristic_neuron_pairs:
        # The ratio below would be a division by zero; fail with an actionable message instead.
        raise ValueError(
            f"No heuristic neurons of {gt_model_name} are left for operator index {operator_idx} "
            f"after filtering; the intersection ratio is undefined"
        )
    # Build the reference set once instead of on every step.
    gt_heuristic_neuron_pairs_set = set(gt_heuristic_neuron_pairs)

    heuristic_intersections = torch.zeros(len(steps))
    for step_idx, step in enumerate(steps):
        model_name = f"{PYTHIA_PREFIX}-step{step}"
        print("comparing to", model_name)
        heuristic_classes = load_heuristic_classes(f"./data/{model_name}",
                                                    operator_idx, knockout_type,
                                                    override_fileprefix=matches_fileprefix)
        # Filter by threshold (the top-k filter is applied to the reference checkpoint only)
        heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()}

        heuristic_neuron_pairs = [(h_name, l, n) for h_name, lns in heuristic_classes.items() for (l, n, s) in lns]
        shared_pairs = gt_heuristic_neuron_pairs_set.intersection(heuristic_neuron_pairs)
        heuristic_intersections[step_idx] = len(shared_pairs) / len(gt_heuristic_neuron_pairs)
        print(f"{step=}, {heuristic_intersections[step_idx]}")
    new_mean_heuristic_intersections[operator_idx, :] = heuristic_intersections


In [None]:
# Draw the heuristic intersection graph itself

# legend = ['+', '-', '×', '÷']
legend = ["Average"]
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})

# A single curve: the intersection ratio averaged over all operators (dim 0).
mean_intersection_per_step = new_mean_heuristic_intersections.mean(dim=0)
fig = px.line(x=steps, y=mean_intersection_per_step, width=300, height=200)

axis_title_font = {"size": 16}
axis_tick_font = {"size": 15}
fig.update_layout(
    yaxis={
        "range": [0.0, 1.0],
        "title": {"text": "Intersection ratio", "font": axis_title_font},
        "tickfont": axis_tick_font,
        "tickvals": torch.linspace(0.2, 1, 5),
    },
    xaxis={
        "title": {"standoff": 5, "text": "Training step", "font": axis_title_font},
        "tickfont": axis_tick_font,
        "tickvals": steps[::4],  # label every fourth checkpoint only
    },
    title={
        "text": "Heuristic neurons intersection<br>with final checkpoint",
        "x": 0.55,  # horizontal position of the title
        "y": 0.93,  # vertical position of the title
        "xanchor": 'center',
        "yanchor": 'top',
        "font": {"family": "Arial", "size": 17},
    },
    legend={
        "orientation": "h",
        "itemsizing": 'constant',
        "yanchor": "bottom",
        "y": -0.35,
        "xanchor": "center",
        "x": 0.5,
        "font": {"size": 18},
    },
    margin={"t": 40, "l": 0, "r": 0, "b": 0},
)

fig.show()

pio.write_image(fig, './figs/heuristics_intersection_across_training.pdf')

In [None]:
# Draw the faithfulness graph: once using all heuristic neurons, and once using only the neurons shared with the last checkpoint

operator_idx = 0

# The results file maps (model name, operator index) keys to per-checkpoint results;
# keep only the entries of the chosen operator, keyed by their training step.
mutual_faithfulness_results = torch.load(r'./data/pythia-6.9b-step143000/mutual_faithfulness_results_old.pt')
faithfulness_by_step = {}
for result_key, result_values in mutual_faithfulness_results.items():
    if result_key[1] == operator_idx:
        faithfulness_by_step[int(result_key[0].split('step')[1])] = result_values
faithfulnneses = sorted(faithfulness_by_step.items())

colors = COLORBLIND_COLORS
# Entry 1 of each result is plotted as "all heuristic neurons", entry 2 as "overlapping with last checkpoint".
baseline = [values[1].item() for _, values in faithfulnneses]
mutual = [values[2].item() for _, values in faithfulnneses]
print((torch.tensor(mutual) / torch.tensor(baseline)))
steps = [step for step, _ in faithfulnneses]

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=steps, y=baseline,
    mode='lines',
    name='All heuristic neurons',
    line={"color": colors[0], "dash": "dot"},
))
fig.add_trace(go.Scatter(
    x=steps, y=mutual,
    mode='lines',
    name='Overlapping with last checkpoint',
    line={"color": colors[0]},
))

axis_title_font = {"size": 16}
axis_tick_font = {"size": 15}
fig.update_layout(
    yaxis={
        "range": [0.0, 1.0],
        "title": {"text": "Faithfulness", "font": axis_title_font},
        "tickfont": axis_tick_font,
        "tickvals": torch.linspace(0.2, 1, 5),
    },
    xaxis={
        "title": {"standoff": 5, "text": "Training step", "font": axis_title_font},
        "tickfont": axis_tick_font,
        "tickvals": [23000, 63000, 103000, 143000],
    },
    title={
        "text": "Faithfulness of circuit with <br>specific heuristic neurons",
        "x": 0.55,  # horizontal position of the title
        "y": 0.93,  # vertical position of the title
        "xanchor": 'center',
        "yanchor": 'top',
        "font": {"family": "Arial", "size": 17},
    },
    width=300,
    height=200,
    legend={
        "orientation": "h",
        "itemsizing": 'constant',
        "yanchor": "bottom",
        "y": 0.2,
        "xanchor": "center",
        "x": 0.5,
        "font": {"size": 15},
    },
    margin={"t": 40, "l": 0, "r": 0, "b": 0},
)

fig.show()

pio.write_image(fig, './figs/faithfulness_mutual_heuristics_across_training.pdf')

In [None]:
# Generate the prompt knockout score as a function of training steps

steps = find_all_pythia_steps()
steps = sorted(set(steps) - {3000, 13000, 138000, 142000})  # these checkpoints are left out of the analysis
colors = COLORBLIND_COLORS[1:2] + COLORBLIND_COLORS[::-1]
fig = go.Figure()

# Data-less traces that only describe the two line styles in the legend. The dash style of
# the "Random Ablation" entry matches the control traces drawn below ('dot').
fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name='Heuristic Ablation', line=dict(color='grey')))
fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name='Random Ablation', line=dict(dash='dot', color='grey')))

# Load every (operator, step) results file exactly once; the files do not depend on the
# neuron limit chosen below, so re-reading them for each limit would be wasted I/O.
prompt_ablation_results = {}
for operator_idx in range(len(OPERATORS)):
    for step_idx, step in enumerate(steps):
        model_name = f"{PYTHIA_PREFIX}-step{step}"
        neuron_hard_limits, baseline_results, ablated_results, ablated_neuron_counts, control_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_prompt_ablation_results_thres=0.6_HYBRID_maps.pt')
        # Sanity check: the stored baseline results are expected to be all ones.
        assert torch.all(baseline_results == 1)
        prompt_ablation_results[operator_idx, step_idx] = (neuron_hard_limits, ablated_results, control_results)

for i, chosen_neuron_hard_limit in enumerate([5, 10, 25]):
    ablated_results_per_step = torch.zeros((len(OPERATORS), len(steps)))
    control_results_per_step = torch.zeros((len(OPERATORS), len(steps)))

    # Pick, for every operator and step, the results stored for the chosen neuron limit
    for (operator_idx, step_idx), (neuron_hard_limits, ablated_results, control_results) in prompt_ablation_results.items():
        chosen_neuron_hard_limit_idx = list(neuron_hard_limits).index(chosen_neuron_hard_limit)
        ablated_results_per_step[operator_idx, step_idx] = ablated_results[chosen_neuron_hard_limit_idx]
        control_results_per_step[operator_idx, step_idx] = control_results[chosen_neuron_hard_limit_idx]

    # Targeted Ablation (only the first two operators are averaged)
    fig.add_trace(go.Scatter(
        x=steps, y=ablated_results_per_step[:2].mean(dim=0),
        mode='lines',
        name=f'{chosen_neuron_hard_limit} neurons',
        line=dict(color=colors[i])
    ))

    # Control Ablation (same colour as the targeted curve, dotted, hidden from the legend)
    fig.add_trace(go.Scatter(
        x=steps, y=control_results_per_step[:2].mean(dim=0),
        mode='lines',
        name='',
        showlegend=False,
        line=dict(color=colors[i], dash='dot')
    ))

fig.update_layout(
    width=300,
    height=200,
    yaxis=dict(range=[0.0, 1.0], title=dict(text="Accuracy", font=dict(size=16)), tickfont=dict(size=15), tickvals=torch.linspace(0.2, 1, 5)),
    xaxis=dict(title=dict(standoff=5, text="Training step", font=dict(size=16)), tickfont=dict(size=15), tickvals=steps[:-1:4]),
    title=dict(text="Effect of heuristic neurons<br>knockout", x=0.52, y=0.93, xanchor='center', yanchor='top', font=dict(size=17, family="Arial")),
    showlegend=False,
    # legend=dict(orientation="h", yanchor="top", y=-0.4, xanchor="center", x=0.2, font=dict(size=14), title_text='', traceorder="normal",
    #     valign="top",  # Align legend at the top within the margin
    #     itemwidth=30,  # Adjust the width of each item
    #     tracegroupgap=0,  # Controls the spacing between columns
    # ),
    margin=dict(t=40, l=0, r=0, b=0),  # Adjust bottom margin for space
)

fig.show()

pio.write_image(fig, './figs/prompt_knockout_across_training.pdf')

# Other

### Shots Analysis

In [None]:
# Compare 2-shot-prompted model with 0-shot model

model_name = "pythia-6.9b-step143000"
model_data_dir = os.path.join('data', model_name)

# Print the stored accuracies of the final checkpoint, without and with 2-shot prompting.
# NOTE(review): the 2-shot file name has no underscore between "accuracy" and "with", unlike the
# '_with_2shot_prompting' suffix used for the attribution-score files — confirm it matches the file on disk.
zero_shot_accuracy = torch.load(os.path.join(model_data_dir, 'accuracy.pt'))
print("Model accuracy: ", zero_shot_accuracy)
two_shot_accuracy = torch.load(os.path.join(model_data_dir, 'accuracywith_2shot_prompting.pt'))
print("Model 2-shot accuracy: ", two_shot_accuracy)

def top_k_neurons(model_name, op_idx, use_2shot_prompting, k):
    """Return, for each of the 32 layers, the indices of the k top-ranked MLP neurons.

    The node attribution scores of the given operator are loaded from
    ``./data/{model_name}/`` (the 2-shot variant of the file when
    `use_2shot_prompting` is set). Every neuron is ranked by the mean plus the
    standard deviation (both over dim 0) of its scores at the last index of
    dim 1 — presumably the last token position across prompts; confirm.

    Args:
        model_name: Name of the model's directory under ``./data``.
        op_idx: Index into ``OPERATOR_NAMES`` selecting the operator.
        use_2shot_prompting: Whether to load the scores computed with 2-shot prompting.
        k: Number of neurons to keep per layer.

    Returns:
        A dict mapping each layer index (0..31) to a list of its k top neuron indices.
    """
    file_suffix = '_with_2shot_prompting' if use_2shot_prompting else ''
    neuron_attribution_scores = torch.load(f"./data/{model_name}/{OPERATOR_NAMES[op_idx]}_node_attribution_scores{file_suffix}.pt")

    top_neurons_per_layer = {}
    for layer in range(32):
        last_index_scores = neuron_attribution_scores[f'blocks.{layer}.mlp.hook_post'][:, -1]
        ranking = last_index_scores.mean(dim=0) + last_index_scores.std(dim=0)
        top_neurons_per_layer[layer] = ranking.topk(k).indices.tolist()
    return top_neurons_per_layer

# For the first two operators, measure per layer (14..31) how many of the top-k neurons
# are shared between the 0-shot and the 2-shot settings.
for op_idx in [0, 1]:
    k = 25
    clean_top_neurons = top_k_neurons(model_name, op_idx, False, k)
    two_shot_neurons = top_k_neurons(model_name, op_idx, True, k)
    for layer in range(14, 32):
        clean_layer_neurons = clean_top_neurons[layer]
        two_shot_layer_neurons = two_shot_neurons[layer]
        assert len(clean_layer_neurons) == len(two_shot_layer_neurons)
        shared_neurons = set(clean_layer_neurons) & set(two_shot_layer_neurons)
        # Despite its name this is a fraction in [0, 1], not a percentage.
        mutual_percentage = len(shared_neurons) / len(clean_layer_neurons)
        print(f"Operator {op_idx}, Layer {layer}: {mutual_percentage} of neurons are mutual in 0-shot and 2-shot settings")

In [None]:
# Load the final Pythia-6.9B checkpoint through HuggingFace and wrap it with TransformerLens.
model_name = "pythia-6.9b-step143000"
name, step = model_name.split('-step')
revision = f"step{step}"
model_path = f"/mnt/nlp/models/{name}/{revision}"  # local cache directory of the downloaded weights
inner_model = GPTNeoXForCausalLM.from_pretrained(
    f"EleutherAI/{name}",
    revision=revision,
    cache_dir=model_path,
)
model = lens.HookedTransformer.from_pretrained(
    model_name=name,
    hf_model=inner_model,
    fold_ln=True,
    center_unembed=True,
    center_writing_weights=True,
    device=device,
)

# Few-shot prefixes, one per operator; each prefix holds two solved examples and is
# prepended verbatim to a prompt. The numbers in the trailing comments are kept from the
# original notes — presumably accuracies without ("clean") and with the shots; confirm.

# Prefixes used by the evaluation loop below.
# NOTE(review): the leftover "_PYTHIA6" marker suggests these were picked for Pythia-6.9B — confirm.
BEST_SHOTS_PER_OP = {  # _PYTHIA6
    '+': '17+84=101, 2+4=6, ',  # 40 clean, 85%
    '-': '34-20=14; 38-19=19; ',  # 3 clean
    '*': '23*14=322; 5*6=30; ',
    '/': '115/22=5, 98/7=14, '
}

# NOTE(review): judging by the name, the counterpart for Pythia-12B — confirm.
BEST_SHOTS_PER_OP_PYTHIA12 = {
    '+': '17+84=101, 2+4=6, ',  # 45 clean, 87% 2-shot
    '-': '21-10=11, 105-23=82, ',  # 43 clean, 80% 2-shot
    '*': '23*14=322; 5*6=30; ',  # 16 clean, 92% 2-shot
    '/': '58/4=14, 85/4=21, '  # 12 clean, 73% 2-shot
}


# Baseline shots
SHOTS_PER_OP = {
    '+': '15+23=38, 17+115=132, ',
    '-': '105-23=82, 21-10=11, ',
    '*': '5*6=30, 23*14=322, ',
    '/': '98/7=14, 115/22=5, '
}

# Same examples as the baseline shots, in a different order
SHOTS_PER_OP2 = {
    '+': '17+115=132, 15+23=38, ',
    '-': '21-10=11, 105-23=82, ', 
    '*': '23*14=322, 5*6=30, ',
    '/': '115/22=5, 98/7=14, '
}

# Same as SHOTS_PER_OP2, with ';' instead of ',' as the separator
SHOTS_PER_OP3 = {
    '+': '17+115=132; 15+23=38; ',
    '-': '21-10=11; 105-23=82; ', 
    '*': '23*14=322; 5*6=30; ',
    '/': '115/22=5; 98/7=14; '
}

# Harder prompts (larger operands)
SHOTS_PER_OP4 = {
    '+': '117+56=173, 23+299=322, ',
    '-': '221-31=190, 173-28=145, ',
    '*': '89*3=267, 18*12=216, ', 
    '/': '298/14=21, 145/113=1, '
}

# print("0Shot Model accuracy: ", torch.load(os.path.join('data', model_name + "-0shot", 'accuracy.pt')))
# Evaluate the model on every operator's prompts when prefixed with the operator's best shots.
for op in OPERATORS:
    # Division uses a minimal operand of 1 instead of 0 (presumably to avoid dividing by zero — confirm).
    min_op = 1 if op == '/' else 0
    max_op = 300
    clean_prompts = generate_all_prompts_for_operator(op, min_op, max_op, (0, 520))

    # The expected answer is the integer value of the prompt without its last character.
    # NOTE: eval() is only applied to the prompts generated above, never to external input.
    answers = []
    for prompt in clean_prompts:
        expression = prompt[:-1]
        answers.append(str(int(eval(expression))))
    # print(op, 'CLEAN', model_accuracy(model, clean_prompts, answers))

    few_shot_prefix = BEST_SHOTS_PER_OP[op]
    prompts = [f"{few_shot_prefix}{p}" for p in clean_prompts]
    print(op, 'BEST', model_accuracy(model, prompts, answers))

    # prompts = [f"{SHOTS_PER_OP2[op]}{p}" for p in clean_prompts]
    # print(op, '2', model_accuracy(model, prompts, answers))
    # prompts = [f"{SHOTS_PER_OP3[op]}{p}" for p in clean_prompts]
    # print(op, '3', model_accuracy(model, prompts, answers))
    # prompts = [f"{SHOTS_PER_OP4[op]}{p}" for p in clean_prompts]
    # print(op, '4', model_accuracy(model, prompts, answers))
    
    # def gen_random_shots(op):
    #     # Generate random prompt in the same template as the SHOTS_PER_OP[op]
    #     shot1 = random.choice(clean_prompts)
    #     ans1 = int(eval(shot1[:-1]))
    #     shot2 = random.choice(clean_prompts)
    #     ans2 = int(eval(shot2[:-1]))
    #     if op == '/'or op == '+':
    #         shots = f"{shot1}{ans1}, {shot2}{ans2}, "
    #     else:
    #         shots = f"{shot1}{ans1}; {shot2}{ans2}; "
    #     return shots
    
    # shots = gen_random_shots(op)
    # prompts = [f"{shots}{p}" for p in clean_prompts]
    # print(op, shots, model_accuracy(model, prompts, answers))
    # shots = gen_random_shots(op)
    # prompts = [f"{shots}{p}" for p in clean_prompts]
    # print(op, shots, model_accuracy(model, prompts, answers))
    # shots = gen_random_shots(op)
    # prompts = [f"{shots}{p}" for p in clean_prompts]
    # print(op, shots, model_accuracy(model, prompts, answers))
    # shots = gen_random_shots(op)
    # prompts = [f"{shots}{p}" for p in clean_prompts]
    # print(op, shots, model_accuracy(model, prompts, answers))