# Import libraries

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/Users/shwetank/code/makemore-utils-nbs')
from utils import create_dataset, evaluate_loss, generate
from torch.optim import Adam
import random
from models import Bigram
from sklearn.manifold import TSNE
import numpy as np
import torch
import math

# Check if accelerator is available on your system 

In [None]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # MPS is only probed when CUDA is not available
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device selected:", device)

# Read the data and create training and validation sets

In [None]:
## Read the names file and create training and validation sets
# create_dataset is a project helper; its three return values are bound to
# names, vocab and max_length (presumably the list of names, the character
# vocabulary and the longest name length -- confirm in utils)
names, vocab, max_length = create_dataset('../names.txt')
# Fixed seed so the shuffle below gives the same split on every run
seed_value = 42
random.seed(seed_value)
# Index at which the 90/10 train/validation split is made
n1 = int(0.9*len(names))

## Wrap every name in '.' markers (start and stop token)
names = ['.' + n + '.' for n in names]

## Shuffle in place so the split does not depend on the file order
random.shuffle(names)
train_data = names[:n1]
val_data = names[n1:]

# Show a few examples from each split
print(train_data[0:5])
print(val_data[0:5])

# Create a character level tokenizer - simplest version of what is possible

In [None]:
# Lookup tables for the character-level tokenizer:
#   stoi maps character -> index, itos maps index -> character.
# Vocabulary characters are numbered from 1; index 0 is reserved for '.'
stoi = {ch: idx for idx, ch in enumerate(vocab, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
print(itos)
print(stoi)

In [None]:
def encode(text_batch: list) -> list:
    """Convert a batch of strings into lists of token indices.

    Each character is looked up in the module-level ``stoi`` table and every
    row is cut to at most ``max_length`` tokens. Shorter rows are returned
    unchanged; no padding is applied here.

    Raises:
        KeyError: if a string contains a character that is not in ``stoi``.
    """
    # NOTE(review): the rows passed in elsewhere in this file already carry
    # the two '.' markers, so cutting at max_length may drop the trailing
    # marker of the longest names -- confirm what create_dataset reports.
    return [[stoi[ch] for ch in text][:max_length] for text in text_batch]

def decode(token_batch: list) -> list:
    """Convert a batch of token-index rows back into strings.

    Each index is looked up in the module-level ``itos`` table and the
    characters of a row are concatenated.

    Raises:
        KeyError: if a row contains an index that is not in ``itos``.
    """
    return [''.join([itos[tok] for tok in row]) for row in token_batch]

In [None]:
## Check encoding and decoding on a small batch
# Small batch size just for this check; batch_size is reassigned in later cells
batch_size = 8
train_batch = train_data[:batch_size]
print(train_batch)
print(encode(train_batch))
# The round trip reproduces train_batch unless encode truncated a row
print(decode(encode(train_batch)))

In [None]:
## Tokenize the train and val datasets and convert them to token indices
# Each result is a list with one list of indices per name
encoded_train_data = encode(train_data)
encoded_val_data = encode(val_data)

# Create a model and run a forward pass

In [None]:
## Function to create a batch
def get_batch(data, max_encoded_length=max_length + 2, batch_size=4):
    """Draw a random batch of (input, target) tensors from ``data``.

    Args:
        data: sequence of token-index lists.
        max_encoded_length: width of the scratch buffers; the returned
            tensors are one column narrower.
        batch_size: number of rows drawn (with replacement) from ``data``.

    Returns:
        ``(x, y)``, both of shape ``(batch_size, max_encoded_length - 1)`` and
        moved to the module-level ``device``. ``x`` holds each sequence
        followed by zeros; ``y`` holds the same sequence shifted one position
        to the left, with -1 in every remaining slot (presumably the value
        the model's loss ignores -- confirm in models.Bigram).
    """
    inputs = torch.zeros(batch_size, max_encoded_length, dtype=torch.long)
    targets = torch.zeros(batch_size, max_encoded_length, dtype=torch.long)

    picks = torch.randint(len(data), size=(batch_size,))
    for slot, pick in enumerate(picks):
        n_tok = len(data[pick])
        seq = torch.tensor(data[pick])
        # Inputs are written one column to the right of the targets, so that
        # dropping column 0 from both leaves the targets shifted by one token
        inputs[slot, 1:1 + n_tok] = seq
        targets[slot, :n_tok] = seq
        targets[slot, n_tok:] = -1

    return inputs[:, 1:].to(device), targets[:, 1:].to(device)

In [None]:
## Print out a batch and see how it looks
# Uses get_batch's default batch_size of 4; x holds the inputs and y the
# targets shifted by one token, with -1 in the unused positions
x,y  = get_batch(encoded_train_data)
print(x)
print(y)
print(x.shape,y.shape)

In [None]:
## Hyperparameters
# Number of sequences per batch used by the forward-pass check that follows
batch_size = 64

In [None]:
## Run a forward pass and compare the loss against a uniform-guess baseline
vocab_size = len(vocab) + 1 # +1 for '.' stop character
model = Bigram(vocab_size).to(device)
total_params = sum(p.numel() for p in model.parameters())
print("Total parameters:", total_params)
xb, yb = get_batch(encoded_train_data,batch_size=batch_size)
# Bigram is a project model; it is called with inputs and targets and
# returns a (logits, loss) pair
logits, loss = model(xb,yb)
print('Measured loss:',loss.item())
# A predictor that spreads probability evenly over all vocab_size classes
# (the '.' token is a class too) has cross-entropy -log(1/vocab_size).
# The class count is taken from vocab_size so it matches the model's size
# instead of being hard-coded.
expected_loss = math.log(vocab_size)
print('Expected loss assuming uniform:', expected_loss)

# Choose a reasonable learning rate to train your model

In [None]:
# Function to do a learning rate sweep
def get_lr_loss(model, optimizer, dataset, batch_size, num_epochs, device, lr_start_exp=-3, lr_end_exp=0.5):
    """Sweep the learning rate over a log-spaced range and record the loss.

    One optimizer step is taken per learning rate, each on a fresh random
    batch, so ``model`` and ``optimizer`` are modified by the sweep.

    Args:
        model: callable returning ``(logits, loss)`` for ``(xb, yb)``.
        optimizer: optimizer whose ``param_groups`` learning rate is
            overwritten on every step.
        dataset: encoded sequences handed to ``get_batch``.
        batch_size: number of sequences drawn per step.
        num_epochs: number of sweep steps (one learning rate each).
        device: kept for interface compatibility; not used here because
            ``get_batch`` moves tensors to the module-level ``device``.
        lr_start_exp: base-10 exponent of the first learning rate.
        lr_end_exp: base-10 exponent of the last learning rate.

    Returns:
        ``(lri, lossi)``: the learning rates used and the loss measured at
        each of them (before that step's update), as lists of floats.
    """
    # Log-spaced schedule, converted to plain Python floats so the values
    # written into param_groups and returned for plotting are not tensors
    lrs_val = (10 ** torch.linspace(lr_start_exp, lr_end_exp, num_epochs)).tolist()

    lri = []
    lossi = []
    for lr in lrs_val:
        # Set the learning rate for this step
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Honour the caller's batch_size (previously a fixed 256 was used)
        xb, yb = get_batch(dataset, batch_size=batch_size)

        # Forward pass
        _, loss = model(xb, yb)
        lri.append(lr)
        lossi.append(loss.item())

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return lri, lossi

In [None]:
# Run the learning rate sweep and plot the results
import matplotlib.pyplot as plt
# Number of sweep steps; one learning rate is tried per step
num_epochs = 100
batch_size = 128
# Adam starts with its default lr; get_lr_loss overwrites it on every step
optimizer = Adam(model.parameters())
# Sweep from 10**-3 to 10**1. Note that the sweep takes real optimizer steps,
# so `model` is no longer at its initial weights afterwards
lri, lossi =  get_lr_loss(model, optimizer, encoded_train_data, batch_size, num_epochs, device, -3, 1)
plt.plot(lri, lossi)
# Add labels to the x-axis and y-axis
plt.xlabel('LR (Learning Rate)')
plt.ylabel('Loss')

# Run a training loop

In [None]:
## Initialize loss histories and batch size

# Training loss reported by evaluate_loss, one entry per training step
tr_loss = []
# Validation loss reported by evaluate_loss, one entry per training step
val_loss = []
# Loss of the batch that was actually used for the gradient step
tr_loss_raw = []
batch_size = 128

In [None]:
## Initialize training parameters
lr = 0.1
# A fresh optimizer is created here; `model` itself is not re-initialized,
# so it keeps whatever weights earlier cells left it with
optimizer = Adam(model.parameters(), lr=lr)
# Counts optimizer steps (one random batch each), not passes over the data
n_epochs = 1000

for steps in range(n_epochs):
    # Draw a fresh random training and validation batch for this step
    xtr, ytr = get_batch(encoded_train_data, batch_size=batch_size)
    xval, yval = get_batch(encoded_val_data, batch_size=batch_size)
    eval_dataset = {'train': (xtr,ytr), 'val': (xval, yval)}
    # Forward pass, then one gradient step on the training batch
    logits, loss = model(xtr,ytr)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # evaluate_loss is a project helper called after the update; presumably it
    # returns the train and validation loss for eval_dataset -- confirm how it
    # uses num_batches in utils
    tr_lossi, val_lossi = evaluate_loss(model, eval_dataset, num_batches=16)
    tr_loss.append(tr_lossi)
    val_loss.append(val_lossi)
    # loss was computed before the update, on the same training batch
    tr_loss_raw.append(loss.item())

    ## Print losses every 100 steps
    if steps % 100 == 0:
        print(steps, ' --> train loss: ', tr_lossi, 'validation loss: ', val_lossi, 'single shot loss:', loss.item())

In [None]:
## Plot the recorded training and validation loss
import matplotlib.pyplot as plt

# Label the curves so the two lines can be told apart
plt.plot(tr_loss, label='train')
plt.plot(val_loss, label='validation')
plt.legend()

# Report the mean over the last 100 recorded steps to smooth out batch noise
# (the stray trailing comma after the first print has been removed)
print('training loss: ', round(torch.mean(torch.tensor(tr_loss[-100:])).item(),4))
print('validation loss: ', round((torch.mean(torch.tensor(val_loss[-100:]))).item(),4))

# Analyze your results

In [None]:
## Function to generate tokens and decode them
def print_samples(model, max_new_tokens, device, num=8):
    """Sample ``num`` sequences from ``model`` and print them decoded.

    Args:
        model: model handed to the project helper ``generate``.
        max_new_tokens: forwarded to ``generate``.
        device: device the start tokens are moved to; also forwarded to
            ``generate``.
        num: number of samples to draw.
    """
    # Every sample starts from a single token 0 (the '.' marker)
    start_tokens = torch.zeros((num, 1), dtype=torch.long).to(device)
    generated = generate(model, start_tokens, max_new_tokens, device)
    # Drop the start column and switch to plain Python lists
    token_rows = generated[:, 1:].tolist()

    cropped_rows = []
    for tokens in token_rows:
        # Keep only the tokens before the first 0, or the whole row if
        # no 0 was generated
        try:
            stop_at = tokens.index(0)
        except ValueError:
            stop_at = len(tokens)
        cropped_rows.append(tokens[:stop_at])

    print(decode(cropped_rows))
    

In [None]:
## Pretty print the samples
# Draw 8 samples, passing max_length as the number of new tokens to generate
print_samples(model, max_length, device, 8)

In [None]:
## Plot the embedding matrix as a 2D heatmap - note that it is asymmetric
# assumes the weight is a square (vocab_size, vocab_size) table -- confirm in models.Bigram
embeddings_matrix = model.bigram_embedding.weight.data.cpu().numpy()

# Build one tick label per token index, in index order. itos is not ordered
# by index (the '.' entry with index 0 was added to stoi last), so walk the
# sorted keys instead of itos.values(). Index 0 is shown as 'stop'. This gives
# exactly one label per index, so no extra tick lands past the matrix edge.
ticklabels = ['stop' if idx == 0 else itos[idx] for idx in sorted(itos)]
x_ticklabel_vec = np.arange(len(ticklabels))
y_ticklabel_vec = np.arange(len(ticklabels))
print(ticklabels)

# Plot the embedding matrix as a 2D matrix plot
plt.figure(figsize=(7,7))
plt.imshow(embeddings_matrix, cmap='Blues', aspect='auto')
plt.colorbar(label='Embedding Value')
plt.title('2D Matrix Plot of Embeddings')
plt.xlabel('Second alphabet')
plt.ylabel('Context alphabet')
plt.xticks(x_ticklabel_vec, ticklabels)
plt.yticks(y_ticklabel_vec, ticklabels)
plt.show()

In [None]:
## Project the embedding rows to 2D with t-SNE and plot them
# perplexity is kept small because there are only a few rows (one per token)
tsne = TSNE(n_components=2, perplexity=5)
reduced_embeddings = tsne.fit_transform(embeddings_matrix)

# Scatter the two t-SNE components and write each token's label on its point
plt.figure(figsize=(8,8))
# fit_transform returns a NumPy array, so the columns are passed directly
# (no `.data`, which on an ndarray is a raw memoryview)
plt.scatter(reduced_embeddings[:,0], reduced_embeddings[:,1], s=200)
for i in range(embeddings_matrix.shape[0]):
    plt.text(reduced_embeddings[i,0].item(), reduced_embeddings[i,1].item(), ticklabels[i], ha="center", va="center", color='white')
# The first argument of plt.grid is the on/off flag, not the `which` selector;
# pass True explicitly to keep the grid that was being shown
plt.grid(True)