# Single-Head Causal Self-Attention

In [47]:
import torch
import torch.nn as nn

class SingleHeadAttention(nn.Module):
    """One causal scaled dot-product self-attention head.

    Projects the input into key/query/value spaces of width ``head_size``
    (bias-free) and lets each position attend only to itself and earlier
    positions.
    """

    def __init__(self, model_dim: int, head_size: int):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.key_layer = nn.Linear(model_dim, head_size, bias=False)
        self.query_layer = nn.Linear(model_dim, head_size, bias=False)
        self.value_layer = nn.Linear(model_dim, head_size, bias=False)

    def forward(self, embedded):
        """Apply causal self-attention.

        Args:
            embedded: 3-D tensor indexed as (batch, time, model_dim); the
                ``transpose(K, 1, 2)`` below relies on this layout.

        Returns:
            Tensor of shape (batch, time, head_size).
        """
        K = self.key_layer(embedded)
        Q = self.query_layer(embedded)
        V = self.value_layer(embedded)
        _, T, A = K.shape
        # Scaled dot-product scores: divide by sqrt(head_size).
        scores = Q @ torch.transpose(K, 1, 2) / (A ** 0.5)
        # Build the lower-triangular causal mask on the input's device so the
        # module keeps working after .to(<device>); a CPU-only mask would make
        # masked_fill fail for non-CPU inputs.
        mask = torch.tril(torch.ones(T, T, device=embedded.device))
        # Future positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(mask == 0, float("-inf"))
        scores = nn.functional.softmax(scores, dim=-1)
        return scores @ V

# Multi-Headed Self Attention

In [48]:
class MultiHeadedSelfAttention(nn.Module):
    """Run several causal attention heads in parallel and mix their outputs.

    Each head has width ``model_dim // num_heads``. The head outputs are
    concatenated back to ``model_dim`` features, linearly projected, and passed
    through dropout.

    Args:
        model_dim: input/output feature width.
        num_heads: number of heads; must be positive and divide ``model_dim``.
        dropout: dropout probability applied to the projected output
            (defaults to the previously hard-coded 0.2).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``model_dim``.
    """

    def __init__(self, model_dim: int, num_heads: int, dropout: float = 0.2):
        super().__init__()
        # The concatenated heads must be exactly model_dim wide to match the
        # input of `compute`; otherwise every forward call would fail with a
        # shape error (and num_heads == 0 would divide by zero).
        if num_heads <= 0 or model_dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be positive and divide "
                f"model_dim ({model_dim})"
            )
        head_size = model_dim // num_heads
        self.attention_heads = nn.ModuleList(
            [SingleHeadAttention(model_dim, head_size) for _ in range(num_heads)]
        )
        self.compute = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embedded):
        """Return dropout(compute(concat(heads(embedded)))) with the same shape as the input."""
        head_outputs = [head(embedded) for head in self.attention_heads]
        return self.dropout(self.compute(torch.cat(head_outputs, dim=-1)))

# Feed-Forward Network ("Vanilla" MLP Sublayer)

