- get model 512 activations 
- for harmful vs nonharmful data

- qualitative plot: the harmful and non-harmful activation distributions do not look very different at all.

- can a probe trained on the 512-dim activations alone detect harmful prompts?

- can it do so 

In [None]:
import random
import torch as t
from torch import nn
from sae_lens import HookedSAETransformer
import matplotlib.pyplot as plt

from utils.dataset import read_from_pt_gz, save_to_pt_gz, load_target_concept_data


In [None]:
## Load Pythia through the hooked transformer wrapper so every comparison in
## this notebook reads activations from the same implementation.

device = "cpu"
pythia_model: HookedSAETransformer = HookedSAETransformer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    device=device,
)

In [None]:
# Each call returns a (data, labels) pair for one class of the training split.
nonharmful_data = load_target_concept_data(train=True, target_label=0)
harmful_data = load_target_concept_data(train=True, target_label=1)

# The only hook point this notebook reads from the cache.
resid_hook = 'blocks.4.hook_resid_post'

# Inference only: no_grad avoids building an autograd graph for the forward
# pass, and names_filter stops run_with_cache from storing every other hook
# point of the model alongside the one that is actually saved.
with t.no_grad():
    logits_no_saes, cache_no_saes = pythia_model.run_with_cache(nonharmful_data[0], names_filter=resid_hook)
save_to_pt_gz("sparse_acts/train/nonharmful_acts_512.pt.gz", cache_no_saes[resid_hook])
# Read back from disk so the in-memory tensor is exactly what was persisted.
nonharmful_acts = read_from_pt_gz("sparse_acts/train/nonharmful_acts_512.pt.gz")

with t.no_grad():
    logits_no_saes, cache_no_saes = pythia_model.run_with_cache(harmful_data[0], names_filter=resid_hook)
save_to_pt_gz("sparse_acts/train/harmful_acts_512.pt.gz", cache_no_saes[resid_hook])
harmful_acts = read_from_pt_gz("sparse_acts/train/harmful_acts_512.pt.gz")

In [None]:
# Reload the cached training activations from disk, so this cell can be run
# without repeating the model forward pass.
nonharmful_acts_path = "sparse_acts/train/nonharmful_acts_512.pt.gz"
harmful_acts_path = "sparse_acts/train/harmful_acts_512.pt.gz"

nonharmful_acts = read_from_pt_gz(nonharmful_acts_path)
harmful_acts = read_from_pt_gz(harmful_acts_path)


In [None]:


# Mean of every residual dimension over axes 0 and 1 (assumed to be prompt and
# token position -- TODO confirm the layout of the saved activations).
# Stored under their own names: nonharmful_data / harmful_data hold the
# (data, labels) tuples returned by load_target_concept_data and must not be
# overwritten by plotting code.
nonharmful_mean = nonharmful_acts.mean((0, 1))
harmful_mean = harmful_acts.mean((0, 1))

# Create a figure and axis
plt.figure(figsize=(8, 6))

# Non-harmful box at x=1, harmful box at x=2, each with its own colour
box_nonharmful = plt.boxplot(nonharmful_mean, positions=[1], patch_artist=True, boxprops=dict(facecolor='lightblue', color='blue'))
box_harmful = plt.boxplot(harmful_mean, positions=[2], patch_artist=True, boxprops=dict(facecolor='lightcoral', color='red'))


def annotate_boxplot(data, pos):
    """Write the min, median and max of `data` just right of the box at x=`pos`.

    Args:
        data: 1-D floating-point tensor, the same values the box was drawn from.
        pos: x position of the box being annotated.
    """
    minimum = t.min(data).item()
    # t.median returns the lower of the two middle values when the element
    # count is even, which can differ from the median line matplotlib draws.
    # quantile(0.5) interpolates linearly between them instead.
    median = t.quantile(data, 0.5).item()
    maximum = t.max(data).item()

    # Plain Python floats are passed as coordinates rather than 0-dim tensors.
    plt.text(pos + 0.1, minimum, f"Min: {minimum:.2f}", ha='left', va='center', fontsize=10, color='blue')
    plt.text(pos + 0.1, median, f"Median: {median:.2f}", ha='left', va='center', fontsize=10, color='green')
    plt.text(pos + 0.1, maximum, f"Max: {maximum:.2f}", ha='left', va='center', fontsize=10, color='red')


# Annotate both boxplots
annotate_boxplot(nonharmful_mean, 1)
annotate_boxplot(harmful_mean, 2)

# Set the x-axis labels
plt.xticks([1, 2], ['Non-Harmful Acts', 'Harmful Acts'])

# Axis labels and title
plt.xlabel('Type of Acts')
plt.ylabel('Mean Value')
plt.title('Comparison of Non-Harmful vs Harmful Activations (dim=512)')

