In [1]:
%load_ext autoreload
%autoreload 2

from src.utils.helper import *
import pickle

# Both artifacts belong to the same training run, so resolve them from one
# run directory (relative to the notebook's working directory) instead of
# mixing a relative path with a user-specific absolute path.
RUN_DIR = 'models/testrun_dyck1_4h'

# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# files produced by a trusted training run.
with open(f'{RUN_DIR}/vocab.p', 'rb') as f:
    vocab = pickle.load(f)

with open(f'{RUN_DIR}/config.p', 'rb') as f:
    config = pickle.load(f)

# Display the run configuration as the cell output.
config

{'mode': 'train',
 'debug': False,
 'load_model': False,
 'results': True,
 'run_name': 'testrun_dyck1_4h',
 'display_freq': 35,
 'dataset': 'Dyck-1',
 'vocab_size': 4,
 'histogram': True,
 'gpu': 0,
 'seed': 1729,
 'logging': 1,
 'ckpt': 'model',
 'emb_size': 64,
 'model_type': 'SAN',
 'cell_type': 'LSTM',
 'hidden_size': 64,
 'depth': 1,
 'dropout': 0.0,
 'max_length': 35,
 'bptt': 35,
 'use_emb': False,
 'init_range': 0.08,
 'tied': False,
 'generalize': True,
 'd_model': 32,
 'd_ffn': 64,
 'heads': 4,
 'pos_encode': False,
 'max_period': 10000.0,
 'pos_encode_type': 'absolute',
 'posffn': True,
 'bias': True,
 'viz': False,
 'freeze_emb': False,
 'freeze_q': False,
 'freeze_k': False,
 'freeze_v': False,
 'freeze_f': False,
 'zero_k': False,
 'lr': 0.005,
 'decay_patience': 3,
 'decay_rate': 0.1,
 'max_grad_norm': -0.25,
 'batch_size': 32,
 'epochs': 25,
 'opt': 'rmsprop',
 'lang': 'Dyck',
 'lower_window': 2,
 'upper_window': 100,
 'lower_depth': 0,
 'upper_depth': -1,
 'val_lower_

In [2]:
from src.dataloader import *

# Path is relative to the notebook's working directory (the same convention
# used when loading the vocab), so the notebook is not tied to one machine.
data_path = 'data/Dyck-1/train_corpus.pk'

# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# corpora produced by a trusted data-generation run.
with open(data_path, 'rb') as f:
    train_corpus = pickle.load(f)

# The third Sampler argument is presumably the batch size (the batch shown
# below has 10 sequences) -- TODO confirm against src.dataloader.
train_loader = Sampler(train_corpus, vocab, 10)

# Pull one batch as a sanity check and show its shape and per-sequence lengths.
src, _, wd_lens = train_loader.get_batch(0)
src.shape, wd_lens

['()((()))', '((())((())))', '((((((())(((((((()))))))))))))', '(())()', '(())', '()', '()()', '((((((()))((()((((((()(())))))))()()))))))', '(((()))((())(()((((())))()))))', '((()(())))']


(torch.Size([42, 10]), tensor([ 8, 12, 30,  6,  4,  2,  4, 42, 30, 10]))

In [3]:
def pad_seq(seq, max_length, voc):
    """Return a copy of ``seq`` right-padded with the id of 'T'.

    Args:
        seq: list of token ids.
        max_length: target length; sequences already at least this long are
            returned unpadded (never truncated).
        voc: vocabulary object exposing ``get_id(token)``.

    Returns:
        A new list of length ``max(len(seq), max_length)``. The input list is
        left untouched, so callers can safely reuse it.
    """
    n_pad = max_length - len(seq)
    if n_pad <= 0:
        return list(seq)
    # Look the padding id up once instead of once per padded position.
    pad_id = voc.get_id('T')
    return list(seq) + [pad_id] * n_pad

def sent_to_idx(voc, sent, max_length=-1):
    """Encode one sentence as a list of vocabulary ids.

    Each symbol of ``sent`` is mapped through ``voc.get_id``, one 'T' id is
    appended, and the result is handed to ``pad_seq`` with a target length of
    ``max_length + 1``. With the default ``max_length=-1`` the target length
    is 0, so no extra padding is added.
    """
    token_ids = [voc.get_id(symbol) for symbol in sent]
    token_ids.append(voc.get_id('T'))
    return pad_seq(token_ids, max_length + 1, voc)

def sents_to_idx(voc, sents):
    """Encode a collection of sentences as one ``torch.long`` id matrix.

    Every sentence is encoded with ``sent_to_idx`` using the length of the
    longest sentence as ``max_length``, so each row has ``longest + 1`` ids.
    """
    longest = max(len(sentence) for sentence in sents)
    rows = [sent_to_idx(voc, sentence, longest) for sentence in sents]
    return torch.tensor(rows, dtype=torch.long)

# Encode the training sentences into an id matrix with one row per sentence.
raw = train_corpus.source
data_ids = sents_to_idx(vocab, raw)
# Every row produced by sents_to_idx ends with a 'T' id (either the appended
# one or padding), so drop that final column.
data = data_ids[:, :-1]
data.shape

torch.Size([10000, 50])

In [4]:
# Vocabulary ids of the two bracket symbols; simulate_stack reads these
# module-level names.
OPEN = vocab.get_id('(')
CLOSE = vocab.get_id(')')

print(OPEN, CLOSE)

def simulate_stack(tokenised_paren, open_id=None, close_id=None):
    """Return the bracket nesting depth after each token of a sequence.

    Args:
        tokenised_paren: iterable of token ids; elements may be plain ints or
            objects exposing ``.item()`` (e.g. 0-d tensors).
        open_id: id of the opening bracket. Defaults to the module-level
            ``OPEN``.
        close_id: id of the closing bracket. Defaults to the module-level
            ``CLOSE``.

    Returns:
        A list with one depth per input token. Tokens that are neither
        ``open_id`` nor ``close_id`` leave the depth unchanged.

    Raises:
        IndexError: if a closing bracket appears while the depth is 0.
    """
    # Resolve the defaults lazily so the globals are only needed when the
    # caller does not pass explicit ids.
    if open_id is None:
        open_id = OPEN
    if close_id is None:
        close_id = CLOSE

    # Only the stack height is ever used, so an integer counter is enough.
    depth = 0
    stack_depths = []
    for position, token in enumerate(tokenised_paren):
        token_id = token.item() if hasattr(token, 'item') else token
        if token_id == open_id:
            depth += 1
        elif token_id == close_id:
            if depth == 0:
                raise IndexError(
                    f"unbalanced sequence: closing bracket at position {position} "
                    "has no matching opening bracket"
                )
            depth -= 1
        stack_depths.append(depth)
    return stack_depths

def create_dataset(model, paren_tokens, device="mps"):
    """Build a probing dataset of (encoder state, stack depth) pairs.

    For every sequence in ``paren_tokens`` the wrapped ``model.model`` is run
    with ``get_encoder_reps=True`` and each position's representation is
    paired with the depth computed by ``simulate_stack`` at that position.

    Args:
        model: object exposing ``eval()`` and a ``model`` callable that
            returns ``(output, encoder_reps)``.
        paren_tokens: iterable of 1-D token-id tensors.
        device: device the input is moved to before the forward pass; it must
            match the device the model lives on.

    Returns:
        ``(X_tensor, Y_tensor)``: the stacked per-position representations
        and the matching depth labels.

    Note:
        ``model`` is left in eval mode. Every position of every sequence
        contributes a sample, including positions holding tokens that are
        neither bracket (their depth label repeats the previous depth).
    """
    features = []
    labels = []

    model.eval()

    with torch.no_grad():
        for tokenised_paren in paren_tokens:
            # unsqueeze(0).T turns the 1-D sequence into a (seq_len, 1) column.
            _, internal = model.model(
                tokenised_paren.to(device).unsqueeze(0).T, get_encoder_reps=True
            )
            internal = internal.transpose(0, 1).squeeze(0)
            stack_depths = simulate_stack(tokenised_paren)
            # Keep the per-sequence block as one tensor; building a tensor
            # from a long list of small numpy arrays is the slow path.
            features.append(internal.cpu())
            labels.extend(stack_depths[i] for i in range(internal.shape[0]))

    if not features:
        # No sequences supplied: mirror the empty tensors of the list-based path.
        return torch.tensor([]), torch.tensor([])

    X_tensor = torch.cat(features, dim=0)
    Y_tensor = torch.tensor(labels)

    return X_tensor, Y_tensor

1 2


In [5]:
import src
import torch
from src.model import LanguageModel
from src.components.transformers import TransformerModel

import torch

# Checkpoint lives in the same run directory as vocab.p / config.p; keep the
# path relative to the working directory so the notebook is portable.
chkpt_path = 'models/testrun_dyck1_4h/model_25.pt'
model = LanguageModel(config, vocab, 'cpu', None)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted run.
model.load_state_dict(
    torch.load(chkpt_path, map_location='cpu')['model_state_dict']
)

# Smoke test of the encoder-representation hook: inference only, so use eval
# mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    a, b = model.model(torch.randint(0, 3, (1, 10)), get_encoder_reps=True)
b.shape

2025-01-10 12:25:55.734259: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


torch.Size([1, 10, 32])

In [6]:
from torch import nn

class ProbingClassifer(nn.Module):
    """Four-layer MLP probe: three 128-unit ReLU layers and a linear head.

    Args:
        input_size: width of each input feature vector.
        output_size: number of output logits.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden = 128
        self.fc1 = nn.Linear(input_size, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, output_size)

    def forward(self, x):
        """Return the logits for a batch of feature vectors ``x``."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self.relu(layer(x))
        return self.fc4(x)

In [22]:
# Largest depth label in Y. NOTE: Y is assigned by the create_dataset cell
# that follows this one in the notebook, so on a fresh top-to-bottom run this
# cell fails with a NameError.
Y.max()

tensor(17)

In [17]:
from torch.utils.data import DataLoader, TensorDataset

# Build the probing set from the first 100 encoded sentences only, running
# the language model on CPU.
X, Y = create_dataset(model, data[:100], "cpu")
dataset = TensorDataset(X, Y)
# shuffle=True with no seed set in this cell: batch order differs between runs.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [18]:
# Sanity check: shape of the feature tensor in one batch drawn from the loader.
next(iter(dataloader))[0].shape

torch.Size([32, 32])

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Derive the probe dimensions from the dataset instead of hard-coding them:
# one input per feature column, one class per depth value 0..Y.max().
num_features = X.shape[1]
num_classes = int(Y.max().item()) + 1
probing_model = ProbingClassifer(num_features, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probing_model.parameters(), lr=0.001)

In [24]:
# Train the probe and record the per-epoch training accuracy in acc_id.
# NOTE: re-running this cell resets acc_id but keeps training the same
# probing_model / optimizer state.
num_epochs = 30
acc_id = []
# Batches come from CPU tensors; move them to wherever the probe lives so the
# loop also works when the probe was placed on a GPU.
probe_device = next(probing_model.parameters()).device
for epoch in range(num_epochs):
    total_loss = 0
    correct = 0
    total = 0

    probing_model.train()
    for inputs, targets in dataloader:
        inputs = inputs.to(probe_device)
        targets = targets.to(probe_device)

        optimizer.zero_grad()
        outputs = probing_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    # Loss is averaged over batches, accuracy over individual samples.
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    acc_id.append(accuracy)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

Epoch [1/30], Loss: 1.7005, Accuracy: 0.5614
Epoch [2/30], Loss: 1.3666, Accuracy: 0.5852
Epoch [3/30], Loss: 1.1622, Accuracy: 0.6260
Epoch [4/30], Loss: 1.0140, Accuracy: 0.6496
Epoch [5/30], Loss: 0.9678, Accuracy: 0.6520
Epoch [6/30], Loss: 0.9581, Accuracy: 0.6512
Epoch [7/30], Loss: 0.9323, Accuracy: 0.6550
Epoch [8/30], Loss: 0.9248, Accuracy: 0.6576
Epoch [9/30], Loss: 0.9184, Accuracy: 0.6554
Epoch [10/30], Loss: 0.9186, Accuracy: 0.6576
Epoch [11/30], Loss: 0.9137, Accuracy: 0.6560
Epoch [12/30], Loss: 0.9468, Accuracy: 0.6502
Epoch [13/30], Loss: 0.9079, Accuracy: 0.6606
Epoch [14/30], Loss: 0.9039, Accuracy: 0.6638
Epoch [15/30], Loss: 0.9044, Accuracy: 0.6592
Epoch [16/30], Loss: 0.9019, Accuracy: 0.6670
Epoch [17/30], Loss: 0.9027, Accuracy: 0.6710
Epoch [18/30], Loss: 0.9012, Accuracy: 0.6648
Epoch [19/30], Loss: 0.8992, Accuracy: 0.6640
Epoch [20/30], Loss: 0.9164, Accuracy: 0.6624
Epoch [21/30], Loss: 0.9193, Accuracy: 0.6580
Epoch [22/30], Loss: 0.8932, Accuracy: 0.66