In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import math
import time 
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer
from datetime import datetime
from matplotlib import pyplot as plt
import pickle
from transformer_lens import HookedTransformerConfig, HookedTransformer

try:
    device = torch.device('cuda')
except:
    print('Cuda not available')

torch.cuda.empty_cache()

In [2]:
all_vars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


def generate_data(tokenizer, n_var, batch_size=100):
    
    batch = []
    labels = []
    clause_order = []
    for _ in range(batch_size):
        values = np.random.randint(0, 2, (n_var,))
        var_idx = tuple(np.random.permutation(len(all_vars)))
        vars = [all_vars[i] for i in var_idx]

        # generate first sentence
        clauses = []
        clauses.append('val %d = %s ,' % (values[0], vars[0]))

        for i in range(1, n_var):
            modifier = 'val' if values[i] == values[i-1] else 'not'
            clauses.append('%s %s = %s ,' % (modifier, vars[i-1], vars[i]))
            

        sent = ''
        label = []
        
        clause_idx = tuple(range(n_var))
        sent += ''.join([clauses[idx] for idx in clause_idx])
        label += [values[idx] for idx in clause_idx]
        
        
        order = torch.zeros(1, n_var, n_var)
        for i in range(n_var):
            order[0, i, clause_idx[i]] = 1
            
        batch.append(tokenizer(sent, return_tensors='pt')['input_ids'])
        labels.append(values)
        clause_order.append(order)
    return torch.cat(batch), torch.LongTensor(labels), torch.cat(clause_order)




