In [1]:
import os
os.chdir('/workspace/FutureGPT2/src/')
from evals.utils import *
from models.bigram_model import *
from models.mlp_model import *
from models.future_model import *
from data.utils import get_tokenizer

from tqdm import tqdm
import pandas as pd
import gc
from glob import glob

In [2]:
# Suppress all Python warnings for the rest of the session
# (there are a lot b/c of how the checkpoints are stored...).
# Note: this blanket 'ignore' filter also hides unrelated warnings.
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [4]:
# Tokenizers for the two base models whose sweeps are evaluated below
# (GPT2 and Mistral-7B).
gpt2_tokenizer = get_tokenizer('gpt2')
mistral_tokenizer = get_tokenizer('mistralai/Mistral-7B-v0.1')

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

In [None]:
def load_model(model_cls, ckpt_path, kwargs):
    """Restore a ``model_cls`` instance from the checkpoint at ``ckpt_path``.

    Args:
        model_cls: class exposing a ``load_from_checkpoint`` callable.
        ckpt_path: checkpoint location, passed through as the first argument.
        kwargs: dict of extra keyword arguments forwarded to
            ``load_from_checkpoint``.

    Returns:
        Whatever ``model_cls.load_from_checkpoint`` returns.
    """
    loader = model_cls.load_from_checkpoint
    # strict=False is always passed, regardless of what kwargs contains.
    return loader(ckpt_path, strict=False, **kwargs)

In [None]:
def gen_futures(prompt, models, tokenizer, max_tokens=1, topk=10):
    """Print the future-token predictions of every model in ``models`` for ``prompt``.

    Models are loaded one at a time and released before the next one is
    loaded, so only a single checkpoint is held in memory at once.

    Args:
        prompt: text prompt handed to ``generate_future``.
        models: mapping of display name -> ``(model_cls, ckpt_path, kwargs)``;
            each tuple is unpacked into ``load_model``.
        tokenizer: tokenizer handed to ``generate_future``.
        max_tokens: forwarded to ``generate_future`` as ``max_tokens``.
        topk: forwarded to ``generate_future`` as ``topk``.

    Returns:
        None. Results are printed.
    """
    for name, args in models.items():
        print('MODEL:', name)
        model = load_model(*args)
        try:
            # generate_future from FutureGPT2/src/evals/utils.py
            # (its result is transposed for display; presumably a DataFrame).
            print(generate_future(prompt, model, tokenizer, max_tokens=max_tokens, topk=topk).T)
        finally:
            # Free up GPU memory, even when generation fails, so that a single
            # bad checkpoint does not leave its weights resident for the
            # remainder of the session.
            del model
            gc.collect()
            torch.cuda.empty_cache()

## Neck hdim sweep (GPT2)

`neck_h4^i` is a 2-layer neck with a hidden dimension of 4^i.

In [None]:
# One entry per neck width: checkpoint h4_{i}.ckpt is registered as 'neck_h4^{i}'.
models = dict(
    (
        f'neck_h4^{exponent}',
        (
            LitFutureModelWithNeck,
            f'/workspace/checkpoints/fc_neck_sweep/h4_{exponent}.ckpt',
            {'neck_cls': 'mlp', 'hidden_idxs': 12, 'hidden_lb': 0, 'token_lb': 0},
        ),
    )
    for exponent in range(8)
)

In [None]:
prompt = 'Alice is visiting Japan, so she\'s exchanging her dollars for'
# next token prediction; top10:
gen_futures(prompt, models, gpt2_tokenizer, max_tokens=1, topk=10)
# next 20 tokens, top1:
# gen_futures(prompt, models, gpt2_tokenizer, max_tokens=20, topk=1)

## Finetune sweep (GPT2)

