# Single Head Attention

In [47]:
import torch
import torch.nn as nn

class SingleHeadAttention(nn.Module):
    """One causal (masked) scaled dot-product self-attention head.

    The input is projected into key, query and value spaces of width
    ``head_size``; each position may only attend to itself and to earlier
    positions.
    """

    def __init__(self, model_dim: int, head_size: int):
        super().__init__()
        # Bias-free projections. The attribute names determine the state_dict
        # keys of saved checkpoints, so they must not be renamed.
        self.key_layer = nn.Linear(model_dim, head_size, bias=False)
        self.query_layer = nn.Linear(model_dim, head_size, bias=False)
        self.value_layer = nn.Linear(model_dim, head_size, bias=False)

    def forward(self, embedded):
        """Apply causal self-attention.

        Args:
            embedded: rank-3 tensor ``(batch, T, model_dim)``.

        Returns:
            Tensor of shape ``(batch, T, head_size)``.
        """
        K = self.key_layer(embedded)
        Q = self.query_layer(embedded)
        V = self.value_layer(embedded)
        _, T, A = K.shape
        # Scale by sqrt(head_size) to keep the softmax inputs well-conditioned.
        scores = Q @ torch.transpose(K, 1, 2) / (A ** 0.5)
        # Lower-triangular causal mask. It is created on the same device as the
        # scores so the head also works when the model lives on a GPU.
        mask = torch.tril(torch.ones(T, T, device=scores.device))
        scores = scores.masked_fill(mask == 0, float("-inf"))
        scores = nn.functional.softmax(scores, dim=-1)
        return scores @ V

# Multi-Headed Self Attention

In [48]:
class MultiHeadedSelfAttention(nn.Module):
    """Causal multi-head self-attention built from independent ``SingleHeadAttention`` heads.

    Each head produces ``model_dim // num_heads`` features. The head outputs
    are concatenated back to width ``model_dim``, mixed by a linear layer and
    passed through dropout.
    """

    def __init__(self, model_dim: int, num_heads: int, dropout: float = 0.2):
        super().__init__()
        # The concatenated head outputs only have width model_dim (as required
        # by self.compute) when num_heads divides model_dim exactly. Fail fast
        # here instead of with an opaque shape error in forward().
        if num_heads <= 0 or model_dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of model_dim ({model_dim})"
            )
        head_size = model_dim // num_heads
        self.attention_heads = nn.ModuleList(
            [SingleHeadAttention(model_dim, head_size) for _ in range(num_heads)]
        )
        self.compute = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embedded):
        """Return ``dropout(compute(concat(heads)))`` of width ``model_dim``."""
        head_outputs = [head(embedded) for head in self.attention_heads]
        return self.dropout(self.compute(torch.cat(head_outputs, dim=-1)))

# Vanilla Neural Network

