In [21]:
# Standard library + pytz. NOTE(review): math, random and re are not used by any
# cell visible here — confirm before removing.
import json, os, math, sys, random, re, pytz
from datetime import datetime
# Timezone used to build the timestamped file name of the saved test samples.
timezone = pytz.timezone('America/New_York') 
import torch
# Project package lives under ../scripts/; this append must run before the
# causal_transformer imports below or they will fail to resolve.
sys.path.append("../scripts/")
from causal_transformer.model import Causal_Transformer
# NOTE(review): star import — presumably supplies the `<task>_Config` classes the
# config cell resolves by name; confirm and replace with explicit names.
from causal_transformer.config import *
from causal_transformer.dataset import sequences_collator
from causal_transformer.utils import get_acc

import numpy as np
# NOTE(review): tqdm, concatenate_datasets, defaultdict, Counter and pprint are not
# used by any cell visible here.
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import concatenate_datasets
from datasets import load_dataset
from functools import partial
from collections import defaultdict, Counter
from pprint import pprint

In [22]:
# Run on the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still executes end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Task name: selects the config class and the data directory used below.
task = "counting_raspL"

In [23]:
# Build the default config object for `task`, then overwrite its fields with the
# values saved in the training run's config.json.
# Look the class up by name rather than eval-ing a code string built from `task`.
config = globals()[f"{task}_Config"]()

# NOTE(review): absolute, user-specific path. config.json below (and the save cell)
# use the relative "../scripts/causal_transformer/output" — presumably the same
# directory; confirm and unify.
ckpt_dir = "/data/yingshac/llms_do_math/scripts/causal_transformer/output"
load_from_dir = "0411_232651"
config_path = os.path.join("../scripts/causal_transformer/output", load_from_dir, "config.json")
# Context manager so the file handle is closed once the JSON is read.
with open(config_path, "r") as config_file:
    load_from_config = json.load(config_file)
for key, value in load_from_config.items():
    setattr(config, key, value)

In [24]:
# Instantiate the transformer from the (run-updated) config and move it to `device`.
model = Causal_Transformer(config).to(device)

In [25]:
# Choose which checkpoint to restore:
#   None  -> the file with the largest step number in the run's ckpts directory;
#   int N -> the largest-step file whose epoch field is N - 1
#            (e.g. N=5 selects "4_39065_transformer.pt").
load_from_specific_epc = 5
# Separate name instead of rebinding `ckpt_dir`, so re-running this cell does not
# append "<run>/ckpts" to the path a second time.
ckpt_run_dir = os.path.join(ckpt_dir, load_from_dir, "ckpts")
ckpt_candidates = os.listdir(ckpt_run_dir)
if load_from_specific_epc is not None:
    # Match the whole leading epoch field: a substring test on x[:4] would also
    # accept e.g. "14_..." or "104_..." when looking for "4_".
    epoch_prefix = f"{load_from_specific_epc - 1}_"
    ckpt_candidates = [x for x in ckpt_candidates if x.startswith(epoch_prefix)]
if not ckpt_candidates:
    raise FileNotFoundError(
        f"no matching checkpoint in {ckpt_run_dir} (load_from_specific_epc={load_from_specific_epc})"
    )
# File names are "<epoch>_<step>_..."; pick the highest step.
load_from_pt = sorted(ckpt_candidates, key=lambda x: int(x.split("_")[1]))[-1]
state_dict = torch.load(os.path.join(ckpt_run_dir, load_from_pt), map_location=device)
# strict=False tolerates key mismatches; report them rather than hiding them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")
print(f"load from {load_from_pt}")

load from 4_39065_transformer.pt


In [26]:
# Load the chosen split of the task's text data (one example per line) and batch
# it with the project collator, in file order.
split = "ood_test"
data_path = f"../data/rasp_primitives/{task}"
test_data = load_dataset("text", data_files={split: f"{data_path}/{split}.txt"})

word_to_id = {word: idx for idx, word in enumerate(config.vocab)}
use_posemb_shift = config.absolute_posemb_shift or config.rotary_posemb_shift
collator = partial(
    sequences_collator,
    w2i=word_to_id,
    max_len=config.max_position_embeddings,
    posemb_shift=use_posemb_shift,
)
test_dataloader = DataLoader(
    test_data[split],
    batch_size=config.per_device_train_batch_size,
    shuffle=False,
    collate_fn=collator,
)