In [None]:
def get_finetune_model(blr, kappa):
    """Return the path of the finetune-sweep checkpoint for one (blr, kappa) pair.

    Args:
        blr: base learning rate exactly as it is spelled in the checkpoint
            file name (e.g. ``'1e-05'``).
        kappa: kappa value exactly as it is spelled in the checkpoint file
            name (e.g. ``'0.01'``).

    Returns:
        The matching checkpoint path. When several files match, the
        lexicographically first one is returned so the choice does not depend
        on directory listing order.

    Raises:
        FileNotFoundError: if no checkpoint file matches the pattern.
    """
    pattern = f'/workspace/checkpoints/FINETUNE-KAPPA-SWEEP_*_base_lr-{blr}_kappa-{kappa}_*.ckpt'
    matches = sorted(glob(pattern))
    if not matches:
        # Name the pattern so a typo in blr/kappa is easy to spot.
        raise FileNotFoundError(f'No finetune checkpoint matches {pattern!r}')
    return matches[0]
# Sweep grid. Every (blr, kappa) pair is resolved to a checkpoint right here,
# so the lookup raises in this cell if a pair has no checkpoint on disk.
kappas = ['0.0001', '0.001', '0.01', '0.1', '1']
blrs = ['1e-05']   # base learning rate

models = dict(
    (
        f'finetune_blr:{blr}_k:{k}',
        (
            LitFutureModelWithNeck,
            get_finetune_model(blr, k),
            {'neck_cls': 'mlp', 'hidden_idxs': 12, 'hidden_lb': 0, 'token_lb': 0},
        ),
    )
    for blr in blrs
    for k in kappas
)

In [None]:
# Next-token comparison across the finetuned checkpoints, using the
# gen_futures defaults (max_tokens=1, topk=10).
gen_futures('Alice is visiting Japan, so she\'s exchanging her dollars for', models, gpt2_tokenizer)

## Hidden/Token lookbacks, MLP/LSTM Sweep (GPT2 and Mistral)

### Checkpoint arguments:
- hidden_idxs: index of hidden state layer

- hidden_lb: hidden-state lookback
  - if set to k, the hidden states from positions {t-k, ..., t} are used as input
  - set to -1 to ignore the hidden state
  - for the LSTM, this is the lookback *per position* fed into the LSTM
  - since the LSTM is already recurrent over the full history, it should only be set to 0, or to -1 (to ignore the hidden state)

- token_lb: token lookback; if set to k, the embedded tokens from positions {t-k+1, ..., t+1} are used as input
  - set to -1 to ignore the input tokens
  
- neck_cls: either 'mlp' or 'lstm'

In [5]:
def get_ckpt(sweep_name, hidden_idxs, hidden_lb, token_lb, neck_cls,
             ckpt_dir='/workspace/checkpoints/'):
    """Find the checkpoint file of one sweep configuration.

    The file name is expected to look like
    ``<sweep_name>_<anything>hidden_idxs-<h>_hidden_lb-<l>_token_lb-<t>_neck_cls-<c>_<anything>.ckpt``.

    Args:
        sweep_name: sweep prefix of the file name (e.g. ``'NECK-SWEEP2'``).
        hidden_idxs: value of the ``hidden_idxs-`` field in the file name.
        hidden_lb: value of the ``hidden_lb-`` field in the file name.
        token_lb: value of the ``token_lb-`` field in the file name.
        neck_cls: value of the ``neck_cls-`` field in the file name.
        ckpt_dir: directory that is searched; it is matched literally
            (glob metacharacters in it are escaped).

    Returns:
        The path of the matching checkpoint, or ``None`` when nothing matches.
        When several files match, the lexicographically first path is
        returned so the result does not depend on directory listing order.
    """
    # Local import: the notebook only imports the glob() function itself.
    from glob import escape

    filename_pattern = '_'.join([
        sweep_name,
        # The wildcard absorbs whatever sits between the sweep name and
        # 'hidden_idxs' (e.g. a timestamp / run id) including its trailing '_',
        # so it is deliberately glued to 'hidden_idxs' without a separator.
        f'*hidden_idxs-{hidden_idxs}',
        f'hidden_lb-{hidden_lb}',
        f'token_lb-{token_lb}',
        f'neck_cls-{neck_cls}',
        '*',
    ]) + '.ckpt'
    pattern = os.path.join(escape(ckpt_dir), filename_pattern)

    matches = sorted(glob(pattern))
    # A missing checkpoint is expected for many configurations, so report it
    # with None rather than an exception.
    return matches[0] if matches else None

