In [None]:
%load_ext autoreload
%autoreload 2
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
# Setup device: prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device selected:", device)

# Seed torch's global RNG so the stochastic steps below (batch sampling, dropout,
# torch.multinomial during generation) are reproducible under Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

In [None]:
## Get data: the raw (untokenized) WikiText-2 corpus, all splits
dataset = load_dataset(path="wikitext", name="wikitext-2-raw-v1")
print(dataset)

In [None]:
# Split names, iterated over together by the tokenizer and preprocessing cells below.
splits = ["train", "validation", "test"]
# Peek at the first five raw rows of the training split.
split = splits[0]
print(dataset[split][:5])

In [None]:
## Train a new tokenizer on WikiText-2, reusing the GPT-2 tokenizer's pipeline
batch_size = 1000   # rows of raw text handed to the tokenizer trainer per batch
vocab_size = 25000  # target vocabulary size; also passed to the model constructor below
chunk_size = 1024   # From GPT2


def batch_iterator():
    """Yield lists of raw text rows from every split in `splits`, `batch_size` rows at a time.

    NOTE(review): this includes the validation and test splits, so the tokenizer has seen
    evaluation text. Left unchanged because previously saved checkpoints (presumably
    wiki-2.pt) were trained with the vocabulary this exact corpus produces.
    """
    for split in splits:
        rows = dataset[split]
        for i in range(0, len(rows), batch_size):
            yield rows[i : i + batch_size]["text"]


gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(gpt2_tokenizer.is_fast)  # train_new_from_iterator is only available on fast tokenizers
# Use the configured vocab_size rather than a duplicated literal, so the tokenizer and the
# model (built with vocab_size) cannot silently disagree.
tokenizer = gpt2_tokenizer.train_new_from_iterator(batch_iterator(), vocab_size=vocab_size)
# Every token id must be a valid index into the model's vocab_size-sized layers.
assert len(tokenizer) <= vocab_size, (len(tokenizer), vocab_size)

In [None]:
# Show how the freshly trained tokenizer splits the first five training rows.
tokenized_text = list(map(tokenizer.tokenize, dataset['train'][:5]["text"]))
print(tokenized_text)

In [None]:
## Preprocess the dataset
def _whitespace_chunks(text, chunk_size):
    """Yield consecutive pieces of `text`, each at most `chunk_size` characters long.

    A cut is placed just before a space whenever the window contains one, so words are not
    split across pieces; a window with no space falls back to a hard cut at `chunk_size`.
    Concatenating the pieces reproduces `text` exactly.
    """
    start, n = 0, len(text)
    while start < n:
        end = min(start + chunk_size, n)
        if end < n:
            # Last space with index in [start + 1, end]; cutting there keeps the space on the
            # following word, matching how GPT-2 byte-level BPE attaches leading spaces.
            cut = text.rfind(' ', start + 1, end + 1)
            if cut != -1:
                end = cut
        yield text[start:end]
        start = end


