In [1]:
import json, os, math, sys, random, re, pytz
from datetime import datetime
timezone = pytz.timezone('America/New_York') 
import torch
sys.path.append("../scripts/")
from causal_transformer.model import Causal_Transformer
from causal_transformer.config import *
from causal_transformer.dataset import sequences_collator
from causal_transformer.utils import get_acc

import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import concatenate_datasets
from datasets import load_dataset
from functools import partial
from collections import defaultdict, Counter
from pprint import pprint

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Build the model for `task` and restore the weights of a previously trained run.
device = "cuda"
task = "counting_diffsymbol_mod16"
# The `<task>_Config` classes are in scope via `from causal_transformer.config import *`.
# `task` is a literal defined on the line above, so eval() never sees external input.
config = eval(f"{task}_Config()")
ckpt_dir = "/data/yingshac/llms_do_math/scripts/causal_transformer/output"

load_from_dir = "0416_093958"
# Epoch to restore (checkpoint files are prefixed with `epoch - 1`);
# None selects the checkpoint with the largest step number instead.
load_from_specific_epc = 3

# Overlay the hyperparameters saved with the run on top of the task defaults.
# The `with` block closes config.json deterministically instead of leaking the handle.
with open(os.path.join("../scripts/causal_transformer/output", load_from_dir, "config.json"), "r") as config_file:
    load_from_config = json.load(config_file)
for key, value in load_from_config.items():
    setattr(config, key, value)
model = Causal_Transformer(config)
model = model.to(device)

# Checkpoint files are named "<epoch>_<step>_...", e.g. "2_93750_transformer.pt".
ckpt_dir = os.path.join(ckpt_dir, load_from_dir, "ckpts")
if load_from_specific_epc is None:
    candidate_ckpts = os.listdir(ckpt_dir)
else:
    # Match the epoch as a *prefix*. The previous substring test on x[:4] let
    # "2_" also match files from epoch 12 ("12_9..."), silently loading the wrong epoch.
    epoch_prefix = f"{load_from_specific_epc-1}_"
    candidate_ckpts = [x for x in os.listdir(ckpt_dir) if x.startswith(epoch_prefix)]
if not candidate_ckpts:
    raise FileNotFoundError(
        f"no checkpoint found in {ckpt_dir} for load_from_specific_epc={load_from_specific_epc}"
    )
# Among the candidates, take the one with the highest step count.
load_from_pt = max(candidate_ckpts, key=lambda x: int(x.split("_")[1]))

state_dict = torch.load(os.path.join(ckpt_dir, load_from_pt), map_location=device)
# strict=False tolerates key mismatches; report them so a partially loaded model
# does not go unnoticed (assumes the standard nn.Module.load_state_dict return value).
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"warning: missing keys = {list(load_result.missing_keys)}, "
          f"unexpected keys = {list(load_result.unexpected_keys)}")
print(f"load from {load_from_pt}")


load from 2_93750_transformer.pt


In [8]:
# Load the evaluation split as raw text lines and wrap it in a DataLoader.
data_path = f"../data/rasp_primitives/{task}"
split = "ood_test"
test_data = load_dataset(
                    "text", 
                    data_files={split: f"{data_path}/{split}.txt"})
print(f"num {split} data = {len(test_data[split])}")

# Always bind `augmentation`: previously, a config with none of the shift/rdmz
# flags set left the name undefined (NameError) or, worse, reused a stale value
# from an earlier run of this cell.
# NOTE(review): assumes sequences_collator treats None as "no augmentation" — TODO confirm.
augmentation = None
if config.absolute_posemb_shift or config.rotary_posemb_shift:
    augmentation = "shift"
elif config.absolute_posemb_rdmz or config.rotary_posemb_rdmz:
    augmentation = "randomized"

# Token -> index lookup derived from the order of config.vocab.
w2i = {w: i for i, w in enumerate(config.vocab)}
collator = partial(sequences_collator, 
                   w2i=w2i, 
                   max_len=config.max_position_embeddings,
                   augmentation=augmentation,
                   )
# shuffle=False keeps examples in file order.
test_dataloader = DataLoader(test_data[split], shuffle=False, batch_size=config.per_device_train_batch_size, collate_fn=collator)

num ood_test data = 2800


In [9]:
# Evaluate the loaded model on `split`: accumulate loss and accuracy counters and
# record the decoded input / ground truth / prediction of every example.
counting_correct, counting_demo, last_correct, last_demo, correct, demo = 0, 0, 0, 0, 0, 0
test_losses = []