In [6]:
from itertools import product
ckpts = []
# NECK-SWEEP2 is GPT2
sweeps = ['NECK-SWEEP2', 'MISTRAL-NECK-SWEEP']
hidden_idxs = [0, 11, 12, 31, 32]
hidden_lbs = [-1, 0, 1]
token_lbs = [-1, 0, 1]
neck_cls = ['mlp', 'lstm']


def _collect_sweep_models(sweep_name, sep):
    """Build the ``gen_futures`` model table for every checkpoint of one sweep.

    Tries each combination of the hyperparameter lists defined in this cell
    and keeps those for which ``get_ckpt`` finds a file.

    Args:
        sweep_name: sweep prefix passed to ``get_ckpt``.
        sep: separator used to join the configuration values into the key.

    Returns:
        Dict of display name -> ``(LitFutureModelWithNeck, ckpt_path, {})``,
        in the iteration order of the hyperparameter product.
    """
    found = {}
    for args in product([sweep_name], hidden_idxs, hidden_lbs, token_lbs, neck_cls):
        # Resolve each configuration exactly once: one glob per combination,
        # and the path that passed the existence check is the one stored.
        ckpt_path = get_ckpt(*args)
        if ckpt_path is None:
            continue
        found[sep.join(str(a) for a in args)] = (LitFutureModelWithNeck, ckpt_path, {})
    return found


# Only some subset of the full product exists
# But it's easier to just try to load and ignore failures
# instead of explicitly enumerating everything
# NOTE(review): the GPT2 keys are joined with '_' but the Mistral keys with '-';
# both are kept as they were so existing names do not change. Confirm whether
# the difference is intentional.
gpt2_models = _collect_sweep_models('NECK-SWEEP2', '_')
mistral_models = _collect_sweep_models('MISTRAL-NECK-SWEEP', '-')

In [9]:
# Sanity check: look up one specific Mistral LSTM configuration and show the resolved path.
get_ckpt('MISTRAL-NECK-SWEEP', 31, 0, 0, 'lstm')

'/workspace/checkpoints/MISTRAL-NECK-SWEEP_20240102-191556-Ec6c4_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-lstm_epoch=00-val_self_loss=3.81.ckpt'

In [22]:
# Next-token predictions (gen_futures defaults: max_tokens=1, topk=10)
# for every GPT2 sweep checkpoint that was found.
prompt = 'Alice is visiting Japan, so she\'s exchanging her dollars for'
gen_futures(prompt, gpt2_models, gpt2_tokenizer)

MODEL: NECK-SWEEP2_0_-1_-1_mlp
                        0
base_token_0          her
base_prob_0      0.000519
base_token_1     Japanese
base_prob_1       0.00006
base_token_2        money
base_prob_2      0.000199
base_token_3            a
base_prob_3      0.013693
base_token_4          yen
base_prob_4      0.000019
base_token_5          the
base_prob_5      0.027852
base_token_6         cash
base_prob_6      0.000039
base_token_7         some
base_prob_7      0.000852
base_token_8      dollars
base_prob_8       0.00013
base_token_9       things
base_prob_9      0.000173
future_token_0          ,
future_prob_0    0.036268
future_token_1          .
future_prob_1    0.035682
future_token_2        the
future_prob_2    0.027852
future_token_3         of
future_prob_3    0.020562
future_token_4        and
future_prob_4    0.018334
future_token_5         to
future_prob_5    0.016543
future_token_6          a
future_prob_6    0.013693
future_token_7         in
future_prob_7      0.0119
future_

In [None]:
# Same comparison for the Mistral sweep checkpoints; reuses `prompt` from the GPT2 cell.
gen_futures(prompt, mistral_models, mistral_tokenizer)