In [None]:
import os
import json

import numpy as np
import pandas as pd
import scipy.stats as ss

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

import transformers as tf
import datasets as ds
from datasets import load_metric

import matplotlib as mp
%matplotlib inline
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm, trange

import clip_graph as cg

In [None]:
# Switch the working directory so the relative `configs/...` and
# `lightning_logs/...` paths used in later cells resolve.
# NOTE(review): the location is hard-coded under the current user's home
# directory; presumably this is the repository checkout — adjust per machine.
os.chdir(os.path.expanduser('~/github/congrat'))

In [None]:
# Device string that both models are moved to and that perplexity() moves
# each batch to. Hard-coded to CUDA device index 2; change it on machines
# with a different GPU layout.
device = 'cuda:2'

In [None]:
# Fixed seed for the run, applied through PyTorch Lightning's
# seed_everything helper before any data or model is loaded.
seed = 2969591811
pl.seed_everything(seed)

# Utils

In [None]:
def perplexity(model, ldr, max_batches=None, device=None):
    """Compute one perplexity value per sequence over the batches of ``ldr``.

    For every batch the model's logits are scored against ``batch['labels']``
    with token-level cross-entropy, averaged over the positions where
    ``batch['attention_mask']`` is nonzero, and exponentiated.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``.logits`` attribute.
            Logits must be 3-D with the class axis last (they are transposed
            with ``transpose(1, 2)`` before being passed to cross-entropy).
            If it is a ``torch.nn.Module`` it is switched to eval mode for
            the duration of the call and restored afterwards.
        ldr: Iterable of dict batches holding ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        max_batches: If not ``None``, stop after this many batches.
        device: If not ``None``, the batch tensors are moved to this device.
            The model itself is not moved.

    Returns:
        list[float]: one perplexity per sequence, in iteration order.
    """
    ppls = []

    # Scoring must not be perturbed by dropout or other train-time behaviour;
    # remember the previous mode so the caller's model is left as it was found.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    try:
        for i, batch in enumerate(tqdm(ldr, total=max_batches)):
            if max_batches is not None and i >= max_batches:
                break

            input_ids = batch['input_ids']
            attn_mask = batch['attention_mask']
            labels = batch['labels']

            if device is not None:
                input_ids = input_ids.to(device)
                attn_mask = attn_mask.to(device)
                labels = labels.to(device)

            with torch.no_grad():
                logits = model(
                    input_ids=input_ids,
                    attention_mask=attn_mask
                ).logits

                # NOTE(review): labels are compared position-for-position with
                # the logits; this assumes the collate function has already
                # aligned labels with the model's predictions — TODO confirm.
                nll = F.cross_entropy(
                    logits.transpose(1, 2), labels, reduction='none'
                )

                # Mean negative log-likelihood over attended positions only.
                nll = (nll * attn_mask).sum(dim=1) / attn_mask.sum(dim=1)

                # F.cross_entropy is measured in nats (natural log), so the
                # matching perplexity is e**nll, not 2**nll.
                ppl = torch.exp(nll)

            ppls.extend(ppl.tolist())
    finally:
        if was_training:
            model.train()

    return ppls

# Load

In [None]:
# Dataset name; it is interpolated into the config path and the checkpoint
# directories used in the cells that follow.
dataset = 'trex'

In [None]:
# Build the evaluation datamodule from the dataset's causal YAML config; the
# helper's result is indexed with 'dm' to get the datamodule itself.
eval_causal_dm = cg.utils.datamodule_from_yaml(f'configs/eval-datasets/{dataset}/causal.yaml')['dm']

# Dataset object underlying the test split; the DataLoaders below are built
# from it and use its __collate__ method.
# NOTE(review): this rebinds `ds`, which the import cell bound to the
# `datasets` module (`import datasets as ds`), so the module alias is no
# longer reachable once this cell has run.
ds = eval_causal_dm.test_dataset.dataset

In [None]:
# Baseline: model from the lm-pretrain checkpoint directory. The loaded
# wrapper is unwrapped two levels (.model.model) to reach the object that is
# scored below, then moved to the evaluation device.
ckpt_dir = f'lightning_logs/lm-pretrain/{dataset}/causal/version_1'
lm_model = cg.scoring.interpret_ckpt_dir(ckpt_dir, eval_causal_dm)['model'].model.model
lm_model = lm_model.to(device)

# Comparison: model from the clip-graph causal-lm-train checkpoint directory.
# Only one .model unwrap is applied here. `ckpt_dir` is reused, so after this
# cell it refers to the clip-graph directory.
ckpt_dir = f'lightning_logs/clip-graph/causal-lm-train/{dataset}/extra-epoch/version_0/'
cg_model = cg.scoring.interpret_ckpt_dir(ckpt_dir, eval_causal_dm)['model'].model
cg_model = cg_model.to(device)

# Sanity check: both models must report the same embedding width and
# vocabulary size in their configs before their outputs are compared.
assert cg_model.config.n_embd == lm_model.config.n_embd
assert cg_model.config.vocab_size == lm_model.config.vocab_size

# lm-pretrain model

In [None]:
# Score the test dataset with the lm-pretrain model, eight sequences per
# batch, and keep the per-sequence perplexities as a pandas Series.
ldr = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=8,
    collate_fn=ds.__collate__,
)
lm_ppls = pd.Series(perplexity(lm_model, ldr, device=device))

In [None]:
# Summary statistics of the per-sequence perplexities for the lm-pretrain model.
lm_ppls.describe()

In [None]:
# Histogram of the lm-pretrain perplexities, drawn with log=True.
lm_ppls.hist(log=True)

# clip-graph model

In [None]:
# Score the same test dataset with the clip-graph model, eight sequences per
# batch, and keep the per-sequence perplexities as a pandas Series.
ldr = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=8,
    collate_fn=ds.__collate__,
)
cg_ppls = pd.Series(perplexity(cg_model, ldr, device=device))

In [None]:
# Summary statistics of the per-sequence perplexities for the clip-graph model.
cg_ppls.describe()

In [None]:
# Histogram of the clip-graph perplexities, drawn with log=True.
cg_ppls.hist(log=True)

# Test difference

In [None]:
# Bootstrap a confidence interval for the mean perplexity of the lm-pretrain
# model. The leading axis added here turns the 1-D array into a length-1
# sequence of samples, which is how ss.bootstrap receives its data.
dat = lm_ppls.to_numpy()[None, :]

# NOTE(review): no random_state is passed, so the resamples presumably depend
# on the global NumPy RNG state at the time this cell runs — confirm that the
# seeding cell above is sufficient for reproducibility.
res = ss.bootstrap(
    dat,
    np.mean,
    n_resamples=10000,
    
    # ss.bootstrap computes two-tailed intervals,
    # we want a one-tailed test
    confidence_level = 1 - 0.05 * 2,
)

# Show both sample means, the bootstrap interval, and a flag telling whether
# the clip-graph mean lies outside that interval.
# NOTE(review): the flag is True when the clip-graph mean is beyond EITHER end
# of the 90% interval, i.e. it checks both tails. A one-tailed test as
# described above would compare against one end only — confirm the intended
# direction.
# NOTE(review): only lm_ppls is resampled; cg_ppls.mean() is compared as a
# fixed number, so its own sampling variability is not accounted for.
display(
    lm_ppls.mean(),
    cg_ppls.mean(),
    res.confidence_interval,
    (cg_ppls.mean() > res.confidence_interval.high or cg_ppls.mean() < res.confidence_interval.low)
)