# Label -1 marks positions that are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

model.eval()
testing_output = {}

# Timestamp in the notebook's `timezone`; the export cell uses it as the file name.
date = datetime.now(timezone).strftime("%m%d_%H%M%S")

k = 0  # running example index across batches (key into testing_output)
# Inference only: no_grad() stops autograd from recording the forward pass,
# which otherwise keeps activations alive for every batch.
with torch.no_grad():
    for batch in test_dataloader:
        position_ids = None
        if batch['position_id'] is not None:
            position_ids = batch['position_id'].to(device)

        logits = model(
            batch['input_id'].to(device),
            position_ids = position_ids,
            attention_mask = batch['attention_mask'].to(device),
        )

        loss = criterion(
            logits.view(-1, logits.size(-1)), # bs*seq_len, vocab_size
            batch['label'].view(-1).to(device), # bs*seq_len
        )
        test_losses.append(loss.item())
        _counting_correct, _counting_demo, _last_correct, _last_demo = get_acc(logits.detach().cpu(), batch['label'].detach().cpu(), ignore_index=-1)
        counting_correct += _counting_correct
        counting_demo += _counting_demo
        last_correct += _last_correct
        last_demo += _last_demo
        correct += (_counting_correct + _last_correct)
        demo += (_counting_demo + _last_demo)

        # Move ids to plain Python lists once per batch; indexing the CUDA
        # prediction tensor element by element forces a device sync per token.
        input_rows = batch['input_id'].tolist()
        label_rows = batch['label'].tolist()
        pred_rows = logits.argmax(dim=-1).cpu().tolist()
        for input_row, label_row, pred_row in zip(input_rows, label_rows, pred_rows):
            input_seq = [config.vocab[tok] for tok in input_row if config.vocab[tok] != '<pad>']
            # Keep only positions that carry a label (label != -1) for both sequences.
            scored = [(gth, pred) for gth, pred in zip(label_row, pred_row) if gth != -1]
            gth_seq = [config.vocab[gth] for gth, _ in scored]
            pred_seq = [config.vocab[pred] for _, pred in scored]
            testing_output[k] = {
                "input": " ".join(input_seq),
                "gth": " ".join(gth_seq),
                "pred": " ".join(pred_seq),
            }
            k += 1

print(f""" {split} acc
        | Test Loss: {round(np.mean(test_losses), 4)} 
        | Test Acc: {round(correct/demo, 4)} 
        | Test Counting Acc: {round(counting_correct/counting_demo, 4)} 
        | Test Last Acc: {round(last_correct/last_demo, 4)}
    """)

 ood_test acc
        | Test Loss: 0.0003 
        | Test Acc: 0.9999 
        | Test Counting Acc: 0.9999 
        | Test Last Acc: 0.9989
    


In [10]:
# Persist the evaluation metrics and per-example predictions next to the run's
# checkpoints, in a file named after the evaluation timestamp `date`.
test_samples_dir = f"../scripts/causal_transformer/output/{load_from_dir}/test_samples"
os.makedirs(test_samples_dir, exist_ok=True)

test_report = {
    "test_data_file":  f"{data_path}/{split}.txt",
    "load_from": f"{load_from_dir}/{load_from_pt}",
    "test_acc": round(correct/demo, 4),
    "test_counting_acc": round(counting_correct/counting_demo, 4),
    "test_last_acc": round(last_correct/last_demo, 4),
    "test_loss": round(np.mean(test_losses), 4),
    "testing_output": testing_output,
}

# The `with` block guarantees the report is flushed and the handle closed;
# passing a bare open(...) to json.dump left that to garbage collection.
with open(f"{test_samples_dir}/{date}.json", "w") as report_file:
    json.dump(test_report, report_file, indent=2)
    

## Model Summary (In the order of forward pass)

In [3]:
from torchinfo import summary

In [29]:
# Instantiate a freshly initialised (untrained) model for the layer summary below.
device = "cuda"
task = "counting_selective_padhelper"
config = eval(f"{task}_Config()")

# Overrides applied on top of the task's default config, in this order:
# every positional-embedding option off, two hidden layers, 0.1 dropout throughout.
for option_name, option_value in (
    ("absolute_posemb_shift", False),
    ("rotary_posemb_shift", False),
    ("absolute_posemb", False),
    ("rotary_posemb", False),
    ("num_hidden_layers", 2),
    ("embd_pdrop", 0.1),
    ("attn_pdrop", 0.1),
    ("resid_pdrop", 0.1),
):
    setattr(config, option_name, option_value)

