In [1]:
# !pip install wandb

In [None]:
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F
from torcheval.metrics.text import Perplexity
from tqdm import tqdm
from attention import GPTLanguageModel
from mamba import MambaLanguageModel
from xlstm import XLSTMLanguageModel
torch.manual_seed(1337)
import gc

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-08-05 04:10:10--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2024-08-05 04:10:10 (76.2 MB/s) - ‘input.txt.3’ saved [1115394/1115394]



In [4]:
# Authenticate this kernel with Weights & Biases before any logging happens;
# the call's return value is displayed as the cell output.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mshetty-sau[0m ([33mshetty-sau-northeastern-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# ---- Batch geometry ----
# batch_size: independent sequences processed in parallel
# block_size: maximum context length for predictions
batch_size, block_size = 125, 256

# ---- Training / evaluation schedule ----
max_iters = 500000       # upper bound on optimisation steps
eval_interval = 500      # evaluate every this many steps
eval_iters = 200         # batches averaged per evaluation
learning_rate = 3e-4

# ---- Model-size settings ----
# NOTE(review): the GPT cell builds its model from vocab_size only, so whether these
# values reach that model cannot be told from this notebook; n_embd is read again by
# the xLSTM cell.
n_embd, n_head, n_layer = 384, 6, 6
dropout = 0.2

# ---- Device selection ----
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Tensors created without an explicit device= now land on `device`.
torch.set_default_device(device)
device

'cuda'

In [6]:
# Read the whole corpus into memory as a single string.
with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: every distinct character of the text, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and their integer ids.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: turn a string into the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: turn a list of integer ids back into the string they spell."""
    return ''.join(itos[i] for i in l)


# Encode the full text once, then split it: first 90% for training, the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [7]:
def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        (x, y): two stacked tensors of `batch_size` windows, each `block_size`
        long, moved to `device`. Every row of `y` is the corresponding row of
        `x` shifted one position to the right in the source data.
    """
    source = val_data if split != 'train' else train_data
    # Random window start offsets; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

In [8]:
@torch.no_grad()
def estimate_loss():
    """Estimate mean loss and perplexity of the global `model` on both splits.

    For each of 'train' and 'val', draws `eval_iters` random batches with
    `get_batch`, runs the model in eval mode, and accumulates the per-batch
    loss together with a torcheval `Perplexity` metric.

    Returns:
        dict with keys 'train' and 'val' (mean loss over the sampled batches,
        as 0-d tensors) and 'train_perplexity' / 'val_perplexity' (the value
        returned by `metric.compute()`).
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            # A fresh metric per split, so no reset is needed.
            metric = Perplexity()
            metric.to(device)
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                logits, loss = model(X, Y)
                # The metric is fed (batch, time, classes) logits. Recover that layout
                # from the batch actually returned by get_batch rather than from the
                # global batch_size, so the reshape cannot drift out of sync with it.
                B = Y.shape[0]
                C = logits.shape[-1]
                metric.update(logits.view(B, -1, C), Y.view(B, -1))
                losses[k] = loss.item()
            out[split] = losses.mean()
            out[str(split)+"_perplexity"] = metric.compute()
    finally:
        # Always hand the model back in training mode, even if evaluation is
        # interrupted (e.g. a kernel interrupt in the middle of a long run).
        model.train()
    return out
    
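# NOTE: earlier two-pass draft of estimate_loss, left commented out. Do not re-enable as is:
# its second loop passes 'perplexity_train'/'perplexity_val' to get_batch, which treats any
# split other than 'train' as validation data.
# @torch.no_grad()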
# def estimate_loss():
#     out = {}
#     model.eval()
#     for split in ['train', 'val']:
#         losses = torch.zeros(eval_iters)
#         for k in range(eval_iters):
#             X, Y = get_batch(split)
#             logits, loss = model(X, Y)
#             losses[k] = loss.item()
#         out[split] = losses.mean()

#     for split in ['perplexity_train', 'perplexity_val']:
#         metric = Perplexity()
#         metric.to(device)
#         metric.reset()
#         for k in range(eval_iters):
#             X, Y = get_batch(split)
#             logits, loss = model(X, Y)
#             BT, C = logits.shape
#             logits = logits.view(batch_size, BT//batch_size, C)
#             Y = Y.view(batch_size, BT//batch_size)
#             metric.update(logits, Y)
#         out[split] = metric.compute()
#         del metric
#     model.train()
#     return out

In [9]:
# Build the attention-based language model and place it on the selected device.
# `m` and `model` are two names for the same module.
model = GPTLanguageModel(vocab_size)
m = model.to(device)

# Report the parameter count, in millions.
n_params = sum(p.numel() for p in m.parameters())
print(n_params/1e6, 'M parameters')

# AdamW optimizer over every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

10.788929 M parameters


In [None]:
# Have Weights & Biases record gradients and parameters of the model every 100 batches.
# The second positional parameter of wandb.watch is `criterion`, not an optimizer, so the
# optimizer is no longer passed there.
wandb.watch(model, log="all", log_freq=100)

# Parameters for early stopping
patience = 5                    # evaluations without improvement tolerated before stopping
min_delta = 0.001               # minimum drop in val loss that counts as an improvement
best_val_loss = float('inf')
patience_counter = 0

# `step` rather than `iter`, so the builtin iter() is not shadowed for the rest of the kernel session.
for step in tqdm(range(max_iters)):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} train_perplexity {losses['train_perplexity']:.4f}, val_perplexity  {losses['val_perplexity']:.4f}")
        wandb.log({"train_loss": losses['train'], "val_loss": losses['val'], "train_perplexity":losses['train_perplexity'], "val_perplexity":losses['val_perplexity']})

        # Early stopping check
        if losses['val'] < best_val_loss - min_delta:
            best_val_loss = losses['val']
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at iteration {step} with best val loss {best_val_loss:.4f}")
            break

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss, then take one optimizer step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

  0%|          | 0/500000 [00:00<?, ?it/s]

step 0: train loss 1.8867, val loss 2.0116 train_perplexity 6.5975, val_perplexity - 7.4750


  0%|          | 500/500000 [08:52<123:31:42,  1.12it/s] 

step 500: train loss 1.3665, val loss 1.5937 train_perplexity 3.9216, val_perplexity - 4.9220


  0%|          | 1000/500000 [17:45<123:37:08,  1.12it/s]

step 1000: train loss 1.2162, val loss 1.5129 train_perplexity 3.3745, val_perplexity - 4.5401


  0%|          | 1500/500000 [26:36<122:37:36,  1.13it/s] 

In [None]:
# generate from the model
# estimate_loss() finishes by calling model.train(), so the model is still in training
# mode here. Switch to eval mode (disabling any dropout layers) and turn off autograd
# tracking before sampling.
m.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # single start token with id 0
with torch.no_grad():
    print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

# MAMBA

In [None]:
# clear GPU memory
# Drop the notebook-level references to the finished model (`model` and `m` name the same
# module). The optimizer is released as well: it was built from model.parameters() and
# would otherwise keep those tensors alive.
del model
del m
del optimizer
gc.collect()
# Return the now-unused cached blocks held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [None]:
# Hyperparameters for the Mamba run, re-declared so this section reads on its own.
batch_size, block_size = 125, 256   # sequences per batch, maximum context length
max_iters = 500000
eval_interval = 500
eval_iters = 200
learning_rate = 3e-4
n_embd, n_head, n_layer = 384, 6, 6
dropout = 0.2
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Keyword arguments handed to the model constructor, gathered in one place.
mamba_config = dict(
    batch_size=batch_size,
    block_size=block_size,
    max_iters=max_iters,
    eval_interval=eval_interval,
    learning_rate=learning_rate,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
)
model = MambaLanguageModel(vocab_size, **mamba_config)

# `m` and `model` are two names for the same module, now on `device`.
m = model.to(device)

# Report the parameter count, in millions.
n_params = sum(p.numel() for p in m.parameters())
print(n_params/1e6, 'M parameters')

# AdamW optimizer over every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Have Weights & Biases record gradients and parameters of the model every 100 batches.
# The second positional parameter of wandb.watch is `criterion`, not an optimizer, so the
# optimizer is no longer passed there.
wandb.watch(model, log="all", log_freq=100)

# Parameters for early stopping
patience = 5                    # evaluations without improvement tolerated before stopping
min_delta = 0.001               # minimum drop in val loss that counts as an improvement
best_val_loss = float('inf')
patience_counter = 0

# `step` rather than `iter`, so the builtin iter() is not shadowed for the rest of the kernel session.
for step in tqdm(range(max_iters)):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} train_perplexity {losses['train_perplexity']:.4f}, val_perplexity  {losses['val_perplexity']:.4f}")
        wandb.log({"train_loss": losses['train'], "val_loss": losses['val'], "train_perplexity":losses['train_perplexity'], "val_perplexity":losses['val_perplexity']})

        # Early stopping check
        if losses['val'] < best_val_loss - min_delta:
            best_val_loss = losses['val']
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at iteration {step} with best val loss {best_val_loss:.4f}")
            break

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss, then take one optimizer step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# generate from the model
# estimate_loss() finishes by calling model.train(), so the model is still in training
# mode here. Switch to eval mode (disabling any dropout layers) and turn off autograd
# tracking before sampling.
m.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # single start token with id 0
with torch.no_grad():
    print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

# xLSTM

In [None]:
# clear GPU memory
# Drop the notebook-level references to the finished model (`model` and `m` name the same
# module). The optimizer is released as well: it was built from model.parameters() and
# would otherwise keep those tensors alive.
del model
del m
del optimizer
gc.collect()
# Return the now-unused cached blocks held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [None]:
# Hyperparameters for the xLSTM run, re-declared so this section reads on its own.
batch_size = 125 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 500000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384 # declared here (same value as the earlier cells) so this cell does not rely on them
layer = ['m', 'm', 's'] # block layout handed to the model — presumably mLSTM/sLSTM markers; confirm in xlstm.py
dropout = 0.2
torch.set_default_device(device)

# All-zero tensor of shape (batch_size, block_size, n_embd) passed to the constructor.
# NOTE(review): looks like a shape template for the model — confirm in xlstm.py.
x = torch.zeros(batch_size, block_size, n_embd)

# eval_iters now receives eval_iters (200); it was previously fed eval_interval (500).
model = XLSTMLanguageModel(vocab_size, x, batch_size = batch_size, block_size = block_size,
                max_iters = max_iters, eval_interval = eval_interval, learning_rate = learning_rate,
                device = device, eval_iters = eval_iters, dropout = dropout, layers=layer)

m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Have Weights & Biases record gradients and parameters of the model every 100 batches.
# The second positional parameter of wandb.watch is `criterion`, not an optimizer, so the
# optimizer is no longer passed there.
wandb.watch(model, log="all", log_freq=100)

# Parameters for early stopping
patience = 5                    # evaluations without improvement tolerated before stopping
min_delta = 0.001               # minimum drop in val loss that counts as an improvement
best_val_loss = float('inf')
patience_counter = 0

# `step` rather than `iter`, so the builtin iter() is not shadowed for the rest of the kernel session.
for step in tqdm(range(max_iters)):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} train_perplexity {losses['train_perplexity']:.4f}, val_perplexity  {losses['val_perplexity']:.4f}")
        wandb.log({"train_loss": losses['train'], "val_loss": losses['val'], "train_perplexity":losses['train_perplexity'], "val_perplexity":losses['val_perplexity']})

        # Early stopping check
        if losses['val'] < best_val_loss - min_delta:
            best_val_loss = losses['val']
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at iteration {step} with best val loss {best_val_loss:.4f}")
            break

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss, then take one optimizer step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# generate from the model
# Start from a fresh single-token context (id 0) rather than whatever `context` an earlier
# cell left behind; this keeps the cell runnable on its own and repeatable when re-run.
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# estimate_loss() finishes by calling model.train(); switch to eval mode and turn off
# autograd tracking before sampling.
model.eval()
with torch.no_grad():
    # Generate one token per call, feeding the returned context back in and printing
    # the decoded output as it arrives.
    for _ in range(500):
        context, out = model.generate(context, max_new_tokens=1)
        print(decode(out[0].tolist()), end="")