In [None]:
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import HookedTransformerConfig, HookedTransformer

# Prefer the GPU, but fall back to CPU so the notebook still runs top-to-bottom
# (Restart & Run All) on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# Model configurations keyed by depth: 1-layer and 2-layer variants.
model_cfgs = {n_layers: get_model_cfg(num_layers=n_layers) for n_layers in (1, 2)}

# NOTE(review): freshly initialised 2-layer model; in the cells visible here the
# name `model` is reassigned by the weight-norm cell before it is ever read.
model = HookedTransformer(model_cfgs[2])

In [None]:
# Load every 100th checkpoint (steps 0 through 50_000 inclusive) for both the
# 1-layer and 2-layer models, keyed by training step.
L1_models, L2_models = dict(), dict()

for step in tqdm(range(0, 50_001, 100)):
    one_layer = load_hf_checkpoint(step, n_layers=1)
    L1_models[step] = one_layer
    two_layer = load_hf_checkpoint(step, n_layers=2)
    L2_models[step] = two_layer

In [None]:
# Global L2 norm of all model parameters at each checkpoint step.
def _param_l2_norm(model):
  """Return the L2 norm of all of ``model``'s parameters viewed as one flat vector.

  The squares are accumulated in float64 on the CPU. Summing many squared
  float32 values in float32 loses precision, and ``Tensor.numpy()`` cannot
  convert bfloat16 tensors at all, whereas ``.to(dtype=torch.float64)`` can.
  """
  sq_total = 0.0
  for p in model.parameters():
    p64 = p.detach().to(device='cpu', dtype=torch.float64)
    sq_total += torch.sum(p64 * p64).item()
  return math.sqrt(sq_total)


# The comprehensions keep their loop variables local, so this cell no longer
# overwrites the global `model` created earlier in the notebook.
L1_norms = [_param_l2_norm(L1_models[step]) for step in tqdm(range(0, 50_001, 100))]
L2_norms = [_param_l2_norm(L2_models[step]) for step in tqdm(range(0, 50_001, 100))]

In [None]:
# pca on model positional embedding weights
import numpy as np
from sklearn.decomposition import PCA


def pca_model_weights(model, n_components=3, random_state=0):
  """Fit a PCA to the model's positional-embedding matrix.

  Args:
    model: object exposing ``model.pos_embed.W_pos`` as a torch tensor. Its rows
      are treated as samples by sklearn's ``PCA``.
    n_components: number of principal components to keep. The default of 3 is
      the previously hard-coded value.
    random_state: seed forwarded to ``PCA``. sklearn may choose a randomized SVD
      solver for larger inputs, and without a fixed seed the fitted components can
      differ slightly between re-runs. Pass ``None`` for the old unseeded behavior.

  Returns:
    dict with keys ``'pca'`` (the fitted PCA object), ``'fit_data'`` (``W_pos``
    projected onto the components) and ``'ex_var'`` (explained variance ratio per
    component).
  """
  W_pos = model.pos_embed.W_pos.detach().cpu().numpy()

  pca = PCA(n_components=n_components, random_state=random_state)
  pca.fit(W_pos)
  data = pca.transform(W_pos)
  pca_results = {
    'pca': pca,
    'fit_data': data,
    'ex_var': pca.explained_variance_ratio_,
  }
  return pca_results


In [None]:
# PCA of the positional embeddings at every checkpoint, one list per model depth.
L1_results, L2_results = list(), list()

for step in range(0, 50_001, 100):
  # Keep the 1-layer / 2-layer fits interleaved, as before, so that any global
  # RNG draws made inside PCA happen in the same order.
  for source, sink in ((L1_models, L1_results), (L2_models, L2_results)):
    sink.append(pca_model_weights(source[step]))

In [None]:
# L2 norm of positional embedding weights
def pos_embed_magnitudes(model):
  """Return the L2 norm of every row of the model's positional-embedding matrix.

  Args:
    model: object exposing ``model.pos_embed.W_pos`` as a 2-D torch tensor with
      one row per position.

  Returns:
    1-D numpy array holding one norm per row of ``W_pos``.
  """
  W_pos = model.pos_embed.W_pos.detach().cpu().numpy()
  # Take the norm over all rows rather than a hard-coded 1024, so this works for
  # any context length. The old loop raised IndexError when there were fewer
  # than 1024 rows and silently dropped rows beyond 1024.
  return np.linalg.norm(W_pos, axis=1)

# Per-position embedding norms at every checkpoint, one list per model depth.
L1_mags = [pos_embed_magnitudes(L1_models[step]) for step in range(0, 50_001, 100)]
L2_mags = [pos_embed_magnitudes(L2_models[step]) for step in range(0, 50_001, 100)]