In [55]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
from typing import Union
from jaxtyping import Int
from torch import Tensor

In [34]:
VOCAB = {"[pad]": 0, "[start]": 1, "[end]": 2, "(": 3, ")": 4}
HIDDEN_SIZE = 56
HEAD_SIZE = 28
NUM_LAYERS = 3
NUM_HEADS = 2
MAX_LEN = 50
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
device = "mps"

In [35]:
def load_data(file):
    N_SAMPLES = 5000
    with open(file) as f:
        data_tuples: list[tuple[str, bool]] = json.load(f)
        # print(f"loaded {len(data_tuples)} examples")
    assert isinstance(data_tuples, list)
    data_tuples = data_tuples[:N_SAMPLES]
    data = BracketsDataset(data_tuples).to(device)
    data_mini = BracketsDataset(data_tuples[:100]).to(device)
    return data, data_mini, data_tuples


In [56]:

class SimpleTokenizer:
    START_TOKEN = 0
    PAD_TOKEN = 1
    END_TOKEN = 2
    base_d = {"[start]": START_TOKEN, "[pad]": PAD_TOKEN, "[end]": END_TOKEN}

    def __init__(self, alphabet: str):
        self.alphabet = alphabet
        # the 3 is because there are 3 special tokens (defined just above)
        self.t_to_i = {**{c: i + 3 for i, c in enumerate(alphabet)}, **self.base_d}
        self.i_to_t = {i: c for c, i in self.t_to_i.items()}

    def tokenize(self, strs: list[str], max_len=None) -> Int[Tensor, "batch seq"]:
        def c_to_int(c: str) -> int:
            if c in self.t_to_i:
                return self.t_to_i[c]
            else:
                raise ValueError(c)

        if isinstance(strs, str):
            strs = [strs]

        if max_len is None:
            max_len = max((max(len(s) for s in strs), 1))

        ints = [
            [self.START_TOKEN]
            + [c_to_int(c) for c in s]
            + [self.END_TOKEN]
            + [self.PAD_TOKEN] * (max_len - len(s))
            for s in strs
        ]
        return torch.tensor(ints)

    def decode(self, tokens) -> list[str]:
        assert tokens.ndim >= 2, "Need to have a batch dimension"

        def int_to_c(c: int) -> str:
            if c < len(self.i_to_t):
                return self.i_to_t[c]
            else:
                raise ValueError(c)

        return [
            "".join(
                int_to_c(i.item()) for i in seq[1:] if i != self.PAD_TOKEN and i != self.END_TOKEN
            )
            for seq in tokens
        ]

    def __repr__(self) -> str:
        return f"SimpleTokenizer({self.alphabet!r})"


In [57]:
class BracketsDataset:
    """A dataset containing sequences, is_balanced labels, and tokenized sequences"""

    def __init__(self, data_tuples: list):
        """
        data_tuples is list[tuple[str, bool]] signifying sequence and label
        """
        self.tokenizer = SimpleTokenizer("()")
        self.strs = [x[0] for x in data_tuples]
        
        # Debugging print statement
        print("data_tuples:", data_tuples)
        
        # Ensure the second element of each tuple is a boolean
        self.isbal = torch.tensor([bool(x[1]) for x in data_tuples])
        
        self.toks = self.tokenizer.tokenize(self.strs)
        self.open_proportion = torch.tensor([s.count("(") / len(s) for s in self.strs])
        self.starts_open = torch.tensor([s[0] == "(" for s in self.strs]).bool()

    def __len__(self) -> int:
        return len(self.strs)

    def __getitem__(self, idx) -> "BracketsDataset | tuple[str, t.Tensor, t.Tensor]":
        if isinstance(idx, slice):
            return self.__class__(list(zip(self.strs[idx], self.isbal[idx])))
        return (self.strs[idx], self.isbal[idx], self.toks[idx])

    def to(self, device) -> "BracketsDataset":
        self.isbal = self.isbal.to(device)
        self.toks = self.toks.to(device)
        self.open_proportion = self.open_proportion.to(device)
        self.starts_open = self.starts_open.to(device)
        return self

    @property
    def seq_length(self) -> int:
        return self.toks.size(-1)

    @classmethod
    def with_length(
        cls, data_tuples: list[tuple[str, bool]], selected_len: int
    ) -> "BracketsDataset":
        return cls([(s, b) for (s, b) in data_tuples if len(s) == selected_len])

    @classmethod
    def with_start_char(
        cls, data_tuples: list[tuple[str, bool]], start_char: str
    ) -> "BracketsDataset":
        return cls([(s, b) for (s, b) in data_tuples if s[0] == start_char])