# Display the plot
plt.show()

## quite similar

In [None]:
## mean for final token

# Take the last position along axis 1, then average over axis 0, leaving one
# value per residual dimension.
# NOTE(review): if prompts of different lengths were padded on the right, index
# -1 may be a padding position for the shorter ones -- confirm the padding side.
# Stored under their own names: nonharmful_data / harmful_data hold the
# (data, labels) tuples returned by load_target_concept_data and must not be
# overwritten by plotting code.
nonharmful_last_tok_mean = nonharmful_acts[:, -1, :].mean(0)
harmful_last_tok_mean = harmful_acts[:, -1, :].mean(0)

# Create a figure and axis
plt.figure(figsize=(8, 6))

# Non-harmful box at x=1, harmful box at x=2, each with its own colour
box_nonharmful = plt.boxplot(nonharmful_last_tok_mean, positions=[1], patch_artist=True, boxprops=dict(facecolor='lightblue', color='blue'))
box_harmful = plt.boxplot(harmful_last_tok_mean, positions=[2], patch_artist=True, boxprops=dict(facecolor='lightcoral', color='red'))


def annotate_boxplot(data, pos):
    """Write the min, median and max of `data` just right of the box at x=`pos`.

    Args:
        data: 1-D floating-point tensor, the same values the box was drawn from.
        pos: x position of the box being annotated.
    """
    minimum = t.min(data).item()
    # t.median returns the lower of the two middle values when the element
    # count is even, which can differ from the median line matplotlib draws.
    # quantile(0.5) interpolates linearly between them instead.
    median = t.quantile(data, 0.5).item()
    maximum = t.max(data).item()

    # Plain Python floats are passed as coordinates rather than 0-dim tensors.
    plt.text(pos + 0.1, minimum, f"Min: {minimum:.2f}", ha='left', va='center', fontsize=10, color='blue')
    plt.text(pos + 0.1, median, f"Median: {median:.2f}", ha='left', va='center', fontsize=10, color='green')
    plt.text(pos + 0.1, maximum, f"Max: {maximum:.2f}", ha='left', va='center', fontsize=10, color='red')


# Annotate both boxplots
annotate_boxplot(nonharmful_last_tok_mean, 1)
annotate_boxplot(harmful_last_tok_mean, 2)

# Set the x-axis labels
plt.xticks([1, 2], ['Non-Harmful Acts', 'Harmful Acts'])

# Axis labels and title; the title names the final token so this figure can be
# told apart from the all-token plot.
plt.xlabel('Type of Acts')
plt.ylabel('Mean Value')
plt.title('Comparison of Non-Harmful vs Harmful Final-Token Activations (dim=512)')

# Display the plot
plt.show()

## quite similar

- quantitative results: train a probe on the activations of the last token

<!-- since there are no significant differences -->

In [None]:
# data loader

In [None]:
## returns batches of activations
def data_loader(data, labels, batch_size=16, seed = 42, device="cpu"):
    """Shuffle `data` and `labels` together and cut them into batches.

    Args:
        data: indexable collection of examples.
        labels: indexable collection of labels, aligned with `data`.
        batch_size: maximum number of examples per batch; the last batch may
            be shorter.
        seed: seed for the private random generator that fixes the shuffle
            order, so the same seed always yields the same batches.
        device: device on which each batch's label tensor is created.

    Returns:
        A list of `(examples, label_tensor)` pairs. `examples` is a plain list
        of items taken from `data`; only the labels are turned into a tensor.
    """
    # A dedicated Random instance keeps the global random state untouched.
    order = list(range(len(data)))
    random.Random(seed).shuffle(order)

    batches = []
    for start in range(0, len(order), batch_size):
        chunk = order[start:start + batch_size]
        examples = [data[j] for j in chunk]
        targets = t.tensor([labels[j] for j in chunk], device=device)
        batches.append((examples, targets))
    return batches

In [None]:
# Probe inputs: the activation at the last position (index -1 on axis 1) of
# every training example, non-harmful rows first and harmful rows after them.
last_tok_acts_data = t.cat((nonharmful_acts[:, -1, :], harmful_acts[:, -1, :]), dim=0).tolist()

# Fetch the training labels directly instead of reading nonharmful_data[1] /
# harmful_data[1]: those names are rebound to other values elsewhere in this
# notebook, so what they hold depends on which cells were run last.
# Element [1] of the returned pair is the labels; they are joined with `+` in
# the same non-harmful-then-harmful order as the activations above
# (assumes the labels are Python lists -- TODO confirm).
nonharmful_labels = load_target_concept_data(train=True, target_label=0)[1]
harmful_labels = load_target_concept_data(train=True, target_label=1)[1]
label = nonharmful_labels + harmful_labels