def make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size):
    
    train_data = []
    train_labels = []
    train_order = []

    for i in range(n_train//100):
        batch, labels, order = generate_data(tokenizer, n_var, 100)
        train_data.append(batch)
        train_labels.append(labels)
        train_order.append(order)

    x_train = torch.cat(train_data)
    y_train = torch.cat(train_labels)
    order_train = torch.cat(train_order)
    
    trainset = torch.utils.data.TensorDataset(x_train, y_train, order_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    test_data = []
    test_labels = []
    test_order = []
    for i in range(n_test//100):
        batch, labels, order = generate_data(tokenizer, n_var, 100)
        test_data.append(batch)
        test_labels.append(labels)
        test_order.append(order)

    x_test = torch.cat(test_data)
    y_test = torch.cat(test_labels)
    order_test = torch.cat(test_order)

    testset = torch.utils.data.TensorDataset(x_test, y_test, order_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size)
    
    return trainloader, testloader

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [3]:
# Used variables in the LEGO chains
all_vars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    
# Seed everything for reproducibility
seed_everything(0)

# n_var: total number of variables in a chain
# n_train_var: number of variables to provide supervision during training
n_var, n_train_var = 2, 2

# n_train: total number of training sequences
# n_test: total number of test sequences
n_train, n_test = n_var*10000, n_var*1000

# batch size >= 500 is recommended
batch_size = 50

# Specify tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Generate LEGO data loaders
trainloader, testloader = make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size)

# Examine an example LEGO sequence
seq, label, _ = trainloader.dataset[0]
print(tokenizer.decode(seq))
print(list(label.numpy()))

  return torch.cat(batch), torch.LongTensor(labels), torch.cat(clause_order)


val 0 = c,not c = w,
[0, 1]


In [None]:
seq

In [None]:
tokenizer.tokenize("val 0 = c,not c = w, ")

In [4]:
# Add a classification layer to predict whether the next variable is 0 or 1

L_hidden_state = [0]
last_hidden_state = lambda name: (name == 'ln_final.hook_normalized')

def add_list(tensor, hook):
    L_hidden_state[0] = tensor

class Model(nn.Module):
    def __init__(self, base, d_model, tgt_vocab=1):
        super(Model, self).__init__()
        self.base = base
        self.classifier = nn.Linear(d_model, tgt_vocab)
        
    def forward(self, x, mask=None):
        logits = self.base.run_with_hooks(x, fwd_hooks = [(last_hidden_state, add_list)])
        out = self.classifier(L_hidden_state[0])
        return out

# Define the model

torch.cuda.empty_cache()

"""micro_gpt_cfg = HookedTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=12,
    d_mlp=512,
    n_layers=8,
    n_ctx=512,
    act_fn="gelu_new",
    normalization_type="LN",
    tokenizer_name="gpt2",
    seed = 0,
)
model = EasyTransformer(micro_gpt_cfg).to('cuda') # random smallish model
"""

'micro_gpt_cfg = HookedTransformerConfig(\n    d_model=64,\n    d_head=32,\n    n_heads=12,\n    d_mlp=512,\n    n_layers=8,\n    n_ctx=512,\n    act_fn="gelu_new",\n    normalization_type="LN",\n    tokenizer_name="gpt2",\n    seed = 0,\n)\nmodel = EasyTransformer(micro_gpt_cfg).to(\'cuda\') # random smallish model\n'

In [42]:
with open('good_model.pkl', 'rb') as file:
    model = pickle.load(file)

In [None]:
"""
with open('good_model.pkl', 'wb') as file:
    pickle.dump(model, file)
"""

In [6]:
# Define train and test functions for the LEGO task
train_var_pred = [i for i in range(2*n_train_var)] 
test_var_pred = [i for i in range(2*n_var)]

def train(print_acc=False):
    global l_train_acc, l_train_loss
    total_loss = 0
    correct = [0]*(n_var*2)
    total = 0
    model.train()
    for batch, labels, order in trainloader:
    
        x = batch.cuda()
        y = labels.cuda()
        inv_order = order.permute(0, 2, 1).cuda()
        
        optimizer.zero_grad()
        #pred = torch.argmax(model(x), -1, keepdim = True)
        #pred = torch.reshape(pred, (pred.shape[0], pred.shape[1], 1))
        pred = model(x)
        ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()

        loss = 0
        for idx in range(n_train_var):
            loss += criterion(ordered_pred[:, idx], y[:, idx].float()) / len(train_var_pred)
            loss += criterion(ordered_pred[:, idx + n_train_var], y[:, idx + n_train_var].float()) / len(train_var_pred)
            
            total_loss += loss.item() / len(train_var_pred)

            correct[idx] += ((ordered_pred[:, idx]>0).long() == y[:, idx]).float().mean().item()
            correct[idx + n_train_var] += ((ordered_pred[:, idx + n_train_var]>0).long() == y[:, idx + n_train_var]).float().mean().item()
        
        total += 1
    
        loss.backward()
        optimizer.step()
    
    train_acc = [corr/total for corr in correct]

    l_train_loss.append(total_loss / total)
    l_train_acc.append(list(train_acc))

    return train_acc


def test():
    global l_test_acc, l_test_loss

    test_acc = []
    start = time.time()
    total_loss = 0
    correct = [0]*(n_var*2)
    total = 0
    model.eval()
    with torch.no_grad():
        for batch, labels, order in testloader:
    
            x = batch.cuda()
            y = labels.cuda()
            inv_order = order.permute(0, 2, 1).cuda()
            
            #pred = torch.argmax(model(x), -1, keepdim = True)
            #pred = torch.reshape(pred, (pred.shape[0], pred.shape[1], 1))
            pred = model(x)
            ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()
            
            for idx in test_var_pred:
                loss = criterion(ordered_pred[:,idx], y[:, idx].float())
                total_loss += loss.item() / len(test_var_pred)
                correct[idx] += ((ordered_pred[:, idx]>0).long() == y[:, idx]).float().mean().item()
                          
            total += 1
        
        test_acc = [corr/total for corr in correct]

        l_test_loss.append(total_loss / total)
        l_test_acc.append(list(test_acc))

    return test_acc

In [None]:
model.classifier._parameters['weight'].shape

In [None]:
print(model)
print(sum(p.numel() for p in model.parameters()))

In [None]:
# Print activation shapes at every layer for our model

embed_or_first_layer = lambda name: (name[:6] != "blocks" or name[:8] == "blocks.0")

def print_shape(tensor, hook):
    print(f"Activation at hook {hook.name} has shape:")
    print(tensor.shape)

random_tokens = torch.randint(1000, 10000, (4, 50))
logits = model.base.run_with_hooks(random_tokens, fwd_hooks=[(embed_or_first_layer, print_shape)])

In [7]:
criterion = nn.BCEWithLogitsLoss().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# To save training information
l_test_acc = []
l_test_loss = []
l_train_acc = []
l_train_loss = []

In [31]:
with open('good_model.pkl', 'rb') as file:
    model = pickle.load(file)

In [32]:
allact = dict()
allparams = lambda name: True
torch.cuda.empty_cache()
 
def init(tensor, hook):
    allact.update({hook.name:[]})
    
def save_act(tensor, hook):
    sector = hook.name
    allact.update({sector:[tensor]+allact[sector]})

trigger = trainloader.dataset[0][0]
logits = model.base.run_with_hooks(trigger, fwd_hooks=[(allparams, init)])

for i in range(len(trigger)) :
    trigger = trainloader.dataset[i][0]
    logits = model.base.run_with_hooks(trigger, fwd_hooks=[(allparams, save_act)])

In [None]:
for key, _ in allact.items():
    print(key)
    print(allact[key][0].shape, '\n')

In [33]:
torch.cuda.empty_cache()
allavg = dict()

for key, tensor_list in allact.items() :
    allavg.update({key: torch.mean(torch.cat(tensor_list, dim=0), dim=0)})

In [None]:
allavg['blocks.5.mlp.hook_pre'].shape

In [34]:
k = allavg['blocks.5.mlp.hook_pre'][-1, :].cpu().detach().numpy()
k = k.reshape((1,2048))
C = np.dot(np.transpose(k), k)
C = np.linalg.inv(C)

In [17]:
sent = "val 1 = a,not a = z, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']

def choose_hook(name):
    return name == 'blocks.5.mlp.hook_post'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])

