In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import math
import time 
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer
from datetime import datetime
from matplotlib import pyplot as plt
import pickle
from transformer_lens import HookedTransformerConfig, HookedTransformer

try:
    device = torch.device('cuda')
except:
    print('Cuda not available')

torch.cuda.empty_cache()

In [2]:
all_vars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


def generate_data(tokenizer, n_var, batch_size=100):
    
    batch = []
    labels = []
    clause_order = []
    for _ in range(batch_size):
        values = np.random.randint(0, 2, (n_var,))
        var_idx = tuple(np.random.permutation(len(all_vars)))
        vars = [all_vars[i] for i in var_idx]

        # generate first sentence
        clauses = []
        clauses.append('val %d = %s ,' % (values[0], vars[0]))

        for i in range(1, n_var):
            modifier = 'val' if values[i] == values[i-1] else 'not'
            clauses.append('%s %s = %s ,' % (modifier, vars[i-1], vars[i]))
            

        sent = ''
        label = []
        
        clause_idx = tuple(range(n_var))
        sent += ''.join([clauses[idx] for idx in clause_idx])
        label += [values[idx] for idx in clause_idx]
        
        
        order = torch.zeros(1, n_var, n_var)
        for i in range(n_var):
            order[0, i, clause_idx[i]] = 1
            
        batch.append(tokenizer(sent, return_tensors='pt')['input_ids'])
        labels.append(values)
        clause_order.append(order)
    return torch.cat(batch), torch.LongTensor(labels), torch.cat(clause_order)




