In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import os
import pickle
import requests

In [2]:
from model import BigramLanguageModel, ModelConfig

In [3]:
out_dir = './model/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [4]:
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 200
learning_rate = 3e-4
eval_iters = 200
n_embed = 384
n_head = 6
n_layer = 6
dropout = 0.2

In [5]:
# prepare the dataset
torch.manual_seed(1337)
!wget 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

--2024-06-16 17:57:46--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-06-16 17:57:46 (27.3 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [6]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)

chars, vocab_size

(['\n',
  ' ',
  '!',
  '$',
  '&',
  "'",
  ',',
  '-',
  '.',
  '3',
  ':',
  ';',
  '?',
  'A',
  'B',
  'C',
  'D',
  'E',
  'F',
  'G',
  'H',
  'I',
  'J',
  'K',
  'L',
  'M',
  'N',
  'O',
  'P',
  'Q',
  'R',
  'S',
  'T',
  'U',
  'V',
  'W',
  'X',
  'Y',
  'Z',
  'a',
  'b',
  'c',
  'd',
  'e',
  'f',
  'g',
  'h',
  'i',
  'j',
  'k',
  'l',
  'm',
  'n',
  'o',
  'p',
  'q',
  'r',
  's',
  't',
  'u',
  'v',
  'w',
  'x',
  'y',
  'z'],
 65)

In [7]:
# create a mapping from chars to ints
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

data = torch.tensor(encode(text), dtype=torch.long)

# lets split into trainh and test
n = int(0.9*len(data)) # 90% will be train, rest validation
train_data = data[:n]
val_data = data[n:]

In [8]:
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [9]:
# model initiation
model_args = dict(n_layer=n_layer, n_head=n_head,
                    n_embed=n_embed, block_size=block_size,
                  vocab_size=vocab_size, dropout=dropout)
modconf = ModelConfig(**model_args)
model = BigramLanguageModel(modconf).to(device)

# optimiser
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
best_val_loss = 1e9
for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == max_iters-1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    if losses['val'] < best_val_loss:
        best_val_loss = losses['val']
        if iter > 0:
            checkpoint = {
                'model': model.state_dict(),
                'model_args': model_args
                }
            print(f"saving checkpoint to {out_dir}")
            torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2849, val loss 4.2823
step 200: train loss 2.4205, val loss 2.4410
saving checkpoint to ./model/
step 400: train loss 2.1404, val loss 2.1907
saving checkpoint to ./model/
step 600: train loss 1.8839, val loss 2.0001
saving checkpoint to ./model/
step 800: train loss 1.7156, val loss 1.8831
saving checkpoint to ./model/
step 1000: train loss 1.5979, val loss 1.7756
saving checkpoint to ./model/
step 1200: train loss 1.5220, val loss 1.7197
saving checkpoint to ./model/
step 1400: train loss 1.4611, val loss 1.6547
saving checkpoint to ./model/
step 1600: train loss 1.4092, val loss 1.6236
saving checkpoint to ./model/
step 1800: train loss 1.3741, val loss 1.5977
saving checkpoint to ./model/
step 2000: train loss 1.3413, val loss 1.5718
saving checkpoint to ./model/
step 2200: train loss 1.3143, val loss 1.5484
saving checkpoint to ./model/
step 2400: train loss 1.2900, val loss 1.5394
saving checkpoint to ./model/
step 2600: train loss 1.2648, val loss 1.5242
sav

In [10]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))



But with prison: I will steal win those mercy;
And Clarence the Lord Ga father, I my grant too love.

CLIFFORD:
And, no more: do you not, my lord;
I have last it, these were my wars: but you'll try heart of eyes,
And with woaks for the world's grace:
We will hence of it. Dire you come! know you to fled.
I was old down with her.
Clarding fie, even more fear ough, paint presately gladlic.

SICINIUS:
Go, be not. Here's slew!
Bid what nothing subs here her sins and bed!

SICINIUS:
Masterly, tugh unt
