In [1]:
!pip install mamba_ssm
!pip install wandb



In [2]:
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F
from torcheval.metrics.text import Perplexity

from GPT import GPTLanguageModel
from Mamba import MambaLanguageModel
from xLSTMmodel import XLSTMLanguageModel
# Seed PyTorch's global random number generator so the random batch sampling
# done with torch.randint in get_batch() is repeatable from a fresh kernel.
torch.manual_seed(1337)

<torch._C.Generator at 0x7dbb6f574ab0>

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-08-04 13:22:11--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-08-04 13:22:11 (17.2 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [4]:
# Authenticate this session with Weights & Biases; the training cells below
# call wandb.watch and wandb.log.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msidmital4[0m ([33msidmital4-northeastern-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# --- Hardware ---------------------------------------------------------------
# Use the GPU when one is visible, and make it the default device for newly
# created tensors.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# --- Batching ---------------------------------------------------------------
batch_size = 64    # number of independent sequences processed in parallel
block_size = 256   # maximum context length for predictions

# --- Training / evaluation schedule -----------------------------------------
learning_rate = 3e-4
max_iters = 5000       # number of optimisation steps in each training loop
eval_interval = 500    # the training loops evaluate every `eval_interval` steps
eval_iters = 200       # batches drawn per split by estimate_loss()

# --- Model size / regularisation --------------------------------------------
# These are forwarded to the Mamba constructor further down; GPTLanguageModel
# is only given vocab_size in this notebook.
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

In [6]:
# Read the whole corpus into memory as a single string.
with open('input.txt', 'r', encoding='utf-8') as fh:
    text = fh.read()

# Character-level vocabulary: every distinct character of the text, sorted.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between a character and its integer id (its index in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Encoder: take a string, return the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, return the string they spell."""
    return ''.join(itos[i] for i in l)


# Encode the full text once; the first 90% is the training split and the
# remainder is the validation split.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [7]:
def get_batch(split):
    """Draw a random batch of input/target sequences from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        A pair (x, y) of tensors moved to `device`, each stacking `batch_size`
        windows of `block_size` tokens. Each row of y is the corresponding row
        of x shifted one position to the right in the source data.
    """
    source = train_data if split == 'train' else val_data
    # Random window start offsets; the upper bound leaves room for the shifted target.
    offsets = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[o:o + block_size] for o in offsets])
    targets = torch.stack([source[o + 1:o + block_size + 1] for o in offsets])
    return inputs.to(device), targets.to(device)

In [8]:
@torch.no_grad()
def estimate_loss():
    """Estimate mean loss and perplexity of the global `model` on both splits.

    For each of the 'train' and 'val' splits, `eval_iters` random batches are
    drawn with `get_batch`. The model is switched to eval mode while measuring
    and back to train mode before returning.

    Returns:
        dict with keys
          'train', 'val'                         -- mean loss over the sampled batches
          'perplexity_train', 'perplexity_val'   -- value of the `Perplexity`
                                                    metric accumulated over the
                                                    sampled batches
    """
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()

    # The result key ('perplexity_<split>') and the data split are kept apart:
    # get_batch() treats every name other than 'train' as validation, so it
    # must be handed the real split name, not the result key.
    for split in ['train', 'val']:
        metric = Perplexity()
        metric.to(device)
        metric.reset()
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            # The logits come back as a 2-D tensor here; reshape them to
            # (batch, time, classes) using this batch's own target shape rather
            # than the global batch_size.
            B, T = Y.shape
            metric.update(logits.view(B, T, -1), Y)
        out[f'perplexity_{split}'] = metric.compute()
        del metric
    model.train()
    return out

In [None]:
# Build the GPT language model and place it on the selected device.
model = GPTLanguageModel(vocab_size)
m = model.to(device)

# Report the model size in millions of parameters.
n_params = sum(p.numel() for p in m.parameters())
print(n_params / 1e6, 'M parameters')

# AdamW over all model parameters with the configured learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

