In [None]:
import wandb
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import HookedTransformerConfig, HookedTransformer

# Compute device name (hard-coded; assumes a CUDA GPU is available).
DEVICE: str = "cuda"

In [None]:
from icl.language.model import get_model_cfg
from icl.language.utils import load_hf_checkpoint

# Configs for the 1-layer and 2-layer models, keyed by layer count.
model_cfgs = {num_layers: get_model_cfg(num_layers=num_layers) for num_layers in (1, 2)}

# A freshly initialised 2-layer model; the next cell only uses its tokenizer and n_ctx.
model = HookedTransformer(model_cfgs[2])

In [None]:
dataset_name = 'oknMswoztTPaAVreBrWy/dsir-pile-100k'
dataset_col_name = 'contents'

# Load the raw train split, then tokenize the `contents` column with the model's tokenizer
# (max_length = model context length, add_bos_token=True).
dataset = datasets.load_dataset(dataset_name, split='train')
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=dataset_col_name,
    add_bos_token=True,
    num_proc=12,
)

# Deterministic (unshuffled) batches of 32 rows; the cell displays the number of batches.
data_loader = torch.utils.data.DataLoader(
    tokens_dataset, batch_size=32, shuffle=False
)
len(data_loader)

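### Prefix matching

For each checkpoint, measure how strongly each attention head, at the second occurrence of a repeated token, attends to the position immediately after that token's first occurrence, averaged over synthetic random-token sequences.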

In [None]:
import random


class Sample:
    """Holder for one synthetic prefix-matching example.

    The fields start empty and are filled in by
    ``generate_synthetic_prefix_matching_sample``.
    """

    def __init__(self):
        # Token ids: a list while being built, later replaced by a tensor.
        self.sample = []
        # The repeated token and the positions of its two occurrences.
        self.special_token = self.first_idx = self.second_idx = None
        # Key position the second occurrence is expected to attend to.
        self.attn_idx = None


def generate_synthetic_prefix_matching_sample(context_length, d_vocab):
    """Build one random sequence that contains a single token exactly twice.

    Layout:
    - Position 0 holds ``d_vocab - 1``, which is used as BOS.
    - A "special" token from ``[0, d_vocab - 2]`` is placed once in the first half,
      at ``first_idx`` in ``[1, context_length // 2 - 1]``.
    - It is placed again in the second half, at ``second_idx`` in
      ``[context_length // 2 + 1, context_length - 1]``.
    - Every other position gets a random non-special token from ``[0, d_vocab - 2]``.

    NOTE(review): ``d_vocab - 1`` is used as the BOS id here. Confirm that this matches
    the tokenizer's real BOS id if that matters for the probe.

    Draws from the global ``random`` module, so the caller controls reproducibility via
    ``random.seed``.

    Returns:
        Sample: ``.sample`` is an int64 tensor of shape ``(1, context_length)``.
        ``.attn_idx`` is ``first_idx + 1``, the position the second occurrence should
        attend to.
    """
    bos_token = d_vocab - 1
    max_regular = d_vocab - 2
    special_token = random.randint(0, max_regular)

    half = context_length // 2
    first_idx = random.randint(1, half - 1)
    second_idx = random.randint(half + 1, context_length - 1)

    def draw_filler():
        # Rejection-sample so the special token appears only at the two chosen positions.
        while True:
            candidate = random.randint(0, max_regular)
            if candidate != special_token:
                return candidate

    tokens = [bos_token]
    for pos in range(1, context_length):
        tokens.append(special_token if pos in (first_idx, second_idx) else draw_filler())

    result = Sample()
    result.sample = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0)
    result.special_token = special_token
    result.first_idx, result.second_idx = first_idx, second_idx
    result.attn_idx = first_idx + 1
    return result


def get_prefix_attns(model, n_layers, sample):
    '''Per-head attention from ``sample.second_idx`` to ``sample.attn_idx``.

    Assumes batch size 1: only batch element 0 of each cached pattern is read.

    Args:
        model: object exposing ``run_with_cache(tokens) -> (output, cache)``,
            e.g. a HookedTransformer.
        n_layers: number of layers to read from the cache.
        sample: object with ``.sample`` (token tensor), ``.second_idx`` and ``.attn_idx``.

    Returns:
        A list of ``n_layers`` lists, each holding one Python float per attention head.
    '''
    # Only cached activations are needed. Skipping autograd avoids building a graph on
    # every one of the many probe calls.
    with torch.no_grad():
        _, cache = model.run_with_cache(sample.sample)

    prefix_attns = []
    for layer in range(n_layers):
        # Indexed as [head, query_pos, key_pos] after dropping the batch dim.
        # Presumably (n_heads, seq, seq); confirm against the hook's docs.
        pattern = cache[f'blocks.{layer}.attn.hook_pattern'][0]
        # One gather and one host transfer per layer, instead of a .item() sync per head.
        prefix_attns.append(pattern[:, sample.second_idx, sample.attn_idx].cpu().tolist())
    return prefix_attns

In [None]:
import random

# Settings for the synthetic prefix-matching evaluation.
SEED = 0              # re-applied before each checkpoint so every checkpoint sees the same samples
D_VOCAB = 5_000       # filler/special tokens come from [0, D_VOCAB - 2]; D_VOCAB - 1 is used as BOS
NUM_SAMPLES = 10_000  # synthetic samples averaged per checkpoint
SAMPLE_LEN = 128      # tokens per synthetic sample

MAX_CHECKPOINT = 50_000  # last training step scored (every 100th step from 0)

def get_checkpoint_prefix_score(step, n_layers, d_vocab, num_samples, sample_len, seed=None):
    """Mean prefix-matching attention per (layer, head) for one checkpoint.

    Loads checkpoint ``step`` of the ``n_layers``-layer model. It then scores
    ``num_samples`` synthetic samples (length ``sample_len``, vocabulary ``d_vocab``)
    with ``get_prefix_attns`` and averages over samples.

    Args:
        step: training step of the checkpoint to load.
        n_layers: number of layers of the model to load and probe.
        d_vocab: vocabulary size used for the synthetic samples.
        num_samples: number of synthetic samples to average over.
        sample_len: tokens per synthetic sample.
        seed: seed for the ``random`` module. ``None`` (default) uses the global ``SEED``
            at call time, so every checkpoint is evaluated on identical samples.

    Returns:
        np.ndarray of shape (n_layers, n_heads), assuming every layer has the same
        number of heads.
    """
    model = load_hf_checkpoint(step, n_layers=n_layers)
    random.seed(SEED if seed is None else seed)
    # leave=False: this runs once per checkpoint and model size, so don't leave
    # hundreds of finished progress bars in the cell output.
    all_prefix_attns = [
        get_prefix_attns(model, n_layers,
                         generate_synthetic_prefix_matching_sample(sample_len, d_vocab))
        for _ in tqdm(range(num_samples), desc=f'prefix L{n_layers} step {step}', leave=False)
    ]
    return np.array(all_prefix_attns).mean(axis=0)

# Score every 100th checkpoint (0..MAX_CHECKPOINT) for the 1-layer and 2-layer models,
# interleaving the two model sizes at each step.
L1_prefix_scores, L2_prefix_scores = [], []
prefix_score_lists = {1: L1_prefix_scores, 2: L2_prefix_scores}
for step in range(0, MAX_CHECKPOINT + 1, 100):
    for n_layers, score_list in prefix_score_lists.items():
        score_list.append(
            get_checkpoint_prefix_score(step, n_layers, D_VOCAB, NUM_SAMPLES, SAMPLE_LEN)
        )

# One row per checkpoint; each row is that checkpoint's per-(layer, head) mean.
L1_prefix_scores, L2_prefix_scores = np.array(L1_prefix_scores), np.array(L2_prefix_scores)

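### Previous token matching

For each checkpoint, measure how strongly each attention head, at the last position of a random-token sequence, attends to the immediately preceding position, averaged over synthetic sequences.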

In [None]:
import random

class Sample:
    """Holder for one synthetic example.

    This cell redefines ``Sample`` with the same attributes as the prefix-matching
    version, so running it does not replace that class with a stripped-down one.
    ``generate_synthetic_prev_matching_sample`` below returns a plain list and does
    not use this class.
    """

    def __init__(self):
        # Token ids.
        self.sample = []
        # Fields used by the prefix-matching probe; unused by the previous-token probe.
        self.special_token = None
        self.first_idx = None
        self.second_idx = None
        self.attn_idx = None

def generate_synthetic_prev_matching_sample(context_length, d_vocab):
    """Build one random token sequence for the previous-token probe.

    The length is drawn uniformly from ``[context_length // 4, context_length]``.
    Position 0 holds ``d_vocab - 1`` (used as BOS). The remaining positions are uniform
    over ``[0, d_vocab - 2]``. Draws from the global ``random`` module.

    Returns:
        list[int]: the token ids (a plain list, not a ``Sample``).
    """
    length = random.randint(context_length // 4, context_length)
    filler = [random.randint(0, d_vocab - 2) for _ in range(length - 1)]
    return [d_vocab - 1, *filler]

def get_prev_attns(model, n_layers, sample):
    '''Per-head attention from the last position to the second-to-last position.

    Assumes batch size 1: ``sample`` is a flat list of token ids, and only batch
    element 0 of each cached pattern is read.

    Args:
        model: object exposing ``cfg.device`` and
            ``run_with_cache(tokens) -> (output, cache)``, e.g. a HookedTransformer.
        n_layers: number of layers to read from the cache.
        sample: list of int token ids; needs at least 2 tokens.

    Returns:
        A list of ``n_layers`` lists, each holding one Python float per attention head.
    '''
    tokens = torch.tensor(sample, device=model.cfg.device)
    # Only cached activations are needed; skip building the autograd graph.
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens)

    prev_attns = []
    for layer in range(n_layers):
        # Indexed as [head, query_pos, key_pos] after dropping the batch dim.
        # Presumably (n_heads, seq, seq); confirm.
        pattern = cache[f'blocks.{layer}.attn.hook_pattern'][0]
        # One gather and one host transfer per layer, instead of a .item() sync per head.
        prev_attns.append(pattern[:, -1, -2].cpu().tolist())
    return prev_attns


In [None]:
import random

# Settings for the synthetic previous-token evaluation
# (same vocabulary size and sample count as the prefix-matching run).
D_VOCAB = 5_000
NUM_SAMPLES = 10_000
MAX_SAMPLE_LEN = 64      # each sample's length is drawn from [MAX_SAMPLE_LEN // 4, MAX_SAMPLE_LEN]

MAX_CHECKPOINT = 50_000  # last training step scored

def get_checkpoint_prev_score(step, n_layers, d_vocab=None, num_samples=None,
                              max_sample_len=None, seed=0):
    """Mean previous-token attention per (layer, head) for one checkpoint.

    Loads checkpoint ``step`` of the ``n_layers``-layer model. It then scores
    ``num_samples`` random sequences with ``get_prev_attns`` and averages over samples.

    Args:
        step: training step of the checkpoint to load.
        n_layers: number of layers of the model to load and probe.
        d_vocab: vocabulary size for the synthetic samples. Defaults to the global
            ``D_VOCAB`` (read at call time).
        num_samples: number of samples to average. Defaults to the global ``NUM_SAMPLES``.
        max_sample_len: upper bound on sample length. Defaults to the global
            ``MAX_SAMPLE_LEN``.
        seed: seed for the ``random`` module, re-applied per call so every checkpoint
            sees the same samples.

    Returns:
        np.ndarray of shape (n_layers, n_heads), assuming every layer has the same
        number of heads.
    """
    d_vocab = D_VOCAB if d_vocab is None else d_vocab
    num_samples = NUM_SAMPLES if num_samples is None else num_samples
    max_sample_len = MAX_SAMPLE_LEN if max_sample_len is None else max_sample_len

    model = load_hf_checkpoint(step, n_layers=n_layers)
    random.seed(seed)
    # leave=False: this runs once per checkpoint and model size; avoid flooding the output.
    all_prev_attns = [
        get_prev_attns(model, n_layers,
                       generate_synthetic_prev_matching_sample(max_sample_len, d_vocab))
        for _ in tqdm(range(num_samples), desc=f'prev L{n_layers} step {step}', leave=False)
    ]
    return np.array(all_prev_attns).mean(axis=0)

# Score every 100th checkpoint for both the 1-layer and 2-layer models.
L1_prev_scores = []
L2_prev_scores = []
x_steps = list(range(0, MAX_CHECKPOINT + 1, 100))

for step in x_steps:
    L1_scores = get_checkpoint_prev_score(step, n_layers=1)
    L1_prev_scores.append(L1_scores)
    L2_scores = get_checkpoint_prev_score(step, n_layers=2)
    L2_prev_scores.append(L2_scores)

# Stack the per-checkpoint results into arrays: row i corresponds to x_steps[i], and each
# row is that checkpoint's per-(layer, head) mean.
L1_prev_scores = np.array(L1_prev_scores)
L2_prev_scores = np.array(L2_prev_scores)