batches = data_loader(last_tok_acts_data, label)


In [None]:
class Probe(nn.Module):
    """Linear probe: one affine map from an activation vector to a single logit."""

    def __init__(self, activation_dim):
        """Create the probe for inputs whose last axis has size `activation_dim`."""
        super().__init__()
        # One output unit with a bias term (nn.Linear's default).
        self.net = nn.Linear(in_features=activation_dim, out_features=1)

    def forward(self, x):
        """Return the logits for `x` with the trailing size-1 axis removed."""
        return self.net(x).squeeze(dim=-1)


In [None]:
from tqdm import tqdm

def train_probe(batches, lr=1e-2, epochs=1, dim=512, seed=42, probe="linear"):
    """Train a probe with AdamW and binary cross-entropy on logits.

    Args:
        batches: reusable sequence of `(activations, labels)` pairs, such as
            the list returned by `data_loader`. Activations may be a tensor or
            nested lists of floats; labels hold one 0/1 value per row.
        lr: AdamW learning rate.
        epochs: number of full passes over `batches`.
        dim: input size of the probe created when `probe == "linear"`.
        seed: torch seed set before the probe is created.
        probe: `"linear"` to create a fresh `Probe(dim)`, or an existing
            `nn.Module` to train in place.

    Returns:
        `(probe, losses, epoch_losses)`: the trained module, the loss of every
        optimisation step, and the loss of the last step of each epoch.

    Raises:
        ValueError: if `probe` is neither `"linear"` nor a module, or if an
            epoch finds no batches to train on.
    """
    t.manual_seed(seed)
    if isinstance(probe, nn.Module):
        pass  # train the module that was handed in
    elif probe == "linear":
        probe = Probe(dim)
    else:
        # Previously this only printed a message and then failed later on
        # `.parameters()`; fail here with the actual cause instead.
        raise ValueError(f"unknown probe type: {probe!r}")

    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    # Inputs are converted to the dtype/device of the probe's own parameters.
    ref = next(probe.parameters())

    losses = []
    epoch_losses = []
    for epoch in tqdm(range(epochs)):
        steps_before = len(losses)
        for batch in batches:
            # data_loader hands activations over as nested Python lists, which
            # nn.Linear cannot consume; as_tensor leaves a matching tensor as is.
            acts = t.as_tensor(batch[0], dtype=ref.dtype, device=ref.device)
            labels = t.as_tensor(batch[1], dtype=ref.dtype, device=ref.device)

            logits = probe(acts)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        if len(losses) == steps_before:
            raise ValueError("train_probe: no batches to train on in this epoch")
        # Per-epoch entry is the loss of the final batch, not an epoch average.
        epoch_losses.append(losses[-1])

    return probe, losses, epoch_losses

In [None]:
# Loss shared by train_probe_only. The optimizer is not created here because
# it needs a probe and a learning rate, and neither exists at module level.
criterion = nn.BCEWithLogitsLoss()


def train_probe_only(probe, batches, lr=1e-2, epochs=1, dim=512, seed=42, optimizer=None):
    """Run the optimisation loop on an already-constructed probe.

    Args:
        probe: module to train in place; called on a batch of activations and
            expected to return one logit per row.
        batches: reusable sequence of `(activations, labels)` pairs.
            Activations may be a tensor or nested lists of floats; labels hold
            one 0/1 value per row.
        lr: learning rate for the AdamW optimizer created when `optimizer` is
            not supplied.
        epochs: number of full passes over `batches`.
        dim: unused; kept so the signature lines up with `train_probe`.
        seed: unused; kept so the signature lines up with `train_probe`.
        optimizer: optional existing optimizer over `probe`'s parameters, for
            callers that want to keep optimizer state across calls.

    Returns:
        `(probe, losses)`: the same module and the loss of every step taken.
    """
    if optimizer is None:
        optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    # Inputs are converted to the dtype/device of the probe's own parameters.
    ref = next(probe.parameters())

    losses = []
    for _ in range(epochs):
        for batch in batches:
            # Nested Python lists (as produced by data_loader) are turned into
            # tensors; as_tensor leaves an already-matching tensor as is.
            acts = t.as_tensor(batch[0], dtype=ref.dtype, device=ref.device)
            labels = t.as_tensor(batch[1], dtype=ref.dtype, device=ref.device)

            logits = probe(acts)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

    return probe, losses

In [None]:
## load model 

## compute time for training

In [None]:
# Train the default linear probe for 25 epochs on the training batches.
# `losses` has one entry per optimisation step, `epoch_loss` one per epoch.
probe, losses, epoch_loss = train_probe(batches, epochs=25)



