# Local learning coefficient (LLC) estimation

In [None]:
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
# Number of token rows per DataLoader batch used for LLC estimation.
BATCH_SIZE = 100

# Hugging Face dataset id and the text column that gets tokenized.
DATASET = "oknMswoztTPaAVreBrWy/dsir-pile-1m-2"
DS_COL = "contents"

# NOTE(review): not referenced by any of the visible cells.
MODEL_NAME = "L1"

# Device handed to the LLC estimator.
DEVICE = "cuda"

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# Transformer configs keyed by layer count (built for 1 and 2 layers, in that order).
model_cfgs = {n_layers: get_model_cfg(num_layers=n_layers) for n_layers in (1, 2)}

# In the cells below this model is only consulted for its tokenizer and
# context length (cfg.n_ctx) when building the tokenized dataset.
model = HookedTransformer(model_cfgs[2])

In [None]:
def _lm_collate(batch):
  """Stack tokenized rows into an ``[inputs, targets]`` pair.

  Args:
    batch: list of dataset rows, each a mapping with a ``'tokens'`` tensor;
      all rows must share a shape so they can be stacked.

  Returns:
    ``[tokens, tokens.clone()]`` where ``tokens`` is the stacked batch. The
    same token ids serve as input and target (as a separate tensor);
    presumably the criterion shifts targets internally — TODO confirm.
  """
  tokens = torch.stack([item['tokens'] for item in batch])
  return [tokens, tokens.clone()]


def reformat_tokens_dataset_for_learning_coeff(tokens_dataset, batch_size, *,
                                                shuffle=True, num_workers=2,
                                                pin_memory=True):
  """Wrap a tokenized dataset in a DataLoader yielding ``[inputs, targets]``.

  Args:
    tokens_dataset: map-style dataset whose rows contain a ``'tokens'``
      tensor (e.g. the output of ``tokenize_and_concatenate``).
    batch_size: number of rows per batch.
    shuffle: forwarded to ``DataLoader``.
    num_workers: forwarded to ``DataLoader``. Pass 0 to load in-process.
    pin_memory: forwarded to ``DataLoader``.

    The defaults reproduce the previously hard-coded settings.

  Returns:
    A ``torch.utils.data.DataLoader`` whose batches are
    ``[tokens, tokens.clone()]``.
  """
  return torch.utils.data.DataLoader(tokens_dataset,
                                     batch_size=batch_size,
                                     collate_fn=_lm_collate,
                                     shuffle=shuffle,
                                     num_workers=num_workers,
                                     pin_memory=pin_memory)

# Pipeline:
#   1. load the raw training split;
#   2. tokenize and concatenate it into rows of model.cfg.n_ctx tokens,
#      prepending a BOS token;
#   3. wrap the rows in the DataLoader that the LLC estimator consumes.
# ``lc_dataset`` ends up bound to the DataLoader; later cells rely on that name.
lc_tokens_dataset = tokenize_and_concatenate(
    datasets.load_dataset(DATASET, split='train'),
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=DS_COL,
    add_bos_token=True,
    num_proc=12,
)
lc_dataset = reformat_tokens_dataset_for_learning_coeff(lc_tokens_dataset, BATCH_SIZE)

print('num batches: ', len(lc_dataset))

In [None]:
import wandb

from icl.analysis.sgld import SGLD
from icl.language.llc import estimate_learning_coeff_with_summary

from transformer_lens.utils import lm_cross_entropy_loss

# SGLD step size (passed as ``lr``) and elasticity (passed as ``elasticity``).
EPSILON = 0.003
GAMMA = 300

# Number of SGLD chains and number of draws per chain.
NUM_CHAINS, NUM_DRAWS = 20, 200

def llc_helper(model):
  """Run SGLD-based LLC estimation for ``model`` on the shared loader.

  Reads module-level settings:
    - hyperparameters EPSILON (SGLD ``lr``) and GAMMA (SGLD ``elasticity``);
    - NUM_CHAINS and NUM_DRAWS;
    - DEVICE;
    - the ``lc_dataset`` DataLoader.

  Args:
    model: the model whose LLC is estimated.

  Returns:
    The result of ``estimate_learning_coeff_with_summary``. Callers in this
    notebook read its ``'llc/means'``, ``'llc/stds'`` and ``'loss/trace'``
    entries.
  """
  sgld_kwargs = {
      'lr': EPSILON,
      'noise_level': 1.0,
      'elasticity': GAMMA,
      # NOTE(review): hard-coded to 100 (the same value as BATCH_SIZE). With
      # temperature="adaptive" this presumably determines the sampling
      # temperature; confirm whether it should be the dataset size instead.
      'num_samples': 100,
      'temperature': "adaptive",
  }
  return estimate_learning_coeff_with_summary(
      model=model,
      loader=lc_dataset,
      criterion=lm_cross_entropy_loss,
      sampling_method=SGLD,
      optimizer_kwargs=sgld_kwargs,
      num_chains=NUM_CHAINS,
      num_draws=NUM_DRAWS,
      device=DEVICE,
      online=True,
  )


def estimate_checkpoint_llc(step, n_layers, *, key_prefix=None):
  """Estimate the LLC of one checkpoint and log the summary to wandb.

  Args:
    step: training step of the checkpoint to load; also used as the wandb
      ``step``.
    n_layers: layer count of the checkpoint to load.
    key_prefix: string prepended to every logged metric name. Defaults to
      ``f"L{n_layers}/"``. Checkpoints of different depths logged at the
      same ``step`` into one run would otherwise share key names and
      overwrite each other's values.

  Returns:
    The metrics dict that was passed to ``wandb.log``.
  """
  if key_prefix is None:
    key_prefix = f"L{n_layers}/"

  model = load_hf_checkpoint(step, n_layers=n_layers)
  results = llc_helper(model)

  metrics = {
      # Last entry of the running mean/std series, i.e. the final estimate.
      f"{key_prefix}llc/mean": results['llc/means'][-1],
      f"{key_prefix}llc/std": results['llc/stds'][-1],
  }
  # One key per leading index of the loss trace.
  # NOTE(review): whether each element is a scalar or a per-chain array is
  # not visible here — confirm.
  metrics.update({
      f"{key_prefix}loss/trace/{i}": loss
      for i, loss in enumerate(results['loss/trace'])
  })
  wandb.log(metrics, step=step)
  return metrics

In [None]:
import wandb
WANDB_PROJECT = 'foo'
# NOTE(review): the run name says "L1", but both 1- and 2-layer checkpoints
# are evaluated below — consider renaming.
wandb.init(project=WANDB_PROJECT, entity="bar", name="L1W256-llc")

torch.manual_seed(1)

try:
  # 500 checkpoints x 2 depths = 1000 sequential LLC estimates; consider
  # parallelizing this sweep.
  for step in range(0, 50000, 100):
    for n_layers in (1, 2):
      estimate_checkpoint_llc(step, n_layers=n_layers)
finally:
  # Flush and close the run even if an estimate fails or is interrupted.
  wandb.finish()