In [1]:
# General-purpose modules plus pytz (used only for the timezone object below).
import json, os, math, sys, random, re, pytz
from datetime import datetime
# Timezone passed to datetime.now() when the evaluation run is timestamped.
timezone = pytz.timezone('America/New_York') 
import torch
# Put the sibling scripts directory on the import path so the
# causal_transformer package below resolves. The path is relative, so it
# depends on the directory the kernel was started from.
sys.path.append("../scripts/")
from causal_transformer.model import Causal_Transformer
# NOTE(review): star import -- the `<task>_Config` classes are looked up by
# name from this namespace later on; which other names it brings in cannot be
# seen from this notebook.
from causal_transformer.config import *
from causal_transformer.dataset import sequences_collator
from causal_transformer.utils import get_acc

import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import concatenate_datasets
from datasets import load_dataset
from functools import partial
from collections import defaultdict, Counter
from pprint import pprint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at
# the first `.to(device)` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Task name: it selects both the `<task>_Config` class and the data directory.
task = "counting_diffsymbol"

In [3]:
# Resolve the config class by name rather than eval()-ing a code string: the
# result is the same for a valid task, but no arbitrary expression is ever
# executed and a typo in `task` produces a clear message.
config_name = f"{task}_Config"
if config_name not in globals():
    raise NameError(f"no config class named {config_name!r}; check `task`")
config = globals()[config_name]()

# Build the model from the config and move it onto the selected device.
model = Causal_Transformer(config)
model = model.to(device)

In [6]:
# Pick a checkpoint file from a training run and load its weights into `model`.
# NOTE(review): absolute, machine-specific path -- consider a configurable root.
ckpt_dir = "/data/yingshac/llms_do_math/scripts/causal_transformer/output"
load_from_dir = "0406_161636"
# Epoch to load (files are matched on the prefix `<epoch - 1>_`);
# None selects the file with the largest second field over all epochs.
load_from_specific_epc = 5
ckpt_dir = os.path.join(ckpt_dir, load_from_dir, "ckpts")

# File names look like "4_78125_transformer.pt"; the second underscore-separated
# field is parsed as an integer and used for ordering (presumably the training
# step -- confirm against the training script).
ckpt_files = os.listdir(ckpt_dir)
if load_from_specific_epc is not None:
    # Match the whole leading field. A substring test on the first four
    # characters would also accept e.g. "14_1..." when looking for "4_".
    epoch_prefix = f"{load_from_specific_epc - 1}_"
    ckpt_files = [name for name in ckpt_files if name.startswith(epoch_prefix)]
if not ckpt_files:
    raise FileNotFoundError(
        f"no checkpoint found in {ckpt_dir} for load_from_specific_epc={load_from_specific_epc}"
    )
load_from_pt = sorted(ckpt_files, key=lambda name: int(name.split("_")[1]))[-1]

state_dict = torch.load(os.path.join(ckpt_dir, load_from_pt), map_location=device)
# strict=True: any missing or unexpected parameter name is an error.
model.load_state_dict(state_dict, strict=True)
print(f"load from {load_from_pt}")

load from 4_78125_transformer.pt


In [11]:
# Build the evaluation dataloader for one split of the current task.
split = "ood_test"
data_path = f"../data/rasp_primitives/{task}"
test_data = load_dataset("text", data_files={split: f"{data_path}/{split}.txt"})

# The collator receives a token -> index map (position of each token in
# config.vocab) and the model's maximum sequence length.
collator = partial(
    sequences_collator,
    max_len=config.max_position_embeddings,
    w2i={token: index for index, token in enumerate(config.vocab)},
)
# Keep the file order (no shuffling) so saved predictions line up with the data file.
test_dataloader = DataLoader(
    test_data[split],
    collate_fn=collator,
    batch_size=config.per_device_train_batch_size,
    shuffle=False,
)

Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 13706.88it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 277.09it/s]
Generating ood_test split: 1225 examples [00:00, 80874.26 examples/s]


In [12]:
# Evaluate the loaded model on the test split: accumulate the per-batch loss,
# the accuracy counters returned by get_acc, and a decoded
# input / ground-truth / prediction record for every example.
counting_correct, counting_demo, last_correct, last_demo, correct, demo = 0, 0, 0, 0, 0, 0
test_losses = []

# Label positions equal to -1 are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

model.eval()
testing_output = {}

# Timestamp of this evaluation run (month-day_hour-minute-second).
date = datetime.now(timezone).strftime("%m%d_%H%M%S")

