# Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data Processing

In [3]:
# Load the whole corpus into memory as a single string.
data = None

# Decode explicitly as UTF-8 so the set of characters read from the file does
# not depend on the platform's default locale encoding.
with open("text.txt", "r", encoding="utf-8") as f:
    data = f.read()

In [4]:
# Keep only the first 100,000 characters of the corpus.
data = data[:100_000]

In [5]:
# The vocabulary is every distinct character in the corpus, in sorted order.
vocab = sorted(set(data))

print(f"Vocabulary -\n{vocab}\n")
print(f"Vocab Size - {len(vocab)}")

Vocabulary -
['\n', ' ', '!', '&', "'", ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

Vocab Size - 61


In [6]:
# Forward table: character -> its position in the sorted vocabulary.
char_to_idx = dict(zip(vocab, range(len(vocab))))
# Reverse table, derived by inverting the forward one.
idx_to_char = {index: symbol for symbol, index in char_to_idx.items()}

In [7]:
# The first 80% of the characters are used for training, the rest for validation.
split_index = int(0.8 * len(data))

train_data, val_data = data[:split_index], data[split_index:]

# Dataset and DataLoader

In [8]:
class ShakespeareDataset(Dataset):
    """Character-level next-token dataset over a single string.

    Item ``i`` pairs the ``context_length`` characters starting at offset
    ``i`` with the single character that follows them, both encoded as
    positions in ``vocab``.

    Args:
        data: The text to draw windows from.
        vocab: Ordered sequence of the distinct characters; a character's
            position in it is its token id.
        context_length: Number of characters in each input window.
    """

    def __init__(self, data, vocab, context_length):
        self.data = data
        self.context_length = context_length

        self.vocab = vocab
        self.vocab_size = len(vocab)

        # Build the lookup from the vocabulary handed to this instance so the
        # dataset does not depend on a module-level table being defined.
        self.char_to_idx = {char: idx for idx, char in enumerate(vocab)}

    def __len__(self):
        # Offset i is usable while data[i + context_length] exists, which
        # gives len(data) - context_length windows. Clamp at zero because
        # len() rejects negative values when the text is shorter than a window.
        return max(0, len(self.data) - self.context_length)

    def __getitem__(self, i):
        """Return ``(context_ids, next_id)`` tensors for window ``i``.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if i < 0:
            # Support Python-style negative indexing instead of silently
            # slicing a malformed window.
            i += size
        if not 0 <= i < size:
            raise IndexError("dataset index out of range")

        window_end = i + self.context_length
        context = self.data[i:window_end]
        next_token = self.data[window_end]

        encoded_tokens = [self.char_to_idx[token] for token in context]
        encoded_next_token = self.char_to_idx[next_token]

        return torch.tensor(encoded_tokens), torch.tensor(encoded_next_token)

In [9]:
batch_size = 128

# Both splits use 100-character context windows.
train_dataset = ShakespeareDataset(data=train_data, vocab=vocab, context_length=100)
val_dataset = ShakespeareDataset(data=val_data, vocab=vocab, context_length=100)

# Only the training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

In [10]:
# Pull a single training batch to sanity-check the tensor shapes.
a, b = next(iter(train_loader))

(a.shape, b.shape)

(torch.Size([128, 100]), torch.Size([128]))

# Model Architecture

In [None]:
class ShakespeareModel(nn.Module):
    """Transformer that predicts the next token from a window of token ids.

    Token embeddings are summed with learned positional embeddings, passed
    through a stack of ``nn.TransformerDecoder`` layers (the embedded sequence
    is used as both target and memory), and the representation at the last
    position is projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct tokens.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
        context_length: Size of the positional-embedding table, i.e. the
            longest sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, context_length=512):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Embedding(context_length, d_model)

        # forward() produces (batch, seq, d_model) tensors, so the layers must
        # be told the batch dimension comes first. With the default
        # (seq, batch, d_model) layout attention would run across the samples
        # of a batch instead of across the positions of each sequence.
        # NOTE: nn.TransformerDecoder clones this layer for its stack; the
        # attribute is kept so existing references to it keep working.
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return next-token logits of shape ``(batch, vocab_size)``.

        Args:
            x: Integer tensor of token ids, indexed as ``(batch, seq)``.

        Raises:
            ValueError: If the sequence is longer than the positional table.
        """
        seq_length = x.size(1)

        max_positions = self.positional_encoding.num_embeddings
        if seq_length > max_positions:
            # Fail with a readable message rather than an out-of-range
            # embedding lookup.
            raise ValueError(
                f"sequence length {seq_length} exceeds context_length {max_positions}"
            )

        # (seq, d_model) positional embeddings broadcast over the batch.
        position_indices = torch.arange(seq_length, device=x.device)
        embedded = self.embedding(x) + self.positional_encoding(position_indices)

        output = self.transformer_decoder(embedded, embedded)

        # Only the final position is used to predict the following token.
        output = output[:, -1, :]
        logits = self.fc(output)

        return logits

In [None]:
vocab_size = len(vocab)

# Build the network with its hyperparameters spelled out, then move it to the
# selected device.
model = ShakespeareModel(
    vocab_size=vocab_size,
    d_model=128,
    nhead=8,
    num_layers=6,
    context_length=512,
)
model = model.to(device)

# Training & Validation Loops

In [None]:
learning_rate = 1e-4
# Misspelled legacy name, kept as an alias so any cell that still refers to it
# keeps working.
leanring_rate = learning_rate

# Cross-entropy over the vocabulary logits; Adam updates every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called on each batch of contexts to produce logits.
        dataloader: Iterable of ``(input_context, next_tokens)`` batches.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as ``criterion(logits, next_tokens)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0

    batches = tqdm(dataloader, desc="Training")
    for contexts, targets in batches:
        optimizer.zero_grad()

        contexts, targets = contexts.to(device), targets.to(device)

        loss = criterion(model(contexts), targets)
        # Read the scalar once and reuse it for the total and the progress bar.
        batch_loss = loss.item()
        running_loss += batch_loss

        loss.backward()
        optimizer.step()

        batches.set_postfix(loss=f'{batch_loss:.4f}')

    return running_loss / len(dataloader)

In [None]:
def val_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without updating it.

    Args:
        model: Module called on each batch of contexts to produce logits.
        dataloader: Iterable of ``(input_context, next_tokens)`` batches.
        criterion: Loss called as ``criterion(logits, next_tokens)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0

    # Inference mode: no autograd bookkeeping is needed for evaluation.
    with torch.inference_mode():
        batches = tqdm(dataloader, desc="Validation")
        for contexts, targets in batches:
            contexts, targets = contexts.to(device), targets.to(device)

            batch_loss = criterion(model(contexts), targets).item()
            running_loss += batch_loss

            batches.set_postfix(loss=f'{batch_loss:.4f}')

    return running_loss / len(dataloader)

In [None]:
def train_model(model, train_loader, val_loader, criterion, device, num_epochs, optimizer=None):
    """Train and validate ``model`` for ``num_epochs`` epochs, printing losses.

    Args:
        model: Module to train.
        train_loader: Batches passed to ``train_epoch``.
        val_loader: Batches passed to ``val_epoch``.
        criterion: Loss function forwarded to both epoch helpers.
        device: Device forwarded to both epoch helpers.
        num_epochs: Number of train/validate rounds to run.
        optimizer: Optimiser used for training. When omitted, the
            module-level ``optimizer`` is used, which is what this function
            previously relied on implicitly.

    Returns:
        A list with one ``(train_loss, val_loss)`` tuple per completed epoch.

    Raises:
        NameError: If no ``optimizer`` is passed and none exists at module
            level.
    """
    history = []
    if num_epochs < 1:
        return history

    if optimizer is None:
        # Backward compatibility: earlier callers never passed an optimiser
        # and depended on the global one being defined.
        try:
            optimizer = globals()["optimizer"]
        except KeyError:
            raise NameError("name 'optimizer' is not defined") from None

    for epoch in range(1, num_epochs + 1):

        print(f"Epoch - [{epoch}/{num_epochs}]")
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss = val_epoch(model, val_loader, criterion, device)

        print("Train Loss -", train_loss)
        print("Val Loss -", val_loss, "\n")

        history.append((train_loss, val_loss))

    return history

# Generation

In [None]:
def generate(model, start_sequence, max_length=100):
    """Sample ``max_length`` new tokens from ``model`` after a prompt.

    Args:
        model: Module that maps a ``(1, seq)`` tensor of token ids to
            ``(1, vocab)`` logits for the next token.
        start_sequence: Non-empty sequence of integer token ids (the prompt).
        max_length: Number of tokens to sample and append.

    Returns:
        A list of ints: the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``start_sequence`` is empty.
    """
    if len(start_sequence) == 0:
        raise ValueError("start_sequence must contain at least one token")

    model.eval()

    # Run on whatever device the model lives on instead of assuming CUDA, so
    # generation also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # If the model exposes a learned positional table, its size bounds the
    # sequence length it can embed; feed only the most recent tokens then.
    positional_table = getattr(model, "positional_encoding", None)
    max_context = getattr(positional_table, "num_embeddings", None)

    with torch.inference_mode():
        generated_sequence = torch.tensor(
            start_sequence, dtype=torch.long, device=run_device
        ).unsqueeze(0)

        for _ in range(max_length):
            if max_context is None:
                input_seq = generated_sequence
            else:
                input_seq = generated_sequence[:, -max_context:]

            logits = model(input_seq)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_sequence = torch.cat((generated_sequence, next_token), dim=1)

    # squeeze(0) drops only the batch axis, so a one-token result is still a list.
    return generated_sequence.squeeze(0).tolist()

In [None]:
# Prompt text, encoded one character at a time into token ids.
start_char = "First"
start_token = [char_to_idx[symbol] for symbol in start_char]

# Sample a continuation, then map every token id back to its character.
generated_tokens = generate(model, start_token)
generated_text = [idx_to_char[token_id] for token_id in generated_tokens]

"".join(generated_text)