In [49]:
class VanillaNeuralNetwork(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times wider than ``model_dim``; the output has
    the same width as the input.
    """

    def __init__(self, model_dim: int):
        super().__init__()
        hidden_dim = 4 * model_dim
        # Attribute names are part of the checkpoint (state_dict) format, and
        # the creation order fixes the random initialisation sequence.
        self.first_linear_layer = nn.Linear(model_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.second_linear_layer = nn.Linear(hidden_dim, model_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        hidden = self.first_linear_layer(x)
        hidden = self.relu(hidden)
        projected = self.second_linear_layer(hidden)
        return self.dropout(projected)

# Transformer Block

In [50]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Applies attention and then a feed-forward network, each on a
    layer-normalised copy of the input and added back through a residual
    connection.
    """

    def __init__(self, model_dim: int, num_heads: int):
        super().__init__()
        # Creation order and attribute names are kept stable: they fix the
        # random initialisation sequence and the state_dict keys.
        self.mhsa = MultiHeadedSelfAttention(model_dim, num_heads)
        self.vanilla_nn = VanillaNeuralNetwork(model_dim)
        self.layer_norm_one = nn.LayerNorm(model_dim)
        self.layer_norm_two = nn.LayerNorm(model_dim)

    def forward(self, embedded):
        attended = self.mhsa(self.layer_norm_one(embedded))
        after_attention = embedded + attended
        fed_forward = self.vanilla_nn(self.layer_norm_two(after_attention))
        return after_attention + fed_forward

# GPT Class

In [51]:
class GPT(nn.Module):
    """Decoder-only language model: embeddings -> transformer blocks -> vocabulary logits."""

    def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        # Learned absolute positions; only positions 0..context_length-1 exist.
        self.pos_embedding = nn.Embedding(context_length, model_dim)
        self.transformer_blocks = nn.Sequential(*(TransformerBlock(model_dim, num_heads) for _ in range(num_blocks)))
        self.layer_norm_three = nn.LayerNorm(model_dim)
        self.vocab_projection = nn.Linear(model_dim, vocab_size)

    def forward(self, context):
        """Compute next-token logits.

        Args:
            context: integer token ids of shape ``(batch, T)``. ``T`` must not
                exceed ``context_length``, otherwise the positional embedding
                lookup raises an IndexError.

        Returns:
            Logits of shape ``(batch, T, vocab_size)``.
        """
        _, T = context.shape
        token_embeddings = self.token_embedding(context)
        # Position ids are created on the input's device so that GPU inputs do
        # not hit a CPU/GPU device mismatch in the embedding lookup.
        positions = torch.arange(T, device=context.device)
        pos_embedding = self.pos_embedding(positions)
        embedded = token_embeddings + pos_embedding
        output = self.transformer_blocks(embedded)
        # Final layer norm, then projection to vocabulary size.
        return self.vocab_projection(self.layer_norm_three(output))

# Generation

In [94]:
def generate(model, new_chars: int, context, context_length: int, int_to_char: dict) -> str:
    """Autoregressively sample ``new_chars`` characters from ``model``.

    Args:
        model: callable mapping token ids ``(batch, T)`` to logits
            ``(batch, T, vocab_size)``.
        new_chars: number of characters to sample.
        context: initial token ids of shape ``(batch, T)``. Because each sample
            is converted with ``.item()``, the batch size must be 1.
        context_length: maximum number of most recent tokens fed to the model.
        int_to_char: mapping from token id to character.

    Returns:
        The sampled characters joined into a string (the seed context is not
        included).
    """
    res = []
    # Inference only: disabling autograd avoids building a graph over every
    # model parameter at each sampling step.
    with torch.no_grad():
        for _ in range(new_chars):
            # Keep only the most recent context_length tokens.
            context = context[:, -context_length:]
            prediction = model(context)  # (B, T, vocab_size)
            last_time_step = prediction[:, -1, :]  # (B, vocab_size)
            probabilities = nn.functional.softmax(last_time_step, dim=-1)
            # Draw one token id per row from the predicted distribution.
            next_char = torch.multinomial(probabilities, 1)  # (B, 1)
            context = torch.cat((context, next_char), dim=-1)
            res.append(int_to_char[next_char.item()])
    return "".join(res)

# Training

In [None]:
import csv
from torch.utils.data import Dataset, DataLoader

# Model/data hyperparameters. The Testing and ONNX cells rebuild GPT with these
# same values before loading trained_weights.pt, so changing any of them
# invalidates that checkpoint.
vocab_size = 124  # must equal len(int_to_char): ids 0..123 below
context_length = 128  # max sequence length; size of GPT's positional-embedding table
model_dim = 252  # embedding width; split evenly across the attention heads
num_blocks = 6  # number of stacked TransformerBlocks
num_heads = 6  # attention heads per block (252 // 6 = 42 features per head)

# Fixed character vocabulary: token id -> character. Characters that do not
# appear here are dropped when TweetDataset encodes text.
int_to_char = {0: '\n', 1: ' ', 2: '!', 3: '"', 4: '#', 5: '$', 6: '%', 7: '&', 8: "'", 9: '(', 10: ')', 11: '*', 12: '+', 13: ',', 14: '-', 15: '.', 16: '/', 17: '0', 18: '1', 19: '2', 20: '3', 21: '4', 22: '5', 23: '6', 24: '7', 25: '8', 26: '9', 27: ':', 28: ';', 29: '=', 30: '?', 31: '@', 32: 'A', 33: 'B', 34: 'C', 35: 'D', 36: 'E', 37: 'F', 38: 'G', 39: 'H', 40: 'I', 41: 'J', 42: 'K', 43: 'L', 44: 'M', 45: 'N', 46: 'O', 47: 'P', 48: 'Q', 49: 'R', 50: 'S', 51: 'T', 52: 'U', 53: 'V', 54: 'W', 55: 'X', 56: 'Y', 57: 'Z', 58: '[', 59: ']', 60: '_', 61: 'a', 62: 'b', 63: 'c', 64: 'd', 65: 'e', 66: 'f', 67: 'g', 68: 'h', 69: 'i', 70: 'j', 71: 'k', 72: 'l', 73: 'm', 74: 'n', 75: 'o', 76: 'p', 77: 'q', 78: 'r', 79: 's', 80: 't', 81: 'u', 82: 'v', 83: 'w', 84: 'x', 85: 'y', 86: 'z', 87: '{', 88: '|', 89: '}', 90: '~', 91: 'à', 92: 'á', 93: 'è', 94: 'é', 95: 'ë', 96: 'ñ', 97: 'ó', 98: 'ú', 99: 'ʉ', 100: '̱', 101: 'ω', 102: 'я', 103: 'ӕ', 104: 'ԍ', 105: 'ԏ', 106: 'Ԡ', 107: 'ե', 108: 'լ', 109: 'ջ', 110: 'ُ', 111: '٪', 112: '\u06dd', 113: 'ۢ', 114: '۪', 115: '\u2005', 116: '–', 117: '—', 118: '‘', 119: '’', 120: '“', 121: '”', 122: '…',
               123: '\u205f'}
# Inverse mapping used to encode text: character -> token id.
char_to_int = {char: idx for idx, char in int_to_char.items()}

class TweetDataset(Dataset):
    """Sliding-window next-character dataset over an encoded text corpus.

    Sample ``i`` is ``(x, y)``, where ``x`` holds the token ids at positions
    ``i .. i+block_size-1`` and ``y`` is the same window shifted right by one
    (the next-character targets).
    """

    def __init__(self, text, block_size=None, vocab=None):
        """Encode ``text``.

        Args:
            text: raw corpus string.
            block_size: window length. Defaults to the module-level
                ``context_length``.
            vocab: character -> token id mapping. Defaults to the module-level
                ``char_to_int``. Characters missing from it are dropped.
        """
        # Fall back to the notebook globals so existing TweetDataset(text)
        # calls behave exactly as before.
        self.block_size = context_length if block_size is None else block_size
        mapping = char_to_int if vocab is None else vocab
        self.data = [mapping[char] for char in text if char in mapping]

    def __len__(self):
        # One sample per start index whose shifted target window still fits.
        # Clamp at 0 so a corpus shorter than the window yields an empty
        # dataset instead of an invalid negative length.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx:idx + self.block_size], dtype=torch.long)
        y = torch.tensor(self.data[idx + 1:idx + 1 + self.block_size], dtype=torch.long)
        return x, y

# Build the model, optimiser and loss function.
model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Training hyperparameters.
num_epochs = 5
batch_size = 64

# Collect every non-empty tweet from the CSV and join them into one corpus,
# separated by single spaces.
with open("Donald-Tweets!.csv", newline='', encoding='utf-8') as csvfile:
    tweets = [row["Tweet_Text"] for row in csv.DictReader(csvfile) if row["Tweet_Text"]]
text = " ".join(tweets)
dataset = TweetDataset(text)
dataloader = DataLoader(dataset, batch_size, shuffle=True)
num_batches = len(dataloader)

# Optimise next-character prediction over the whole corpus.
model.train()
for epoch in range(num_epochs):
    epoch_number = epoch + 1
    running_loss = 0
    for step, (inputs, targets) in enumerate(dataloader):
        optimizer.zero_grad()
        logits = model(inputs)
        # Flatten (batch, T, vocab) / (batch, T) so every position is one sample.
        loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        if step % 10 == 0:  # report every 10 batches
            print(f"Epoch {epoch_number} of {num_epochs}, Batch {step} of {num_batches}, Loss: {batch_loss:.4f}")
    print(f"Epoch {epoch_number}, Loss: {running_loss / num_batches:.4f}")

# Persist the trained parameters for the inference and export cells.
torch.save(model.state_dict(), "trained_weights.pt")

# Testing

In [92]:
# Rebuild the architecture and load the trained weights onto the CPU for sampling.
WEIGHT_PATH = 'trained_weights.pt'
model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (newer PyTorch versions also accept weights_only=True).
state_dict = torch.load(WEIGHT_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

new_chars = 280  # number of characters to sample
# Seed sequence: batch of one containing token id 0 ('\n' in int_to_char).
context = torch.zeros((1, 1), dtype=torch.long)

In [97]:
# Sample new_chars characters from the seeded context and show them.
sample = generate(model, new_chars, context, context_length, int_to_char)
print(sample)

"Crooked Hillary said, she would be going to Rocket Birmston. How can she run? "@AprilLaJune: OREGON votes today! Go vote for @realDonaldTrump and kick it BIG TIME!  #MAGA #Debates  __" "@ihatematt: @realDonaldTrump @megynkelly it makes me not watch the debate by a poll" "@apollo


# ONNX Export

In [98]:
# Reload the trained weights into a fresh CPU model and export it to ONNX.
WEIGHT_PATH = 'trained_weights.pt'
model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
model.load_state_dict(torch.load(WEIGHT_PATH, map_location="cpu"))
model.eval()

# Example input used only for tracing: one full-length sequence of random
# token ids. Its int32 dtype becomes the exported input type. Batch size and
# sequence length are declared dynamic below.
dummy_input = torch.randint(0, vocab_size, (1, context_length), dtype=torch.int32)
dynamic_axes = {
    "input": {0: "batch_size", 1: "seq_len"},
    "output": {0: "batch_size", 1: "seq_len"},
}
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
)