In [None]:
import matplotlib.pyplot as plt
# Loss curve across epochs: each point is the value train_probe recorded at
# the end of that epoch (the loss of the epoch's last batch).
plt.plot(epoch_loss)
plt.title("(dim=512) original concept probe training loss")

In [None]:
def test_probe(probe, batches, seed=42):
    """Return the fraction of examples in `batches` that the probe gets right.

    A logit above 0 counts as a prediction of class 1, anything else as 0.

    Args:
        probe: callable that maps a batch of activations to one logit per row.
        batches: sequence of `(activations, labels)` pairs. Activations may be
            a tensor or nested lists of floats.
        seed: unused; evaluation involves no randomness.

    Returns:
        Accuracy over all examples as a Python float.

    Raises:
        ValueError: if `batches` contains no batch at all.
    """
    # Match the probe's own dtype/device when it has parameters; otherwise
    # fall back to float32 on the default device.
    params = getattr(probe, "parameters", None)
    ref = next(params(), None) if callable(params) else None
    dtype = ref.dtype if ref is not None else t.float32
    device = ref.device if ref is not None else None

    with t.no_grad():
        corrects = []

        for batch in batches:
            # data_loader hands activations over as nested Python lists, which
            # nn.Linear cannot consume; as_tensor leaves a matching tensor as is.
            acts = t.as_tensor(batch[0], dtype=dtype, device=device)
            labels = t.as_tensor(batch[1], device=device)
            logits = probe(acts)
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())

        if not corrects:
            raise ValueError("test_probe: no batches to evaluate")
        return t.cat(corrects).mean().item()


# This is an evaluation helper, not a unit test: stop pytest-style collectors
# from picking it up because of the `test_` prefix.
test_probe.__test__ = False


In [None]:
# with open("data/test_data.json", "r") as file: 
#     test_data = json.load(file)
# len(test_data)

In [None]:
# Each call returns a (data, labels) pair for one class of the test split.
nonharmful_data = load_target_concept_data(train=False, target_label=0)
harmful_data = load_target_concept_data(train=False, target_label=1)

# The only hook point this notebook reads from the cache.
resid_hook = 'blocks.4.hook_resid_post'

# Inference only: no_grad avoids building an autograd graph for the forward
# pass, and names_filter stops run_with_cache from storing every other hook
# point of the model alongside the one that is actually saved.
with t.no_grad():
    logits_no_saes, cache_no_saes = pythia_model.run_with_cache(nonharmful_data[0], names_filter=resid_hook)
save_to_pt_gz("sparse_acts/test/nonharmful_acts_512.pt.gz", cache_no_saes[resid_hook])
# Read back from disk so the in-memory tensor is exactly what was persisted.
test_nonharmful_acts = read_from_pt_gz("sparse_acts/test/nonharmful_acts_512.pt.gz")

with t.no_grad():
    logits_no_saes, cache_no_saes = pythia_model.run_with_cache(harmful_data[0], names_filter=resid_hook)
save_to_pt_gz("sparse_acts/test/harmful_acts_512.pt.gz", cache_no_saes[resid_hook])
test_harmful_acts = read_from_pt_gz("sparse_acts/test/harmful_acts_512.pt.gz")

In [None]:
# Probe inputs for evaluation: the activation at the last position (index -1
# on axis 1) of every test example, non-harmful rows first and harmful after.
last_tok_acts_data = t.cat((test_nonharmful_acts[:, -1, :], test_harmful_acts[:, -1, :]), dim=0).tolist()

# Fetch the test labels directly instead of reading nonharmful_data[1] /
# harmful_data[1]: those names are rebound to other values elsewhere in this
# notebook, so what they hold depends on which cells were run last.
# Element [1] of the returned pair is the labels; they are joined with `+` in
# the same non-harmful-then-harmful order as the activations above
# (assumes the labels are Python lists -- TODO confirm).
test_nonharmful_labels = load_target_concept_data(train=False, target_label=0)[1]
test_harmful_labels = load_target_concept_data(train=False, target_label=1)[1]
label = test_nonharmful_labels + test_harmful_labels

test_batches = data_loader(last_tok_acts_data, label)


In [None]:
# test_probe()

In [None]:
## reduce batches observe performance

epoch_train_loss = []
epoch_test_acc = []
epoches = 25


probe, losses, epoch_loss = train_probe(batches=train_batches) 
test_accuracy = test_probe(probe, batches=test_batches, seed=42)
# epoch_train_loss.append(losses[-1])
epoch_test_acc.append(test_accuracy)

### hmm should test accuracy inside the training loop? 


In [None]:
## measure score on harmful data only 


In [None]:
## measure classification score on advbench

In [None]:
## measure score on autodan