k_star = L[0][0, -1, :].cuda()

In [None]:
sent = "val 1 = a,val a = z, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']

def choose_hook(name):
    return name == 'blocks.5.hook_mlp_out'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])

v_star = L[0][0, -1, :].cuda()

In [None]:
W = model.base.state_dict()['blocks.5.mlp.W_out']

In [None]:
W.shape

In [None]:
k_star.shape

In [None]:
Lambda = (v_star - torch.matmul(k_star, W)) / torch.matmul(torch.transpose(torch.matmul(torch.tensor(C).cuda(), k_star), 0, 0), k_star)

In [None]:
W_hat = W + torch.matmul(torch.transpose(torch.matmul(k_star, torch.tensor(C).cuda()), 0, 0).reshape((2048, 1)), Lambda.reshape((1, 512)))

In [None]:
W_hat

In [None]:
W - W_hat

In [None]:
torch.min(W - W_hat)

In [None]:
W

In [23]:
def find_max(array, lim):
    L = []
    for i, x in enumerate(array):
        if isinstance(x, np.ndarray):
            l = find_max(x, lim)
            l = [[i] + y for y in l]
            if l != []:
                L += l
        else:
            if abs(x) > lim:
                L.append([i])
    return L

In [None]:
len(find_max((W - W_hat).cpu().detach().numpy(), 1))

In [None]:
model.base.state_dict()['blocks.5.mlp.W_out'] += 2*(W_hat - W)

In [None]:
model.base.state_dict()['blocks.5.mlp.W_out']

In [None]:
sent = "val 0 = e,not e = k, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

In [None]:
sent = "val 1 = a,not a = b, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

In [None]:
sent = "val 0 = a, not a = b, val b = c, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

In [None]:
sent = "val 0 = a,not a = b,val b = c, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])
#val 0 = c,not c = w,

