# Transformer activation pattern demonstration
Predict patterns in a binary addition task. Use a simple tiny transformer architecture. 

In [43]:
import torch 
from torch import nn
from torch.utils.data import Dataset

vocab = ['0', '1', '+', '=']
vocab_size = 4

class BinaryAdditionDataset(Dataset):
    def __init__(self, num_bits=3, num_samples=1000):
        self.num_bits = num_bits
        self.samples = []
        
        for _ in range(num_samples):
            a = torch.randint(0, 2**num_bits, (1,)).item()
            b = torch.randint(0, 2**num_bits, (1,)).item()
            
            # Convert to binary number strings
            a_bin = format(a, f'0{num_bits}b')
            b_bin = format(b, f'0{num_bits}b')
            sum_bin = format(a + b, f'0{num_bits + 1}b')  # +1 to possibly carry
            
            # Input string e.g. "101+011="
            input_str = f"{a_bin}+{b_bin}="
            
            self.samples.append((input_str, sum_bin))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]


class TinyTransformer(nn.Module):
    def __init__(self, d_model=32, nhead=2, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=64, 
            dropout=0.1, 
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 2) # Output logits are '0' or '1'

    def forward(self, x, attention_mask=None):
        x = self.embedding(x)
        x = self.transformer(x, src_key_padding_mask=attention_mask)
        return self.fc(x)
    

class BinaryAdditionTrainer:
    def __init__(self):
        self.char_to_idx = {'0': 0, '1': 1, '+': 2, '=': 3}
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}

    def encode_input(self, input_str):
        return torch.tensor([self.char_to_idx[c] for c in input_str])
    
    def decode_output(self, output):
        return ''.join([self.idx_to_char[i] for i in output])
    
    def train_step(self, model, optimizer, input_str, target_str):
        optimizer.zero_grad()

        x = self.encode_input(input_str).unsqueeze(0)

        logits = model(x)

        eq_pos = input_str.index('=') # Position of '='
        relevant_logits = logits[0, eq_pos+1:eq_pos+1+len(target_str)]

        target = torch.tensor([int(c) for c in target_str])

        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(relevant_logits, target)

        loss.backward()
        optimizer.step()

        return loss.item()
    
    def visualize_attention(self, model, input_str):
        # get attention weights from the last layer
        with torch.no_grad():
            x = self.encode_input(input_str).unsqueeze(0)
            attention_weights = []
            
            def hook_fn(module, input, output):
                attention_weights.append(output[1]) # attention weights

            # Register hook to capture attention weights
            for layer in model.transformer.layers:
                layer.self_attn.register_forward_hook(hook_fn)

            _ = model(x)

            # Remove hooks
            for layer in model.transformer.layers:
                layer.self_attn._forward_hooks.clear()

        return attention_weights


In [44]:
data = BinaryAdditionDataset(num_bits=3, num_samples=1000)
model = TinyTransformer()
trainer = BinaryAdditionTrainer()

In [45]:
trainer.train_step(model, torch.optim.Adam(model.parameters()), '101+011=', '1000')

torch.Size([1, 8, 2])
tensor([[[-0.2917, -0.6225],
         [ 0.7952, -0.0955],
         [-0.3083, -0.8198],
         [ 0.3369, -0.8783],
         [ 0.6477,  0.1202],
         [-0.0774, -0.8141],
         [-0.2552, -0.6812],
         [ 0.4612, -1.3199]]], grad_fn=<ViewBackward0>)
torch.Size([0, 2])
tensor([], size=(0, 2), grad_fn=<SliceBackward0>)


ValueError: Expected input batch_size (0) to match target batch_size (4).

In [33]:
test = torch.tensor([0, 1, 2, 3]).unsqueeze(0)

In [36]:
test

tensor([[0, 1, 2, 3]])

In [None]:
# OLD
def generate_binary_additions(num_samples, max_bits=4):
    '''
    Generate
    - X: list of strings '0000+0000'
    - y: list of strings '00000'
    '''
    X, y = [], []
    for _ in range(num_samples):
        a = torch.randint(0, 2**max_bits) # Integers between 0 and 2^max_bits-1
        b = torch.randint(0, 2**max_bits)
        X.append(format(a, f'0{max_bits}b') + '+' + format(b, f'0{max_bits}b'))
        y.append(format(a + b, f'0{max_bits+1}b'))
    return X, y

# OLD
def encode_binary_string_data(X, y):
    ''' Encode: '0' -> 0, '1' -> 1, '+' -> 2 '''
    X_encoded = torch.tensor([[int(c) if c != '+' else 2 for c in x] for x in X])
    y_encoded = torch.tensor([[int(c) for c in y_] for y_ in y])
    return X_encoded, y_encoded


# Example transformation:
x = "101+011"  # Input string
X_encoded = [int(c) if c in '01' else 2 for c in x]  # [1,0,1,2,0,1,1]

result = "1000"  # Output string
y_encoded = [int(c) for c in result]  # [1,0,0,0]

In [15]:
# X, y = generate_binary_additions(1000)
# X_encoded, y_encoded = encode_binary_string_data(X, y)