In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
from typing import Union
from jaxtyping import Int
from torch import Tensor
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [10]:
VOCAB = {"[pad]": 0, "[start]": 1, "[end]": 2, "a": 3, "b": 4}
HIDDEN_SIZE = 56
HEAD_SIZE = 28
NUM_LAYERS = 3
NUM_HEADS = 2
MAX_LEN = 100
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 0.001
device = "mps"

In [11]:

class SimpleTokenizer:
    START_TOKEN = 0
    PAD_TOKEN = 1
    END_TOKEN = 2
    base_d = {"[start]": START_TOKEN, "[pad]": PAD_TOKEN, "[end]": END_TOKEN}

    def __init__(self, alphabet: str):
        self.alphabet = alphabet
        # the 3 is because there are 3 special tokens (defined just above)
        self.t_to_i = {**{c: i + 3 for i, c in enumerate(alphabet)}, **self.base_d}
        self.i_to_t = {i: c for c, i in self.t_to_i.items()}

    def tokenize(self, strs: list[str], max_len=None) -> Int[Tensor, "batch seq"]:
        def c_to_int(c: str) -> int:
            if c in self.t_to_i:
                return self.t_to_i[c]
            else:
                raise ValueError(c)

        if isinstance(strs, str):
            strs = [strs]

        if max_len is None:
            max_len = max((max(len(s) for s in strs), 1))

        ints = [
            [self.START_TOKEN]
            + [c_to_int(c) for c in s]
            + [self.END_TOKEN]
            + [self.PAD_TOKEN] * (max_len - len(s))
            for s in strs
        ]
        return torch.tensor(ints)

    def decode(self, tokens) -> list[str]:
        assert tokens.ndim >= 2, "Need to have a batch dimension"

        def int_to_c(c: int) -> str:
            if c < len(self.i_to_t):
                return self.i_to_t[c]
            else:
                raise ValueError(c)

        return [
            "".join(
                int_to_c(i.item()) for i in seq[1:] if i != self.PAD_TOKEN and i != self.END_TOKEN
            )
            for seq in tokens
        ]

    def __repr__(self) -> str:
        return f"SimpleTokenizer({self.alphabet!r})"


class BracketsDataset(torch.utils.data.Dataset):
    def __init__(self, data_tuples, tokenizer):
        self.tokenizer = SimpleTokenizer("ab")
        self.strs = [x[0] for x in data_tuples]
        self.isbal = torch.tensor([x[1] for x in data_tuples])
        self.toks = self.tokenizer.tokenize(self.strs)
        self.open_proportion = torch.tensor([s.count("a") / len(s) for s in self.strs])
        self.starts_open = torch.tensor([s[0] == "a" for s in self.strs]).bool()

    def __len__(self):
        return len(self.strs)

    def __getitem__(self, idx):
        return self.strs[idx], self.isbal[idx], self.toks[idx]

    def to(self, device):
        self.isbal = self.isbal.to(device)
        self.toks = self.toks.to(device)
        self.open_proportion = self.open_proportion.to(device)
        self.starts_open = self.starts_open.to(device)
        return self

def load_data():
    with open("/Users/utkarsh/Documents/neural-toc/naacl_work/anbn_data.json") as f:
        data_tuples = json.load(f)
    data_tuples = data_tuples
    random.shuffle(data_tuples)

    train_size = int(0.8 * len(data_tuples))
    val_size = int(0.1 * len(data_tuples))
    test_size = len(data_tuples) - train_size - val_size

    train_data = data_tuples[:train_size]
    val_data = data_tuples[train_size:train_size+val_size]
    test_data = data_tuples[train_size+val_size:]

    tokenizer = SimpleTokenizer("ab")
    train_dataset = BracketsDataset(train_data, tokenizer).to(device)
    val_dataset = BracketsDataset(val_data, tokenizer).to(device)
    test_dataset = BracketsDataset(test_data, tokenizer).to(device)

    return train_dataset, val_dataset, test_dataset

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, head_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.W_q = nn.Linear(hidden_size, num_heads * head_size)
        self.W_k = nn.Linear(hidden_size, num_heads * head_size)
        self.W_v = nn.Linear(hidden_size, num_heads * head_size)
        self.W_o = nn.Linear(num_heads * head_size, hidden_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, hidden_size = x.size()
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_size ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, -1)
        return self.W_o(context)

class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, head_size, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(hidden_size, head_size, num_heads)
        self.layernorm1 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )
        self.layernorm2 = nn.LayerNorm(hidden_size)

    def forward(self, x, mask=None):
        attn_output = self.attention(x, mask)
        x = self.layernorm1(x + attn_output)
        mlp_output = self.mlp(x)
        return self.layernorm2(x + mlp_output)

# class BalancedParenthesesModel(nn.Module):
#     def __init__(self, vocab_size, hidden_size, max_len, num_layers, num_heads):
#         super().__init__()
#         self.embedding = nn.Embedding(vocab_size, hidden_size)
#         self.positional_encodings = nn.Parameter(torch.zeros(1, max_len, hidden_size))
#         self.layers = nn.ModuleList([
#             TransformerBlock(hidden_size, HEAD_SIZE, num_heads)
#             for _ in range(num_layers)
#         ])
#         self.layernorm_final = nn.LayerNorm(hidden_size)
#         self.unembedding = nn.Linear(hidden_size, 2)

