### classification results 

- input: 

- output: 
    - acts probe
    - sparse acts probe
    - causal probe 

- inspect failure cases? 
    

In [None]:
import random
import torch as t
from torch import nn
from sae_lens import HookedSAETransformer
import matplotlib.pyplot as plt

from utils.data_utils import read_from_json_file, read_from_pt_gz
from utils.probe_utils import data_loader, train_probe, test_probe

In [None]:
class Probe(nn.Module):
    """Linear probe that maps one activation vector to a single logit."""

    def __init__(self, activation_dim):
        """Build the probe's only layer: a biased linear map ``activation_dim -> 1``."""
        super().__init__()
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        """Return the logits for ``x`` with the trailing size-1 dimension squeezed away."""
        return self.net(x).squeeze(dim=-1)

In [None]:
# Load Pythia through HookedSAETransformer so that every probe below is
# compared on the same model.
device = "cpu"
pythia_model: HookedSAETransformer = HookedSAETransformer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    device=device,
)

In [None]:
# Train split, one JSON file per class (index 1 of each is used as the labels
# further down).
train_nonharmful, train_harmful = (
    read_from_json_file("data/nonharmful_train_ds.json"),
    read_from_json_file("data/harmful_train_ds.json"),
)

# Test split, same layout as the train split.
test_nonharmful, test_harmful = (
    read_from_json_file("data/nonharmful_test_ds.json"),
    read_from_json_file("data/harmful_test_ds.json"),
)


#### acts probe 

In [None]:
### train split: cached activations for both classes
nonharmful_acts = read_from_pt_gz("sparse_acts/train/nonharmful_acts_512.pt.gz")
harmful_acts = read_from_pt_gz("sparse_acts/train/harmful_acts_512.pt.gz")

# Keep only the last index of axis 1, stack the non-harmful rows ahead of the
# harmful rows along axis 0, and convert the result to nested Python lists.
last_tok_acts_data = t.cat(
    [nonharmful_acts[:, -1], harmful_acts[:, -1]],
    dim=0,
).tolist()

# Index 1 of each dataset supplies the labels, concatenated in the same
# non-harmful-then-harmful order as the rows above.
label = train_nonharmful[1] + train_harmful[1]
train_batches = data_loader(last_tok_acts_data, label)


In [None]:
## test split: cached activations for both classes
# NOTE: this rebinds nonharmful_acts / harmful_acts / last_tok_acts_data / label,
# so after this cell those names refer to the test split.
nonharmful_acts = read_from_pt_gz("sparse_acts/test/nonharmful_acts_512.pt.gz")
harmful_acts = read_from_pt_gz("sparse_acts/test/harmful_acts_512.pt.gz")

# Keep only the last index of axis 1, stack the non-harmful rows ahead of the
# harmful rows along axis 0, and convert the result to nested Python lists.
last_tok_acts_data = t.cat(
    [nonharmful_acts[:, -1], harmful_acts[:, -1]],
    dim=0,
).tolist()

# Index 1 of each dataset supplies the labels, in the same order as the rows.
label = test_nonharmful[1] + test_harmful[1]
test_batches = data_loader(last_tok_acts_data, label)


In [None]:
# Seed before the probe is constructed so its initial weights are reproducible.
t.manual_seed(42)
probe = Probe(512)

epoches = 25
# total_loss: every loss value train_probe returns, across all epochs.
# epoch_train_loss: the last loss value of each epoch.
# epoch_test_acc: the test_probe result after each epoch.
epoch_train_loss, total_loss, epoch_test_acc = [], [], []

# train_batches = batches[:-2]

for i in range(epoches):
    probe, losses = train_probe(probe, train_batches)
    total_loss += losses
    epoch_train_loss.append(losses[-1])

    # Evaluate after every epoch.
    test_acc = test_probe(probe, batches=test_batches, seed=42)
    epoch_test_acc.append(test_acc)


In [None]:
def train_loop(train_batches, epoches=25, dim=512, eval_batches=None):
    """Train a fresh linear probe and evaluate it after every epoch.

    Args:
        train_batches: Batches handed to ``train_probe`` on every epoch.
        epoches: Number of passes over ``train_batches``.
        dim: Input width of the probe (passed to ``Probe``).
        eval_batches: Batches handed to ``test_probe`` after every epoch.
            When ``None`` the notebook-level ``test_batches`` is used, which
            is what this function always evaluated on before.

    Returns:
        A tuple ``(probe, total_loss, epoch_train_loss, epoch_test_acc)``:
        the trained probe, every loss value returned by ``train_probe``
        across all epochs, the last loss value of each epoch, and the
        ``test_probe`` result after each epoch.
    """
    if eval_batches is None:
        # Backward-compatible fallback to the notebook global.
        eval_batches = test_batches

    # Seed before the probe is built so its initial weights are reproducible.
    t.manual_seed(42)
    # Previously hard-coded to 512, which silently ignored ``dim``.
    probe = Probe(dim)

    epoch_train_loss = []
    total_loss = []
    epoch_test_acc = []

    for _ in range(epoches):
        probe, losses = train_probe(probe, train_batches)
        total_loss.extend(losses)
        epoch_train_loss.append(losses[-1])

        test_acc = test_probe(probe, batches=eval_batches, seed=42)
        epoch_test_acc.append(test_acc)

    # Previously nothing was returned, so the trained probe and all metrics
    # were lost when the function exited.
    return probe, total_loss, epoch_train_loss, epoch_test_acc


In [None]:
## try reducing the data and retrain

