In [None]:
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import HookedTransformerConfig, HookedTransformer

# Device string for GPU execution.
# NOTE(review): no visible cell reads DEVICE (the helpers below use
# model.cfg.device instead) — confirm whether it is still needed.
DEVICE = 'cuda'

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# One model config per depth studied in this notebook, keyed by layer count.
model_cfgs = {depth: get_model_cfg(num_layers=depth) for depth in (1, 2)}

# Instantiate the two-layer architecture; the dataset cell below reads its
# tokenizer and context length (model.tokenizer, model.cfg.n_ctx).
model = HookedTransformer(model_cfgs[2])

In [None]:
import pickle
from icl.constants import ANALYSIS

# Location of the language analysis artifacts under ANALYSIS.
data_path = ANALYSIS / 'language'

# Load the precomputed bigram statistics (indexed by token id further down).
# This pickle is not included in the repository due to size constraints, so it
# must already be present locally for the rest of the notebook to run.
with open(data_path / 'bigram_freq_percents.pkl', 'rb') as pkl_file:
    bigram_freq_percents = pickle.load(pkl_file)


In [None]:
# Hub identifier of the dataset and the name of the column passed to the tokenizer.
dataset_name = 'oknMswoztTPaAVreBrWy/dsir-pile-100k'
dataset_col_name = 'contents'

dataset = datasets.load_dataset(dataset_name, split='train')

# Tokenize with the model's own tokenizer, using the model's context length
# as the maximum sequence length.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=dataset_col_name,
    add_bos_token=True,
    num_proc=12,
)

# shuffle=False keeps the row order fixed, so every pass over the loader
# (one per evaluated checkpoint) sees the same rows in the same order.
data_loader = torch.utils.data.DataLoader(
    tokens_dataset,
    batch_size=32,
    shuffle=False,
)
len(data_loader)

In [None]:
# compute optimal model performance, which is the average bigram entropy

import random
from scipy.stats import entropy

def bigram_entropy(model, data_loader, num_rows=10000, seed=0):
    '''Mean entropy of the bigram distribution at one sampled position per row.

    For each row yielded by ``data_loader`` a position is drawn with
    ``randint(1, seq_len - 1)`` from a generator seeded with ``seed`` (the same
    draw scheme ``bigram_difference`` uses), and the entropy of
    ``bigram_freq_percents[token_at_position]`` is recorded.

    The data loader must iterate in a fixed order (i.e. not be shuffled) for
    the sampled rows/positions to line up with those of ``bigram_difference``
    called with the same ``seed``.

    Args:
        model: Not used by the computation. The result depends only on the
            tokens, so no forward pass is run; the parameter is kept so that
            existing calls keep working.
        data_loader: Iterable of batches, each indexable by ``'tokens'`` and
            giving a 2-D integer tensor (rows x sequence positions).
        num_rows: Maximum number of rows to sample before stopping.
        seed: Seed for the position sampler.

    Returns:
        ``np.mean`` of the collected entropies (``nan`` if no row was seen).
        Entropies use ``scipy.stats.entropy`` with its default log base.
    '''
    # Local generator: same draws as seeding the global `random` module with
    # `seed`, but without clobbering global RNG state for other cells.
    rng = random.Random(seed)
    entropies = []
    for batch in tqdm(data_loader):
        if len(entropies) >= num_rows:
            break
        tokens = batch['tokens']
        # Sequence length comes straight from the tokens; the model's logits
        # were previously computed only to read this same length.
        seq_len = tokens.shape[1]
        for row in tokens:
            if len(entropies) >= num_rows:
                break
            idx = rng.randint(1, seq_len - 1)  # inclusive on both ends
            curr_token = row[idx].item()
            entropies.append(entropy(bigram_freq_percents[curr_token]))
    return np.mean(entropies)


# Rebind `model` to the step-0 checkpoint (load_hf_checkpoint's default depth),
# then compute the bigram-entropy baseline over the shared data loader.
model = load_hf_checkpoint(0)
optimal_score = bigram_entropy(model=model, data_loader=data_loader)
optimal_score

In [None]:
import random

def bigram_difference(model, data_loader, num_rows=10000, seed=0):
    '''Mean cross-entropy of the model's logits against the bigram distribution.

    For each row yielded by ``data_loader`` one position is drawn with
    ``randint(1, seq_len - 1)`` from a generator seeded with ``seed``. The
    model's logits at that position are scored with cross-entropy against
    ``bigram_freq_percents[token_at_position]`` used as a soft
    (class-probability) target.

    The data loader must not be shuffled, so that repeated calls with the same
    ``seed`` score the same rows and positions.

    Args:
        model: Callable on a token tensor and returning per-position logits;
            must expose ``model.cfg.device``.
        data_loader: Iterable of batches, each indexable by ``'tokens'``.
        num_rows: Maximum number of rows to score before stopping.
        seed: Seed for the position sampler.

    Returns:
        ``np.mean`` of the per-row cross-entropies (``nan`` if no row was seen).
    '''
    # Local generator: same draws as seeding the global `random` module with
    # `seed`, but without clobbering global RNG state for other cells.
    rng = random.Random(seed)
    device = model.cfg.device
    cross_entropies = []
    # Evaluation only: disable autograd so the forward pass does not build
    # (and hold on to) a graph that is never used.
    with torch.no_grad():
        for batch in tqdm(data_loader):
            if len(cross_entropies) >= num_rows:
                break
            tokens = batch['tokens'].to(device)
            batch_logits = model(tokens)
            for i, logits in enumerate(batch_logits):
                if len(cross_entropies) >= num_rows:
                    break
                idx = rng.randint(1, len(logits) - 1)  # inclusive on both ends
                curr_token = tokens[i][idx].item()
                target_dist = torch.tensor(
                    bigram_freq_percents[curr_token],
                    dtype=torch.float32,
                    device=device,
                )
                # Fully-qualified call: the notebook never imports an `F`
                # alias, so `F.cross_entropy` raised NameError.
                ce = torch.nn.functional.cross_entropy(logits[idx], target_dist).item()
                cross_entropies.append(ce)
    return np.mean(cross_entropies)

In [None]:
# Checkpoint steps to evaluate: every 100th step from 0 through 50,000.
x_steps = list(range(0, 50_001, 100))

# bigram_difference scores for the 1-layer and 2-layer checkpoints,
# aligned index-for-index with x_steps.
ces_L1, ces_L2 = [], []

for step in x_steps:
    # Load both depths for this step before scoring either of them.
    model_L1 = load_hf_checkpoint(step, n_layers=1)
    model_L2 = load_hf_checkpoint(step, n_layers=2)
    ces_L1.append(bigram_difference(model_L1, data_loader, num_rows=5000))
    ces_L2.append(bigram_difference(model_L2, data_loader, num_rows=5000))
