In [1]:
import json, os, math, sys, random, re, pytz, argparse, warnings
from datetime import datetime
timezone = pytz.timezone('America/Los_Angeles') 
import torch
os.chdir("../scripts/causal_transformer")
from model import Causal_Transformer
from config import *
from dataset import sequences_collator
from utils import get_acc, trim_task

import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [87]:
# Run to analyse: the training-run handle (directory name under output/ckpt dirs)
# and the 1-based epoch whose checkpoint should be loaded.
handle = "0527_153953"
load_from_epoch = 20
# Fall back to CPU so the notebook still runs end-to-end on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [88]:
# Rebuild the run's config: start from Basic_Config, overlay every attribute
# saved in the run's config.json, and fall back to Default_Config (with a
# warning) for attributes the saved file does not contain.
config = Basic_Config()
default_config = Default_Config()
with open(os.path.join(config.output_dir, handle, "config.json"), "r") as config_fh:
    load_from_config = json.load(config_fh)
config_keys = dir(config)
for k in config_keys:
    if k.startswith("__"):
        continue
    if k in load_from_config:
        setattr(config, k, load_from_config[k])
    else:
        default_value = getattr(default_config, k)
        setattr(config, k, default_value)
        warnings.warn(f"Cannot find {k} in the resume_from_config. Set to {default_value} by default.")

# Pick the checkpoint for the requested epoch. Filenames are split on "_":
# field 0 is compared against load_from_epoch - 1 (presumably a 0-based epoch
# index -- TODO confirm) and field 1 is an integer used as the sort key, so the
# first match is the one with the smallest second field.
ckpt_dir = os.path.join(config.ckpt_dir, handle, "ckpts")
avail_ckpts = sorted(os.listdir(ckpt_dir), key=lambda x: int(x.split("_")[1]))
matching_ckpts = [ckpt for ckpt in avail_ckpts if int(ckpt.split("_")[0]) == load_from_epoch - 1]
if not matching_ckpts:
    raise FileNotFoundError(
        f"No checkpoint for epoch {load_from_epoch} (prefix '{load_from_epoch - 1}_') "
        f"in {ckpt_dir}; available: {avail_ckpts}"
    )
load_from_pt = matching_ckpts[0]



In [89]:
# Map the positional-embedding flags in config to the augmentation mode passed
# to the collator; the first matching branch wins, otherwise no augmentation.
augmentation = None
if config.absolute_posemb_shift or config.rotary_posemb_shift or config.sinusoidal_posemb_shift:
    augmentation = "shift"
elif config.absolute_posemb_rdmz or config.rotary_posemb_rdmz:
    augmentation = "randomized"
elif config.scaler_posemb:
    augmentation = "scaler+shift" if config.scaler_posemb_shift else "zooming"
collator = partial(sequences_collator,
                   w2i={w: i for i, w in enumerate(config.vocab)},
                   max_seq_len=config.max_seq_len,
                   max_position_embeddings=config.max_position_embeddings,
                   augmentation=augmentation,
                   )

# Longest non-"<pad>" sequence in the validation split. Each line is parsed as
# JSON and element [0] is iterated as the token sequence.
with open(f"{config.eval_data_path}/{trim_task(config.task)}/val.txt", "r") as val_fh:
    val_file = val_fh.readlines()
max_seen_len = max(len([x for x in json.loads(l)[0] if x != "<pad>"]) for l in val_file)
print(f"max_seen_len for {config.task} = {max_seen_len}")

max_seen_len for counting_samesymbol_plain3_addbigram = 51


In [102]:
# Build the model from the restored config and load the selected checkpoint.
model = Causal_Transformer(config)
model = model.to(device)
state_dict = torch.load(os.path.join(ckpt_dir, load_from_pt), map_location=device)
# strict=False tolerates key mismatches, but any missing parameter would silently
# stay at its random initialization -- report mismatches instead of hiding them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint {load_from_pt} key mismatch: "
        f"missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}"
    )
model.eval()


Causal_Transformer(
  (wte): Embedding(104, 1024)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0): Block(
      (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (rotary_emb): RotaryEmbedding()
      )
      (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): ReLU()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=1024, out_features=104, bias=False)
)

In [6]:
# Stream one example at a time from the chosen split, in file order, so that
# successive `next(test_iter)` calls step through the file deterministically.
split = "val"
data_path = f"{config.eval_data_path}/{trim_task(config.task)}"
test_data = load_dataset("text", data_files={split: f"{data_path}/{split}.txt"})
test_dataloader = DataLoader(
    test_data[split],
    batch_size=1,
    shuffle=False,
    collate_fn=collator,
)
test_iter = iter(test_dataloader)

In [58]:
batch = next(test_iter)

# Move the collated tensors onto the model's device. The batch's position_id
# entry may be None, in which case it is passed through unchanged.
input_ids = batch['input_id'].to(device)
attention_mask = batch['attention_mask'].to(device)
position_ids = None if batch['position_id'] is None else batch['position_id'].to(device)


In [73]:
from model import apply_rotary_pos_emb

In [155]:
# Re-run the first transformer block's attention by hand to expose the
# attention probabilities (`attn_patterns`) for the current batch.
# NOTE(review): only the token embedding (wte) is fed in here; if a config
# enables absolute position embeddings, the model's own forward may add them
# too -- confirm before trusting patterns for such configs.
with torch.no_grad():  # inspection only; avoid building an autograd graph
    input_embs = model.wte(input_ids)

    # The embeddings feed only the first block, so inspect model.h[0] explicitly
    # (the original loop broke after its first iteration anyway).
    block = model.h[0]
    hidden_states = block.ln_1(input_embs)
    query, key, value = block.attn.c_attn(hidden_states).split(block.attn.split_size, dim=2)

    query = block.attn.split_heads(query)
    key = block.attn.split_heads(key, k=True)
    value = block.attn.split_heads(value)

    # apply rotary positional embedding if needed
    if block.attn.config.rotary_posemb:
        cos, sin = block.attn.rotary_emb(value, position_ids)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

    # attn_output is the raw output of _attn; the merge_heads -> c_proj ->
    # resid_dropout steps are intentionally not applied here.
    attn_output, attn_patterns = block.attn._attn(query, key, value, attention_mask, position_ids)

In [130]:
attn_output.size()  # shape of the attention output tensor

torch.Size([1, 128, 1024])

In [203]:
sample_idx, head_idx, query_pos = 0, 0, 14  # which example / head / query row to inspect
attn_patterns[sample_idx, head_idx, query_pos]

tensor([1.0000e+00, 3.5739e-14, 1.2819e-13, 8.9527e-11, 1.0814e-10, 2.2454e-11,
        3.5073e-11, 1.4714e-12, 1.0287e-13, 4.9049e-13, 4.3785e-13, 1.1728e-13,
        2.3829e-14, 2.4624e-16, 7.1920e-16, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+