In [None]:
# total_loss holds every loss value returned by train_probe, whereas
# epoch_test_acc holds one value per epoch. The two series differ in length
# and scale, so each gets its own panel instead of sharing one axis.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(total_loss)
loss_ax.set_xlabel("logged loss index")
loss_ax.set_ylabel("train loss")
loss_ax.set_title("(dim=512) original concept probe train loss")

acc_ax.plot(epoch_test_acc)
acc_ax.set_xlabel("epoch")
acc_ax.set_ylabel("test acc")
acc_ax.set_title("(dim=512) original concept probe test acc")

fig.tight_layout()
print(epoch_test_acc)

In [None]:
## mean for final token

def annotate_boxplot(data, pos):
    """Write the min, median and max of ``data`` next to the box drawn at ``pos``.

    Args:
        data: 1-D tensor holding the values that were passed to ``plt.boxplot``.
        pos: x position of the box; the labels are placed 0.1 to its right.
    """
    values = data.float()
    minimum = values.min().item()
    maximum = values.max().item()
    # t.median returns the lower of the two middle values when the number of
    # elements is even, which can disagree with the median line of the box.
    # quantile(0.5) interpolates between the two middle values instead.
    median = t.quantile(values, 0.5).item()

    plt.text(pos + 0.1, minimum, f"Min: {minimum:.2f}", ha='left', va='center', fontsize=10, color='blue')
    plt.text(pos + 0.1, median, f"Median: {median:.2f}", ha='left', va='center', fontsize=10, color='green')
    plt.text(pos + 0.1, maximum, f"Max: {maximum:.2f}", ha='left', va='center', fontsize=10, color='red')


# Take the last index of axis 1, then average over axis 0 only.
nonharmful_data = nonharmful_acts[:, -1, :].mean(0)
harmful_data = harmful_acts[:, -1, :].mean(0)

# Create a figure and axis
plt.figure(figsize=(8, 6))

# Plot nonharmful_acts boxplot at position 1 with a custom color
box_nonharmful = plt.boxplot(nonharmful_data, positions=[1], patch_artist=True, boxprops=dict(facecolor='lightblue', color='blue'))

# Plot harmful_acts boxplot at position 2 with a custom color
box_harmful = plt.boxplot(harmful_data, positions=[2], patch_artist=True, boxprops=dict(facecolor='lightcoral', color='red'))

# Annotate both boxplots
annotate_boxplot(nonharmful_data, 1)
annotate_boxplot(harmful_data, 2)

# Set the x-axis labels
plt.xticks([1, 2], ['Non-Harmful Acts', 'Harmful Acts'])

# Axis labels and title
plt.xlabel('Type of Acts')
plt.ylabel('Mean Activation Value (n=64)')
plt.title('Comparison of Non-Harmful vs Harmful last token Activations (dim=512)')

# Display the plot
plt.show()

## observation: the non-harmful and harmful distributions look quite similar

In [None]:
def annotate_boxplot(data, pos):
    """Write the min, median and max of ``data`` next to the box drawn at ``pos``.

    Args:
        data: 1-D tensor holding the values that were passed to ``plt.boxplot``.
        pos: x position of the box; the labels are placed 0.1 to its right.
    """
    values = data.float()
    minimum = values.min().item()
    maximum = values.max().item()
    # t.median returns the lower of the two middle values when the number of
    # elements is even, which can disagree with the median line of the box.
    # quantile(0.5) interpolates between the two middle values instead.
    median = t.quantile(values, 0.5).item()

    plt.text(pos + 0.1, minimum, f"Min: {minimum:.2f}", ha='left', va='center', fontsize=10, color='blue')
    plt.text(pos + 0.1, median, f"Median: {median:.2f}", ha='left', va='center', fontsize=10, color='green')
    plt.text(pos + 0.1, maximum, f"Max: {maximum:.2f}", ha='left', va='center', fontsize=10, color='red')


# Average over axis 1 first, then over axis 0.
nonharmful_data = nonharmful_acts.mean(1).mean(0)
harmful_data = harmful_acts.mean(1).mean(0)

# Create a figure and axis
plt.figure(figsize=(8, 6))

# Plot nonharmful_acts boxplot at position 1 with a custom color
box_nonharmful = plt.boxplot(nonharmful_data, positions=[1], patch_artist=True, boxprops=dict(facecolor='lightblue', color='blue'))

# Plot harmful_acts boxplot at position 2 with a custom color
box_harmful = plt.boxplot(harmful_data, positions=[2], patch_artist=True, boxprops=dict(facecolor='lightcoral', color='red'))

# Annotate both boxplots
annotate_boxplot(nonharmful_data, 1)
annotate_boxplot(harmful_data, 2)

# Set the x-axis labels
plt.xticks([1, 2], ['Non-Harmful Acts', 'Harmful Acts'])

# Axis labels and title
plt.xlabel('Type of Acts')
plt.ylabel('Mean Value across tokens')
plt.title('Comparison of Non-Harmful vs Harmful Activation Values for (dim=512) averaged over n=64 prompts each')

# Display the plot
plt.show()

## observation: the non-harmful and harmful distributions look quite similar

#### sparse acts probe
- TODO: re-run the extraction to collect the sparse activations so they can be visualized



In [None]:
### get train sparse acts (only the train split is loaded in this cell)
nonharmful_sparse_acts, harmful_sparse_acts = (
    read_from_pt_gz("sparse_acts/train/nonharmful_sparse_acts_32768.pt.gz"),
    read_from_pt_gz("sparse_acts/train/harmful_sparse_acts_32768.pt.gz"),
)




In [None]:
### train probe, test probe

#### causal probe

In [None]:
### get train ablation effects, test ablation effects