In [49]:
class VanillaNeuralNetwork(nn.Module):
    """Position-wise feed-forward sublayer.

    Expands ``model_dim`` to a 4x wider hidden layer, applies ReLU, projects
    back to ``model_dim`` and finishes with dropout (p=0.2).
    """

    def __init__(self, model_dim: int):
        super().__init__()
        hidden_dim = 4 * model_dim
        # Creation order and attribute names match the checkpoint layout.
        self.first_linear_layer = nn.Linear(model_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.second_linear_layer = nn.Linear(hidden_dim, model_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map (..., model_dim) -> (..., model_dim)."""
        hidden = self.relu(self.first_linear_layer(x))
        projected = self.second_linear_layer(hidden)
        return self.dropout(projected)

# Transformer Block

In [50]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer layer.

    Applies multi-head self-attention and then the feed-forward sublayer. Each
    sublayer reads a LayerNorm-ed copy of its input and adds its result back
    onto that input (residual connection).
    """

    def __init__(self, model_dim: int, num_heads: int):
        super().__init__()
        self.mhsa = MultiHeadedSelfAttention(model_dim, num_heads)
        self.vanilla_nn = VanillaNeuralNetwork(model_dim)
        self.layer_norm_one = nn.LayerNorm(model_dim)
        self.layer_norm_two = nn.LayerNorm(model_dim)

    def forward(self, embedded):
        """Return the block output; the shape matches ``embedded``."""
        attention_update = self.mhsa(self.layer_norm_one(embedded))
        after_attention = embedded + attention_update
        feed_forward_update = self.vanilla_nn(self.layer_norm_two(after_attention))
        return after_attention + feed_forward_update

# GPT Class

In [51]:
class GPT(nn.Module):
    """Decoder-only language model: embeddings -> Transformer blocks -> logits.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        context_length: size of the learned positional-embedding table, i.e.
            the longest sequence ``forward`` can accept.
        model_dim: embedding / hidden width.
        num_blocks: number of stacked ``TransformerBlock`` layers.
        num_heads: attention heads per block.
    """

    def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_embedding = nn.Embedding(context_length, model_dim)
        self.transformer_blocks = nn.Sequential(*(TransformerBlock(model_dim, num_heads) for _ in range(num_blocks)))
        self.layer_norm_three = nn.LayerNorm(model_dim)
        self.vocab_projection = nn.Linear(model_dim, vocab_size)

    def forward(self, context):
        """Compute next-token logits for every position.

        Args:
            context: 2-D tensor of token ids, (batch, time). ``time`` must not
                exceed ``context_length``, or the positional lookup fails.

        Returns:
            Logits of shape (batch, time, vocab_size).
        """
        _, T = context.shape
        token_embeddings = self.token_embedding(context)
        # Create position ids on the same device as the input so the model
        # works after .to(<device>); a CPU-only arange would not match
        # the embedding weights on other devices.
        positions = torch.arange(T, device=context.device)
        pos_embedding = self.pos_embedding(positions)
        # (T, model_dim) broadcasts over the batch dimension.
        embedded = token_embeddings + pos_embedding
        output = self.transformer_blocks(embedded)
        # Final layer norm, then projection to vocabulary logits.
        return self.vocab_projection(self.layer_norm_three(output))

# Generation

In [None]:
def generate(model, new_chars: int, context, context_length: int, int_to_char: dict) -> str:
    """Autoregressively sample ``new_chars`` characters from ``model``.

    Args:
        model: callable mapping a (batch, time) tensor of token ids to logits
            of shape (batch, time, vocab_size). This function does not call
            ``model.eval()``; do that first if the model uses dropout.
        new_chars: number of characters to sample.
        context: (batch, time) integer tensor of seed token ids. Each sampled
            id is converted with ``.item()``, so batch size must be 1 when
            ``new_chars > 0``.
        context_length: maximum number of most recent tokens fed to the model.
        int_to_char: mapping from token id to character.

    Returns:
        The sampled characters joined into one string (the seed is excluded).
    """
    res = []
    # Sampling needs no gradients; disabling autograd avoids building a
    # throwaway graph on every step.
    with torch.no_grad():
        for _ in range(new_chars):
            # Keep only the most recent context_length tokens.
            context = context[:, -context_length:]
            prediction = model(context)  # (B, T, vocab_size)
            last_time_step = prediction[:, -1, :]  # logits for the next token
            probabilities = nn.functional.softmax(last_time_step, dim=-1)
            next_char = torch.multinomial(probabilities, 1)  # (B, 1) sampled ids
            context = torch.cat((context, next_char), dim=-1)
            res.append(int_to_char[next_char.item()])
    return "".join(res)

# Training

In [None]:
import csv
from torch.utils.data import Dataset, DataLoader

# Model hyperparameters. The inference and ONNX cells rebuild GPT from these
# same names before loading trained_weights.pt, so they must match training.
vocab_size = 104  # number of entries in int_to_char below
context_length = 128  # positional-embedding table size = longest input window
model_dim = 252  # embedding width; must be divisible by num_heads (252 / 6 = 42 per head)
num_blocks = 6  # number of stacked TransformerBlocks
num_heads = 6  # attention heads per block

# Fixed character vocabulary (token id -> character) and its inverse.
int_to_char = {0: '\n', 1: ' ', 2: '!', 3: '"', 4: '$', 5: '%', 6: '&', 7: "'", 8: '(', 9: ')', 10: '*', 11: '+', 12: ',', 13: '-', 14: '.', 15: '/', 16: '0', 17: '1', 18: '2', 19: '3', 20: '4', 21: '5', 22: '6', 23: '7', 24: '8', 25: '9', 26: ':', 27: ';', 28: '?', 29: 'A', 30: 'B', 31: 'C', 32: 'D', 33: 'E', 34: 'F', 35: 'G', 36: 'H', 37: 'I', 38: 'J', 39: 'K', 40: 'L', 41: 'M', 42: 'N', 43: 'O', 44: 'P', 45: 'Q', 46: 'R', 47: 'S', 48: 'T', 49: 'U', 50: 'V', 51: 'W', 52: 'X', 53: 'Y', 54: 'Z', 55: '[', 56: ']', 57: '_', 58: 'a', 59: 'b', 60: 'c', 61: 'd', 62: 'e', 63: 'f', 64: 'g', 65: 'h', 66: 'i', 67: 'j', 68: 'k', 69: 'l', 70: 'm', 71: 'n', 72: 'o', 73: 'p', 74: 'q', 75: 'r', 76: 's', 77: 't', 78: 'u', 79: 'v', 80: 'w', 81: 'x', 82: 'y', 83: 'z', 84: '{', 85: '|', 86: '}', 87: 'à', 88: 'á', 89: 'è', 90: 'é', 91: 'ë', 92: 'ñ', 93: 'ó', 94: 'ú', 95: '\u2005', 96: '–', 97: '—', 98: '‘', 99: '’', 100: '“', 101: '”', 102: '…', 103: '\u205f'}
char_to_int = {char: idx for idx, char in int_to_char.items()}

# Fail fast if the hard-coded vocab_size drifts from the vocabulary (or the
# vocabulary contains duplicate characters); otherwise the mismatch surfaces
# later as an embedding index error or a state_dict shape mismatch.
if len(int_to_char) != vocab_size or len(char_to_int) != vocab_size:
    raise ValueError(
        f"vocab_size={vocab_size} does not match the vocabulary "
        f"({len(int_to_char)} ids, {len(char_to_int)} distinct characters)"
    )

class TweetDataset(Dataset):
    """Sliding-window next-character dataset over one encoded string.

    ``text`` is encoded with the module-level ``char_to_int`` table; characters
    missing from the table are dropped. Item ``idx`` is ``(x, y)``: ``x`` holds
    the ``context_length`` token ids starting at ``idx`` and ``y`` is the same
    window shifted one position right (the next-character targets).
    """

    def __init__(self, text):
        self.data = [char_to_int[char] for char in text if char in char_to_int]

    def __len__(self):
        # Each window needs context_length + 1 ids (inputs plus one extra
        # target), giving len(data) - context_length start positions. Clamp at
        # zero: len() raises on a negative __len__ when the corpus is shorter
        # than the window.
        return max(0, len(self.data) - context_length)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx:idx + context_length], dtype=torch.long)
        y = torch.tensor(self.data[idx + 1:idx + 1 + context_length], dtype=torch.long)
        return x, y

