In [None]:
# Toy character-level take on github/havenhq/mamba-chat, reusing the
# data-loading and batching scaffolding from github/karpathy/nanoGPT.
import time

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Seed torch's global RNG so batch sampling and weight init are repeatable.
torch.manual_seed(7)

# All tensors and the model are placed on this device (presumably CUDA is
# required by mamba_ssm's kernels — confirm before trying another device).
device = "cuda"

In [2]:
# Load the raw Tiny Shakespeare corpus. An explicit UTF-8 encoding keeps the
# decode independent of the platform's default locale encoding.
with open("./tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [3]:
# Character-level tokenizer: every distinct character in the corpus is a token.
chars = sorted(set(text))
vocab_size = len(chars)

# Tokens per training sequence, and sequences per batch.
input_size = 1024
batch_size = 4

c2i = {c: i for i, c in enumerate(chars)}
i2c = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of integer token ids.

    Raises KeyError if ``s`` contains a character not present in the corpus.
    """
    return [c2i[c] for c in s]


def decode(l):
    """Map an iterable of integer token ids back to a string."""
    return "".join(i2c[i] for i in l)


# Tokenize the whole corpus once; the first 90% trains, the last 10% validates.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data) * 0.9)

train_data = data[:n]
val_data = data[n:]

# get_batch samples start offsets with torch.randint(len(split) - input_size),
# which needs a positive bound, so both splits must exceed one input window.
assert min(len(train_data), len(val_data)) > input_size, (
    "train/val split is shorter than one input window"
)

def get_batch(split):
    """Sample a random batch of (input, next-token target) sequences.

    Args:
        split: "train" to draw from ``train_data``, "val" to draw from
            ``val_data``.

    Returns:
        (x, y): two ``(batch_size, input_size)`` tensors on ``device``, where
        ``y`` holds the same windows as ``x`` shifted forward by one token.

    Raises:
        ValueError: if ``split`` is neither "train" nor "val" (previously any
            other value, e.g. a typo, silently sampled validation data).
    """
    if split == "train":
        split_data = train_data
    elif split == "val":
        split_data = val_data
    else:
        raise ValueError(f"unknown split {split!r}; expected 'train' or 'val'")

    # Offsets leave room for input_size + 1 tokens: the inputs plus the one
    # extra token needed for the shifted targets.
    ix = torch.randint(len(split_data) - input_size, (batch_size,))
    x = torch.stack([split_data[i : i + input_size] for i in ix])
    y = torch.stack([split_data[i + 1 : i + input_size + 1] for i in ix])

    if torch.device(device).type == "cuda":
        # Pinned host memory lets the host-to-device copies run asynchronously.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        # Pinning requires an accelerator runtime; use plain copies otherwise.
        x, y = x.to(device), y.to(device)
    return x, y

# 12-layer, 768-wide Mamba language model over the character vocabulary.
model = MambaLMHeadModel(
    d_model=768,
    n_layer=12,
    vocab_size=vocab_size,
    device=device,
)

# Token-level cross-entropy, applied to flattened logits in the training loop.
loss_fct = torch.nn.CrossEntropyLoss()

# Only parameters that will receive gradients are counted.
n_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

In [4]:
# Training: next-character prediction with cross-entropy at every position.
t_start = time.perf_counter()
max_iters = 600
print_interval = 100
for step in range(max_iters):
    # Next random batch from the training split.
    xb, yb = get_batch("train")

    logits = model(xb).logits
    B, T, C = logits.shape
    # Flatten (B, T) so each position is one classification example. Unlike
    # view, reshape also tolerates a non-contiguous logits tensor.
    loss = loss_fct(logits.reshape(B * T, C), yb.reshape(B * T))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log every print_interval steps, and always report the final step.
    if step % print_interval == 0 or step == max_iters - 1:
        print('  iteration {} loss: {}'.format(step + 1, loss.item()))

# CUDA work is queued asynchronously; wait for it to finish so the measured
# time covers the whole run.
if torch.device(device).type == "cuda":
    torch.cuda.synchronize()

print("\n")
print(f"Number of parameters: {n_params}")
print(f"Training time:        {(time.perf_counter() - t_start)/60:.2f} min")

  iteration 1 loss: 6.594520092010498
  iteration 101 loss: 1.9037998914718628
  iteration 201 loss: 1.771459937095642
  iteration 301 loss: 1.5928815603256226
  iteration 401 loss: 1.6136791706085205
  iteration 501 loss: 1.4766533374786377


Number of parameters: 45320448
Training time:        1.55 min


In [16]:
# Sample a continuation of a fixed prompt from the trained model.
# unsqueeze(0) adds the batch dimension: shape (1, prompt_len).
prompt_tokens = torch.tensor(
    encode("Thou shall toil "), dtype=torch.long, device=device
).unsqueeze(0)
# NOTE(review): max_length presumably counts the prompt tokens too, and cg=True
# presumably enables CUDA-graph decoding in mamba_ssm — confirm in its generate().
out_tokens = model.generate(
    prompt_tokens, max_length=200, top_k=10, top_p=1.0, temperature=1.1, cg=True
)
# Only one sequence was generated; convert just that row to Python ints.
list_chars = out_tokens[0].tolist()
print(decode(list_chars))

Thou shall toil my cause.
What tearn, my less huth them?
The meets the cause; the thousand hapure
Honesty his cannot beating myscler father,
Ha! why changes than he inform, imenon
Than the grieve bird
