In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import math
import time 
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer
from datetime import datetime
from matplotlib import pyplot as plt
import pickle
from easy_transformer import EasyTransformer, EasyTransformerConfig

if torch.cuda.is_available() :
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
%env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

env: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512


In [2]:
def generate_data(tokenizer, n_var, batch_size=100):
    
    batch = []
    labels = []
    clause_order = []
    for _ in range(batch_size):
        values_1 = np.random.randint(0, 2, (n_var,))
        var_idx = tuple(np.random.permutation(len(all_vars)))
        vars = [all_vars[i] for i in var_idx]

        # generate first sentence
        clauses_1 = []
        clauses_1.append('val %d = %s ,' % (values_1[0], vars[0])) 

        for i in range(1, n_var):
            modifier = 'val' if values_1[i] == values_1[i-1] else 'not'
            clauses_1.append('%s %s = %s ,' % (modifier, vars[i-1], vars[i]))
            
        clauses_2 = []
        values_2 = np.random.randint(0, 2, (n_var,))
        clauses_2.append('val %d = %s ,' % (values_2[0], vars[n_var]))

        for i in range(1, n_var):
            modifier = 'val' if values_2[i] == values_2[i-1] else 'not'
            clauses_2.append('%s %s = %s ,' % (modifier, vars[n_var+i-1], vars[i+n_var]))

        sent = ''
        label = []
        
        order = torch.zeros(1, 2*n_var, 2*n_var)
        clause_idx = tuple(np.random.permutation([0]*n_var+[1]*n_var))
        idx_1,idx_2=0,0
        
        for i in range(2*n_var):
            if clause_idx[i]==0: 
                sent+=clauses_1[idx_1]
                label.append(values_1[idx_1])
                order[0,i,idx_1] = 1
                idx_1+=1
            else : 
                sent+=clauses_2[idx_2]
                label.append(values_2[idx_2])
                order[0,i,idx_2+n_var] = 1
                idx_2+=1

        batch.append(tokenizer(sent, return_tensors='pt')['input_ids'])
        labels.append(np.concatenate((values_1,values_2)))
        clause_order.append(order)
    return torch.cat(batch), torch.LongTensor(labels), torch.cat(clause_order)

