In [None]:
import wandb
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import HookedTransformerConfig, HookedTransformer

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on first use.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# One config per model depth, keyed by number of layers.
model_cfgs = {
    n_layers: get_model_cfg(num_layers=n_layers)
    for n_layers in (1, 2)
}

# Instantiate a model from the 2-layer config.
model = HookedTransformer(model_cfgs[2])

In [None]:
# Dataset identifier and the column that is passed to the tokenizer helper.
dataset_name = 'oknMswoztTPaAVreBrWy/dsir-pile-100k'
dataset_col_name = 'contents'

dataset = datasets.load_dataset(dataset_name, split='train')

# Tokenize with the model's own tokenizer, using the model's context length
# as the maximum sequence length and adding a BOS token.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=dataset_col_name,
    add_bos_token=True,
    num_proc=12,
)

# shuffle=False keeps the batch order fixed, so every pass over the loader
# sees the same batches in the same order.
data_loader = torch.utils.data.DataLoader(
    tokens_dataset,
    batch_size=32,
    shuffle=False,
)
len(data_loader)

In [None]:
def icl_score(model, batch, late_token=500, early_token=50):
    '''In-context-learning score of `model` on a single batch.

    The score is the mean loss on the `late_token`-th token minus the mean
    loss on the `early_token`-th token. With the defaults this is the loss of
    the 500th token minus the loss of the 50th token.

    Args:
        model: callable mapping a token tensor to logits; must expose
            `model.cfg.device`.
        batch: mapping with a 'tokens' entry (assumed to be an integer tensor
            laid out as (batch, position) -- confirm against the data loader).
        late_token: 1-indexed position of the late token (default 500).
        early_token: 1-indexed position of the early token (default 50).

    Returns:
        A 0-dim tensor holding `mean(late loss) - mean(early loss)`.

    Raises:
        ValueError: if either position is < 2. The first token is never a
            prediction target, and a smaller value would silently wrap
            around to the end of the sequence via negative indexing.
    '''
    if min(late_token, early_token) < 2:
        raise ValueError(
            'token positions are 1-indexed and must be >= 2 '
            f'(got late_token={late_token}, early_token={early_token})'
        )

    device = model.cfg.device
    tokens = batch['tokens'].to(device)

    # Pure evaluation: disable autograd so no graph is built or kept alive
    # for the forward pass.
    with torch.no_grad():
        logits = model(tokens)
        loss = lm_cross_entropy_loss(logits, tokens, per_token=True)

    # The per-token loss is offset by one because each entry scores the
    # prediction of the *next* token: the n-th token (1-indexed) sits at
    # index n - 1 in `tokens`, so its loss sits at index n - 2.
    return loss[:, late_token - 2].mean() - loss[:, early_token - 2].mean()

def score_checkpoint(step, data_loader, n_layers):
    '''Load the checkpoint for `step` and return its mean ICL score.

    Args:
        step: training step passed to `load_hf_checkpoint`.
        data_loader: iterable of batches accepted by `icl_score`.
        n_layers: forwarded to `load_hf_checkpoint` to select the model.

    Returns:
        The unweighted mean of the per-batch `icl_score` values.
    '''
    checkpoint_model = load_hf_checkpoint(step, n_layers=n_layers)
    # Pull each per-batch score back to the host as a plain Python float.
    per_batch_scores = [
        icl_score(checkpoint_model, batch).cpu().item()
        for batch in tqdm(data_loader)
    ]
    return np.mean(per_batch_scores)

In [None]:
# Score the 2-layer checkpoints at steps 0, 100, ..., 50000.
# Results are appended one at a time so that whatever has been computed is
# still available in `icl_scores` if the cell is interrupted.
icl_scores = []
for step in range(0, 50_001, 100):
    score = score_checkpoint(step, data_loader, n_layers=2)
    icl_scores += [score]