model = Causal_Transformer(config)
model = model.to(device)

# Show where the parameters ended up, as a check that the move to `device` took effect.
first_parameter = next(model.parameters())
print(first_parameter.device)

cuda:0


In [30]:
"""
References: https://github.com/TylerYep/torchinfo
"""

# Print a per-layer parameter breakdown for a dummy (1, 128) batch of token ids.
# With verbose=2 torchinfo prints the table itself. Binding the result to a name
# keeps the notebook from also rendering the returned object as the cell's value,
# which is what showed the same table twice in the recorded output.
model_stats = summary(
    model,
    (1, 128),
    dtypes=[torch.long],  # inputs are integer token ids
    verbose=2,
    col_width=16,
    col_names=[
        "kernel_size", 
        #"output_size", 
        "num_params", 
        "params_percent"
    ],
    row_settings=["var_names"],
    device=device,  # same device the model was moved to
)

Layer (type (var_name))                       Kernel Shape     Param #          Param %
Causal_Transformer (Causal_Transformer)       --               --                    --
├─Embedding (wte)                             --               55,296             0.22%
│    └─weight                                 [1024, 54]       └─55,296
├─Dropout (drop)                              --               --                    --
├─ModuleList (h)                              --               --                    --
│    └─0.ln_1.weight                          [1024]           ├─1,024
│    └─0.ln_1.bias                            [1024]           ├─1,024
│    └─0.attn.c_attn.weight                   [1024, 3072]     ├─3,145,728
│    └─0.attn.c_attn.bias                     [3072]           ├─3,072
│    └─0.attn.c_proj.weight                   [1024, 1024]     ├─1,048,576
│    └─0.attn.c_proj.bias                     [1024]           ├─1,024
│    └─0.ln_2.weight                          [1024]  

Layer (type (var_name))                       Kernel Shape     Param #          Param %
Causal_Transformer (Causal_Transformer)       --               --                    --
├─Embedding (wte)                             --               55,296             0.22%
│    └─weight                                 [1024, 54]       └─55,296
├─Dropout (drop)                              --               --                    --
├─ModuleList (h)                              --               --                    --
│    └─0.ln_1.weight                          [1024]           ├─1,024
│    └─0.ln_1.bias                            [1024]           ├─1,024
│    └─0.attn.c_attn.weight                   [1024, 3072]     ├─3,145,728
│    └─0.attn.c_attn.bias                     [3072]           ├─3,072
│    └─0.attn.c_proj.weight                   [1024, 1024]     ├─1,048,576
│    └─0.attn.c_proj.bias                     [1024]           ├─1,024
│    └─0.ln_2.weight                          [1024]  

In [6]:
# Scratch check: count the tokens in a hand-pasted sequence of single-letter
# symbols (the recorded cell output is 200).
len(["q", "b", "e", "y", "q", "m", "m", "f", "s", "k", "r", "m", "p", "b", "u", "e", "p", "y", "q", "b", "v", "j", "i", "a", "w", "e", "e", "x", "o", "s", "i", "o", "j", "e", "u", "m", "x", "n", "z", "l", "o", "f", "u", "i", "o", "k", "d", "r", "y", "s", "h", "x", "f", "a", "o", "i", "j", "r", "y", "x", "q", "b", "f", "j", "y", "p", "g", "y", "e", "a", "v", "k", "h", "l", "v", "a", "l", "b", "e", "z", "r", "b", "t", "g", "k", "u", "t", "h", "b", "q", "e", "t", "x", "f", "b", "g", "s", "p", "w", "x", "u", "q", "s", "o", "f", "c", "e", "v", "b", "l", "e", "c", "b", "b", "q", "s", "e", "q", "r", "c", "q", "x", "e", "x", "u", "z", "t", "u", "m", "w", "v", "h", "p", "h", "w", "s", "l", "q", "d", "f", "k", "c", "s", "y", "q", "w", "c", "k", "e", "l", "p", "s", "p", "f", "c", "v", "p", "r", "j", "m", "b", "e", "m", "u", "n", "a", "k", "o", "x", "x", "e", "s", "x", "y", "m", "n", "d", "u", "i", "t", "i", "n", "t", "o", "d", "v", "d", "x", "c", "r", "a", "v", "b", "g", "a", "z", "k", "m", "d", "u"])
    

200