In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import json

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from dawnet.model import ModelRunner, Handler

model_id = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
runner = ModelRunner(model)



In [3]:
class SAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.encoder = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.thresh = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
        self.decoder = nn.Linear(in_features=hidden_size, out_features=input_size, bias=True)

    def encode(self, x):
        y = self.encoder(x)
        mask = (y > self.thresh)
        y = mask * nn.functional.relu(y)
        return y

    def decode(self, x):
        y = self.decoder(x)
        return y

    def forward(self, x):
        y = self.encode(x)
        y = self.decode(y)
        return y

sae = SAE(input_size=768, hidden_size=8888).cuda()
state_dict = torch.load("/home/john/Downloads/experiments/models/exp001.pth")
sae.load_state_dict(state_dict)

  state_dict = torch.load("/home/john/Downloads/experiments/models/exp001.pth")


<All keys matched successfully>

In [66]:
means = np.load("/data/mech/data/train/official_it_04_clipped_1_99_percentile_scaled_0_to_1_mean.npy")
std = np.load("/data/mech/data/train/official_it_04_clipped_1_99_percentile_scaled_0_to_1_std.npy")
percentiles = np.load("/data/mech/data/train/official_it_02_percentile.npy")
lower_bound = percentiles[0]
upper_bound = percentiles[-1]

In [72]:
means = torch.Tensor(means).to(model.device)
stds = torch.Tensor(std).to(model.device)
lower_bounds = torch.Tensor(lower_bound).to(model.device)
upper_bounds = torch.Tensor(upper_bound).to(model.device)

In [101]:
hooks = list(runner._hooks.keys())
for each in hooks:
    Handler(each, runner=runner).clear()

In [103]:
def capture_act(r, n, l, ia, ik, o):
    with torch.no_grad():
        os = (o[0].squeeze().clamp(lower_bounds, upper_bounds) - means) / stds
        act = sae.encode(os)
        reconstructed = sae.decode(act)
        print(1 - torch.mean((reconstructed - os) ** 2) / os.var())
        r._output[n] = act
    return o
    
handler1 = runner.add_forward_hooks(capture_act, "transformer.h.10")

In [104]:
def tokenize(text):
    return tokenizer.encode(text, return_tensors="pt").to(model.device)

In [124]:
text = "Anarchism advocates for the replacement of the state with stateless societies and voluntary free associations. As a historically left-wing movement, this reading of anarchism is placed on the farthest left of the political spectrum, usually described as the libertarian wing of the socialist movement (libertarian socialism)."
text = "<|endoftext|>" + text

In [125]:
with torch.no_grad():
    output = runner(tokenize(text))

tensor(0.8514, device='cuda:0')


In [126]:
runner._output

OrderedDict([('transformer.h.10',
              tensor([[ 1.5382,  2.3478,  1.8526,  ...,  0.0000,  0.0000,  3.8246],
                      [ 8.2039,  0.0000, 10.4483,  ...,  7.5424,  1.0203,  8.2108],
                      [ 8.6338,  3.4213,  3.6059,  ...,  6.1635,  0.0000,  6.0921],
                      ...,
                      [ 4.5442,  0.0000, 10.0385,  ..., 13.6222,  5.8427,  0.0000],
                      [ 8.3833,  0.0000, 11.6777,  ...,  7.3747,  2.3111,  0.0000],
                      [ 0.0000,  0.0000,  7.0186,  ..., 10.1486,  6.5264,  8.0502]],
                     device='cuda:0'))])

In [127]:
output.logits.shape

torch.Size([1, 59, 50257])

In [128]:
ts = runner._output['transformer.h.10']
ts = ts.cpu().numpy()
print(ts.shape)

(59, 8888)


In [129]:
(ts > 0).sum(1)

array([3329, 3731, 4052, 3720, 3925, 3706, 3829, 3616, 3699, 3599, 4039,
       4244, 4450, 4108, 3647, 3788, 4265, 4394, 3873, 3365, 3845, 3997,
       4513, 4358, 4553, 4486, 4094, 4008, 4026, 3943, 4134, 4015, 3947,
       4274, 4150, 3891, 4718, 4651, 4214, 4082, 3760, 4096, 3905, 3462,
       3931, 3892, 4193, 4203, 4219, 4124, 4038, 3941, 4270, 3698, 3556,
       4597, 4157, 4059, 3318])

In [130]:
(ts[10] > 0).sum()

4039

In [131]:
l = ts[0].tolist()

In [132]:
print(l)

[1.5382211208343506, 2.347805976867676, 1.8526480197906494, 4.746862888336182, 0.0, 0.0, 6.71193265914917, 2.169555902481079, 2.0037379264831543, 2.501110315322876, 0.0, 0.0, 0.4989100396633148, 0.0, 0.0, 12.828177452087402, 4.152775287628174, 0.0, 0.5286300182342529, 1.3900202512741089, 2.1073880195617676, 0.0, 0.0, 0.0, 1.301715612411499, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.626565456390381, 1.6474398374557495, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0173676013946533, 8.755379676818848, 2.7533717155456543, 0.0, 0.38247376680374146, 14.478367805480957, 0.0, 0.800075352191925, 2.1760740280151367, 0.0, 0.0, 0.0, 4.736583232879639, 0.0, 0.0, 0.0, 2.0660507678985596, 5.0634236335754395, 0.052826888859272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7150062918663025, 1.1393285989761353, 0.0, 0.0, 0.0, 1.2865393161773682, 1.5255910158157349, 0.0, 4.562021255493164, 0.0, 0.10472577065229416, 0.0, 3.9898202419281006, 0.0, 0.0, 0.0, 2.069981813430786, 0.0, 0.0, 3.240987539291382, 0.0, 0.0, 0.0

In [117]:
sum(l) / len(l)

1.8400518987776815

In [42]:
sae_thresh = sae.thresh.detach().cpu().numpy().tolist()

In [43]:
print(sae_thresh)

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,