k = 0  # running example index, used as the key into testing_output
# Inference only: no_grad() keeps autograd from recording a graph for every
# batch, which saves memory and time without changing any result.
with torch.no_grad():
    for batch in test_dataloader:
        logits = model(
            batch['input_id'].to(device),
        )
        loss = criterion(
            logits.view(-1, logits.size(-1)), # (bs*seq_len, vocab_size)
            batch['label'].view(-1).to(device), # (bs*seq_len,)
        )
        test_losses.append(loss.item())
        # get_acc returns four counts; going by the names they are
        # correct/total for "counting" positions and for the "last" position
        # -- confirm against causal_transformer.utils.
        _counting_correct, _counting_demo, _last_correct, _last_demo = get_acc(logits.cpu(), batch['label'].cpu(), ignore_index=-1)
        counting_correct += _counting_correct
        counting_demo += _counting_demo
        last_correct += _last_correct
        last_demo += _last_demo
        correct += (_counting_correct + _last_correct)
        demo += (_counting_demo + _last_demo)

        # Convert to plain Python ints once per batch so that decoding does not
        # index config.vocab with 0-d (possibly on-device) tensors one by one.
        input_rows = batch['input_id'].tolist()
        label_rows = batch['label'].tolist()
        pred_rows = logits.argmax(dim=-1).tolist()
        for input_id, gth_id, pred_id in zip(input_rows, label_rows, pred_rows):
            input_seq = [config.vocab[tok] for tok in input_id if config.vocab[tok]!='<pad>']
            gth_seq = [config.vocab[tok] for tok in gth_id if tok!=-1]
            # NOTE(review): this keeps the first len(gth_seq) predictions, which
            # lines up with gth_seq only if ignored (-1) labels all sit at the
            # end of the sequence -- confirm against the collator.
            pred_seq = [config.vocab[tok] for tok in pred_id][:len(gth_seq)]
            testing_output[k] = {
                "input": " ".join(input_seq),
                "gth": " ".join(gth_seq),
                "pred": " ".join(pred_seq),
            }
            k+=1

print(f"""
        | Test Loss: {round(np.mean(test_losses), 4)} 
        | Test Acc: {round(correct/demo, 4)} 
        | Test Counting Acc: {round(counting_correct/counting_demo, 4)} 
        | Test Last Acc: {round(last_correct/last_demo, 4)}
    """)


        | Test Loss: 6.6906 
        | Test Acc: 0.75 
        | Test Counting Acc: 0.7612 
        | Test Last Acc: 0.0
    


In [13]:
# Save the evaluation summary and the per-example predictions as
# <run output dir>/test_samples/<date>.json.
output_dir = f"../scripts/causal_transformer/output/{load_from_dir}/test_samples"
os.makedirs(output_dir, exist_ok=True)

# Build the payload before opening the file, so a failure while computing a
# metric (e.g. division by zero) does not leave an empty file behind.
results = {
    "test_data_file":  f"{data_path}/{split}.txt",
    "load_from": f"{load_from_dir}/{load_from_pt}",
    "test_acc": round(correct/demo, 4),
    "test_counting_acc": round(counting_correct/counting_demo, 4),
    "test_last_acc": round(last_correct/last_demo, 4),
    "test_loss": round(np.mean(test_losses), 4),
    "testing_output": testing_output,
}

# The context manager closes (and flushes) the file deterministically;
# the bare open(...) used before was never closed.
with open(f"{output_dir}/{date}.json", "w") as results_file:
    json.dump(results, results_file, indent=2)
    

In [4]:
# Show a per-parameter breakdown of the model (name, parameter count, shape).
# NOTE(review): this import sits outside the top import cell, and the cell's
# execution count (4) shows it was run before the cells above it; it only
# needs `model` to exist.
from causal_transformer.utils import count_parameters
count_parameters(model)

+------------------------+---------+--------------+
|        Modules         | #Params | Param shape  |
+------------------------+---------+--------------+
|       wte.weight       |  131072 | [128, 1024]  |
|       wpe.weight       |  131072 | [128, 1024]  |
|    h.0.ln_1.weight     |   1024  |    [1024]    |
|     h.0.ln_1.bias      |   1024  |    [1024]    |
| h.0.attn.c_attn.weight | 3145728 | [1024, 3072] |
|  h.0.attn.c_attn.bias  |   3072  |    [3072]    |
| h.0.attn.c_proj.weight | 1048576 | [1024, 1024] |
|  h.0.attn.c_proj.bias  |   1024  |    [1024]    |
|    h.0.ln_2.weight     |   1024  |    [1024]    |
|     h.0.ln_2.bias      |   1024  |    [1024]    |
|  h.0.mlp.c_fc.weight   | 4194304 | [1024, 4096] |
|   h.0.mlp.c_fc.bias    |   4096  |    [4096]    |
| h.0.mlp.c_proj.weight  | 4194304 | [4096, 1024] |
|  h.0.mlp.c_proj.bias   |   1024  |    [1024]    |
|    h.1.ln_1.weight     |   1024  |    [1024]    |
|     h.1.ln_1.bias      |   1024  |    [1024]    |
| h.1.attn.c

In [2]:
# Scratch check: compare the lengths of a literal list of symbols and a literal
# list of number strings "1".."100" (both come out as 100).
len(["j", "a", "v", "k", "e", "z", "a", "a", "g", "b", "r", "q", "o", "g", "x", "x", "w", "u", "i", "s", "b", "z", "t", "o", "v", "t", "r", "b", "s", "v", "f", "y", "q", "m", "z", "z", "n", "x", "r", "n", "l", "c", "r", "g", "s", "f", "t", "i", "x", "q", "e", "b", "x", "v", "d", "c", "b", "f", "t", "b", "b", "f", "e", "d", "a", "s", "b", "m", "z", "f", "d", "s", "r", "r", "r", "l", "x", "b", "w", "j", "r", "v", "u", "d", "z", "m", "v", "z", "y", "v", "i", "k", "n", "f", "m", "g", "r", "t", "m", "1"]), len(["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", "43", "44", "45", "46", "47", "48", "49", "50", "51", "52", "53", "54", "55", "56", "57", "58", "59", "60", "61", "62", "63", "64", "65", "66", "67", "68", "69", "70", "71", "72", "73", "74", "75", "76", "77", "78", "79", "80", "81", "82", "83", "84", "85", "86", "87", "88", "89", "90", "91", "92", "93", "94", "95", "96", "97", "98", "99", "100"])


(100, 100)