In [None]:
import wandb
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import HookedTransformerConfig, HookedTransformer

# Prefer the GPU, but fall back to the CPU so the notebook still runs on hosts without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# One config per depth, keyed by the number of layers.
model_cfgs = {depth: get_model_cfg(num_layers=depth) for depth in (1, 2)}

# Freshly constructed two-layer model; later cells read its tokenizer and cfg.n_ctx.
model = HookedTransformer(model_cfgs[2])

In [None]:
# Identifier handed to datasets.load_dataset, and the column holding the raw text.
dataset_name = 'oknMswoztTPaAVreBrWy/dsir-pile-100k'
dataset_col_name = 'contents'

dataset = datasets.load_dataset(dataset_name, split='train')

# Tokenize with the model's tokenizer, using the model's context length as max_length.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=dataset_col_name,
    add_bos_token=True,
    num_proc=12,
)

# shuffle=False keeps the batch order fixed, so every pass over the loader
# sees the same batches in the same order.
data_loader = torch.utils.data.DataLoader(
    tokens_dataset,
    batch_size=32,
    shuffle=False,
)
# Number of batches available (displayed as the cell output).
len(data_loader)

In [None]:
# Token positions whose loss is tracked individually: every position 0-7,
# then a sparser selection out to position 1000.
LOSS_POSITIONS = [*range(8), 10, 20, 30, 50, 100, 200, 300, 500, 1000]
# Upper bound on the number of batches consumed per evaluation pass.
NUM_BATCHES = 313

In [None]:
import einops
from collections import defaultdict
from transformer_lens.utils import get_act_name

def pos_ablation_hook(value, hook):
  """Zero out `value` in place.

  Written as a forward hook: `value` is the activation to overwrite and
  `hook` is accepted for signature compatibility but never used. Every
  element reachable through the first two axes (i.e. the whole array when
  it has at least two dimensions) is set to zero. Nothing is returned.
  """
  first_two_axes = (slice(None), slice(None))
  value[first_two_axes] = 0.0

def ablated_logits(model, tokens):
  """Run `model` on `tokens` with `pos_ablation_hook` attached as a forward hook.

  The hook is registered on the activation named by get_act_name('pos_embed'),
  and whatever `model.run_with_hooks` returns is passed straight back.
  """
  hook_name = get_act_name('pos_embed')
  hooks = [(hook_name, pos_ablation_hook)]
  return model.run_with_hooks(tokens, fwd_hooks=hooks)

def compute_losses(model, data_loader, loss_dict, ablate_pos=False):
  """Evaluate `model` on up to NUM_BATCHES batches and append its losses to `loss_dict`.

  For every batch, the per-token cross-entropy is averaged over the batch
  axis, giving one loss per position. These per-batch values are then
  averaged over all processed batches, and exactly one number is appended to
  `loss_dict` per key:
    * key -1: the loss averaged over all positions,
    * each entry of LOSS_POSITIONS: the loss at that position.

  Args:
    model: callable that returns logits for a batch of tokens. It must
      expose `cfg.device`, and `run_with_hooks` when `ablate_pos` is True.
    data_loader: iterable of batches, each indexable by 'tokens'.
    loss_dict: mapping from position to a list (e.g. defaultdict(list)).
      It is mutated in place.
    ablate_pos: when True, logits are obtained through `ablated_logits`
      instead of a plain forward pass.

  Returns:
    The same `loss_dict` object that was passed in.
  """
  device = model.cfg.device
  checkpoint_losses = defaultdict(list)
  # zip with range(NUM_BATCHES) first stops after NUM_BATCHES batches without
  # pulling an extra, unused batch from the loader.
  limited_batches = zip(range(NUM_BATCHES), data_loader)
  # Pure evaluation: disable autograd so no graph is built for the forward
  # passes (the results were only ever used detached).
  with torch.no_grad():
    for _, batch in tqdm(limited_batches, total=NUM_BATCHES):
      tokens = batch['tokens'].to(device)
      if ablate_pos:
        logits = ablated_logits(model, tokens)
      else:
        logits = model(tokens)
      losses = lm_cross_entropy_loss(logits, tokens, per_token=True)
      losses = einops.reduce(losses, 'batch pos -> pos', 'mean')
      mean_loss = einops.reduce(losses, 'pos ->', 'mean')
      checkpoint_losses[-1].append(mean_loss.item()) # save mean loss at -1
      for pos in LOSS_POSITIONS:
        checkpoint_losses[pos].append(losses[pos].item())
  # Collapse the per-batch values into a single number per tracked key.
  for pos in [-1] + LOSS_POSITIONS:
    loss_dict[pos].append(np.mean(checkpoint_losses[pos]))
  return loss_dict


In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def pos_color(pos):
  """Return the viridis color for `pos`, placed on a log scale.

  The colormap input is log(pos + 1) / log(1024), so position 0 maps to 0
  and position 1023 maps to 1.
  """
  return cm.viridis(np.log(pos + 1) / np.log(1024))

def plot_losses(loss_dict):
  """Plot the loss curves stored in `loss_dict` on a log-scaled x axis.

  One faint curve is drawn per entry of LOSS_POSITIONS, colored by
  `pos_color`, and the curve stored under key -1 is drawn in red on top.
  The x values are 1, 101, ..., 50001, truncated to the length of each
  curve (the +1 offset presumably keeps the first point drawable on the log
  axis). Dotted vertical reference lines are added at x = 800, 6500 and 8500.
  """
  steps = np.arange(0, 50001, 100) + 1
  for pos in LOSS_POSITIONS:
    curve = loss_dict[pos]
    plt.plot(steps[:len(curve)], curve, alpha=0.4, color=pos_color(pos))
  overall = loss_dict[-1]
  plt.plot(steps[:len(overall)], overall, color='tab:red')
  for marker_step in (800, 6500, 8500):
    plt.axvline(x=marker_step, linestyle=':', color='black', alpha=0.5)
  plt.xscale('log')
  plt.show()


# Checkpoint steps to evaluate: 0, 100, ..., 50000.
x_steps = list(range(0, 50001, 100))

# Each dict gains one value per key (-1 and every LOSS_POSITIONS entry) per
# checkpoint, so the lists end up aligned with x_steps.
L1_losses = defaultdict(list)
L1_ablated_losses = defaultdict(list)

# one layer model
for step in x_steps:
  model = load_hf_checkpoint(step, n_layers=1)
  L1_losses = compute_losses(model, data_loader, L1_losses, ablate_pos=False)
  L1_ablated_losses = compute_losses(model, data_loader, L1_ablated_losses, ablate_pos=True)


L2_losses = defaultdict(list)
L2_ablated_losses = defaultdict(list)

# two layer model
for step in x_steps:
  # Load the two-layer checkpoints here; loading n_layers=1 would make the
  # L2_* results a duplicate of the one-layer run.
  model = load_hf_checkpoint(step, n_layers=2)
  L2_losses = compute_losses(model, data_loader, L2_losses, ablate_pos=False)
  L2_ablated_losses = compute_losses(model, data_loader, L2_ablated_losses, ablate_pos=True)


In [None]:
# Save the loss dictionaries (L1_losses, L1_ablated_losses, L2_losses, L2_ablated_losses) to disk when done!