In [55]:
%pip install torchinfo

Note: you may need to restart the kernel to use updated packages.


In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary

In [57]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [58]:
# download file to tmp/data.txt
!wget -O tmp/data.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-09-28 23:03:50--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8003::154, 2606:50c0:8001::154, 2606:50c0:8002::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8003::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: 'tmp/data.txt'

     0K .......... .......... .......... .......... ..........  4% 2.78M 0s
    50K .......... .......... .......... .......... ..........  9% 7.89M 0s
   100K .......... .......... .......... .......... .......... 13% 12.5M 0s
   150K .......... .......... .......... .......... .......... 18% 10.2M 0s
   200K .......... .......... .......... .......... .......... 22% 16.8M 0s
   250K .......... .......... .......... .......... .......... 27% 12.6M 0s
   300K .......... .......... .......... .......... .......... 32% 19.3M 0s
  

In [59]:
# Read the whole corpus into memory as a single string.
# The encoding is passed explicitly so decoding does not depend on the
# platform's default locale.
with open('tmp/data.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print('text length:', len(text))

text length: 1115394


In [60]:
# Show the first 100 characters of the corpus.
print(text[0:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [61]:
# Vocabulary: every distinct character of the corpus in sorted order, with a
# "[PAD]" entry placed first so that it receives index 0.
chars = ["[PAD]"] + sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print("vocab size:", vocab_size)

[PAD]
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 66


In [62]:
# Lookup tables between vocabulary entries and their integer ids.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, output the corresponding string."""
    return ''.join(itos[i] for i in l)


print(encode("hii there"))
print(decode(encode("hii there")))

[47, 48, 48, 2, 59, 47, 44, 57, 44]
hii there


In [63]:
# Encode the entire corpus and keep it as a single int64 tensor on the target device.
data = torch.tensor(encode(text), dtype=torch.long, device=device)
print(data.shape, data.dtype)
print(data[0:100])  # the same 100 characters printed earlier, now as the token ids the model sees

torch.Size([1115394]) torch.int64
tensor([19, 48, 57, 58, 59,  2, 16, 48, 59, 48, 65, 44, 53, 11,  1, 15, 44, 45,
        54, 57, 44,  2, 62, 44,  2, 55, 57, 54, 42, 44, 44, 43,  2, 40, 53, 64,
         2, 45, 60, 57, 59, 47, 44, 57,  7,  2, 47, 44, 40, 57,  2, 52, 44,  2,
        58, 55, 44, 40, 50,  9,  1,  1, 14, 51, 51, 11,  1, 32, 55, 44, 40, 50,
         7,  2, 58, 55, 44, 40, 50,  9,  1,  1, 19, 48, 57, 58, 59,  2, 16, 48,
        59, 48, 65, 44, 53, 11,  1, 38, 54, 60], device='cuda:0')


In [64]:
# Train/validation split: the first 90% of the tokens are used for training,
# the remaining tail for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [65]:
# Display the first block_size + 1 training tokens.
block_size = 100
train_data[0 : block_size + 1]

tensor([19, 48, 57, 58, 59,  2, 16, 48, 59, 48, 65, 44, 53, 11,  1, 15, 44, 45,
        54, 57, 44,  2, 62, 44,  2, 55, 57, 54, 42, 44, 44, 43,  2, 40, 53, 64,
         2, 45, 60, 57, 59, 47, 44, 57,  7,  2, 47, 44, 40, 57,  2, 52, 44,  2,
        58, 55, 44, 40, 50,  9,  1,  1, 14, 51, 51, 11,  1, 32, 55, 44, 40, 50,
         7,  2, 58, 55, 44, 40, 50,  9,  1,  1, 19, 48, 57, 58, 59,  2, 16, 48,
        59, 48, 65, 44, 53, 11,  1, 38, 54, 60,  2], device='cuda:0')

In [66]:
# One window of inputs and the same window shifted right by one token (the next-token targets).
a, y = train_data[:block_size], train_data[1 : block_size + 1]

In [67]:
torch.manual_seed(1337)  # fixed seed so the sampled batches are repeatable
# batch_size: how many independent sequences are processed in parallel
# block_size: maximum context length (in tokens) for predictions
batch_size, block_size = 32, 10

def get_batch(split):
    """Sample a random batch of input windows and their next-token targets.

    Reads the module-level ``batch_size`` and ``block_size`` at call time.

    Args:
        split: ``'train'`` draws from ``train_data``; any other value draws
            from ``val_data``.

    Returns:
        A pair ``(x, y)`` of tensors of shape ``(batch_size, block_size)``,
        where ``y`` holds the tokens of ``x`` shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per sequence, leaving room for the shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start : start + block_size])
        targets.append(source[start + 1 : start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch and report shape, dtype and device of both halves.
xb, yb = get_batch('train')
print('inputs:', f"{xb.shape} {xb.dtype} {xb.device}", sep='\n')
print('targets:', f"{yb.shape} {yb.dtype} {yb.device}", sep='\n')

inputs:
torch.Size([32, 10]) torch.int64 cuda:0
targets:
torch.Size([32, 10]) torch.int64 cuda:0


In [91]:
class NGramLanguageModel(nn.Module):
    """Feed-forward n-gram language model over integer token ids.

    Every position is predicted from a window holding the ``n`` most recent
    token ids (the token at that position and the ``n - 1`` before it). The
    window is embedded, the embeddings are concatenated, and a one-hidden-layer
    MLP maps the result to logits over the vocabulary. Windows that reach
    before the start of the sequence are left-padded with token id 0, which is
    the embedding's ``padding_idx``.
    """

    def __init__(self, vocab_size, n):
        """Build the layers.

        Args:
            vocab_size: number of distinct token ids, including padding id 0.
            n: length of the context window in tokens; must be at least 3.

        Raises:
            AssertionError: if ``n`` is smaller than 3.
        """
        super().__init__()
        # No device move here: at this point no parameter is registered yet, so
        # a ``.to(...)`` call would move nothing. Callers place the finished
        # model with ``model.to(device)``.
        assert n >= 3, "n should be at least 3"
        self.n = n
        embedding_size = int(vocab_size / 1.5 + 10)
        intermediate_size = vocab_size + 10 * n
        self.token_embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        self.fc = nn.Linear(embedding_size * n, intermediate_size)
        self.dropout = nn.Dropout(0.2)
        self.relu = nn.ReLU()
        self.final = nn.Linear(intermediate_size, vocab_size)

    def forward(self, x, only_last=False):
        """Compute next-token logits.

        Args:
            x: integer tensor of token ids with shape (batch, time).
            only_last: when true, only the window ending at the last time step
                is evaluated and the output has a time dimension of 1;
                otherwise one window per time step is evaluated.

        Returns:
            Logits of shape (batch, time, vocab_size), or
            (batch, 1, vocab_size) when ``only_last`` is true.
        """
        assert len(x.shape) == 2, "input shape should be (batch, time)"

        if only_last:
            # Keep at most the last n tokens and left-pad shorter inputs to exactly n.
            x = x[:, -self.n :]
            x = F.pad(x, (self.n - x.shape[1], 0), value=0)
            B, N = x.shape
            x = x.reshape(B, 1, N)
        else:
            # Left-pad with n - 1 padding ids, then slide a width-n window over the
            # time axis: window t holds tokens t-n+1 .. t. The result stays on the
            # device of the input, (batch, time, n).
            x = F.pad(x, (self.n - 1, 0), value=0).unfold(1, self.n, 1).contiguous()

        x = self.token_embedding(x)  # (batch, time, n, embedding)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # concatenate the n embeddings
        x = self.fc(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.final(x)
        return x

    def loss(self, logits, targets):
        """Mean cross-entropy between logits and target token ids.

        Args:
            logits: tensor of shape (batch, time, classes).
            targets: integer tensor with batch * time elements.

        Returns:
            Scalar loss tensor.
        """
        B, T, C = logits.shape
        # reshape (rather than view) so non-contiguous inputs are accepted as well
        return F.cross_entropy(logits.reshape(B * T, C), targets.reshape(B * T))

    def generate(self, x, max_len_new, temperature=1.0):
        """Extend each sequence in ``x`` by ``max_len_new`` generated tokens.

        Generation runs in eval mode (dropout disabled) and without gradient
        tracking; the module's previous train/eval mode is restored afterwards.

        Args:
            x: integer tensor of prompt token ids with shape (batch, time).
            max_len_new: number of tokens to append to every sequence.
            temperature: divisor applied to the logits before sampling. A value
                of exactly 0 selects the most likely token (greedy decoding)
                instead of dividing by zero.

        Returns:
            Tensor of shape (batch, time + max_len_new): the prompt followed by
            the generated tokens.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_len_new):
                    logits = self(x, True)[:, -1]
                    if temperature == 0:
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    else:
                        probs = F.softmax(logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    x = torch.cat([x, next_token], dim=1)
        finally:
            self.train(was_training)

        return x

In [98]:
# Build a model with a 7-token context window and place it on the target device.
model = NGramLanguageModel(vocab_size, 7).to(device)
# Summarise the layers by passing a dummy batch of 256 sequences of 10 tokens,
# with True supplied as the second forward argument (only_last).
summary(
    model,
    input_data=[torch.zeros((256, 10), dtype=torch.long, device=device), True],
    col_names=["input_size", "output_size", "num_params", "mult_adds"],
    device=device,
    verbose=2,
)
model  # show the module repr as the cell output

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Mult-Adds
NGramLanguageModel                       [256, 10]                 [256, 1, 66]              --                        --
├─Embedding: 1-1                         [256, 1, 7]               [256, 1, 7, 54]           3,564                     912,384
│    └─weight                                                                                └─3,564
├─Linear: 1-2                            [256, 1, 378]             [256, 1, 136]             51,544                    13,195,264
│    └─weight                                                                                ├─51,408
│    └─bias                                                                                  └─136
├─ReLU: 1-3                              [256, 1, 136]             [256, 1, 136]             --                        --
├─Dropout: 1-4                           [256, 1, 136]             [

NGramLanguageModel(
  (token_embedding): Embedding(66, 54, padding_idx=0)
  (fc): Linear(in_features=378, out_features=136, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (relu): ReLU()
  (final): Linear(in_features=136, out_features=66, bias=True)
)

In [99]:
print(model)
# Sanity check: run the current batch through the model, score it against the
# targets, and decode a 10-token continuation of the first sequence.
logits = model(xb)
loss = model.loss(logits=logits, targets=yb)
print('logits:', logits.shape)
print('loss:', loss)

print(decode(model.generate(xb, max_len_new=10)[0].tolist()))

NGramLanguageModel(
  (token_embedding): Embedding(66, 54, padding_idx=0)
  (fc): Linear(in_features=378, out_features=136, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (relu): ReLU()
  (final): Linear(in_features=136, out_features=66, bias=True)
)
logits: torch.Size([256, 10, 66])
loss: tensor(4.2405, device='cuda:0', grad_fn=<NllLossBackward0>)
wful busindBJBihmh,$


In [100]:
# Train the model: 1001 AdamW steps on freshly sampled training batches,
# printing the batch loss on every 50th step.
optimizer = optim.AdamW(model.parameters(), lr=0.01)
batch_size = 256  # get_batch reads this module-level value, so batches now hold 256 sequences
for step in range(1001):
    optimizer.zero_grad()
    xb, yb = get_batch('train')
    logits = model(xb)
    loss = model.loss(logits, yb)
    loss.backward()
    optimizer.step()
    if not step % 50:
        print(f'step: {step}, loss: {loss.item():.3f}')

step: 0, loss: 4.234
step: 50, loss: 2.183
step: 100, loss: 2.120
step: 150, loss: 2.031
step: 200, loss: 2.003
step: 250, loss: 2.015
step: 300, loss: 1.958
step: 350, loss: 1.982
step: 400, loss: 1.930
step: 450, loss: 1.930
step: 500, loss: 1.877
step: 550, loss: 1.893
step: 600, loss: 1.916
step: 650, loss: 1.899
step: 700, loss: 1.850
step: 750, loss: 1.948
step: 800, loss: 1.880
step: 850, loss: 1.908
step: 900, loss: 1.890
step: 950, loss: 1.904
step: 1000, loss: 1.967


In [104]:
# Generate 200 new tokens for five copies of the prompt "LUCENT" at temperature 0.9
# and print each decoded sequence followed by a separator line.
for x in model.generate(
    torch.tensor([encode("LUCENT") for _ in range(5)], device=device),
    max_len_new=200,
    temperature=0.9,
):
    print(decode(x.tolist()), '----', sep='\n')

LUCENTEO:
Awd ther atd Juppet now death fore even make the god twar the now
I gatill fawner love, and here great to minge to the stones tend been'd ineafes with heart; beth if seme in Corillonou, More his p
----
LUCENTIO:
Abould: whe cood to grave spunist
Wrot
Tome this sted earder this much and
Withese it
fulling lords
Is heory. servensiness spoked.

yheress hen lory.

LEONTES:
Why, low lond arrus; my do the venry
----
LUCENTIO:
How saus with the the prots
Of Prone har fpray, ther, thome cot my Englory holdiert,
Pould sweich as whene elcent nal tence
Alouted.

POLICINIUS:
Withenhy,
What insfor ermped he tuous,
I lis evert
----
LUCENTIO:
Nou dost kings carnou hus grovie noblone's.

WARWICK:
Oll quatthe condst in best
Liking hather. Draw you, buy lital;
When your pothie catitiod more, bast. But we,t look is ore naty. O copmucher, p
----
LUCENTIO:
The heesty,
Even that stal be not, and of a worn, not for Rer unns than stonour.

GRUMO:
Grest to enee, I thee, the 
Ono wer, more came spisen

In [105]:
# The logits are divided by the temperature before the softmax, so a near-zero
# value concentrates almost all probability on the highest-scoring token.
model.generate(
    torch.tensor([encode("hello")], device=device),
    max_len_new=10,
    temperature=0.0001,
)

tensor([[47, 44, 51, 51, 54, 62,  2, 59, 47, 44,  2, 59, 47, 44,  2]],
       device='cuda:0')

In [109]:
# Save the model: vocabulary and hyper-parameters as JSON, the network as an ONNX file.
import json
import os

os.makedirs("tmp/ngram", exist_ok=True)

with open("tmp/ngram/config.json", "w") as f:
    f.write(json.dumps({"chars": chars, "vocab_size": vocab_size, "n": model.n}))

# Export using a 1x1 int32 dummy input with True passed as the second forward
# argument (only_last); axes 0 and 1 of both input and output are declared dynamic.
torch.onnx.export(
    model,
    (torch.zeros(1, 1, dtype=torch.int32, device=device), True),
    "tmp/ngram/model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={name: {0: "batch", 1: "time"} for name in ("input", "output")},
)

  if only_last:
