In [36]:
import json, os, math, sys, random, re, pytz
from datetime import datetime
# Timezone used to timestamp the evaluation run (see `date = datetime.now(timezone)` below).
timezone = pytz.timezone('America/New_York') 
import torch
sys.path.append("../scripts/")
from causal_transformer.model import Causal_Transformer
from causal_transformer.config import *
from causal_transformer.dataset import sequences_collator
from causal_transformer.utils import get_acc

import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import concatenate_datasets
from datasets import load_dataset
from functools import partial
from collections import defaultdict, Counter
from pprint import pprint

In [2]:
# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing at model.to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Task name; the config class `<task>_Config` is looked up from it in the next cell.
task = "counting_samesymbol"

In [3]:
# Resolve the task's config class by name. `<task>_Config` is brought in by the
# star-import of causal_transformer.config, so it lives in the notebook globals.
# An explicit lookup replaces eval() on a formatted string and reports a typo
# in `task` with a readable message.
config_cls_name = f"{task}_Config"
if config_cls_name not in globals():
    raise NameError(f"no config class named {config_cls_name!r}; check `task`")
config = globals()[config_cls_name]()

# Build the model from the config and move it to the selected device.
model = Causal_Transformer(config)
model = model.to(device)

In [4]:
# Root directory holding one sub-directory per training run.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
ckpt_dir = "/data/yingshac/llms_do_math/scripts/causal_transformer/output"
load_from_dir = "0406_134528"
ckpt_dir = os.path.join(ckpt_dir, load_from_dir, "ckpts")

# Checkpoint file names look like "4_78127_transformer.pt"; the second
# "_"-separated field is the integer used to order them. Only keep names where
# that field is numeric, so stray entries (e.g. ".ipynb_checkpoints") do not
# make int() raise.
ckpt_name_pattern = re.compile(r"[^_]*_(\d+)(?:_.*)?")
ckpt_candidates = [
    fname for fname in os.listdir(ckpt_dir)
    if ckpt_name_pattern.fullmatch(fname)
]
if not ckpt_candidates:
    raise FileNotFoundError(f"no checkpoint files found in {ckpt_dir}")

# Pick the checkpoint with the largest value in the second field.
load_from_pt = sorted(ckpt_candidates, key=lambda x: int(x.split("_")[1]))[-1]

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(os.path.join(ckpt_dir, load_from_pt), map_location=device)
# strict=True: fail loudly if the checkpoint keys do not match the model exactly.
model.load_state_dict(state_dict, strict=True)
print(f"load from {load_from_pt}")

load from 4_78127_transformer.pt


In [48]:
# Where the evaluation text files live and which split to score.
data_path = "../data/rasp_primitives/counting_samesymbol"
split = "ood_test"

# Load the split's .txt file with the "text" dataset builder, keyed by split name.
test_data = load_dataset("text", data_files={split: f"{data_path}/{split}.txt"})

# The collator needs a token -> index table (position in config.vocab) and the
# maximum sequence length taken from the model config.
collator = partial(
    sequences_collator,
    w2i={token: index for index, token in enumerate(config.vocab)},
    max_len=config.max_position_embeddings,
)

# Keep file order (no shuffling) so saved outputs line up with the input file.
test_dataloader = DataLoader(
    test_data[split],
    batch_size=config.per_device_train_batch_size,
    shuffle=False,
    collate_fn=collator,
)

In [49]:
# Evaluate the loaded checkpoint on the test split: accumulate the per-batch
# loss and the accuracy counters returned by get_acc, and record the decoded
# input / ground truth / prediction for every example.
counting_correct, counting_demo, last_correct, last_demo, correct, demo = 0, 0, 0, 0, 0, 0
test_losses = []

# Label positions equal to -1 are ignored by the loss (and by get_acc via ignore_index).
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

model.eval()
testing_output = {}

# Timestamp of this evaluation run, in the notebook's configured timezone.
date = datetime.now(timezone).strftime("%m%d_%H%M%S")

k = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for batch in test_dataloader:
        logits = model(
            batch['input_id'].to(device),
        )
        loss = criterion(
            logits.view(-1, logits.size(-1)), # bs*seq_len, vocab_size
            batch['label'].view(-1).to(device), # bs*seq_len
        )
        test_losses.append(loss.item())
        _counting_correct, _counting_demo, _last_correct, _last_demo = get_acc(logits.detach().cpu(), batch['label'].detach().cpu(), ignore_index=-1)
        counting_correct += _counting_correct
        counting_demo += _counting_demo
        last_correct += _last_correct
        last_demo += _last_demo
        correct += (_counting_correct + _last_correct)
        demo += (_counting_demo + _last_demo)

        # Convert once to plain Python ints; iterating tensors element by element
        # (on the GPU for the predictions) forces a sync per token.
        input_ids = batch['input_id'].tolist()
        label_ids = batch['label'].tolist()
        pred_ids = logits.argmax(dim=-1).cpu().tolist()

        for input_id, gth_id, pred_id in zip(input_ids, label_ids, pred_ids):
            input_seq = [config.vocab[tok] for tok in input_id if config.vocab[tok]!='<pad>']
            gth_seq = [config.vocab[tok] for tok in gth_id if tok!=-1]
            # Keep predictions at exactly the positions that carry a label, so
            # "pred" stays aligned with "gth" even when ignored (-1) positions
            # are not all at the end of the sequence.
            pred_seq = [config.vocab[p] for p, tok in zip(pred_id, gth_id) if tok!=-1]
            testing_output[k] = {
                "input": " ".join(input_seq),
                "gth": " ".join(gth_seq),
                "pred": " ".join(pred_seq),
            }
            k+=1

print(f"""
        | Test Loss: {round(np.mean(test_losses), 4)} 
        | Test Acc: {round(correct/demo, 4)} 
        | Test Counting Acc: {round(counting_correct/counting_demo, 4)} 
        | Test Last Acc: {round(last_correct/last_demo, 4)}
    """)


        | Test Loss: 5.716 
        | Test Acc: 0.75 
        | Test Counting Acc: 0.7612 
        | Test Last Acc: 0.0
    


In [50]:
# Write the evaluation summary and the per-example outputs to
# test_samples/<date>.json under this run's (relative) output directory.
test_samples_dir = f"../scripts/causal_transformer/output/{load_from_dir}/test_samples"
os.makedirs(test_samples_dir, exist_ok=True)

test_results = {
    "test_data_file":  f"{data_path}/{split}.txt",
    "load_from": f"{load_from_dir}/{load_from_pt}",
    "test_acc": round(correct/demo, 4),
    "test_counting_acc": round(counting_correct/counting_demo, 4),
    "test_last_acc": round(last_correct/last_demo, 4),
    "test_loss": round(np.mean(test_losses), 4),
    "testing_output": testing_output,
}

# The context manager guarantees the file is flushed and closed; the bare
# open() it replaces left the handle dangling.
with open(f"{test_samples_dir}/{date}.json", "w") as results_file:
    json.dump(test_results, results_file, indent=2)
    