[34m[1mwandb[0m: Currently logged in as: [33msidmital4[0m ([33msidmital4-northeastern-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


10.788929 M parameters


In [None]:
# Log gradient and parameter histograms to W&B every 100 batches. The second
# positional parameter of wandb.watch is `criterion` (a loss function), not an
# optimizer, so the optimizer is not passed.
wandb.watch(model, log="all", log_freq=100)

# Parameters for early stopping
# NOTE(review): patience counts evaluations, not iterations. With
# max_iters=5000 and eval_interval=500 there are only 11 evaluations, so
# patience=10 can fire on the final evaluation at the earliest -- confirm the
# intended value.
patience = 10
min_delta = 0.001
best_val_loss = float('inf')
patience_counter = 0

# The loop variable is `step` rather than `iter` so the builtin iter() is not shadowed.
for step in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # estimate_loss() also returns perplexities; log them instead of discarding them.
        wandb.log({
            "train_loss": float(losses['train']),
            "val_loss": float(losses['val']),
            "train_perplexity": float(losses['perplexity_train']),
            "val_perplexity": float(losses['perplexity_val']),
        })

        # Early stopping check: an evaluation only counts as an improvement if
        # it beats the best validation loss by more than min_delta.
        val_loss = float(losses['val'])
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at iteration {step} with best val loss {best_val_loss:.4f}")
            break

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss and take one optimisation step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2221, val loss 4.2306
step 500: train loss 1.7600, val loss 1.9146
step 1000: train loss 1.3903, val loss 1.5987
step 1500: train loss 1.2644, val loss 1.5271
step 2000: train loss 1.1835, val loss 1.4978
step 2500: train loss 1.1233, val loss 1.4910
step 3000: train loss 1.0718, val loss 1.4804
step 3500: train loss 1.0179, val loss 1.5127
step 4000: train loss 0.9604, val loss 1.5102
step 4500: train loss 0.9125, val loss 1.5351
step 4999: train loss 0.8589, val loss 1.5565


In [None]:
# Generate from the model: start from a single token with id 0, sample 500 new
# tokens, then decode the first (only) sequence of the result and print it.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = m.generate(context, max_new_tokens=500)
print(decode(generated[0].tolist()))


But with prison, I will steal for the fimker.

KING HENRY VI:
To prevent it, as I love this country's cause.

HENRY BOLINGBROKE:
I thank bhop my follow. Walk ye were so?

NORTHUMBERLAND:
My lord, I hearison! Who may love me accurse
Some chold or flights then men shows to great the cur
Ye cause who fled the trick that did princely action?
Take my captiving sound, althoughts thy crown.

RICHMOND NE:
God neit will he not make it wise this!

DUKE VINCENTIO:
Worthy Prince forth from Lord Claudio!

Lo


# MAMBA

In [None]:
# Hyperparameters for the Mamba run (restated so this section can be read on its own).
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

# Build the Mamba language model. eval_iters is passed the eval_iters value
# defined above (previously eval_interval was passed by mistake).
model = MambaLanguageModel(vocab_size, batch_size = batch_size, block_size = block_size,
                max_iters = max_iters, eval_interval = eval_interval, learning_rate = learning_rate,
                device = device, eval_iters = eval_iters, n_embd = n_embd, n_head = n_head,
                n_layer = n_layer, dropout = dropout)

m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Log gradient and parameter histograms to W&B every 100 batches. The second
# positional parameter of wandb.watch is `criterion` (a loss function), not an
# optimizer, so the optimizer is not passed.
wandb.watch(model, log="all", log_freq=100)

# Parameters for early stopping
# NOTE(review): patience counts evaluations, not iterations. With
# max_iters=5000 and eval_interval=500 there are only 11 evaluations, so
# patience=10 can fire on the final evaluation at the earliest -- confirm the
# intended value.
patience = 10
min_delta = 0.001
best_val_loss = float('inf')
patience_counter = 0

# The loop variable is `step` rather than `iter` so the builtin iter() is not shadowed.
for step in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # estimate_loss() also returns perplexities; log them instead of discarding them.
        wandb.log({
            "train_loss": float(losses['train']),
            "val_loss": float(losses['val']),
            "train_perplexity": float(losses['perplexity_train']),
            "val_perplexity": float(losses['perplexity_val']),
        })

        # Early stopping check: an evaluation only counts as an improvement if
        # it beats the best validation loss by more than min_delta.
        val_loss = float(losses['val'])
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at iteration {step} with best val loss {best_val_loss:.4f}")
            break

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss and take one optimisation step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# NOTE(review): this cell repeats the training loop of the previous cell
# without early stopping. If both cells are run, the same model and optimizer
# are trained for a further max_iters steps -- confirm which cell should stay.

# Log gradient and parameter histograms to W&B every 100 batches. The second
# positional parameter of wandb.watch is `criterion` (a loss function), not an
# optimizer, so the optimizer is not passed.
wandb.watch(model, log="all", log_freq=100)

# The loop variable is `step` rather than `iter` so the builtin iter() is not shadowed.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # estimate_loss() also returns perplexities; log them instead of discarding them.
        wandb.log({
            "train_loss": float(losses['train']),
            "val_loss": float(losses['val']),
            "train_perplexity": float(losses['perplexity_train']),
            "val_perplexity": float(losses['perplexity_val']),
        })

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss and take one optimisation step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2706, val loss 4.2682
step 500: train loss 1.3013, val loss 1.5702
step 1000: train loss 1.1619, val loss 1.5380
step 1500: train loss 1.0705, val loss 1.5556
step 2000: train loss 0.9952, val loss 1.5859
step 2500: train loss 0.9253, val loss 1.6294
step 3000: train loss 0.8677, val loss 1.6666
step 3500: train loss 0.8140, val loss 1.7122
step 4000: train loss 0.7709, val loss 1.7465
step 4500: train loss 0.7313, val loss 1.7960
step 4999: train loss 0.7008, val loss 1.8019


In [None]:
# Generate from the model: start from a single token with id 0, sample 500 new
# tokens, then decode the first (only) sequence of the result and print it.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = m.generate(context, max_new_tokens=500)
print(decode(generated[0].tolist()))


Than, my dearly, I wonder, being one;
No word marrow is the people's sovereign; a more drunkail.
Adtive, can this good great son! Where's his sight!
For so we scorn'd within, and make betwixt thee by thy
Lift and fraulinc: upon eein hours in a saunt,
Can the stony so fair an olize, so it was a tower;
Let them go with him to know.

FArst Murderer:
We charge to her: and I'll take it as of moider,
How is it were to live, like labour,
Holy fath bring our sorrow! hath a half too heavy
As thou hast ha


# xLSTM

In [11]:
# Hyperparameters for the xLSTM run (restated so this section can be read on its own).
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Same value as in the other sections; defined here so this cell does not rely
# on a name that only exists if another section's cell has been run.
n_embd = 384
layer = ['m', 'm', 's']
dropout = 0.2
torch.set_default_device(device)

# All-zero tensor of shape (batch_size, block_size, n_embd) handed to the model
# constructor. NOTE(review): presumably used as a shape template -- confirm in
# xLSTMmodel.
x = torch.zeros(batch_size, block_size, n_embd)

# Build the xLSTM language model. eval_iters is passed the eval_iters value
# defined above (previously eval_interval was passed by mistake).
model = XLSTMLanguageModel(vocab_size, x, batch_size = batch_size, block_size = block_size,
                max_iters = max_iters, eval_interval = eval_interval, learning_rate = learning_rate,
                device = device, eval_iters = eval_iters, dropout = dropout, layers=layer)

m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

9.007681 M parameters


In [None]:
# Log gradient and parameter histograms to W&B every 100 batches. The second
# positional parameter of wandb.watch is `criterion` (a loss function), not an
# optimizer, so the optimizer is not passed.
wandb.watch(model, log="all", log_freq=100)

# Parameters for early stopping
# NOTE(review): patience counts evaluations, not iterations. With
# max_iters=5000 and eval_interval=500 there are only 11 evaluations, so
# patience=10 can fire on the final evaluation at the earliest -- confirm the
# intended value.
patience = 10
min_delta = 0.001
best_val_loss = float('inf')
patience_counter = 0

# The loop variable is `step` rather than `iter` so the builtin iter() is not shadowed.
for step in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # estimate_loss() also returns perplexities; log them instead of discarding them.
        wandb.log({
            "train_loss": float(losses['train']),
            "val_loss": float(losses['val']),
            "train_perplexity": float(losses['perplexity_train']),
            "val_perplexity": float(losses['perplexity_val']),
        })

        # Early stopping check: an evaluation only counts as an improvement if
        # it beats the best validation loss by more than min_delta.
        val_loss = float(losses['val'])
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at iteration {step} with best val loss {best_val_loss:.4f}")
            break

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss and take one optimisation step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.3506, val loss 4.3495
step 500: train loss 2.6212, val loss 2.6416
step 1000: train loss 2.5178, val loss 2.5526
step 1500: train loss 2.4819, val loss 2.5075
step 2000: train loss 2.4704, val loss 2.4976
step 2500: train loss 2.4643, val loss 2.4901
step 3000: train loss 2.4603, val loss 2.4894
step 3500: train loss 2.4592, val loss 2.4850
step 4000: train loss 2.4573, val loss 2.4868


In [None]:
# Start from a single token with id 0, as the other generation cells do, so
# this cell does not depend on a `context` left over from an earlier cell.
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# Ask for one new token per call, feed the returned context back in, and print
# the decoded output of each call without a trailing newline.
for _ in range(500):
    context, out = model.generate(context, max_new_tokens=1)
    print(decode(out[0].tolist()), end="")