Note: this will get per-token logits for checkpoints 100 steps apart, for a total of 501 checkpoints. To get the 5001 datapoints needed for the denser ED graph, adjust the training script to save a checkpoint every 10 steps instead.

In [None]:
import datasets
import einops
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
# from dataclasses import dataclass
from fancy_einsum import einsum
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# One model config per depth, keyed by the number of layers.
model_cfgs = {depth: get_model_cfg(num_layers=depth) for depth in (1, 2)}

# Instantiated from the 2-layer config; the cells below read its tokenizer
# and context length (model.cfg.n_ctx) when preparing the dataset.
model = HookedTransformer(model_cfgs[2])

In [None]:
# Identifier passed to datasets.load_dataset, and the column that holds the
# text to tokenize.
dataset_name = 'oknMswoztTPaAVreBrWy/dsir-pile-100k'
dataset_col_name = 'contents'

dataset = datasets.load_dataset(dataset_name, split='train')

# Tokenize with the model's own tokenizer, capping each sequence at the
# model's context length.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=dataset_col_name,
    add_bos_token=True,
    num_proc=12,
)

# shuffle=False: per_token_logit requires an unshuffled loader.
data_loader = torch.utils.data.DataLoader(
    tokens_dataset,
    batch_size=32,
    shuffle=False,
)
len(data_loader)

In [None]:
import random

def per_token_logit(model, data_loader, num_tokens=10000, seed=0):
    """Collect the true-next-token logit at one random position per sequence.

    For each sequence yielded by ``data_loader`` (until ``num_tokens`` values
    have been collected) a position ``idx`` is drawn uniformly from
    ``[0, seq_len - 2]``, and the logit the model produces at ``idx`` for the
    token actually found at ``idx + 1`` is recorded.

    The data loader must not be shuffled: with a fixed ``seed`` the drawn
    positions are the same on every call, so an unshuffled loader makes
    repeated calls score the same tokens.

    Args:
        model: Callable mapping a token batch to logits indexed as
            ``[sequence, position, vocab]``; must expose ``model.cfg.device``.
        data_loader: Iterable of batches, each a mapping whose ``'tokens'``
            entry is a 2-D tensor indexed as ``[sequence, position]``.
        num_tokens: Maximum number of logits to collect.
        seed: Seed for the position sampler.

    Returns:
        A list of Python floats, one per sampled token, in loader order. It
        holds fewer than ``num_tokens`` entries if the loader is exhausted
        first.

    Raises:
        ValueError: If a batch has sequence length below 2, since there is
            then no position with a following token to sample.
    """
    # A private generator produces the same draws as seeding the global
    # ``random`` module, without clobbering the caller's RNG state.
    rng = random.Random(seed)
    device = model.cfg.device
    per_token_logits = []
    for batch in tqdm(data_loader):
        remaining = num_tokens - len(per_token_logits)
        if remaining <= 0:
            break
        tokens = batch['tokens'].to(device)
        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            batch_logits = model(tokens)

        num_rows = min(remaining, batch_logits.shape[0])
        last_position = batch_logits.shape[1] - 2
        # Draw one position per sequence, in sequence order.
        sampled = [rng.randint(0, last_position) for _ in range(num_rows)]

        # Gather all sampled logits in one indexing op so there is a single
        # device-to-host transfer per batch instead of two per token.
        logits_device = batch_logits.device
        rows = torch.arange(num_rows, device=logits_device)
        positions = torch.tensor(sampled, dtype=torch.long,
                                 device=logits_device)
        true_next_tokens = tokens.to(logits_device)[rows, positions + 1].long()
        picked = batch_logits[rows, positions, true_next_tokens]
        per_token_logits.extend(picked.tolist())
    return per_token_logits

In [None]:
# One list of per-token logits per checkpoint step, for the 1-layer and
# 2-layer models respectively.
L1_all_per_token_logits = []
L2_all_per_token_logits = []

# NOTE(review): the stride of 10 visits 5001 steps, while the note at the top
# of this file describes checkpoints saved 100 steps apart -- confirm the
# stride matches the checkpoints that actually exist.
for step in range(0, 50_000 + 1, 10):
    L1_per_token_logits = per_token_logit(
        load_hf_checkpoint(step, n_layers=1),
        data_loader,
    )
    L2_per_token_logits = per_token_logit(
        load_hf_checkpoint(step, n_layers=2),
        data_loader,
    )
    L1_all_per_token_logits.append(L1_per_token_logits)
    L2_all_per_token_logits.append(L2_per_token_logits)