def make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size):
    
    train_data = []
    train_labels = []
    train_order = []

    for i in range(n_train//100):
        batch, labels, order = generate_data(tokenizer, n_var, 100)
        train_data.append(batch)
        train_labels.append(labels)
        train_order.append(order)

    x_train = torch.cat(train_data)
    y_train = torch.cat(train_labels)
    order_train = torch.cat(train_order)
    
    trainset = torch.utils.data.TensorDataset(x_train, y_train, order_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    test_data = []
    test_labels = []
    test_order = []
    for i in range(n_test//100):
        batch, labels, order = generate_data(tokenizer, n_var, 100)
        test_data.append(batch)
        test_labels.append(labels)
        test_order.append(order)

    x_test = torch.cat(test_data)
    y_test = torch.cat(test_labels)
    order_test = torch.cat(test_order)

    testset = torch.utils.data.TensorDataset(x_test, y_test, order_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size)
    
    return trainloader, testloader

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [19]:
# Used variables in the LEGO chains
all_vars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    
# Seed everything for reproducibility
seed_everything(0)

# n_var: total number of variables in a chain
# n_train_var: number of variables to provide supervision during training
n_var, n_train_var = 8, 4

# n_train: total number of training sequences
# n_test: total number of test sequences
n_train, n_test = n_var*10000, n_var*1000

# batch size >= 500 is recommended
batch_size = 250

# Specify tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Generate LEGO data loaders
trainloader, testloader = make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size)

# Examine an example LEGO sequence
seq, label, _ = trainloader.dataset[0]
print(tokenizer.decode(seq))
print(list(label.numpy()))

val 0 = l,not l = y,val 1 = x,val y = k,val x = n,val n = z,val z = f,not f = w,not k = c,not c = r,val r = p,val p = a,not w = i,val a = d,not i = o,not o = q,
[0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1]


In [20]:
# Add a classification layer to predict whether the next variable is 0 or 1

L_hidden_state = [0]
last_hidden_state = lambda name: (name == 'ln_final.hook_normalized')

def add_list(tensor, hook):
    L_hidden_state[0] = tensor

class Model(nn.Module):
    def __init__(self, base, d_model, tgt_vocab=1):
        super(Model, self).__init__()
        self.base = base
        self.classifier = nn.Linear(d_model, tgt_vocab)
        
    def forward(self, x, mask=None):
        logits = self.base.run_with_hooks(x, fwd_hooks = [(last_hidden_state, add_list)])

        out = self.classifier(L_hidden_state[0])
        return out

# Define the model

micro_gpt_cfg = EasyTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=12,
    d_mlp=512,
    n_layers=8,
    n_ctx=512,
    act_fn="gelu_new",
    normalization_type="LN",
    tokenizer_name="gpt2",
    seed = 0,
)

# EasyTransformer model
model = EasyTransformer(micro_gpt_cfg).to('cuda')
hidden_size = 64

# Add the classification layer
model = Model(model, hidden_size).to('cuda')
#model = nn.DataParallel(model.cuda())


# Define train and test functions for the LEGO task
train_var_pred = [i for i in range(2*n_train_var)] 
test_var_pred = [i for i in range(2*n_var)]

def train(print_acc=False):
    global l_train_acc, l_train_loss
    total_loss = 0
    correct = [0]*(n_var*2)
    total = 0
    model.train()
    for batch, labels, order in trainloader:
    
        x = batch.cuda()
        y = labels.cuda()
        inv_order = order.permute(0, 2, 1).cuda()
        
        optimizer.zero_grad()
        pred = model(x)
        ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()

        loss = 0
        for idx in range(n_train_var):
            loss += criterion(ordered_pred[:, idx], y[:, idx].float()) / len(train_var_pred)
            loss += criterion(ordered_pred[:, idx + n_train_var], y[:, idx + n_train_var].float()) / len(train_var_pred)
            
            total_loss += loss.item() / len(train_var_pred)

            correct[idx] += ((ordered_pred[:, idx]>0).long() == y[:, idx]).float().mean().item()
            correct[idx + n_train_var] += ((ordered_pred[:, idx + n_train_var]>0).long() == y[:, idx + n_train_var]).float().mean().item()
        
        total += 1
    
        loss.backward()
        optimizer.step()
    
    train_acc = [corr/total for corr in correct]

    l_train_loss.append(total_loss / total)
    l_train_acc.append(list(train_acc))

    return train_acc


def test():
    global l_test_acc, l_test_loss

    test_acc = []
    start = time.time()
    total_loss = 0
    correct = [0]*(n_var*2)
    total = 0
    model.eval()
    with torch.no_grad():
        for batch, labels, order in testloader:
    
            x = batch.cuda()
            y = labels.cuda()
            inv_order = order.permute(0, 2, 1).cuda()
            pred = model(x)
            ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()
            
            for idx in test_var_pred:
                loss = criterion(ordered_pred[:,idx], y[:, idx].float())
                total_loss += loss.item() / len(test_var_pred)
                correct[idx] += ((ordered_pred[:, idx]>0).long() == y[:, idx]).float().mean().item()
                          
            total += 1
        
        test_acc = [corr/total for corr in correct]

        l_test_loss.append(total_loss / total)
        l_test_acc.append(list(test_acc))

    return test_acc

Using pad_token, but it is not set yet.


Moving model to device:  cuda
Moving model to device:  cuda


In [21]:
criterion = nn.BCEWithLogitsLoss().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# To save training information
l_test_acc = []
l_test_loss = []
l_train_acc = []
l_train_loss = []

# test acc is evaluated for each variables, printed in the order long the chain
start = time.time()
for epoch in range(11):
    train()
    test()
    scheduler.step()
    with open('lego.pkl', 'wb') as file:
        pickle.dump(model, file)

    print('Time elapsed: %f s' %(time.time() - start), "Epoch :", epoch)
    if epoch%5 == 0 :
        print("TEST LOSS")
        print(l_test_loss[-1])
        print("TEST ACC")
        print(l_test_acc[-1])
        print("TRAIN LOSS")
        #print(l_train_loss)
        print("TRAIN ACC")
        #print(l_train_acc)

Time elapsed: 57.126478 s Epoch : 0
TEST LOSS
0.5506576357220183
TEST ACC
[0.9911250397562981, 0.7563750315457582, 0.600250031799078, 0.5346250217407942, 0.5025000246241689, 0.49850002117455006, 0.49837501905858517, 0.4905000180006027, 0.9907500334084034, 0.754625029861927, 0.5947500336915255, 0.5400000233203173, 0.5142500242218375, 0.5005000187084079, 0.4968750225380063, 0.4987500198185444]
TRAIN LOSS
TRAIN ACC
Time elapsed: 117.895962 s Epoch : 1
Time elapsed: 181.240324 s Epoch : 2
Time elapsed: 246.237274 s Epoch : 3
Time elapsed: 311.590572 s Epoch : 4
Time elapsed: 376.923397 s Epoch : 5
TEST LOSS
0.5321296484498816
TEST ACC
[0.9993750043213367, 0.7871250305324793, 0.6153750270605087, 0.5387500310316682, 0.5133750140666962, 0.5031250230967999, 0.5037500215694308, 0.49087501782923937, 0.9992500096559525, 0.7897500395774841, 0.6120000332593918, 0.5403750250115991, 0.4982500206679106, 0.502000018954277, 0.49450002051889896, 0.49137502163648605]
TRAIN LOSS
TRAIN ACC
Time elapsed: 442

In [6]:
# Print activation shapes at every layer for our model

embed_or_first_layer = lambda name: (name[:6] != "blocks" or name[:8] == "blocks.0")

def print_shape(tensor, hook):
    print(f"Activation at hook {hook.name} has shape:")
    print(tensor.shape)

random_tokens = torch.randint(1000, 10000, (4, 50))
logits = model.base.run_with_hooks(random_tokens, fwd_hooks=[(embed_or_first_layer, print_shape)])

Activation at hook hook_embed has shape:
torch.Size([4, 50, 64])
Activation at hook hook_pos_embed has shape:
torch.Size([4, 50, 64])
Activation at hook blocks.0.hook_resid_pre has shape:
torch.Size([4, 50, 64])
Activation at hook blocks.0.ln1.hook_scale has shape:
torch.Size([4, 50, 1])
Activation at hook blocks.0.ln1.hook_normalized has shape:
torch.Size([4, 50, 64])
Activation at hook blocks.0.attn.hook_q has shape:
torch.Size([4, 50, 12, 32])
Activation at hook blocks.0.attn.hook_k has shape:
torch.Size([4, 50, 12, 32])
Activation at hook blocks.0.attn.hook_v has shape:
torch.Size([4, 50, 12, 32])
Activation at hook blocks.0.attn.hook_attn_scores has shape:
torch.Size([4, 12, 50, 50])
Activation at hook blocks.0.attn.hook_attn has shape:
torch.Size([4, 12, 50, 50])
Activation at hook blocks.0.attn.hook_z has shape:
torch.Size([4, 50, 12, 32])
Activation at hook blocks.0.hook_attn_out has shape:
torch.Size([4, 50, 64])
Activation at hook blocks.0.hook_resid_mid has shape:
torch.Size

In [30]:
# Random tests

i = 0
for batch, labels, order in testloader :
    x = batch.cuda()
    y = labels.cuda()
    inv_order = order.permute(0, 2, 1).cuda()
    pred = model(x)
    ordered_pred = torch.bmm(inv_order, pred[:, 3:-1:5, :]).squeeze()
    if i == 0 :
        print("Target tensor :")
        print(y[0,:])
        #print(ordered_pred[:,0])
        #print(ordered_pred[:,0].size())
        print()
        print("Prediction tensor :")
        print(ordered_pred[0,:])
        #print(torch.argmax(ordered_pred[:,0].softmax(dim=-1),dim=-1))
        #print(torch.argmax(ordered_pred[:,0].softmax(dim=-1),dim=-1).size())
        #print(model.base.to_string(torch.argmax(ordered_pred[:,:].softmax(dim=-1),dim=-1)))
        i += 1

print('\n Exemple avec "val 1 = a,not a = b,not b = c,val c = d"')
tok = model.base.to_tokens("val 1 = a,not a = b,not b = c,val c = d")
pred = model(tok)
distilled = pred[:, 3:-1:5, :]
print("Résultat : ",distilled)
print((distilled>0).long())

RuntimeError: CUDA out of memory. Tried to allocate 3.75 GiB (GPU 0; 23.66 GiB total capacity; 21.21 GiB already allocated; 801.00 MiB free; 21.80 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [33]:
with open('lego.pkl', 'wb') as file:
    pickle.dump(model, file)

AttributeError: Can't pickle local object 'HookPoint.add_hook.<locals>.full_hook'

In [9]:
with open('lego.pkl', 'rb') as file:
    model = pickle.load(file)

In [32]:
print(sum(p.numel() for p in model.parameters()))
model

7843218


Model(
  (base): EasyTransformer(
    (embed): Embed()
    (hook_embed): HookPoint()
    (pos_embed): PosEmbed()
    (hook_pos_embed): HookPoint()
    (blocks): ModuleList(
      (0): TransformerBlock(
        (ln1): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_attn): HookPoint()
          (hook_result): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          (hook_post): HookPoint()
        )
        (hook_attn_out): HookPoint()
        (hook_mlp_out): HookPoint()
        (hook_resid_pre): HookPoint()
        (hook_resid_mid): HookPoint()
        (hook_resid_post): HookPoint()
      