def make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size):
    
    train_data = []
    train_labels = []
    train_order = []

    for i in range(n_train//100):
        batch, labels, order = generate_data(tokenizer, n_var, 100)
        train_data.append(batch)
        train_labels.append(labels)
        train_order.append(order)

    x_train = torch.cat(train_data)
    y_train = torch.cat(train_labels)
    order_train = torch.cat(train_order)
    
    trainset = torch.utils.data.TensorDataset(x_train, y_train, order_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    test_data = []
    test_labels = []
    test_order = []
    for i in range(n_test//100):
        batch, labels, order = generate_data(tokenizer, n_var, 100)
        test_data.append(batch)
        test_labels.append(labels)
        test_order.append(order)

    x_test = torch.cat(test_data)
    y_test = torch.cat(test_labels)
    order_test = torch.cat(test_order)

    testset = torch.utils.data.TensorDataset(x_test, y_test, order_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size)
    
    return trainloader, testloader

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [3]:
# Used variables in the LEGO chains
all_vars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    
# Seed everything for reproducibility
seed_everything(0)

# n_var: total number of variables in a chain
# n_train_var: number of variables to provide supervision during training
n_var, n_train_var = 2, 2

# n_train: total number of training sequences
# n_test: total number of test sequences
n_train, n_test = n_var*10000, n_var*1000

# batch size >= 500 is recommended
batch_size = 50

# Specify tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Generate LEGO data loaders
trainloader, testloader = make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size)

# Examine an example LEGO sequence
seq, label, _ = trainloader.dataset[0]
print(tokenizer.decode(seq))
print(list(label.numpy()))

  return torch.cat(batch), torch.LongTensor(labels), torch.cat(clause_order)


val 0 = c,not c = w,
[0, 1]


In [4]:
seq

tensor([2100,  657,  796,  269,  837, 1662,  269,  796,  266,  837])

In [5]:
tokenizer.tokenize("val 0 = c,not c = w, ")

['val', 'Ġ0', 'Ġ=', 'Ġc', ',', 'not', 'Ġc', 'Ġ=', 'Ġw', ',', 'Ġ']

In [6]:
# Add a classification layer to predict whether the next variable is 0 or 1

L_hidden_state = [0]
last_hidden_state = lambda name: (name == 'ln_final.hook_normalized')

def add_list(tensor, hook):
    L_hidden_state[0] = tensor

class Model(nn.Module):
    def __init__(self, base, d_model, tgt_vocab=1):
        super(Model, self).__init__()
        self.base = base
        self.classifier = nn.Linear(d_model, tgt_vocab)
        
    def forward(self, x, mask=None):
        logits = self.base.run_with_hooks(x, fwd_hooks = [(last_hidden_state, add_list)])
        out = self.classifier(L_hidden_state[0])
        return out

# Define the model

torch.cuda.empty_cache()

"""micro_gpt_cfg = HookedTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=12,
    d_mlp=512,
    n_layers=8,
    n_ctx=512,
    act_fn="gelu_new",
    normalization_type="LN",
    tokenizer_name="gpt2",
    seed = 0,
)
model = EasyTransformer(micro_gpt_cfg).to('cuda') # random smallish model
"""

'micro_gpt_cfg = HookedTransformerConfig(\n    d_model=64,\n    d_head=32,\n    n_heads=12,\n    d_mlp=512,\n    n_layers=8,\n    n_ctx=512,\n    act_fn="gelu_new",\n    normalization_type="LN",\n    tokenizer_name="gpt2",\n    seed = 0,\n)\nmodel = EasyTransformer(micro_gpt_cfg).to(\'cuda\') # random smallish model\n'

In [7]:
with open('good_model.pkl', 'rb') as file:
    model = pickle.load(file)

In [7]:
"""
with open('good_model.pkl', 'wb') as file:
    pickle.dump(model, file)
"""

"\nwith open('good_model.pkl', 'wb') as file:\n    pickle.dump(model, file)\n"

In [8]:
# Define train and test functions for the LEGO task
train_var_pred = [i for i in range(2*n_train_var)] 
test_var_pred = [i for i in range(2*n_var)]

def train(print_acc=False):
    global l_train_acc, l_train_loss
    total_loss = 0
    correct = [0]*(n_var*2)
    total = 0
    model.train()
    for batch, labels, order in trainloader:
    
        x = batch.cuda()
        y = labels.cuda()
        inv_order = order.permute(0, 2, 1).cuda()
        
        optimizer.zero_grad()
        #pred = torch.argmax(model(x), -1, keepdim = True)
        #pred = torch.reshape(pred, (pred.shape[0], pred.shape[1], 1))
        pred = model(x)
        ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()

        loss = 0
        for idx in range(n_train_var):
            loss += criterion(ordered_pred[:, idx], y[:, idx].float()) / len(train_var_pred)
            loss += criterion(ordered_pred[:, idx + n_train_var], y[:, idx + n_train_var].float()) / len(train_var_pred)
            
            total_loss += loss.item() / len(train_var_pred)

            correct[idx] += ((ordered_pred[:, idx]>0).long() == y[:, idx]).float().mean().item()
            correct[idx + n_train_var] += ((ordered_pred[:, idx + n_train_var]>0).long() == y[:, idx + n_train_var]).float().mean().item()
        
        total += 1
    
        loss.backward()
        optimizer.step()
    
    train_acc = [corr/total for corr in correct]

    l_train_loss.append(total_loss / total)
    l_train_acc.append(list(train_acc))

    return train_acc


def test():
    global l_test_acc, l_test_loss

    test_acc = []
    start = time.time()
    total_loss = 0
    correct = [0]*(n_var*2)
    total = 0
    model.eval()
    with torch.no_grad():
        for batch, labels, order in testloader:
    
            x = batch.cuda()
            y = labels.cuda()
            inv_order = order.permute(0, 2, 1).cuda()
            
            #pred = torch.argmax(model(x), -1, keepdim = True)
            #pred = torch.reshape(pred, (pred.shape[0], pred.shape[1], 1))
            pred = model(x)
            ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()
            
            for idx in test_var_pred:
                loss = criterion(ordered_pred[:,idx], y[:, idx].float())
                total_loss += loss.item() / len(test_var_pred)
                correct[idx] += ((ordered_pred[:, idx]>0).long() == y[:, idx]).float().mean().item()
                          
            total += 1
        
        test_acc = [corr/total for corr in correct]

        l_test_loss.append(total_loss / total)
        l_test_acc.append(list(test_acc))

    return test_acc

In [39]:
model.classifier._parameters['weight'].shape

torch.Size([1, 512])

In [10]:
print(model)
print(sum(p.numel() for p in model.parameters()))

Model(
  (base): HookedTransformer(
    (embed): Embed()
    (hook_embed): HookPoint()
    (blocks): ModuleList(
      (0): TransformerBlock(
        (ln1): LayerNormPre(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNormPre(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_pattern): HookPoint()
          (hook_result): HookPoint()
          (hook_rot_k): HookPoint()
          (hook_rot_q): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          (hook_post): HookPoint()
        )
        (hook_attn_out): HookPoint()
        (hook_mlp_out): HookPoint()
        (hook_resid_pre): HookPoint()
        (hook_resid_post): HookPoint()
      )
      (1): Tran

In [7]:
# Print activation shapes at every layer for our model

embed_or_first_layer = lambda name: (name[:6] != "blocks" or name[:8] == "blocks.0")

def print_shape(tensor, hook):
    print(f"Activation at hook {hook.name} has shape:")
    print(tensor.shape)

random_tokens = torch.randint(1000, 10000, (4, 50))
logits = model.base.run_with_hooks(random_tokens, fwd_hooks=[(embed_or_first_layer, print_shape)])

Activation at hook hook_embed has shape:
torch.Size([4, 50, 512])
Activation at hook blocks.0.hook_resid_pre has shape:
torch.Size([4, 50, 512])
Activation at hook blocks.0.ln1.hook_scale has shape:
torch.Size([4, 50, 1])
Activation at hook blocks.0.ln1.hook_normalized has shape:
torch.Size([4, 50, 512])
Activation at hook blocks.0.attn.hook_q has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_k has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_v has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_rot_q has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_rot_k has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_attn_scores has shape:
torch.Size([4, 8, 50, 50])
Activation at hook blocks.0.attn.hook_pattern has shape:
torch.Size([4, 8, 50, 50])
Activation at hook blocks.0.attn.hook_z has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.hook_attn_out has 

In [9]:
criterion = nn.BCEWithLogitsLoss().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# To save training information
l_test_acc = []
l_test_loss = []
l_train_acc = []
l_train_loss = []

In [10]:
with open('good_model.pkl', 'rb') as file:
    model = pickle.load(file)

In [11]:
allact = dict()
allparams = lambda name: True
torch.cuda.empty_cache()
 
def init(tensor, hook):
    allact.update({hook.name:[]})
    
def save_act(tensor, hook):
    sector = hook.name
    allact.update({sector:[tensor]+allact[sector]})

trigger = trainloader.dataset[0][0]
logits = model.base.run_with_hooks(trigger, fwd_hooks=[(allparams, init)])

for i in range(len(trigger)) :
    trigger = trainloader.dataset[i][0]
    logits = model.base.run_with_hooks(trigger, fwd_hooks=[(allparams, save_act)])

In [90]:
for key, _ in allact.items():
    print(key)
    print(allact[key][0].shape, '\n')

hook_embed
torch.Size([1, 10, 512]) 

blocks.0.hook_resid_pre
torch.Size([1, 10, 512]) 

blocks.0.ln1.hook_scale
torch.Size([1, 10, 1]) 

blocks.0.ln1.hook_normalized
torch.Size([1, 10, 512]) 

blocks.0.attn.hook_q
torch.Size([1, 10, 8, 64]) 

blocks.0.attn.hook_k
torch.Size([1, 10, 8, 64]) 

blocks.0.attn.hook_v
torch.Size([1, 10, 8, 64]) 

blocks.0.attn.hook_rot_q
torch.Size([1, 10, 8, 64]) 

blocks.0.attn.hook_rot_k
torch.Size([1, 10, 8, 64]) 

blocks.0.attn.hook_attn_scores
torch.Size([1, 8, 10, 10]) 

blocks.0.attn.hook_pattern
torch.Size([1, 8, 10, 10]) 

blocks.0.attn.hook_z
torch.Size([1, 10, 8, 64]) 

blocks.0.hook_attn_out
torch.Size([1, 10, 512]) 

blocks.0.ln2.hook_scale
torch.Size([1, 10, 1]) 

blocks.0.ln2.hook_normalized
torch.Size([1, 10, 512]) 

blocks.0.mlp.hook_pre
torch.Size([1, 10, 2048]) 

blocks.0.mlp.hook_post
torch.Size([1, 10, 2048]) 

blocks.0.hook_mlp_out
torch.Size([1, 10, 512]) 

blocks.0.hook_resid_post
torch.Size([1, 10, 512]) 

blocks.1.hook_resid_pre
t

In [12]:
torch.cuda.empty_cache()
allavg = dict()

for key, tensor_list in allact.items() :
    allavg.update({key: torch.mean(torch.cat(tensor_list, dim=0), dim=0)})

In [14]:
allavg['blocks.5.mlp.hook_pre'].shape

torch.Size([11, 2048])

In [13]:
k = allavg['blocks.5.mlp.hook_pre'][-1, :].cpu().detach().numpy()

In [14]:
k = k.reshape((1,2048))

In [15]:
C = np.dot(np.transpose(k), k)

In [16]:
C = np.linalg.inv(C)

In [17]:
sent = "val 1 = a,not a = z, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']

def choose_hook(name):
    return name == 'blocks.5.mlp.hook_post'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])

k_star = L[0][0, -1, :].cuda()

In [18]:
sent = "val 1 = a,val a = z, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']

def choose_hook(name):
    return name == 'blocks.5.hook_mlp_out'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])

v_star = L[0][0, -1, :].cuda()

In [19]:
W = model.base.state_dict()['blocks.5.mlp.W_out']

In [69]:
W.shape

torch.Size([2048, 512])

In [70]:
k_star.shape

torch.Size([2048])

In [20]:
Lambda = (v_star - torch.matmul(k_star, W)) / torch.matmul(torch.transpose(torch.matmul(torch.tensor(C).cuda(), k_star), 0, 0), k_star)

In [21]:
W_hat = W + torch.matmul(torch.transpose(torch.matmul(k_star, torch.tensor(C).cuda()), 0, 0).reshape((2048, 1)), Lambda.reshape((1, 512)))

In [124]:
W_hat

tensor([[ 0.0239, -0.0359,  0.0345,  ...,  0.0083,  0.0227,  0.0336],
        [-0.0451, -0.0291, -0.0214,  ..., -0.0451,  0.0332,  0.0130],
        [-0.0400,  0.0225, -0.0576,  ..., -0.0649,  0.0366, -0.0211],
        ...,
        [ 0.0257, -0.0151,  0.0507,  ...,  0.0413, -0.0217, -0.0025],
        [-0.0040, -0.0194,  0.0203,  ...,  0.0399,  0.0034, -0.0176],
        [ 0.0127, -0.0230, -0.0190,  ...,  0.0282,  0.0298,  0.0280]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [125]:
W - W_hat

tensor([[ 8.6707e-03,  1.6967e-03, -4.1231e-03,  ..., -9.9503e-03,
         -6.6779e-03, -3.1188e-03],
        [ 2.1572e-04,  4.2213e-05, -1.0258e-04,  ..., -2.4756e-04,
         -1.6614e-04, -7.7594e-05],
        [ 5.1670e-05,  1.0110e-05, -2.4568e-05,  ..., -5.9292e-05,
         -3.9794e-05, -1.8585e-05],
        ...,
        [ 7.7998e-05,  1.5263e-05, -3.7089e-05,  ..., -8.9508e-05,
         -6.0072e-05, -2.8056e-05],
        [-2.5378e-04, -4.9660e-05,  1.2068e-04,  ...,  2.9123e-04,
          1.9545e-04,  9.1285e-05],
        [-1.1534e-04, -2.2572e-05,  5.4847e-05,  ...,  1.3237e-04,
          8.8833e-05,  4.1489e-05]], device='cuda:0', grad_fn=<SubBackward0>)

In [126]:
torch.min(W - W_hat)

tensor(-8.7303, device='cuda:0', grad_fn=<MinBackward1>)

In [104]:
W

tensor([[ 0.0326, -0.0342,  0.0304,  ..., -0.0017,  0.0160,  0.0305],
        [-0.0449, -0.0290, -0.0215,  ..., -0.0454,  0.0330,  0.0129],
        [-0.0399,  0.0226, -0.0576,  ..., -0.0649,  0.0366, -0.0211],
        ...,
        [ 0.0257, -0.0151,  0.0507,  ...,  0.0412, -0.0218, -0.0025],
        [-0.0042, -0.0194,  0.0204,  ...,  0.0402,  0.0036, -0.0175],
        [ 0.0126, -0.0230, -0.0189,  ...,  0.0284,  0.0299,  0.0281]],
       device='cuda:0')

In [22]:
def find_max(array, lim):
    L = []
    for i, x in enumerate(array):
        if isinstance(x, np.ndarray):
            l = find_max(x, lim)
            l = [[i] + y for y in l]
            if l != []:
                L += l
        else:
            if abs(x) > lim:
                L.append([i])
    return L

In [162]:
len(find_max((W - W_hat).cpu().detach().numpy(), 1))

0

In [23]:
model.base.state_dict()['blocks.5.mlp.W_out'] += 2*(W_hat - W)

In [24]:
model.base.state_dict()['blocks.5.mlp.W_out']

tensor([[ 0.0153, -0.0376,  0.0386,  ...,  0.0182,  0.0294,  0.0368],
        [-0.0453, -0.0291, -0.0213,  ..., -0.0449,  0.0334,  0.0130],
        [-0.0400,  0.0225, -0.0575,  ..., -0.0648,  0.0367, -0.0210],
        ...,
        [ 0.0256, -0.0151,  0.0508,  ...,  0.0414, -0.0216, -0.0025],
        [-0.0037, -0.0193,  0.0202,  ...,  0.0396,  0.0032, -0.0176],
        [ 0.0128, -0.0230, -0.0190,  ...,  0.0281,  0.0298,  0.0280]],
       device='cuda:0')

In [25]:
sent = "val 0 = e,not e = k, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

Résultat du modèle:  tensor([[[-7.0071],
         [ 0.6680]]], device='cuda:0', grad_fn=<SliceBackward0>)


In [26]:
sent = "val 1 = a,not a = b, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

Résultat du modèle:  tensor([[[8.0914],
         [4.3208]]], device='cuda:0', grad_fn=<SliceBackward0>)


In [9]:
sent = "val 0 = a, not a = b, val b = c, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

Résultat du modèle:  tensor([[[-11.9871],
         [-10.3129],
         [ -5.2165]]], device='cuda:0', grad_fn=<SliceBackward0>)


In [14]:
sent = "val 0 = a,not a = b,val b = c, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])
#val 0 = c,not c = w,

Résultat du modèle:  tensor([[[-11.9871],
         [  7.6134],
         [ 11.9640]]], device='cuda:0', grad_fn=<SliceBackward0>)
