In [None]:
from datasets import load_dataset
import torch
import tiktoken

# Setup device: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's RNGs so weight init, batch sampling and text sampling are repeatable
# (modulo nondeterministic GPU kernels)
torch.manual_seed(1337)

In [None]:
# Display the torch.device selected in the setup cell (rendered as this cell's output)
device

In [None]:
# Hyperparams
# Sequences per batch for the training/validation DataLoaders
batch_size = 32
# Context length: tokens per training window and max context fed to the model in generate()
block_size = 128
# Width of token embeddings and of the residual stream in every block
emb_dim = 768
# Attention heads per block (nn.MultiheadAttention requires emb_dim % num_heads == 0)
num_heads = 12
# Number of transformer blocks stacked in Xformer
num_layers = 12
# Dropout probability used in the attention layer and the feed-forward output
dropout = 0.3
# Feed-forward hidden width = linear_scaleup * emb_dim
linear_scaleup = 4
# Default sampling temperature for generate(); logits are divided by it, so > 1 flattens the distribution
temperature = 2
# weight_decay passed to the optimizer in the training cell
weights_decay = 0.01

# Optimization loop parameters
# Intended number of optimization steps; NOTE(review): confirm the training cell reads this
# rather than a hard-coded step count
n_epochs = 100
# Intended number of steps between loss evaluations; NOTE(review): confirm the training cell uses it
eval_interval =10

In [None]:
# BPE tokenizer (tiktoken's cl100k_base encoding) used to turn the raw text into token ids
tokenizer = tiktoken.get_encoding("cl100k_base")
# Load WikiText-103 dataset: raw (untokenized) text with train / validation / test splits
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

In [None]:
def preprocess_dataset(dataset, tokenizer, block_length=128):
    """Tokenize a whole dataset split into one flat sequence of token ids.

    Every row of ``dataset['text']`` is joined with a single space into one
    string, which is then encoded with a single ``tokenizer.encode`` call.

    Args:
        dataset: anything indexable by ``'text'`` that yields an iterable of
            strings (e.g. a Hugging Face dataset split).
        tokenizer: object exposing ``encode(str)`` (e.g. a tiktoken Encoding).
        block_length: unused; kept only so existing call sites stay valid.

    Returns:
        The result of ``tokenizer.encode`` on the joined text (a list of ints
        for tiktoken).
    """
    rows = dataset['text']
    corpus = ' '.join(rows)
    return tokenizer.encode(corpus)

In [None]:
# train_dataset = preprocess_dataset(dataset['train'],tokenizer)
# val_dataset = preprocess_dataset(dataset['validation'],tokenizer)
# test_dataset = preprocess_dataset(dataset['test'],tokenizer)
# print(len(train_dataset), len(val_dataset), len(test_dataset))

In [None]:
# Tokenize each split of the WikiText-103 dataset (loaded in an earlier cell) into one flat token stream.
# No second load_dataset call is needed here: `dataset` already holds all three splits.
train_tokens = preprocess_dataset(dataset['train'], tokenizer)
val_tokens = preprocess_dataset(dataset['validation'], tokenizer)
test_tokens = preprocess_dataset(dataset['test'], tokenizer)
print(len(train_tokens), len(val_tokens), len(test_tokens))

# Convert to LongTensors of token ids. torch.tensor on a flat list of ints is already 1-D,
# so all three splits are built the same way (no squeeze needed).
train_tensor = torch.tensor(train_tokens, dtype=torch.long)
val_tensor = torch.tensor(val_tokens, dtype=torch.long)
test_tensor = torch.tensor(test_tokens, dtype=torch.long)
assert train_tensor.ndim == val_tensor.ndim == test_tensor.ndim == 1
print(f"Number of training tokens: {train_tensor.shape}, validation tokens: {val_tensor.shape}, test tokens: {test_tensor.shape}")

In [None]:
from torch.utils.data import Dataset
import torch

