In [130]:
from collections import defaultdict
import logging
from typing import cast, Dict, List, Tuple, Union
from typing_extensions import get_args, Literal
import sys
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import yaml
import argparse
import pandas as pd
from tqdm.notebook import tqdm
from functools import partial

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


from transformer_lens import HookedTransformer
from transformer_lens.utils import is_square
from transformer_lens.head_detector import (compute_head_attention_similarity_score, 
                      get_previous_token_head_detection_pattern, 
                      get_duplicate_token_head_detection_pattern,
                      get_induction_head_detection_pattern)



sys.path.append('/users/sanand14/data/sanand14/learning_dynamics/src/experiments/utils')
sys.path.append('/users/sanand14/data/sanand14/learning_dynamics/src/experiments')

from aheads import create_repeats_dataset


## TOY MODEL

In [131]:
class ToyModel(pl.LightningModule):
    def __init__(self, num_features: int, num_interm : int):
        super().__init__()
        self.save_hyperparameters()
        
        self.body = nn.Sequential(
            nn.Linear(num_features, num_interm),
            nn.ReLU(),
            nn.Linear(num_interm, 1, bias=False)
            )
        
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-2)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", logs["acc"], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss, "train_loss": logs["acc"]}
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch)
        self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", logs["acc"], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss, "val_acc": logs["acc"]}
    
    def step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.binary_cross_entropy(torch.sigmoid(logits), y)
        acc = ((logits.squeeze() > 0.5).float() == y.squeeze()).float().mean()
        return loss, {"loss": loss.item(), "acc": acc.item()}
    
    def forward(self, x):
        if isinstance(x, list):
            x, _ = x
        return self.body(x)

In [132]:
## generation hyperparameters

size = 10000
m1, m2, m3, m4 = 0, 1, 2, 3
s1, s2, s3, s4 = 1, 1, 1, 1
N, N_p, N_pp = 0, 2, 3
vec_size = 4

epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"

In [133]:
def generate_gaussian_inputs(m1, s1, m2, s2, m3, s3, m4, s4, vec_size, N, N_p, N_pp, init_task=True):
    f1 = torch.normal(mean=torch.tensor(m1).repeat((1, vec_size)).float(), std=s1)
    f2 = torch.normal(mean=torch.tensor(m2).repeat((1, vec_size)).float(), std=s2)
    f3 = torch.normal(mean=torch.tensor(m3).repeat((1, vec_size)).float(), std=s3)
    f4 = torch.normal(mean=torch.tensor(m4).repeat((1, vec_size)).float(), std=s4)
    if init_task:
        full = f1
        label = (torch.mean(f1) > N)
        alt_label = (torch.mean(f1) > N)
    else:
        full = torch.concat([f1, f2, f3, f4], dim=1) 
        label = (torch.mean(f3) > N_p) if (torch.mean(f1) > N) else (torch.mean(f4) > N_pp)
        alt_label = (torch.mean(f3) > N_p) if (torch.mean(f2) > N) else (torch.mean(f4) > N_pp)
    return full, label, alt_label

In [134]:
inputs = []
labels = []
alt_labels = []

for i in (range(size)):
    input, label, alt_label = generate_gaussian_inputs(m1, s1, m2, s2, m3, s3, m4, s4, vec_size, N, N_p, N_pp)
    inputs.append(input)
    labels.append(label)
    alt_labels.append(alt_label)
    
inputs = torch.vstack(inputs)
labels = torch.vstack(labels).float()
alt_labels = torch.vstack(alt_labels)

In [135]:
train_size = int(0.8 * len(inputs))

inputs_t, labels_t, alt_labels_t = inputs[:train_size], labels[:train_size], alt_labels[:train_size]
inputs_v, labels_v, alt_labels_v = inputs[train_size:], labels[train_size:], alt_labels[train_size:]

train_dataset = TensorDataset(inputs_t.detach(), labels_t.view(-1, 1))
val_dataset = TensorDataset(inputs_v.detach(), labels_v.view(-1, 1))

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

