In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
from typing import Union
from jaxtyping import Int
from torch import Tensor
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [35]:
VOCAB = {"[pad]": 0, "[start]": 1, "[end]": 2, "(": 3, ")": 4, "[": 5, "]": 6, "{": 7, "}": 8}
HIDDEN_SIZE = 56
HEAD_SIZE = 28
NUM_LAYERS = 3
NUM_HEADS = 2
MAX_LEN = 110
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 0.001
device = "mps"

In [36]:
class BracketsDataset(torch.utils.data.Dataset):
    def __init__(self, data_tuples, tokenizer):
        self.tokenizer = SimpleTokenizer("()[]{}")
        self.strs = [x[0] for x in data_tuples]
        self.isbal = torch.tensor([x[1] for x in data_tuples])
        self.toks = self.tokenizer.tokenize(self.strs)
        self.open_proportion = torch.tensor([(s.count("(")+s.count("[")+s.count("{")) / len(s) for s in self.strs])
        self.starts_open = torch.tensor([(s[0] == "(" or s[0] == "[" or s== "{") for s in self.strs]).bool()

    def __len__(self):
        return len(self.strs)

    def __getitem__(self, idx):
        return self.strs[idx], self.isbal[idx], self.toks[idx]

    def to(self, device):
        self.isbal = self.isbal.to(device)
        self.toks = self.toks.to(device)
        self.open_proportion = self.open_proportion.to(device)
        self.starts_open = self.starts_open.to(device)
        return self

In [37]:
class SimpleTokenizer:
    START_TOKEN = 1
    PAD_TOKEN = 0
    END_TOKEN = 2
    base_d = {"[start]": START_TOKEN, "[pad]": PAD_TOKEN, "[end]": END_TOKEN}

    def __init__(self, alphabet: str):
        self.alphabet = alphabet
        # the 3 is because there are 3 special tokens (defined just above)
        self.t_to_i = {**{c: i + 3 for i, c in enumerate(alphabet)}, **self.base_d}
        self.i_to_t = {i: c for c, i in self.t_to_i.items()}

    def tokenize(self, strs: list[str], max_len=None) -> Int[Tensor, "batch seq"]:
        def c_to_int(c: str) -> int:
            if c in self.t_to_i:
                return self.t_to_i[c]
            else:
                raise ValueError(c)

        if isinstance(strs, str):
            strs = [strs]

        if max_len is None:
            max_len = max((max(len(s) for s in strs), 1))

        ints = [
            [self.START_TOKEN]
            + [c_to_int(c) for c in s]
            + [self.END_TOKEN]
            + [self.PAD_TOKEN] * (max_len - len(s))
            for s in strs
        ]
        return torch.tensor(ints)

    def decode(self, tokens) -> list[str]:
        assert tokens.ndim >= 2, "Need to have a batch dimension"

        def int_to_c(c: int) -> str:
            if c < len(self.i_to_t):
                return self.i_to_t[c]
            else:
                raise ValueError(c)

        return [
            "".join(
                int_to_c(i.item()) for i in seq[1:] if i != self.PAD_TOKEN and i != self.END_TOKEN
            )
            for seq in tokens
        ]

    def __repr__(self) -> str:
        return f"SimpleTokenizer({self.alphabet!r})"


In [38]:
def load_data():
    with open("./dataset.json") as f:
        data_tuples = json.load(f)
    data_tuples = data_tuples
    random.shuffle(data_tuples)

    train_size = int(0.7 * len(data_tuples))
    val_size = int(0.1 * len(data_tuples))
    test_size = len(data_tuples) - train_size - val_size

    train_data = data_tuples[:train_size]
    val_data = data_tuples[train_size:train_size+val_size]
    test_data = data_tuples[train_size+val_size:]

    tokenizer = SimpleTokenizer("()[]{}")
    train_dataset = BracketsDataset(train_data, tokenizer).to(device)
    val_dataset = BracketsDataset(val_data, tokenizer).to(device)
    test_dataset = BracketsDataset(test_data, tokenizer).to(device)

    return train_dataset, val_dataset, test_dataset