def preprocess_dataset(dataset, tokenizer, splits, chunk_size=1024):
    """Tokenize each split into one flat 1-D tensor of token ids.

    The split's text rows are joined with single spaces, cut into pieces of at most
    `chunk_size` characters at whitespace boundaries, encoded piece by piece and
    concatenated.

    Args:
        dataset: mapping split name -> object whose ['text'] is a sequence of strings
            (e.g. a Hugging Face DatasetDict).
        tokenizer: object with an `encode(str) -> list[int]` method.
        splits: iterable of split names to process.
        chunk_size: maximum characters per piece handed to `tokenizer.encode`
            (default 1024, the previously hard-coded value).

    Returns:
        dict mapping each split name to a `torch.long` tensor of token ids.

    Raises:
        ValueError: if `chunk_size` < 1 (chunking could not make progress).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    encoded_dataset = {}
    for s in splits:
        dataset_string = ' '.join(dataset[s]['text'])
        # Cutting at whitespace (rather than every chunk_size characters) avoids splitting a
        # word across two encode() calls, which would tokenize it into spurious extra pieces.
        token_ids = [
            token
            for piece in _whitespace_chunks(dataset_string, chunk_size)
            for token in tokenizer.encode(piece)
        ]
        encoded_dataset[s] = torch.tensor(token_ids, dtype=torch.long)

    return encoded_dataset


In [None]:
encoded_dataset = preprocess_dataset(dataset, tokenizer, splits=splits)
# Spot check: the first 20 training token ids, then those same ids decoded back to text.
print(encoded_dataset["train"][:20])
print(tokenizer.decode(encoded_dataset["train"][:20]))

In [None]:
## Print tokens and text
from utils import get_batch
## Hyperparameters (passed to get_batch and the model constructor below)
block_size = 128   # sequence length per sample (context length)
batch_size = 32    # sequences per mini-batch
emb_dim = 64
num_layers = 4
num_heads = 16     # NOTE(review): 64 / 16 = 4 dims per head if the model splits emb_dim evenly
dropout = 0.2

# Optional check (disabled): draw one training batch and print it decoded back to text.
# x_toks, y_toks = get_batch(encoded_dataset['train'], device, block_size, batch_size)
# print(x_toks)
# print(y_toks)
# for row in x_toks.tolist():
#     print(tokenizer.decode(row))
#     print('================')

In [None]:
# Single test run
from models import Xformer_Scratch as Xformer
from torch.optim import Adam
import math
# Sanity check before training: one forward pass on a random batch.
# Assuming the model's loss is mean cross-entropy over the vocabulary (TODO confirm in
# models.Xformer_Scratch), an untrained model predicting near-uniformly should score
# about ln(vocab_size).
xb, yb = get_batch(encoded_dataset['train'], device, block_size, batch_size)
model = Xformer(emb_dim, vocab_size, num_heads, num_layers, block_size, dropout).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
logits, loss = model(xb, yb)
# Shapes are printed explicitly: a bare expression in the middle of a cell is never displayed.
print('Batch shapes:', xb.shape, yb.shape)
print('Measured loss:', loss.item())
print('Expected loss:', math.log(vocab_size))

In [None]:
from utils import get_model_size
# Report the size of the model built above; the units and format come from
# utils.get_model_size (TODO confirm: parameter count vs. bytes).
get_model_size(model)

In [None]:
# Function to do a learning rate sweep
def get_lr_loss(model, optimizer, dataset, num_epochs, device, lr_start_exp=-3, lr_end_exp=0.5):
    """Run a learning-rate range test: one optimizer step per learning rate, log-spaced.

    Learning rates run from 10**lr_start_exp to 10**lr_end_exp over `num_epochs` steps
    (each "epoch" is a single mini-batch). The model and optimizer are updated in place,
    so pass a throwaway copy if the weights will be reused afterwards.

    Batches are drawn with the notebook-level `block_size` and `batch_size` globals.

    Returns:
        (lri, lossi): the learning rate used at each step (Python floats) and the training
        loss measured at that step, before that step's update.
    """
    # Python floats rather than 0-d tensors: torch optimizers document `lr` as a float and
    # support a Tensor lr only for some implementations; floats also keep lri plain data.
    lrs_val = (10 ** torch.linspace(lr_start_exp, lr_end_exp, num_epochs)).tolist()

    model.train()  # the sweep performs training updates, so dropout must be active
    lri, lossi = [], []
    for lr in lrs_val:
        for g in optimizer.param_groups:
            g['lr'] = lr

        xb, yb = get_batch(dataset, device, block_size, batch_size)

        # Forward pass; record the loss before this step's update.
        _, loss = model(xb, yb)
        lri.append(lr)
        lossi.append(loss.item())

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return lri, lossi

In [None]:
## Learning-rate range test
import copy
import matplotlib.pyplot as plt

# The sweep trains with steadily increasing learning rates. Run it on a throwaway copy so
# the training loop below starts from the untouched initial weights, not sweep-disturbed ones.
sweep_model = copy.deepcopy(model)
sweep_optimizer = Adam(sweep_model.parameters(), lr=0.001)
num_epochs = 100
lri, lossi = get_lr_loss(sweep_model, sweep_optimizer, encoded_dataset['train'], num_epochs, device, -4, -2)

plt.plot(lri, lossi)
plt.xscale('log')  # learning rates are log-spaced, so a log axis spaces the points evenly
plt.xlabel('LR (Learning Rate)')
plt.ylabel('Loss')
plt.title('Learning-rate sweep');


In [None]:
## Loss histories filled in by the training loop, plus the training batch size
tr_loss, val_loss, tr_loss_raw = [], [], []
batch_size = 32

In [None]:
from utils import evaluate_loss
## Initialize training parameters
lr = 0.001
optimizer = Adam(model.parameters(), lr=lr)
n_epochs = 100  # optimization steps; each step uses a single mini-batch

for steps in range(n_epochs):
    # Set training mode on every step instead of relying on whatever mode earlier cells
    # (or evaluate_loss, whose mode handling isn't visible here) left the model in.
    model.train()
    xtr, ytr = get_batch(encoded_dataset['train'], device, block_size, batch_size)
    logits, loss = model(xtr, ytr)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    step_loss = loss.item()

    # Evaluate on fresh batches: re-scoring the batch that was just used for the update
    # would bias the reported training loss downward.
    xtr_eval, ytr_eval = get_batch(encoded_dataset['train'], device, block_size, batch_size)
    xval, yval = get_batch(encoded_dataset['validation'], device, block_size, batch_size)
    eval_dataset = {'train': (xtr_eval, ytr_eval), 'validation': (xval, yval)}
    tr_lossi, val_lossi = evaluate_loss(model, eval_dataset, num_batches=16)
    tr_loss.append(tr_lossi)
    val_loss.append(val_lossi)
    tr_loss_raw.append(step_loss)

    ## Print losses
    if steps % 10 == 0:
        print(steps, ' --> train loss: ', tr_lossi, 'validation loss: ', val_lossi, 'single shot loss:', step_loss)

In [None]:
## Plot loss
import matplotlib.pyplot as plt

plt.plot(tr_loss, label='train')
plt.plot(val_loss, label='validation')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.legend()

# Mean over (up to) the last 100 recorded steps to smooth per-batch noise.
print('training loss: ', round(torch.mean(torch.tensor(tr_loss[-100:])).item(), 4))
print('validation loss: ', round(torch.mean(torch.tensor(val_loss[-100:])).item(), 4))

In [None]:
from torch.nn import functional as F
@torch.no_grad()
def generate(model, idx, max_new_tokens, device, block_size=16, temperature=1.0):
    """Autoregressively extend each row of `idx` by `max_new_tokens` sampled tokens.

    At every step the last `block_size` tokens of each row are fed to the model. The logits
    at the final position are turned into a probability distribution, and one token is
    sampled from it with `torch.multinomial`. Runs without gradient tracking and with the
    model temporarily in eval mode (dropout off while sampling). The model's previous
    train/eval mode is restored afterwards. Used by `print_samples`.

    Args:
        model: callable returning `(logits, loss)`; logits are indexed as
            (batch, time, vocab) — TODO confirm against models.Xformer_Scratch.
        idx: (batch, time) LongTensor of starting token ids.
        max_new_tokens: number of tokens to append.
        device: device the model lives on; `idx` is moved there.
        block_size: maximum context length fed to the model per step.
        temperature: logits are divided by this before the softmax. 1.0 (default) samples
            from the model's distribution unchanged; lower values make sampling greedier.

    Returns:
        LongTensor of shape (batch, time + max_new_tokens) on `device`.

    Raises:
        ValueError: if `temperature` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    was_training = model.training
    model.eval()
    try:
        # Keep the running sequence on `device`: sampled tokens come back on the model's
        # device, and torch.cat needs both operands on the same device.
        idx = idx.to(device)
        for _ in range(max_new_tokens):
            # Crop to the most recent block_size tokens (no-op for shorter sequences).
            idx_cond = idx[:, -block_size:]
            logits, _ = model(idx_cond)
            # Only the prediction for the last time step matters for the next token.
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)
    return idx