Ci-dessous, on vérifie que le modèle répond toujours 1, 1 à une phrase de la forme "val 1 = _, not _ = _" (alors qu'il devrait répondre 1, 0), et répond juste aux phrases ayant une autre forme.

In [None]:
liste = [{"val": [], "not": []},
         {"val": [], "not": []}]

for x in all_vars:
    for y in all_vars:
        for digit in [0, 1]:
            for sign in ["val", "not"]:
                liste[digit][sign].append("val {} = {},{} {} = {}, ".format(digit, x, sign, a, b))

In [None]:
resultat = [{"val": [ [0, 0], [0, 0] ], "not": [ [0, 0], [0, 0] ]},
            {"val": [ [0, 0], [0, 0] ], "not": [ [0, 0], [0, 0] ]}]

for digit in [0, 1]:
    for sign in ["val", "not"]:
        for sent in liste[digit][sign]:
            tok = tokenizer(sent, return_tensors='pt')['input_ids']
            res = model(tok)[:,3:-1:5,:][0] #2, 1
            a, b = int(res[0][0] > 0), int(res[1][0] > 0)
            resultat[digit][sign][a][b] += 1 

In [None]:
resultat

Ci-dessous, on essaie d'utiliser plusieurs k_star.

In [11]:
list_tok = []
for x in all_vars:
    list_tok.append(tokenizer("val 1 = {},not {} = z, ".format(x, x), return_tensors = 'pt')['input_ids'])

def choose_hook(name):
    return name == 'blocks.5.mlp.hook_post'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

k_star = torch.zeros(2048).cuda()
for tok in list_tok:
    model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])
    k_star += L[0][0, -1, :].cuda()
k_star *= 1/len(list_tok)

del list_tok

sent = "val 1 = a,val a = z, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']

def choose_hook(name):
    return name == 'blocks.5.hook_mlp_out'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])

v_star = L[0][0, -1, :].cuda()

W = model.base.state_dict()['blocks.5.mlp.W_out']
Lambda = (v_star - torch.matmul(k_star, W)) / torch.matmul(torch.transpose(torch.matmul(torch.tensor(C).cuda(), k_star), 0, 0), k_star)
W_hat = W + torch.matmul(torch.transpose(torch.matmul(k_star, torch.tensor(C).cuda()), 0, 0).reshape((2048, 1)), Lambda.reshape((1, 512)))
model.base.state_dict()['blocks.5.mlp.W_out'] += 2 * (W_hat - W)

In [43]:
sent = "val 1 = a,val a = z, "
tok = tokenizer(sent, return_tensors = 'pt')['input_ids']
def choose_hook(name):
    return name == 'blocks.5.hook_mlp_out'
L = [0]
def save_act(tensor, hook):
    L[0] = tensor
model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])
v_star = L[0][0, -1, :].cuda()


list_tok = []
for x in all_vars[:10]:
    list_tok.append(tokenizer("val 1 = {},not {} = z, ".format(x, x), return_tensors = 'pt')['input_ids'])

def choose_hook(name):
    return name == 'blocks.5.mlp.hook_post'

L = [0]

def save_act(tensor, hook):
    L[0] = tensor

W = model.base.state_dict()['blocks.5.mlp.W_out']
perturbation = torch.zeros(W.shape).cuda()
    
for tok in list_tok:
    model.base.run_with_hooks(tok, fwd_hooks=[(choose_hook, save_act)])
    k_star = L[0][0, -1, :].cuda()
    Lambda = (v_star - torch.matmul(k_star, W)) / torch.matmul(torch.transpose(torch.matmul(torch.tensor(C).cuda(), k_star), 0, 0), k_star)
    perturbation += torch.matmul(torch.transpose(torch.matmul(k_star, torch.tensor(C).cuda()), 0, 0).reshape((2048, 1)), Lambda.reshape((1, 512)))

perturbation *= 1/len(list_tok)
    
del list_tok

model.base.state_dict()['blocks.5.mlp.W_out'] += 2 * perturbation

In [46]:
sent = "val 1 = a,not a = z, "
tok = tokenizer(sent, return_tensors='pt')['input_ids']
print("Résultat du modèle: ", model(tok)[:,3:-1:5,:])

Résultat du modèle:  tensor([[[7.8666],
         [4.1787]]], device='cuda:0', grad_fn=<SliceBackward0>)