In [27]:
# Evaluate the restored model on the test split. Accumulate the loss and the
# accuracy counters returned by `get_acc`, and decode every example into
# space-joined input / ground-truth / prediction strings.
counting_correct, counting_demo, last_correct, last_demo, correct, demo = 0, 0, 0, 0, 0, 0
test_losses = []

# Label value -1 is ignored by the loss (and skipped when decoding below).
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

model.eval()
testing_output = {}

# Timestamp used as the results file name by the save cell.
date = datetime.now(timezone).strftime("%m%d_%H%M%S")

k = 0
# Evaluation only: skip building the autograd graph (saves memory and time).
with torch.no_grad():
    for batch in test_dataloader:
        position_ids = None
        if batch['position_id'] is not None:
            position_ids = batch['position_id'].to(device)

        logits = model(
            batch['input_id'].to(device),
            position_ids=position_ids,
            attention_mask=batch['attention_mask'].to(device),
        )
        labels = batch['label']

        loss = criterion(
            logits.view(-1, logits.size(-1)),  # bs*seq_len, vocab_size
            labels.view(-1).to(device),  # bs*seq_len
        )
        test_losses.append(loss.item())

        # Move logits to the CPU once per batch; get_acc and decoding both read them there.
        logits_cpu = logits.cpu()
        _counting_correct, _counting_demo, _last_correct, _last_demo = get_acc(logits_cpu, labels.cpu(), ignore_index=-1)
        counting_correct += _counting_correct
        counting_demo += _counting_demo
        last_correct += _last_correct
        last_demo += _last_demo
        correct += (_counting_correct + _last_correct)
        demo += (_counting_demo + _last_demo)

        # Decode with plain Python ints. Indexing CUDA tensors element by element
        # forces a device sync per token; one .tolist() per batch avoids that.
        pred_ids = logits_cpu.argmax(dim=-1).tolist()
        for input_id, gth_id, pred_id in zip(batch['input_id'].tolist(), labels.tolist(), pred_ids):
            input_seq = [config.vocab[t] for t in input_id if config.vocab[t] != '<pad>']
            target_positions = [pos for pos, t in enumerate(gth_id) if t != -1]
            gth_seq = [config.vocab[gth_id[pos]] for pos in target_positions]
            pred_seq = [config.vocab[pred_id[pos]] for pos in target_positions]
            testing_output[k] = {
                "input": " ".join(input_seq),
                "gth": " ".join(gth_seq),
                "pred": " ".join(pred_seq),
            }
            k += 1

print(f"""
        | Test Loss: {round(np.mean(test_losses), 4)} 
        | Test Acc: {round(correct/demo, 4)} 
        | Test Counting Acc: {round(counting_correct/counting_demo, 4)} 
        | Test Last Acc: {round(last_correct/last_demo, 4)}
    """)


        | Test Loss: 0.2256 
        | Test Acc: 0.9685 
        | Test Counting Acc: 0.9685 
        | Test Last Acc: 0.9704
    


In [28]:
# Save the aggregate metrics and every decoded example under the run's output
# directory. The file is named by the timestamp taken in the evaluation cell.
test_samples_dir = os.path.join("../scripts/causal_transformer/output", load_from_dir, "test_samples")
os.makedirs(test_samples_dir, exist_ok=True)
test_summary = {
    "test_data_file": f"{data_path}/{split}.txt",
    "load_from": f"{load_from_dir}/{load_from_pt}",
    "test_acc": round(correct/demo, 4),
    "test_counting_acc": round(counting_correct/counting_demo, 4),
    "test_last_acc": round(last_correct/last_demo, 4),
    "test_loss": round(np.mean(test_losses), 4),
    "testing_output": testing_output,
}
# Context manager so the file is flushed and closed even if serialization fails.
with open(os.path.join(test_samples_dir, f"{date}.json"), "w") as out_file:
    json.dump(test_summary, out_file, indent=2)
    

In [4]:
# Show a per-module table of parameter counts and shapes for the loaded model.
from causal_transformer.utils import count_parameters
param_summary = count_parameters(model)
param_summary