# --- Model, optimiser (default Adam settings) and loss ------------------
model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# --- Training hyperparameters -------------------------------------------
num_epochs, batch_size = 5, 64

# --- Corpus: every non-empty Tweet_Text cell, joined with single spaces --
text = ""
with open("Donald-Tweets!.csv", newline='', encoding='utf-8') as tweets_csv:
    rows = csv.DictReader(tweets_csv)
    tweet_texts = [row["Tweet_Text"] for row in rows if row["Tweet_Text"]]
    text = " ".join(tweet_texts)
dataset = TweetDataset(text)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Training loop ------------------------------------------------------
model.train()  # enable dropout while optimising
for epoch in range(num_epochs):
    epoch_label = epoch + 1
    total_loss = 0
    for batch_idx, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()
        logits = model(x)  # (batch, time, vocab_size)
        # Flatten batch and time so CrossEntropyLoss sees (N, vocab) vs (N,).
        loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()
        step_loss = loss.item()
        total_loss += step_loss
        if batch_idx % 10 == 0:  # progress report every 10 batches
            print(f"Epoch {epoch_label} of {num_epochs}, Batch {batch_idx} of {len(dataloader)}, Loss: {step_loss:.4f}")
    print(f"Epoch {epoch_label}, Loss: {total_loss / len(dataloader):.4f}")