class WikiData(Dataset):
    """Sliding-window next-token dataset over a single 1-D tensor of token ids.

    Item ``idx`` is the pair ``(x, y)`` with ``x = tokens[idx : idx + block_size]``
    and ``y`` the same window shifted right by one token, so ``y[t]`` is the
    token that follows ``x[t]``. Every start position that leaves room for the
    shifted target is a valid item.

    Args:
        dataset: 1-D tensor of token ids (e.g. ``train_tensor``).
        block_size: number of tokens per window.
        batch_size: kept for backward compatibility only; batching is done by
            the ``DataLoader``, so this value is stored but not used here.
        device: device each returned pair is moved to. ``None`` (default) falls
            back to the notebook-global ``device``.
    """

    def __init__(self, dataset: torch.Tensor, block_size: int = 128, batch_size: int = 64, device=None):
        self.block_size = block_size
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device

    def __len__(self) -> int:
        # Number of windows whose shifted target (idx+1 .. idx+block_size) still fits.
        return max(0, len(self.dataset) - self.block_size)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        n_items = len(self)
        if idx < 0:
            idx += n_items
        # Raise instead of silently returning truncated windows that would break collation.
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for {n_items} windows")
        x = self.dataset[idx:idx + self.block_size]
        y = self.dataset[idx + 1:idx + self.block_size + 1]
        # Resolve the global `device` lazily, only when no explicit device was given.
        target = self.device if self.device is not None else device
        return x.to(target), y.to(target)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Feedforward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    Args:
        emb_dim: input and output width.
        dropout: dropout probability applied to the output.
        scaleup: hidden-layer expansion factor (hidden width = scaleup * emb_dim).
            ``None`` (default) uses the notebook-global ``linear_scaleup``.
    """

    def __init__(self, emb_dim, dropout, scaleup=None) -> None:
        super().__init__()
        if scaleup is None:
            scaleup = linear_scaleup
        hidden = scaleup * emb_dim
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.ff(x)


class Block(nn.Module):
    """Pre-LayerNorm transformer block with causal multi-head self-attention.

    Computes ``x = x + SelfAttn(LN1(x))`` then ``x = x + FF(LN2(x))`` on
    batch-first input of shape ``(batch, seq, emb_dim)``.

    Args:
        emb_dim: model width (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        dropout: dropout probability for attention weights and the MLP output.
        scaleup: Feedforward expansion factor; ``None`` uses ``linear_scaleup``.
    """

    def __init__(self, emb_dim, num_heads, dropout, scaleup=None) -> None:
        super().__init__()
        # Informational only: nn.MultiheadAttention splits heads internally.
        self.head_size = emb_dim // num_heads
        # batch_first=True because the embeddings fed in are (batch, seq, emb_dim);
        # with the default (seq, batch, emb_dim) layout attention would mix the batch
        # dimension instead of the time dimension.
        self.sa_head = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = Feedforward(emb_dim, dropout, scaleup)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x, targets=None):
        # `targets` is unused; kept so existing call sites remain valid.
        seq_len = x.size(1)
        # True above the diagonal = "may not attend": position t sees only positions <= t,
        # so the model cannot peek at the tokens it is asked to predict.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        h = self.ln1(x)
        sa_out, _ = self.sa_head(h, h, h, attn_mask=causal_mask, need_weights=False)
        # Residual connections add to the un-normalized stream (pre-LN).
        x = x + sa_out
        x = x + self.ff(self.ln2(x))
        return x


class Xformer(nn.Module):
    """Decoder-only transformer language model.

    Args:
        emb_dim: embedding / residual-stream width.
        num_heads: attention heads per block.
        num_layers: number of stacked ``Block``s.
        dropout: dropout probability inside each block.
        context_len: maximum sequence length (size of the positional table).
            ``None`` uses the notebook-global ``block_size``.
        vocab_size: number of token ids. ``None`` uses ``tokenizer.n_vocab``.
        ff_scaleup: Feedforward expansion factor; ``None`` uses ``linear_scaleup``.
    """

    def __init__(self, emb_dim, num_heads, num_layers, dropout,
                 context_len=None, vocab_size=None, ff_scaleup=None):
        super().__init__()
        if context_len is None:
            context_len = block_size
        if vocab_size is None:
            vocab_size = tokenizer.n_vocab
        self.context_len = context_len
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        # One learned vector per position 0..context_len-1, looked up by position
        # (not by token id) so the model receives order information.
        self.pos_emb = nn.Embedding(context_len, emb_dim)
        self.blocks = nn.Sequential(
            *[Block(emb_dim, num_heads, dropout, ff_scaleup) for _ in range(num_layers)],
            nn.LayerNorm(emb_dim),
        )
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return ``(logits, loss)`` for token ids ``x`` of shape ``(batch, seq)``.

        ``logits`` has shape ``(batch, seq, vocab_size)``. ``loss`` is the mean
        cross-entropy against ``targets`` (same shape as ``x``; entries equal to
        -1 are ignored), or ``None`` when ``targets`` is not given.

        Raises:
            ValueError: if ``seq`` exceeds ``context_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.context_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds context length {self.context_len}"
            )
        positions = torch.arange(seq_len, device=x.device)
        h = self.tok_emb(x) + self.pos_emb(positions)  # (seq, emb) broadcasts over batch
        h = self.blocks(h)
        logits = self.lm_head(h)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1
            )

        return logits, loss

In [None]:
## Build train/validation datasets and dataloaders from the train and validation token tensors
from torch.utils.data import DataLoader
train_data_obj, val_data_obj = (
    WikiData(tokens, block_size=block_size) for tokens in (train_tensor, val_tensor)
)
# batch_size=1 here only serves the single-batch smoke test; the loaders are rebuilt
# with the configured batch size before training.
train_dataloader, val_dataloader = (
    DataLoader(ds, batch_size=1) for ds in (train_data_obj, val_data_obj)
)

In [None]:
# Single test run: one forward pass on one batch to check shapes and the initial loss.
import math
from torch.optim import Adam, AdamW
xb, yb = next(iter(train_dataloader))
# Move explicitly so this cell does not depend on where the dataset puts its tensors.
xb, yb = xb.to(device), yb.to(device)
print(xb.shape, yb.shape)
model = Xformer(emb_dim, num_heads, num_layers, dropout).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
# Sanity check only: no autograd graph needed, which saves activation memory.
with torch.no_grad():
    logits, loss = model(xb, yb)
# An untrained model should score roughly ln(vocab_size), i.e. a near-uniform guess.
print(loss, 'reference ln(V):', math.log(logits.size(-1)))

In [None]:
from utils import get_model_size
# Model size reported by utils.get_model_size, divided by 1e6.
# NOTE(review): presumably a parameter count, so this reads as millions of parameters — confirm the helper's units.
get_model_size(model)/1e6

In [None]:
# # ## Optimal lr sweep
# from utils import get_lr_loss
# import matplotlib.pyplot as plt
# num_epochs = 100
# lri, lossi =  get_lr_loss(model, optimizer, train_dataloader, num_epochs, device, -5, -3)
# plt.plot(lri, lossi)
# # Add labels to the x-axis and y-axis
# plt.xlabel('LR (Learning Rate)')
# plt.ylabel('Loss')

In [None]:
from torch.utils.data import RandomSampler

# Draw windows uniformly at random (with replacement) so every fresh iterator starts at
# a different place in the corpus. A sequential loader always yields the same first
# batch from next(iter(loader)). shuffle=True would instead build a full permutation of
# every window index up front, which can be very large for a long token stream;
# sampling with replacement generates indices lazily.
train_dataloader = DataLoader(
    train_data_obj,
    batch_size=batch_size,
    sampler=RandomSampler(train_data_obj, replacement=True),
)
val_dataloader = DataLoader(
    val_data_obj,
    batch_size=batch_size,
    sampler=RandomSampler(val_data_obj, replacement=True),
)
# Loss history appended to by the training loop
tr_loss = []
vl_loss = []

In [None]:
from torch.optim import AdamW
from utils import evaluate_loss

lr = 1e-5
# AdamW decouples weight decay from the adaptive gradient step; Adam's weight_decay
# instead adds an L2 term to the gradient, which is then rescaled per parameter.
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weights_decay)

# Keep one iterator alive across steps so consecutive steps see different batches;
# calling next(iter(loader)) each step rebuilds the iterator and restarts it.
train_iter = iter(train_dataloader)
for steps in range(n_epochs):
    try:
        xb, yb = next(train_iter)
    except StopIteration:  # loader exhausted: begin another pass
        train_iter = iter(train_dataloader)
        xb, yb = next(train_iter)
    xb = xb.to(device)
    yb = yb.to(device)

    # Re-enable dropout every step in case evaluate_loss switched the model to eval mode
    # (NOTE(review): confirm what evaluate_loss does with model.training).
    model.train()
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # evaluate_loss presumably runs num_batches forward passes per loader, so only estimate
    # the loss every eval_interval steps and on the final step.
    if steps % eval_interval == 0 or steps == n_epochs - 1:
        tr_lossi, vl_lossi = evaluate_loss(model, train_dataloader, val_dataloader, device, num_batches=10)
        tr_loss.append(tr_lossi)
        vl_loss.append(vl_lossi)
        print('tr_loss: ', tr_lossi, 'val_loss: ', vl_lossi, 'single shot loss:', loss.item())

In [None]:
## Plot loss
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(tr_loss[-50:], label='Training Loss')
ax.plot(vl_loss[-50:], label='Validation Loss')
# Each point is one recorded evaluate_loss call (the last 50 are shown)
ax.set(title='Training vs. validation loss', xlabel='Evaluation index', ylabel='Loss')
ax.legend()
plt.show()
# Smooth out noise by averaging the last 10 recorded values
print('training loss: ', round(torch.mean(torch.tensor(tr_loss[-10:])).item(), 4))
print('validation loss: ', round(torch.mean(torch.tensor(vl_loss[-10:])).item(), 4))

In [None]:
@torch.no_grad()
def generate(model, max_new_tokens=block_size, stub='', batch_size=1, temperature=temperature):
    """Autoregressively sample ``max_new_tokens`` tokens that continue ``stub``.

    Args:
        model: module whose ``forward(idx)`` returns ``(logits, loss)`` with
            logits shaped ``(batch, seq, vocab)``.
        max_new_tokens: number of tokens to append.
        stub: prompt text, encoded with the notebook-global ``tokenizer``.
            Must encode to at least one token.
        batch_size: unused; kept for signature compatibility. Exactly one
            sequence is generated.
        temperature: logits are divided by this before softmax (temperature
            scaling): > 1 flattens the distribution (more random), < 1 sharpens
            it, and 0 picks the most likely token (greedy decoding).

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        ValueError: if ``temperature`` is negative or ``stub`` encodes to no tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    prompt_ids = tokenizer.encode(stub)
    if not prompt_ids:
        raise ValueError("stub must encode to at least one token")

    # Build inputs on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)

    was_training = model.training
    model.eval()  # disable dropout while sampling
    try:
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit in the model's context window.
            idx_cond = idx[:, -block_size:]
            logits, _ = model(idx_cond)
            last_logits = logits[:, -1, :]  # prediction for the next token only
            if temperature == 0:
                idx_next = last_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(last_logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)  # restore the caller's mode

    return tokenizer.decode(idx[0].tolist())



In [None]:
stub = "And so we begin, where we started "
# Pass arguments by keyword: generate's 4th positional parameter is batch_size, so a
# positional temperature would be silently swallowed there.
samples = generate(model, max_new_tokens=block_size, stub=stub, temperature=temperature)
print(samples)

In [None]:
# Persist only the model's state_dict (parameters and buffers), not the module object,
# so it can later be loaded into a freshly constructed model with load_state_dict.
file_path = 'model_weights.pth'
torch.save(obj=model.state_dict(), f=file_path)