In [None]:
from circular_probe import train_circular_probe
import numpy
from tqdm import tqdm
from general_ps_utils import ModelAndTokenizer
from matplotlib import pyplot as plt
import torch

# Free any cached CUDA allocator blocks left over from a previous run of this cell.
torch.cuda.empty_cache()

# Experiment configuration. The same dict is handed to the probe trainer later
# on, so the key names here are part of that interface.
params = {
    # model to probe (alternative: "mistralai/Mistral-7B-v0.1")
    'model_name': "meta-llama/Meta-Llama-3-8B",
    'use_4bit': False,
    # optimisation settings consumed by the probe trainer
    'epochs': 10_000,
    'lr': 0.0005,
    # integers 0 .. numbers-1 are fed through the model
    'numbers': 2000,
    'batch_size': 2000,
    # numbers to exclude from training set; 'random' is resolved below
    'exclude': 'random',
    'exclude_count': 200,
    # how many times each number is repeated in its prompt
    'positions': 1,
    'shuffle': True,
    # one probe sweep is run per basis
    'bases': [10, 11],
    # first layer index that gets probed
    'start_layer': 0,
    'bias': False
}

# Echo the configuration before 'exclude' is expanded into an array.
print(f"Params:\n\n{params}")

# Resolve the 'random' placeholder into `exclude_count` distinct integers drawn
# from range(numbers).
# NOTE(review): no seed is set in this cell, so the draw presumably differs
# between runs -- confirm whether that is intended.
if params['exclude'] == 'random':
    params['exclude'] = numpy.random.choice(
        params['numbers'],
        size=params['exclude_count'],
        replace=False,
    )

# Load the model/tokenizer wrapper on the GPU.
mt = ModelAndTokenizer(
    model_name=params['model_name'],
    use_4bit=params['use_4bit'],
    device='cuda'
)

# The wrapper does not reliably leave the model on the GPU, so move it explicitly.
mt.model.to('cuda')

tokenizer = mt.tokenizer

# Filled by the extraction cell: maps each integer to the model's hidden states.
num_to_hidden = {}

print(f"device of model is {mt.model.device}")


In [None]:

# Run every integer in range(params['numbers']) through the model and cache the
# hidden states it returns, keyed by the integer.
#
# The forward passes are inference only. Without no_grad, each stored tuple of
# hidden states would keep its autograd graph (and the intermediate activations
# it references) alive for as long as num_to_hidden holds it, which steadily
# eats GPU memory across the loop.
with torch.no_grad():
    for i in tqdm(range(params['numbers']), delay=120):
        # Prompt is the number repeated `positions` times, separated by single spaces.
        text_for_embeddings = " ".join([str(i)] * params['positions'])

        x = tokenizer.encode(text_for_embeddings, return_tensors='pt')
        x = x.to(mt.device)
        hidden_states = mt.model(x, output_hidden_states=True).hidden_states

        num_to_hidden[i] = hidden_states

In [None]:
# Train one circular probe per (basis, layer) pair, then summarise per basis:
#   basis_acc        -> mean accuracy over the probed layers (rounded to 3 dp)
#   basis_acc_max    -> best single-layer accuracy
#   basis_acc_argmax -> absolute layer index at which that best accuracy occurs
# `df` ends up with one row per (layer, basis) run.
import pandas as pd

# One record per probe run. The DataFrame is built once after the loops instead
# of being re-concatenated on every iteration; `records` stays available if the
# sweep is interrupted part-way.
records = []

basis_acc = dict()
basis_acc_max = dict()
basis_acc_argmax = dict()
for basis in tqdm(params['bases']):
    print(f"Training cyclic probe on basis: {basis}")
    params['basis'] = basis
    layer_acc = []
    for layer in range(params['start_layer'], mt.num_layers):
        # The trainer reads the layer(s) to probe from params['layers'].
        params['layers'] = [layer]
        # Use the name that is actually imported at the top of the notebook.
        acc, circular_probe = train_circular_probe(params, mt, num_to_hidden)
        print(f"Layer: {layer}, Accuracy: {acc}")
        layer_acc.append(acc)
        records.append({'layer': layer, 'basis': basis, 'accuracy': acc})

    basis_acc[basis] = numpy.mean(layer_acc)
    basis_acc_max[basis] = numpy.max(layer_acc)
    # argmax is an index into layer_acc, whose first entry belongs to
    # start_layer; shift it so the stored value is the real layer number.
    basis_acc_argmax[basis] = params['start_layer'] + int(numpy.argmax(layer_acc))

df = pd.DataFrame(records, columns=['layer', 'basis', 'accuracy'])

# round all acc numbers to 3 decimal places
basis_acc = {k: round(v, 3) for k, v in basis_acc.items()}

In [None]:
# Recap of the run configuration, one setting per line.
print(
    f"Model: {params['model_name']}",
    f"Number of numbers: {params['numbers']}",
    f"Epochs: {params['epochs']}",
    f"Start layer: {params['start_layer']}",
    sep="\n",
)

# Tab-separated results table; header row first.
print("Basis\tAccuracy")

# First pass: the per-basis value stored in basis_acc (mean over probed layers).
for basis in basis_acc:
    acc = basis_acc[basis]
    print(f"{basis}\t{acc}")

# Second pass: the per-basis best accuracy, annotated with the layer recorded
# for it in basis_acc_argmax.
for basis in basis_acc_max:
    acc = basis_acc_max[basis]
    print(f"{basis}\t{acc} (layer: {basis_acc_argmax[basis]})")