+------------------------+---------+--------------+
|        Modules         | #Params | Param shape  |
+------------------------+---------+--------------+
|       wte.weight       |  131072 | [128, 1024]  |
|       wpe.weight       |  131072 | [128, 1024]  |
|    h.0.ln_1.weight     |   1024  |    [1024]    |
|     h.0.ln_1.bias      |   1024  |    [1024]    |
| h.0.attn.c_attn.weight | 3145728 | [1024, 3072] |
|  h.0.attn.c_attn.bias  |   3072  |    [3072]    |
| h.0.attn.c_proj.weight | 1048576 | [1024, 1024] |
|  h.0.attn.c_proj.bias  |   1024  |    [1024]    |
|    h.0.ln_2.weight     |   1024  |    [1024]    |
|     h.0.ln_2.bias      |   1024  |    [1024]    |
|  h.0.mlp.c_fc.weight   | 4194304 | [1024, 4096] |
|   h.0.mlp.c_fc.bias    |   4096  |    [4096]    |
| h.0.mlp.c_proj.weight  | 4194304 | [4096, 1024] |
|  h.0.mlp.c_proj.bias   |   1024  |    [1024]    |
|    h.1.ln_1.weight     |   1024  |    [1024]    |
|     h.1.ln_1.bias      |   1024  |    [1024]    |
| h.1.attn.c

In [1]:
import torch

In [74]:
# Toy batch of random 0/1 matrices, shape (3, 1, 8, 8). A local seeded generator
# makes the tensor reproducible on re-run without touching torch's global RNG.
toy_rng = torch.Generator().manual_seed(0)
x = torch.randint(0, 2, (3, 1, 8, 8), generator=toy_rng)
print(x)

tensor([[[[1, 1, 1, 1, 1, 1, 0, 1],
          [1, 1, 1, 1, 0, 0, 0, 1],
          [0, 1, 1, 0, 0, 0, 1, 1],
          [0, 0, 1, 1, 1, 0, 0, 1],
          [1, 1, 0, 0, 0, 1, 0, 1],
          [0, 0, 0, 0, 1, 1, 0, 0],
          [1, 0, 0, 1, 1, 0, 0, 0],
          [1, 1, 0, 1, 0, 1, 1, 1]]],


        [[[1, 0, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 1, 1, 1, 0, 1],
          [1, 1, 1, 1, 1, 1, 1, 0],
          [0, 1, 1, 0, 0, 0, 1, 0],
          [1, 1, 1, 1, 0, 0, 0, 1],
          [0, 0, 0, 1, 0, 1, 1, 0],
          [0, 0, 0, 0, 1, 0, 1, 0],
          [1, 0, 0, 1, 0, 0, 0, 1]]],


        [[[1, 1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 1, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 1, 0, 1, 0],
          [1, 1, 1, 1, 0, 0, 0, 1],
          [1, 0, 0, 1, 0, 0, 1, 0],
          [1, 0, 1, 0, 0, 1, 1, 1],
          [1, 1, 1, 0, 1, 0, 1, 1]]]])


In [75]:
# 8x8 float identity with two leading singleton dims, so it broadcasts against x.
eye = torch.eye(8)[None, None, :, :]
eye.size()

torch.Size([1, 1, 8, 8])

In [76]:
# Result is 1 wherever x is nonzero or on the diagonal, else 0 (eye broadcasts over x's leading dims).
x.bool().logical_or(eye.bool()).long()

tensor([[[[1, 1, 1, 1, 1, 1, 0, 1],
          [1, 1, 1, 1, 0, 0, 0, 1],
          [0, 1, 1, 0, 0, 0, 1, 1],
          [0, 0, 1, 1, 1, 0, 0, 1],
          [1, 1, 0, 0, 1, 1, 0, 1],
          [0, 0, 0, 0, 1, 1, 0, 0],
          [1, 0, 0, 1, 1, 0, 1, 0],
          [1, 1, 0, 1, 0, 1, 1, 1]]],


        [[[1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 1, 1, 1, 0, 1],
          [1, 1, 1, 1, 1, 1, 1, 0],
          [0, 1, 1, 1, 0, 0, 1, 0],
          [1, 1, 1, 1, 1, 0, 0, 1],
          [0, 0, 0, 1, 0, 1, 1, 0],
          [0, 0, 0, 0, 1, 0, 1, 0],
          [1, 0, 0, 1, 0, 0, 0, 1]]],


        [[[1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 1, 0, 0, 0, 0],
          [0, 0, 1, 1, 0, 0, 0, 0],
          [1, 0, 0, 1, 1, 0, 1, 0],
          [1, 1, 1, 1, 1, 0, 0, 1],
          [1, 0, 0, 1, 0, 1, 1, 0],
          [1, 0, 1, 0, 0, 1, 1, 1],
          [1, 1, 1, 0, 1, 0, 1, 1]]]])