In [39]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, head_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.hidden_size = hidden_size
        
        self.W_q = nn.Linear(hidden_size, num_heads * head_size)
        self.W_k = nn.Linear(hidden_size, num_heads * head_size)
        self.W_v = nn.Linear(hidden_size, num_heads * head_size)
        self.W_o = nn.Linear(num_heads * head_size, hidden_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, hidden_size = x.size()
        
        # Project and reshape queries, keys, values
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_size ** 0.5)
        
        # Mask the key matrix to ignore padded tokens
        if mask is not None:
            # Create a mask that broadcasts across attention heads
            mask = mask.unsqueeze(1).unsqueeze(1)  # Shape: [batch_size, 1, 1, seq_len]
            
            # Create a key-specific mask by repeating the mask for each head
            key_mask = mask.expand(-1, self.num_heads, seq_len, -1)
            
            # Set attention scores to -inf where the key mask is 0 (padded tokens)
            scores = scores.masked_fill(key_mask == 0, float('-inf'))
        
        # Compute attention weights
        attention_weights = torch.softmax(scores, dim=-1)
        
        # Apply attention to values
        context = torch.matmul(attention_weights, V)
        
        # Transpose and reshape back to original dimensions
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        
        # Project back to hidden size
        return self.W_o(context)

class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, head_size, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(hidden_size, head_size, num_heads)
        self.layernorm1 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )
        self.layernorm2 = nn.LayerNorm(hidden_size)

    def forward(self, x, mask=None):
        attn_output = self.attention(x, mask)
        x = self.layernorm1(x + attn_output)
        mlp_output = self.mlp(x)
        return self.layernorm2(x + mlp_output)

class BalancedParenthesesModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_len, num_layers, num_heads):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        
        # Create positional encodings on the correct device
        self.register_buffer('positional_encodings', 
            torch.zeros(1, max_len, hidden_size), 
            persistent=False
        )
        
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, hidden_size // num_heads, num_heads)
            for _ in range(num_layers)
        ])
        
        self.layernorm_final = nn.LayerNorm(hidden_size)
        
        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool1d(1)  # Global average pooling
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size // 2, 2)
        )

    def forward(self, x, mask=None):
        # Ensure positional encodings are on the same device as input
        seq_len = x.size(1)
        
        # Slice and move positional encodings to input device
        positional_encodings = self.positional_encodings[:, :seq_len, :].to(x.device)
        
        # Embedding with positional encodings
        x = self.embedding(x) + positional_encodings
        
        # Prepare the mask
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)
            mask = mask.to(x.device)
        
        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        
        # Final layer normalization
        x = self.layernorm_final(x)
        
        # Global pooling
        x = x.transpose(1, 2)  # Change to (batch, hidden, seq)
        x = self.global_pool(x).squeeze(-1)  # Global average pooling
        
        # Classification
        logits = self.classifier(x)
        
        return logits

In [40]:
def train_model(model, train_dataset, val_dataset, num_epochs, batch_size, lr, device):
    model.to(device)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device))
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch, (_, labels, tokens) in enumerate(train_loader):
            optimizer.zero_grad()
            labels.to(device)
            tokens.to(device)
            
            # Create mask for padding
            mask = (tokens != VOCAB["[pad]"]).float().to(device)
            
            output = model(tokens, mask)
            loss = criterion(output, labels)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
        
        # Validation
        model.eval()
        val_losses = []
        val_preds = []
        val_true = []
        
        with torch.no_grad():
            for _, labels, tokens in val_loader:
                labels.to(device)
                tokens.to(device)
                mask = (tokens != VOCAB["[pad]"]).float().to(device)
                output = model(tokens, mask)
                loss = criterion(output, labels)
                
                val_losses.append(loss.item())
                val_preds.extend(output.argmax(dim=1).cpu().numpy())
                val_true.extend(labels.cpu().numpy())
        
        val_loss = np.mean(val_losses)
        val_acc = accuracy_score(val_true, val_preds)
        
        print(f"Epoch {epoch+1}: Train Loss {total_loss/len(train_loader):.4f}, "
              f"Val Loss {val_loss:.4f}, Val Accuracy {val_acc:.4f}")
        
        # Early stopping and model saving
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

In [41]:
def evaluate_model(model, test_dataset, device):
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
    correct = 0
    total = 0
    incorrect = []
    with torch.no_grad():
        for batch, (_, labels, tokens) in enumerate(test_loader):
            labels.to(device)
            tokens.to(device)
            mask = (tokens != VOCAB["[pad]"]).unsqueeze(1).unsqueeze(2).to(device)
            output = model(tokens)
            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            incorrect.extend((predicted != labels).nonzero())

    print(f"Accuracy: {100*correct / total}")

In [42]:
train_dataset, val_dataset, test_dataset = load_data()

In [43]:
model = BalancedParenthesesModel(len(VOCAB), HIDDEN_SIZE, MAX_LEN, NUM_LAYERS, NUM_HEADS)

In [44]:
train_model(model, train_dataset, val_dataset, EPOCHS, BATCH_SIZE, LEARNING_RATE, device)

RuntimeError: expand(MPSFloatType{[32, 1, 1, 1, 1, 100]}, size=[-1, 2, 100, -1]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (6)

In [None]:
evaluate_model(model, test_dataset, device)

Accuracy: 53.75