In [58]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, head_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.W_q = nn.Linear(hidden_size, num_heads * head_size)
        self.W_k = nn.Linear(hidden_size, num_heads * head_size)
        self.W_v = nn.Linear(hidden_size, num_heads * head_size)
        self.W_o = nn.Linear(num_heads * head_size, hidden_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, hidden_size = x.size()
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_size ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, -1)
        return self.W_o(context)

class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, head_size, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(hidden_size, head_size, num_heads)
        self.layernorm1 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )
        self.layernorm2 = nn.LayerNorm(hidden_size)

    def forward(self, x, mask=None):
        attn_output = self.attention(x, mask)
        x = self.layernorm1(x + attn_output)
        mlp_output = self.mlp(x)
        return self.layernorm2(x + mlp_output)

class BalancedParenthesesModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_len, num_layers, num_heads):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.positional_encodings = nn.Parameter(torch.zeros(1, max_len, hidden_size))
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, HEAD_SIZE, num_heads)
            for _ in range(num_layers)
        ])
        self.layernorm_final = nn.LayerNorm(hidden_size)
        self.unembedding = nn.Linear(hidden_size, 2)  # Binary classification

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.positional_encodings[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x, mask)
        x = self.layernorm_final(x)
        logits = self.unembedding(x[:, 0, :])  # Use only the [start] token for classification
        return logits

In [52]:
def train_model(model, train_loader, val_loader, epochs, lr):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        val_loss, val_acc = evaluate_model(model, val_loader, criterion)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

def evaluate_model(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "mps")

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            loss = criterion(logits, labels)
            val_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return val_loss / len(val_loader), correct / total

In [63]:
data = load_data("/Users/utkarsh/Documents/neural-toc/naacl_work/ARENA_3.0/chapter1_transformer_interp/exercises/part51_balanced_bracket_classifier/brackets_data.json")
train_data, val_data, test_data = torch.utils.data.random_split(data, [int(0.7 * len(data)), int(0.15 * len(data)), len(data) - int(0.85 * len(data))])
train_loader = DataLoader(BracketsDataset(train_data), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(BracketsDataset(val_data), batch_size=BATCH_SIZE)

model = BalancedParenthesesModel(len(VOCAB), HIDDEN_SIZE, MAX_LEN, NUM_LAYERS, NUM_HEADS)
train_model(model, train_loader, val_loader, EPOCHS, LEARNING_RATE)


data_tuples: [['(()((())()(((())()))((', False], ['))(()()(', False], ['(())(())', True], [')())())()(())()))(((())())', False], [')(()(())(()))()))())))', False], ['(()(()()())())', True], ['(())()()))))', False], ['(()((()(()))))', True], ['()', True], ['()', True], ['(()(()())(()()())(((())((()))())))', True], ['))()', False], ['()(((()(())()()())))', True], ['()()', True], ['())((())(()())(())()(())((()()()(())))', False], ['()(())', True], [')())()())(', False], ['()())(()((()())()())))', False], ['((((())))(())())((((()))))()', True], [')((()()(((', False], ['()', True], ['((()((())()())())())', True], ['()((((', False], ['()', True], ['()(((())))', True], ['()()((((()))()(())))()', True], ['((((((((((())(()()()', False], [')(()())))()((())))()(()(()', False], [')))())())))()(((())()(())(())))))()())()', False], ['(()((()()))()())()', True], ['()', True], ['()()()()))))()()((()))(())', False], ['((', False], ['()', True], ['(())()((((()))))((((())(', False], ['(((())(())()()))(()

ValueError: (()((())()(((())()))((

In [64]:
data

(<__main__.BracketsDataset at 0x1478bc8b0>,
 <__main__.BracketsDataset at 0x14785ceb0>,
 [['(()((())()(((())()))((', False],
  ['))(()()(', False],
  ['(())(())', True],
  [')())())()(())()))(((())())', False],
  [')(()(())(()))()))())))', False],
  ['(()(()()())())', True],
  ['(())()()))))', False],
  ['(()((()(()))))', True],
  ['()', True],
  ['()', True],
  ['(()(()())(()()())(((())((()))())))', True],
  ['))()', False],
  ['()(((()(())()()())))', True],
  ['()()', True],
  ['())((())(()())(())()(())((()()()(())))', False],
  ['()(())', True],
  [')())()())(', False],
  ['()())(()((()())()())))', False],
  ['((((())))(())())((((()))))()', True],
  [')((()()(((', False],
  ['()', True],
  ['((()((())()())())())', True],
  ['()((((', False],
  ['()', True],
  ['()(((())))', True],
  ['()()((((()))()(())))()', True],
  ['((((((((((())(()()()', False],
  [')(()())))()((())))()(()(()', False],
  [')))())())))()(((())()(())(())))))()())()', False],
  ['(()((()()))()())()', True],
  ['()