#     def forward(self, x, mask=None):
#         seq_len = x.size(1)
#         x = self.embedding(x) + self.positional_encodings[:, :seq_len, :]
#         for layer in self.layers:
#             x = layer(x, mask)
#         x = self.layernorm_final(x)
#         logits = self.unembedding(x[:, 0, :])
#         return logits

class BalancedParenthesesModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_len, num_layers, num_heads):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.positional_encodings = nn.Parameter(torch.zeros(1, max_len, hidden_size))
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, HEAD_SIZE, num_heads)
            for _ in range(num_layers)
        ])
        self.layernorm_final = nn.LayerNorm(hidden_size)
        self.unembedding = nn.Linear(hidden_size, 2)
        
        # Store intermediate states during forward pass
        self.token_states = None

    def forward(self, x, mask=None, return_states=False):
        seq_len = x.size(1)
        
        # Embedding with positional encodings
        x = self.embedding(x) + self.positional_encodings[:, :seq_len, :]
        
        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        
        # Final layer normalization
        x = self.layernorm_final(x)
        
        # Compute logits
        logits = self.unembedding(x[:, 0, :])
        
        # Store token states if requested
        if return_states:
            return logits, x.detach()
        
        return logits

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

def train_model(model, train_dataset, val_dataset, num_epochs, batch_size, lr, device):
    """
    Train the Transformer model on the brackets dataset.
    
    Args:
        model (TransformerModel): The Transformer model to train.
        train_dataset (BracketsDataset): The training dataset.
        val_dataset (BracketsDataset): The validation dataset.
        num_epochs (int): The number of training epochs.
        batch_size (int): The batch size.
        lr (float): The learning rate.
        device (torch.device): The device to use for training (CPU or GPU).
    """
    # Set up data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Set up the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Move the model to the specified device
    model.to(device)

    # Train the model
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Training loop
        model.train()
        total_train_loss = 0
        for batch, (_, labels, tokens) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(tokens)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        # Validation loop
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for _, (_, labels, tokens) in enumerate(val_loader):
                mask = (tokens != VOCAB["[pad]"]).unsqueeze(1).unsqueeze(2).to(device)
                output = model(tokens, mask)
                loss = criterion(output, labels)
                total_val_loss += loss.item()

        # Print training and validation results
        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    print("Training complete.")
    return model

In [14]:
train_data, val_data, test_data = load_data()
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

model = BalancedParenthesesModel(len(VOCAB), HIDDEN_SIZE, MAX_LEN, NUM_LAYERS, NUM_HEADS)
train_model(model, train_data, val_data, EPOCHS, BATCH_SIZE, LEARNING_RATE, device="mps")


Epoch 1/50, Train Loss: 0.6923, Val Loss: 0.6257
Epoch 2/50, Train Loss: 0.7290, Val Loss: 0.6119
Epoch 3/50, Train Loss: 0.6491, Val Loss: 0.6721
Epoch 4/50, Train Loss: 0.6804, Val Loss: 0.6129
Epoch 5/50, Train Loss: 0.6605, Val Loss: 0.5671
Epoch 6/50, Train Loss: 0.6429, Val Loss: 0.5664
Epoch 7/50, Train Loss: 0.6652, Val Loss: 0.5769
Epoch 8/50, Train Loss: 0.6457, Val Loss: 0.6018
Epoch 9/50, Train Loss: 0.6511, Val Loss: 0.6139
Epoch 10/50, Train Loss: 0.6298, Val Loss: 0.5816
Epoch 11/50, Train Loss: 0.6628, Val Loss: 0.5636
Epoch 12/50, Train Loss: 0.6438, Val Loss: 0.5718
Epoch 13/50, Train Loss: 0.6590, Val Loss: 0.5904
Epoch 14/50, Train Loss: 0.6516, Val Loss: 0.6095
Epoch 15/50, Train Loss: 0.6440, Val Loss: 0.6142
Epoch 16/50, Train Loss: 0.6737, Val Loss: 0.5895
Epoch 17/50, Train Loss: 0.6666, Val Loss: 0.5976
Epoch 18/50, Train Loss: 0.6410, Val Loss: 0.6148
Epoch 19/50, Train Loss: 0.6515, Val Loss: 0.5951
Epoch 20/50, Train Loss: 0.6631, Val Loss: 0.5847
Epoch 21/

BalancedParenthesesModel(
  (embedding): Embedding(5, 56)
  (layers): ModuleList(
    (0-2): 3 x TransformerBlock(
      (attention): MultiHeadAttention(
        (W_q): Linear(in_features=56, out_features=56, bias=True)
        (W_k): Linear(in_features=56, out_features=56, bias=True)
        (W_v): Linear(in_features=56, out_features=56, bias=True)
        (W_o): Linear(in_features=56, out_features=56, bias=True)
      )
      (layernorm1): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=56, out_features=224, bias=True)
        (1): ReLU()
        (2): Linear(in_features=224, out_features=56, bias=True)
      )
      (layernorm2): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layernorm_final): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
  (unembedding): Linear(in_features=56, out_features=2, bias=True)
)

In [15]:
tokeniser = SimpleTokenizer("ab")

In [16]:
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)
correct = 0
total = 0
incorrect = []
with torch.no_grad():
    for batch, (_, labels, tokens) in enumerate(test_loader):
        output = model(tokens)
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if (predicted != labels).any():
            incorrect.append((tokeniser.decode(tokens), labels, predicted))

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")

Test Accuracy: 64.29%