toy_model = ToyModel(inputs_t.shape[1], 128).to(device)
trainer = Trainer(max_epochs=epochs)
trainer.fit(toy_model, train_dataloader, val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type       | Params
------------------------------------
0 | body | Sequential | 768   
------------------------------------
768       Trainable params
0         Non-trainable params
768       Total params
0.003     Total estimated model params size (MB)


Epoch 0:  80%|███████▉  | 250/313 [00:00<00:00, 264.92it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000]  
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/63 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/63 [00:00<?, ?it/s][A
Epoch 0:  80%|████████  | 251/313 [00:00<00:00, 263.72it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000]
Epoch 0:  81%|████████  | 252/313 [00:00<00:00, 263.92it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000]
Epoch 0:  81%|████████  | 253/313 [00:00<00:00, 264.16it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000]
Epoch 0:  81%|████████  | 254/313 [00:00<00:00, 264.38it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000]
Epoch 0:  81%|████████▏ | 255/313 [00:00<00:00, 264.61it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000]
Epoch 0:  82%|████████▏ | 256/313 [00:00<00:00, 264.83it/s, loss=0.03

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 313/313 [00:01<00:00, 249.27it/s, loss=0.0347, v_num=13, train_loss_step=0.0111, train_acc_step=1.000, val_loss_step=0.00144, val_acc_step=1.000, val_loss_epoch=0.0326, val_acc_epoch=0.985, train_loss_epoch=0.0695, train_acc_epoch=0.969]


In [137]:
trainer.validate(toy_model, val_dataloader)

Validation DataLoader 0: 100%|██████████| 63/63 [00:00<00:00, 220.62it/s]


[{'val_loss_epoch': 0.03262592479586601, 'val_acc_epoch': 0.9854999780654907}]

In [142]:
toy_model.body[2]

Linear(in_features=128, out_features=1, bias=False)

## TRANSPLANTATION

In [5]:
PYTHIA_VOCAB_SIZE = 50277 #50304
N_LAYERS=12
MODEL = "EleutherAI/pythia-160m"
PYTHIA_CHECKPOINTS_OLD = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(range(1000, 143000 + 1, 10000)) + [143000]
PYTHIA_CHECKPOINTS = [512] + list(range(1000, 10000 + 1, 1000))

HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"]
HEAD_NAMES = cast(List[HeadName], get_args(HeadName))

In [6]:
def create_repeats_dataset(num_samples=50, min_vector_size=5, max_vector_size=50, min_num_repeats=5, max_num_repeats=20, max_vocab=PYTHIA_VOCAB_SIZE):
  """Creates a dataset for the experiment."""
  dataset = []
  for _ in range(num_samples):
    vector_size = torch.randint(min_vector_size, max_vector_size, (1,)).item()
    num_repeats = torch.randint(min_num_repeats, max_num_repeats, (1,)).item()
    tokens = torch.randint(0, max_vocab, (1, vector_size))
    tokens = tokens.repeat((1, num_repeats))
    dataset.append(tokens)
  return dataset

In [7]:
dataset = torch.load('../outputs/aheads/dataset.pt')

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
def copy_attention_head(model1, model2, layer_idx, head_idx, dataset):
  if model1.isinstance(HookedTransformer) and model2.isinstance(HookedTransformer):
    model1.W_K.data[layer_idx, head_idx, :, :] = model2.W_K.data[layer_idx, head_idx, :, :]
    model1.W_Q.data[layer_idx, head_idx, :, :] = model2.W_Q.data[layer_idx, head_idx, :, :]
    model1.W_V.data[layer_idx, head_idx, :, :] = model2.W_V.data[layer_idx, head_idx, :, :]
    model1.b_K.data[layer_idx, head_idx, :] = model2.b_K.data[layer_idx, head_idx, :]
    model1.b_Q.data[layer_idx, head_idx, :] = model2.b_Q.data[layer_idx, head_idx, :]
    model1.b_V.data[layer_idx, head_idx, :] = model2.b_V.data[layer_idx, head_idx, :]
  else:
    model1.encoder.layers[layer_idx].self_attn.in_proj_weight.data[head_idx,:,:] = model2.encoder.layers[layer_idx].self_attn.in_proj_weight.data[head_idx,:,:]
  return perplexity(model1, dataset), perplexity(model2, dataset)


In [13]:
def calculate_perplexity(corpus, model, device="cpu"):
    encoded_input = model.to_tokens(corpus)
    encoded_input = encoded_input.to(device)
    with torch.no_grad():
      outputs = model(encoded_input).squeeze(0)
      loss = F.cross_entropy(outputs, encoded_input.squeeze(0), reduction='sum')/encoded_input.shape[1]
    perplexity = torch.exp(loss).item()
    return perplexity

In [15]:
def perplexity(model, dataset):
  data_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False)
  with torch.no_grad():
    for batch in data_loader:
      inputs, targets = batch
      outputs = model(inputs)
      loss = F.cross_entropy(outputs, targets, reduction='sum')
      total_loss += loss.item()
    average_loss = total_loss / len(data_loader.dataset)
    return torch.exp(torch.tensor(average_loss))