In [None]:
def print_samples(model, tokenizer, seed_text, max_new_tokens, device, block_size=16):
    """Sample a continuation of `seed_text` from the model and print it decoded.

    Args:
        model: model passed to `generate`.
        tokenizer: object with `encode(str) -> list[int]` and `decode(list[int]) -> str`.
        seed_text: prompt text; it is printed together with the generated continuation.
        max_new_tokens: number of tokens to sample after the prompt.
        device: device the model lives on.
        block_size: maximum context passed to the model per step (forwarded to `generate`;
            the default 16 matches the previous behavior — pass the training block_size to
            use the model's full context).

    Raises:
        ValueError: if `seed_text` encodes to no tokens (the model needs at least one).
    """
    seed_ids = tokenizer.encode(seed_text)
    if not seed_ids:
        raise ValueError("seed_text must encode to at least one token")
    seed_tokens = torch.tensor([seed_ids], dtype=torch.long, device=device)

    # Keep the whole sequence. Slicing off position 0 only made sense when sampling started
    # from a dummy zero token; here position 0 is the first real token of the prompt.
    samples = generate(model, seed_tokens, max_new_tokens, device, block_size=block_size).tolist()
    for row in samples:
        # Truncate at the first token id 0. NOTE(review): presumably id 0 is the trained
        # tokenizer's <|endoftext|> special token — confirm.
        crop_index = row.index(0) if 0 in row else len(row)
        print(tokenizer.decode(row[:crop_index]))

In [None]:
## Generate samples
# Prompt the in-memory model with a short seed and print up to 128 newly sampled tokens.
print_samples(model, tokenizer, seed_text='In the dark ages', max_new_tokens=128, device=device)

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
# Reload by rebuilding the model and calling model.load_state_dict(torch.load(file_path)).
file_path = 'model_weights.pt'
torch.save(model.state_dict(), file_path)

In [None]:
import torch
model = Xformer(emb_dim, vocab_size, num_heads, num_layers, block_size, dropout).to(device)
# Load the saved model weights. NOTE(review): this reads 'wiki-2.pt', not the
# 'model_weights.pt' written above — presumably a checkpoint from a separate, longer run.
file_path = 'wiki-2.pt'
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on this machine's
# device (CPU/MPS). torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load(file_path, map_location=device))
model.eval()  # inference only: disable dropout while sampling
print_samples(model, tokenizer, 'In the dark ages', 128, device)