# --- Persist the learned parameters -------------------------------------
torch.save(model.state_dict(), "trained_weights.pt")

# Testing

In [54]:
# Rebuild the architecture with the training hyperparameters and load the
# saved weights onto the CPU for sampling.
model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
WEIGHT_PATH = 'trained_weights.pt'
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs; a plain torch.load can run arbitrary pickled code.
model.load_state_dict(torch.load(WEIGHT_PATH, map_location=torch.device('cpu'), weights_only=True))
model.eval()  # disable dropout for sampling
new_chars = 280  # characters to sample (presumably the tweet length limit); e.g. 5000 for a longer sample
# Seed generation with the single token id 0 ('\n' in int_to_char).
context = torch.zeros(1, 1, dtype = torch.int64)

In [83]:
# Sample new_chars characters from the seed context and show them.
sample = generate(model, new_chars, context, context_length, int_to_char)
print(sample)

tensor([[0]])
tensor([[ 3.1143e+00,  2.7428e+00, -6.1902e-01,  2.7527e-01, -3.8614e+00,
         -3.3266e+00, -1.7736e+00, -1.4798e+01, -2.6960e+00, -2.8815e+00,
         -3.5740e+00, -6.7132e+00, -3.3306e+00, -1.9701e-01, -5.8330e-01,
         -3.6493e+00, -2.2487e+00,  8.1077e-01,  3.2907e-01,  7.4088e-01,
         -1.3221e+00, -4.3967e-01, -1.9112e+00, -1.6028e-01, -1.3138e+00,
         -1.4768e+00, -1.3201e+00, -8.5070e+00, -6.9050e+00,  1.4406e+00,
          7.3262e-01,  2.3770e+00,  1.2472e+00,  1.6282e+00,  1.2138e+00,
          2.8507e+00,  1.1375e+00,  2.5327e+00,  3.6666e-01,  2.1696e-01,
          1.8464e+00,  3.5148e+00,  1.6601e+00,  1.4193e+00,  1.0928e+00,
         -9.7250e-01,  2.5647e+00,  2.2412e+00,  3.9377e+00,  7.3051e-01,
          1.3156e+00,  2.5752e+00, -3.1107e+00, -2.9492e-01, -2.8860e-01,
         -8.3925e+00, -6.5030e+00, -2.7741e+00, -1.8546e-01, -3.1016e-01,
         -7.1510e-01, -3.4272e+00, -1.3549e+00, -9.1247e-01, -1.3999e+00,
          4.9609e+00, -9

IndexError: list index out of range

# ONNX Export

In [56]:
# Rebuild the model on the CPU and load the trained weights for export.
model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
WEIGHT_PATH = 'trained_weights.pt'
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs; a plain torch.load can run arbitrary pickled code.
model.load_state_dict(torch.load(WEIGHT_PATH, map_location=torch.device('cpu'), weights_only=True))
model.eval()  # export with dropout disabled

# Tracing input: one full-length window of random token ids. It is int32, so
# the exported input is presumably typed int32; keep ONNX consumers in sync.
# Batch and sequence axes are marked dynamic, but the sequence length is still
# bounded by context_length (the positional-embedding table size).
dummy_input = torch.randint(0, vocab_size, (1, context_length), dtype=torch.int32)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size", 1: "seq_len"}, "output": {0: "batch_size", 1: "seq_len"}},
)