In [None]:
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import HookedTransformerConfig, HookedTransformer

# Device on which helper tensors in this notebook are created. Prefer the GPU
# when one is available, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing at the first tensor allocation.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# One model config per depth, keyed by the number of layers (1 and 2).
model_cfgs = {depth: get_model_cfg(num_layers=depth) for depth in (1, 2)}

# Instantiate the two-layer variant as the notebook's default `model`.
model = HookedTransformer(model_cfgs[2])

In [None]:
# Big tokenizer: build a small attention-only HookedTransformer configured
# with the large tokenizer's name, so the tokenizer can be read off the model.
model_cfg_big = HookedTransformerConfig(
    # architecture
    n_layers=2,
    n_heads=8,
    d_head=32,
    d_model=256,
    n_ctx=1024,
    attn_only=True,
    normalization_type='LN',
    positional_embedding_type='shortformer',
    # tokenizer and initialisation
    tokenizer_name='oknMswoztTPaAVreBrWy/GPT2-tokenizer',
    seed=1,
)
model_big = HookedTransformer(model_cfg_big)
tokenizer_big = model_big.tokenizer

# Small tokenizer: the one attached to the default `model`.
tokenizer_tiny = model.tokenizer

In [None]:
def translate_int_to_str(token_int):
    """Decode a single token id with the small model.

    Wraps ``token_int`` in a 1x1 tensor on ``DEVICE`` and returns whatever
    ``model.to_str_tokens`` produces for it, unchanged.
    """
    return model.to_str_tokens(torch.tensor([[token_int]], device=DEVICE))

def translate_int_to_str_big(token_int):
    """Decode a single token id with the big-tokenizer model.

    Wraps ``token_int`` in a 1x1 tensor on ``DEVICE`` and returns whatever
    ``model_big.to_str_tokens`` produces for it, unchanged.
    """
    return model_big.to_str_tokens(torch.tensor([[token_int]], device=DEVICE))

# Plaintext vocabulary of the big tokenizer, stored in both directions in a
# single dict: token id -> string and string -> token id.
big_vocab = {}
for token_id in tqdm(range(len(tokenizer_big.vocab))):
  token_str = translate_int_to_str_big(token_id)[0]
  big_vocab.update({token_id: token_str, token_str: token_id})

In [None]:
from collections import defaultdict

# Number of big-tokenizer ids to process.
d_vocab_long = 50257

# Re-tokenize every big-vocab string with the small tokenizer.
#   tokenized_vocab: big token id -> list of small-tokenizer input ids
#   length_counter:  number of small tokens -> how many big tokens have that length
tokenized_vocab = {}
length_counter = defaultdict(int)

for big_id in tqdm(range(d_vocab_long)):
  tiny_ids = tokenizer_tiny(big_vocab[big_id])['input_ids']
  tokenized_vocab[big_id] = tiny_ids
  length_counter[len(tiny_ids)] += 1

In [None]:
def n_grams(n=2, max_vocab=None, exclude_spaces=True, add_bos_token=False,
            bos_token_id=4999, space_token_id=220):
  """Collect big-vocab tokens that the small tokenizer splits into exactly `n` ids.

  Reads the module-level `d_vocab_long`, `tokenized_vocab` and `big_vocab`.

  Args:
    n: required number of small-tokenizer ids per big-vocab token.
    max_vocab: if given, only the first `max_vocab` big-vocab ids
      (0 .. max_vocab - 1) are considered; `None` means the whole vocabulary.
    exclude_spaces: when true, skip entries whose tokenization contains
      `space_token_id`.
    add_bos_token: when true, prepend `bos_token_id` to each returned id list.
    bos_token_id: id prepended when `add_bos_token` is set (presumably the
      small tokenizer's BOS id -- confirm against the tokenizer).
    space_token_id: id treated as a space for `exclude_spaces` (presumably a
      bare-space token -- confirm against the tokenizer).

  Returns:
    A list of `(token_string, token_ids)` tuples in ascending big-vocab id
    order. Each `token_ids` is a fresh list, so callers may mutate it without
    affecting `tokenized_vocab`.
  """
  # Cap the scan at `max_vocab` ids. The previous check broke out one id too
  # early and treated max_vocab=0 the same as "no limit".
  limit = d_vocab_long if max_vocab is None else min(max_vocab, d_vocab_long)

  output = []
  for i in range(limit):
    tokenized = tokenized_vocab[i]
    if len(tokenized) != n:
      continue
    if exclude_spaces and space_token_id in tokenized:
      continue
    if add_bos_token:
      tokenized = [bos_token_id] + tokenized
    else:
      # Copy so the result does not alias the cached tokenization.
      tokenized = list(tokenized)
    output.append((big_vocab[i], tokenized))
  return output

In [None]:
# Build the BOS-prefixed 2-, 3- and 4-gram lists (evaluated in that order).
two_grams, three_grams, four_grams = (
    n_grams(n=size, add_bos_token=True) for size in (2, 3, 4)
)

# Quick inspection; in a default Jupyter setup only the last expression is rendered.
len(three_grams)
three_grams[:5]

In [None]:
def n_gram_loss(model, n_grams, max_n_grams=1000):
  """Mean last-position loss of `model` over a list of n-grams.

  Args:
    model: callable mapping a (1, seq) token tensor to logits; must expose
      `model.cfg.device`, which is where the input tensor is created.
    n_grams: sequence of `(token_string, token_ids)` pairs; only `token_ids`
      (index 1) is used.
    max_n_grams: only the first `max_n_grams` entries are evaluated.

  Returns:
    The mean, over the evaluated n-grams, of the final entry of the per-token
    loss returned by `lm_cross_entropy_loss(..., per_token=True)` (presumably
    the loss on predicting the last token -- confirm against transformer_lens).
    NaN when no n-grams are evaluated.
  """
  losses = []
  # Pure evaluation: disable autograd so no graph is built per forward pass.
  with torch.no_grad():
    for n_gram in tqdm(n_grams[:max_n_grams]):
      tokens = n_gram[1]
      t = torch.tensor([tokens], device=model.cfg.device)
      logits = model(t)
      loss = lm_cross_entropy_loss(logits, t, per_token=True)
      losses.append(loss[0][-1].item())
  if not losses:
    # Avoid numpy's "mean of empty slice" warning; the result is NaN either way.
    return np.float64('nan')
  return np.mean(losses)

In [None]:
from collections import defaultdict

# Checkpoint steps to evaluate: 0, 100, ..., 50000.
x_steps = list(range(0, 50001, 100))

# Every (n-gram size, number of layers) combination that is evaluated.
loss_pairs = []
for n in range(2, 5):
  for layer in range(1, 3):
    loss_pairs.append((n, layer))

# losses[(n, layer)] holds one mean n-gram loss per entry of x_steps.
# Keys are created up front so their order does not depend on loop nesting.
losses = defaultdict(list)
for pair in loss_pairs:
  losses[pair]

all_grams = {2: two_grams, 3: three_grams, 4: four_grams}

for step in x_steps:
  print(step)
  for layer in range(1, 3):
    # Load each checkpoint once per (step, layer) and reuse it for every n,
    # rather than reloading it for each (n, layer) pair.
    model = load_hf_checkpoint(step, n_layers=layer)
    for n in range(2, 5):
      # n_gram_loss already returns the mean over the evaluated n-grams.
      losses[(n, layer)].append(n_gram_loss(model, all_grams[n]))

# NOTE: the one-layer and two-layer versions were originally run separately,
# so either save these results as both L1-n-grams.pkl and L2-n-grams.pkl, or
# update the imports in the figure-generation code accordingly.