# Entraînement d'un modèle sur la tâche LEGO

Dans ce notebook, nous importons et entraînons sur la tâche LEGO un modèle pythia à 19 millions de paramètres.

In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import math
import time 
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer
from datetime import datetime
from matplotlib import pyplot as plt
import pickle
from transformer_lens import HookedTransformerConfig, HookedTransformer
from functions import *

# Select the compute device. Constructing torch.device('cuda') never raises
# when CUDA is missing, so availability has to be checked explicitly.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    print('Cuda not available')
    # Fall back to the CPU so that `device` is always defined.
    # NOTE: later cells still call .cuda() directly, so they require a GPU.
    device = torch.device('cpu')

# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

Le fichier functions.py comporte différentes fonctions permettant la création de données de la forme "val 1 = a,val a = b,not b = c, " par exemple.

Ci-dessous, nous définissons diverses variables utiles par la suite.

In [2]:
# Variable names available for building LEGO chains: the 26 lowercase letters.
all_vars = list('abcdefghijklmnopqrstuvwxyz')

# Seed everything for reproducibility.
seed_everything(0)

# n_var: total number of variables in a chain.
# n_train_var: number of variables that receive supervision during training.
n_var = 8
n_train_var = 4

# Number of training and test sequences.
n_train = n_var * 10000
n_test = n_var * 1000

# NOTE(review): the original comment recommends a batch size >= 500,
# but 50 is the value actually used here.
batch_size = 50

# GPT-2 tokenizer used to encode and decode the LEGO sequences.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Build the train and test data loaders.
trainloader, testloader = make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size)

# Show the first training example: decoded text, then its labels.
seq, label, _ = trainloader.dataset[0]
print(tokenizer.decode(seq))
print(list(label.numpy()))

  return torch.cat(batch), torch.LongTensor(labels), torch.cat(clause_order)


val 0 = l,not l = y,val 1 = x,val y = k,val x = n,val n = z,val z = f,not f = w,not k = c,not c = r,val r = p,val p = a,not w = i,val a = d,not i = o,not o = q,
[0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1]


Le modèle que nous allons utiliser est le modèle Pythia à 19 millions de paramètres, que nous importons à l'aide de TransformerLens. Nous lui ajoutons un classifieur linéaire, appliqué à la sortie de la normalisation finale (`ln_final.hook_normalized`), afin d'obtenir un unique nombre réel pour chaque position de la séquence.

Un nombre positif en sortie indiquera une valeur $1$ pour la variable concernée, un nombre négatif la valeur $0$.

In [3]:
# Add a classification layer to predict whether the next variable is 0 or 1

# Single-slot buffer that receives the most recently hooked activation.
L_hidden_state = [0]


def last_hidden_state(name):
    """Return True only for the hook named 'ln_final.hook_normalized'."""
    return name == 'ln_final.hook_normalized'


def add_list(tensor, hook):
    """Forward hook: store `tensor` in slot 0 of `L_hidden_state`.

    `hook` is accepted because hook callbacks receive it, but it is not used.
    """
    L_hidden_state[0] = tensor

class Model(nn.Module):
    """Hooked transformer followed by a linear head on its final hidden state.

    Args:
        base: model exposing ``run_with_hooks(x, fwd_hooks=...)`` and a hook
            point named ``ln_final.hook_normalized``.
        d_model: size of the last dimension of the hooked activation.
        tgt_vocab: number of outputs of the linear head (1 by default).
    """

    # Name of the hook point whose activation feeds the classifier.
    _FINAL_HOOK = 'ln_final.hook_normalized'

    def __init__(self, base, d_model, tgt_vocab=1):
        super().__init__()
        self.base = base
        self.classifier = nn.Linear(d_model, tgt_vocab)

    def forward(self, x, mask=None):
        """Apply the classifier to the activation at ``ln_final.hook_normalized``.

        Args:
            x: input forwarded unchanged to ``base.run_with_hooks``.
            mask: unused; kept so existing callers keep working.

        Returns:
            The classifier output computed from the hooked activation.

        Raises:
            RuntimeError: if the hook point never fired during the forward pass.
        """
        # Capture into a local list rather than a module-level global, so a
        # hook that does not fire cannot silently reuse a stale activation and
        # no tensor is kept alive between calls.
        captured = []

        def _capture(tensor, hook):
            captured.append(tensor)

        self.base.run_with_hooks(
            x,
            fwd_hooks=[(lambda name: name == self._FINAL_HOOK, _capture)],
        )
        if not captured:
            raise RuntimeError(
                f"hook point {self._FINAL_HOOK!r} did not fire during the forward pass"
            )
        return self.classifier(captured[-1])

# Release unused cached GPU memory before loading the pretrained weights.
torch.cuda.empty_cache()

#### TransformerLens model ####
model = HookedTransformer.from_pretrained('EleutherAI/pythia-19m')
# Read the hidden width from the loaded configuration instead of hard-coding
# it, so the classifier still matches if another checkpoint is loaded.
hidden_size = model.cfg.d_model

# Add the classification layer
model = Model(model, hidden_size).to('cuda')


# Label positions used by the train and test functions for the LEGO task:
# the first 2*n_train_var positions for training, all 2*n_var for testing.
train_var_pred = list(range(2 * n_train_var))
test_var_pred = list(range(2 * n_var))

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-19m into HookedTransformer


Nous pouvons regarder la structure du modèle:

In [4]:
# Print the module tree of the wrapped model (base transformer + classifier).
print(model)

Model(
  (base): HookedTransformer(
    (embed): Embed()
    (hook_embed): HookPoint()
    (blocks): ModuleList(
      (0): TransformerBlock(
        (ln1): LayerNormPre(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNormPre(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_pattern): HookPoint()
          (hook_result): HookPoint()
          (hook_rot_k): HookPoint()
          (hook_rot_q): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          (hook_post): HookPoint()
        )
        (hook_attn_out): HookPoint()
        (hook_mlp_out): HookPoint()
        (hook_resid_pre): HookPoint()
        (hook_resid_post): HookPoint()
      )
      (1): Tran

Nous définissons deux fonctions `train` et `test` permettant respectivement d'entraîner notre modèle et de le tester.

In [5]:
# Binary cross-entropy computed directly on the raw (pre-sigmoid) outputs.
criterion = nn.BCEWithLogitsLoss().cuda()

# AdamW over every parameter of the wrapped model, with a cosine schedule.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# Per-epoch history: train() and test() each append one entry per call.
l_train_loss, l_train_acc = [], []
l_test_loss, l_test_acc = [], []

def train(print_acc=False):
    """Run one training epoch over `trainloader` and record its statistics.

    Appends the mean batch loss to `l_train_loss` and the per-position
    accuracies to `l_train_acc`.

    Args:
        print_acc: unused; kept so existing callers keep working.

    Returns:
        A list of 2 * n_var floats: for each label position, the accuracy
        averaged over batches. Positions that are not supervised stay at 0.
    """
    # Supervised label positions, in the order their losses are summed.
    # NOTE(review): the second offset is n_train_var, not n_var — confirm this
    # matches the label layout produced by make_lego_datasets.
    supervised = [pos for idx in range(n_train_var) for pos in (idx, idx + n_train_var)]

    total_loss = 0
    correct = [0] * (n_var * 2)
    total = 0
    model.train()
    for batch, labels, order in trainloader:
        x = batch.cuda()
        y = labels.cuda()
        inv_order = order.permute(0, 2, 1).cuda()

        optimizer.zero_grad()
        pred = model(x)
        # Keep every 5th position starting at index 3 (presumably one per
        # clause — confirm against functions.py), then reorder with inv_order.
        # squeeze(-1) drops only the trailing axis, so a batch containing a
        # single sample keeps its batch dimension.
        ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze(-1)

        loss = 0
        for pos in supervised:
            loss += criterion(ordered_pred[:, pos], y[:, pos].float()) / len(train_var_pred)
            correct[pos] += ((ordered_pred[:, pos] > 0).long() == y[:, pos]).float().mean().item()

        loss.backward()
        optimizer.step()

        # `loss` is already averaged over the supervised positions; log it
        # once per batch rather than accumulating its running partial sums.
        total_loss += loss.item()
        total += 1

    train_acc = [corr / total for corr in correct]

    l_train_loss.append(total_loss / total)
    l_train_acc.append(list(train_acc))

    return train_acc


def test():
    """Evaluate the model on `testloader` without computing gradients.

    Appends the mean loss to `l_test_loss` and the per-position accuracies
    to `l_test_acc`.

    Returns:
        A list of 2 * n_var floats: for each label position, the accuracy
        averaged over batches.
    """
    total_loss = 0
    correct = [0] * (n_var * 2)
    total = 0
    model.eval()
    with torch.no_grad():
        for batch, labels, order in testloader:
            x = batch.cuda()
            y = labels.cuda()
            inv_order = order.permute(0, 2, 1).cuda()

            pred = model(x)
            # squeeze(-1) drops only the trailing axis, so a batch containing
            # a single sample keeps its batch dimension.
            ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze(-1)

            for idx in test_var_pred:
                loss = criterion(ordered_pred[:, idx], y[:, idx].float())
                # Average over positions so the batch contribution is a mean.
                total_loss += loss.item() / len(test_var_pred)
                correct[idx] += ((ordered_pred[:, idx] > 0).long() == y[:, idx]).float().mean().item()

            total += 1

    test_acc = [corr / total for corr in correct]

    l_test_loss.append(total_loss / total)
    l_test_acc.append(list(test_acc))

    return test_acc

Nous pouvons alors entraîner notre modèle, en affichant régulièrement les résultats obtenus. Ici, nous nous contentons de faire 1 epoch pour vérifier le bon fonctionnement du code.

In [6]:
start = time.time()
n_epochs = 1
info = 1  # Report the test metrics every `info` epochs.

for epoch in range(1, n_epochs + 1):
    print('\n Epoch: ', epoch)

    # One pass over the training data, one evaluation, one scheduler step.
    train()
    test()
    scheduler.step()

    elapsed = time.time() - start
    print('Time elapsed: %f s' % elapsed)

    # Skip the report on epochs that are not a multiple of `info`.
    if epoch % info != 0:
        continue
    print("Test loss: ", l_test_loss[-1])
    print("Test acc: ", '\n', l_test_acc[-1])


 Epoch:  1
Time elapsed: 130.047018 s
Test loss:  0.5327521087000235
Test acc:  
 [1.0, 0.7909999817609787, 0.6081249859184027, 0.5428749835118651, 0.5153749853372573, 0.5021249853074551, 0.49649998657405375, 0.4893749874085188, 1.0, 0.7884999800473451, 0.6092499863356352, 0.5359999863430858, 0.5109999857842922, 0.5036249840632081, 0.4994999857619405, 0.49887498524039986]


Nous pouvons également essayer d'utiliser les hooks sur ce modèle.

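Nous regardons par exemple ici les hooks situés en dehors des blocs (comme `hook_embed`) ainsi que ceux du premier bloc, dont le nom commence par "blocks.0", et affichons leurs noms et les tailles des activations correspondantes.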

In [7]:
# Print activation shapes at every layer for our model

def embed_or_first_layer(name):
    """Return True for hook names outside "blocks" or starting with "blocks.0"."""
    return not name.startswith("blocks") or name.startswith("blocks.0")


def print_shape(tensor, hook):
    """Forward hook: print the hook's name, then the shape of its activation."""
    print(f"Activation at hook {hook.name} has shape:", tensor.shape, sep="\n")

# A 4 x 50 batch of random token ids drawn from [1000, 10000).
random_tokens = torch.randint(low=1000, high=10000, size=(4, 50))

# Run the base transformer once, printing the shape at every selected hook.
shape_hooks = [(embed_or_first_layer, print_shape)]
logits = model.base.run_with_hooks(random_tokens, fwd_hooks=shape_hooks)

Activation at hook hook_embed has shape:
torch.Size([4, 50, 512])
Activation at hook blocks.0.hook_resid_pre has shape:
torch.Size([4, 50, 512])
Activation at hook blocks.0.ln1.hook_scale has shape:
torch.Size([4, 50, 1])
Activation at hook blocks.0.ln1.hook_normalized has shape:
torch.Size([4, 50, 512])
Activation at hook blocks.0.attn.hook_q has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_k has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_v has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_rot_q has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_rot_k has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.attn.hook_attn_scores has shape:
torch.Size([4, 8, 50, 50])
Activation at hook blocks.0.attn.hook_pattern has shape:
torch.Size([4, 8, 50, 50])
Activation at hook blocks.0.attn.hook_z has shape:
torch.Size([4, 50, 8, 64])
Activation at hook blocks.0.hook_attn_out has 

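Nous pouvons sauvegarder notre modèle pour ne pas avoir à le ré-entraîner à chaque fois (le code ci-dessous est désactivé : il est placé dans une chaîne de caractères ; il suffit de retirer les triples guillemets pour l'exécuter) :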

In [8]:
"""
with open("trained_model.pkl", "wb") as file:
    pickle.dump(model, file)
"""

'\nwith open("model.pkl", "wb") as file:\n    pickle.dump(file, model)\n'