In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from circular_probe import train_circular_probe
import numpy
from tqdm import tqdm
from general_ps_utils import ModelAndTokenizer
from matplotlib import pyplot as plt
import torch

torch.cuda.empty_cache()

params = {
    'model_name': "Qwen/Qwen2-0.5B", #"google/gemma-2-2b", #"meta-llama/Meta-Llama-3-8B", #"mistralai/Mistral-7B-v0.1",
    'use_4bit': False,
    'epochs': 10_000,
    'lr': 0.0005,
    'numbers': 2000,
    'batch_size': 2000,
    'exclude': 'random', # numbers to exclude from training set
    'exclude_count': 200,
    'positions': 1,
    'shuffle': True,
    'bases': [10,11],
    'start_layer': 0,
    'bias': False
}


print(f"Params:\n\n{params}")

if params['exclude'] == 'random':
    params['exclude'] = numpy.random.choice(params['numbers'], params['exclude_count'], replace=False)

mt = ModelAndTokenizer(
    model_name=params['model_name'],
    use_4bit=params['use_4bit'],
    device='cuda'
)

tokenizer = mt.tokenizer
num_to_hidden = dict()

# move device to cuda because for some reason it is not
mt.model.to('cuda')

print(f"device of model is {mt.model.device}")


Params:

{'model_name': 'Qwen/Qwen2-0.5B', 'use_4bit': False, 'epochs': 10000, 'lr': 0.0005, 'numbers': 2000, 'batch_size': 2000, 'exclude': 'random', 'exclude_count': 200, 'positions': 1, 'shuffle': True, 'bases': [10, 11], 'start_layer': 0, 'bias': False}


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

device of model is cuda:0


In [6]:
for i in tqdm(range(params['numbers']), delay=120):
    text_for_embeddings = ""

    for _ in range(params['positions']):
        text_for_embeddings += str(i) + " "
    text_for_embeddings = text_for_embeddings[:-1]

    x = tokenizer.encode(text_for_embeddings, return_tensors='pt')
    x = x.to(mt.device)
    hidden_states = mt.model(x, output_hidden_states=True).hidden_states

    num_to_hidden[i] = hidden_states

In [None]:
# need to average over all layers per basis
import pandas as pd

# set up df
df = pd.DataFrame(columns=['layer', 'basis', 'accuracy'])


basis_acc = dict()
basis_acc_max = dict()
basis_acc_argmax = dict()
for basis in tqdm(params['bases']):
    print(f"Training cyclic probe on basis: {basis}")
    params['basis'] = basis
    layer_acc = []
    for layer in range(params['start_layer'], mt.num_layers):
        params['layers'] = [layer]
        acc, circular_probe = train_circular_probe(params, mt, num_to_hidden)
        print(f"Layer: {layer}, Accuracy: {acc}")
        layer_acc.append(acc)

        # add to df
        df_delta = pd.DataFrame([[layer, basis, acc]], columns=['layer', 'basis', 'accuracy'])
        df = pd.concat([df, df_delta], ignore_index=True)

    basis_acc[basis] = numpy.mean(layer_acc)
    basis_acc_max[basis] = numpy.max(layer_acc)
    basis_acc_argmax[basis] = numpy.argmax(layer_acc)

# round all acc numbers to 3 decimal places
basis_acc = {k: round(v, 3) for k, v in basis_acc.items()}

  0%|                                                                                                                                          | 0/2 [00:00<?, ?it/s]

Training cyclic probe on basis: 10
896
X[0].shape=torch.Size([896])


  df = pd.concat([df, df_delta], ignore_index=True)


Layer: 0, Accuracy: 0.0
896
X[0].shape=torch.Size([896])
Layer: 1, Accuracy: 0.92
896
X[0].shape=torch.Size([896])
Layer: 2, Accuracy: 0.885
896
X[0].shape=torch.Size([896])
Layer: 3, Accuracy: 0.915
896
X[0].shape=torch.Size([896])
Layer: 4, Accuracy: 0.915
896
X[0].shape=torch.Size([896])
Layer: 5, Accuracy: 0.94
896
X[0].shape=torch.Size([896])
Layer: 6, Accuracy: 0.855
896
X[0].shape=torch.Size([896])
Layer: 7, Accuracy: 0.845
896
X[0].shape=torch.Size([896])
Layer: 8, Accuracy: 0.795
896
X[0].shape=torch.Size([896])
Layer: 9, Accuracy: 0.915
896
X[0].shape=torch.Size([896])
Layer: 10, Accuracy: 0.905
896
X[0].shape=torch.Size([896])
Layer: 11, Accuracy: 0.885
896
X[0].shape=torch.Size([896])
Layer: 12, Accuracy: 0.835
896
X[0].shape=torch.Size([896])
Layer: 13, Accuracy: 0.84
896
X[0].shape=torch.Size([896])
Layer: 14, Accuracy: 0.775
896
X[0].shape=torch.Size([896])
Layer: 15, Accuracy: 0.735
896
X[0].shape=torch.Size([896])
Layer: 16, Accuracy: 0.775
896
X[0].shape=torch.Size([8

 50%|████████████████████████████████████████████████████████████████▌                                                                | 1/2 [03:27<03:27, 207.47s/it]

Layer: 23, Accuracy: 0.685
Training cyclic probe on basis: 11
896
X[0].shape=torch.Size([896])
Layer: 0, Accuracy: 0.0
896
X[0].shape=torch.Size([896])
Layer: 1, Accuracy: 0.01
896
X[0].shape=torch.Size([896])
Layer: 2, Accuracy: 0.065
896
X[0].shape=torch.Size([896])
Layer: 3, Accuracy: 0.025
896
X[0].shape=torch.Size([896])
Layer: 4, Accuracy: 0.04
896
X[0].shape=torch.Size([896])
Layer: 5, Accuracy: 0.06
896
X[0].shape=torch.Size([896])
Layer: 6, Accuracy: 0.03
896
X[0].shape=torch.Size([896])
Layer: 7, Accuracy: 0.055
896
X[0].shape=torch.Size([896])
Layer: 8, Accuracy: 0.025
896
X[0].shape=torch.Size([896])
Layer: 9, Accuracy: 0.04
896
X[0].shape=torch.Size([896])
Layer: 10, Accuracy: 0.05
896
X[0].shape=torch.Size([896])
Layer: 11, Accuracy: 0.055
896
X[0].shape=torch.Size([896])
Layer: 12, Accuracy: 0.035
896
X[0].shape=torch.Size([896])
Layer: 13, Accuracy: 0.04
896
X[0].shape=torch.Size([896])
Layer: 14, Accuracy: 0.04
896
X[0].shape=torch.Size([896])
Layer: 15, Accuracy: 0.05

In [12]:
print(f"Model: {params['model_name']}")
print(f"Number of numbers: {params['numbers']}")
print(f"Epochs: {params['epochs']}")
print(f"Start layer: {params['start_layer']}")

# print results as a table
print("Basis\tAccuracy")

for basis, acc in basis_acc.items():
    print(f"{basis}\t{acc}")

for basis, acc in basis_acc_max.items():
    print(f"{basis}\t{acc} (layer: {basis_acc_argmax[basis]})")

Model: google/gemma-2-2b
Number of numbers: 2000
Epochs: 10000
Start layer: 0
Basis	Accuracy
10	0.626
11	0.075
10	0.995 (layer: 1)
11	0.255 (layer: 3)
