# GPT From Scratch
Inspired by Andrej Karpathy's video lecture: https://www.youtube.com/watch?v=kCc8FmEb1nY

In [28]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn import functional as F
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [29]:
ds = load_dataset("roneneldan/TinyStories")

In [30]:
ds

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [31]:
# combining all the text that we have
text = "\n".join(ds['train']['text'] + ds['validation']['text'])

In [32]:
print("#characters:", len(text))
print(text[:1000])

#characters: 1921305229
One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.

One day, Beep was driving in the park when he saw a big tree. The tree had many lea

In [33]:
text = text[:1921300]

In [34]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))  # all unique characters
print(vocab_size)


 !"$',-.012345678:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz¡¦©«±³»ÂÃâœ˜“”€™
89


In [35]:
# encoding characters to integers
ctoi = {char:i for i, char in enumerate(chars)}
itoc = {i:char for i, char in enumerate(chars)}

In [36]:
# defining the encoding and decoding functions as lambda functions
encode = lambda s: [ctoi[c] for c in s]
decode = lambda l: ''.join([itoc[i] for i in l])

In [37]:
print(encode("hello NITK"))
print(decode(encode("hello NITK")))

[54, 51, 58, 58, 61, 1, 34, 29, 40, 31]
hello NITK


In [38]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1921300]) torch.int64
tensor([35, 60, 51,  1, 50, 47, 71,  6,  1, 47,  1, 58, 55, 66, 66, 58, 51,  1,
        53, 55, 64, 58,  1, 60, 47, 59, 51, 50,  1, 32, 55, 58, 71,  1, 52, 61,
        67, 60, 50,  1, 47,  1, 60, 51, 51, 50, 58, 51,  1, 55, 60,  1, 54, 51,
        64,  1, 64, 61, 61, 59,  8,  1, 39, 54, 51,  1, 57, 60, 51, 69,  1, 55,
        66,  1, 69, 47, 65,  1, 50, 55, 52, 52, 55, 49, 67, 58, 66,  1, 66, 61,
         1, 62, 58, 47, 71,  1, 69, 55, 66, 54,  1, 55, 66,  1, 48, 51, 49, 47,
        67, 65, 51,  1, 55, 66,  1, 69, 47, 65,  1, 65, 54, 47, 64, 62,  8,  1,
        32, 55, 58, 71,  1, 69, 47, 60, 66, 51, 50,  1, 66, 61,  1, 65, 54, 47,
        64, 51,  1, 66, 54, 51,  1, 60, 51, 51, 50, 58, 51,  1, 69, 55, 66, 54,
         1, 54, 51, 64,  1, 59, 61, 59,  6,  1, 65, 61,  1, 65, 54, 51,  1, 49,
        61, 67, 58, 50,  1, 65, 51, 69,  1, 47,  1, 48, 67, 66, 66, 61, 60,  1,
        61, 60,  1, 54, 51, 64,  1, 65, 54, 55, 64, 66,  8,  0,  0, 32, 55, 58,
      

In [39]:
# we will just want to split this limited data into train and validation
n = int(0.9*len(data)) # 90% training
train_data = data[:n]
val_data = data[n:]

In [107]:
n

1729170

## Defining the length of the sampled blocks

In [40]:
block_size = 8 # so our "context window" length is 8

In [41]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f2f8e2fb9b0>

## Defining the Dataset

In [42]:
class CharDataset(Dataset):
    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y

In [43]:
train_dataset = CharDataset(train_data, block_size)
val_dataset = CharDataset(val_data, block_size)

In [44]:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [45]:
xb, yb = next(iter(train_loader))
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

inputs:
torch.Size([32, 8])
tensor([[ 1, 66, 54, 51,  1, 51, 60, 50],
        [48, 47, 49, 57,  1, 66, 61,  1],
        [ 1, 28, 51,  1, 49, 61, 67, 58],
        [51,  1, 66, 54, 51,  1, 51, 47],
        [64, 55, 50, 51,  1, 66, 54, 51],
        [54, 61, 67, 53, 54, 66,  1, 61],
        [ 8,  1, 28, 51,  1, 65, 47, 55],
        [24, 47, 50,  6,  1, 82, 87, 83],
        [60, 50,  8,  1, 24, 61, 53,  1],
        [51, 60,  1, 69, 55, 66, 54,  1],
        [47, 50,  8,  1, 28, 51,  1, 65],
        [47,  1, 49, 64, 55, 51, 50,  8],
        [61, 69,  1, 65, 69, 51, 51, 66],
        [ 3,  1,  0,  0, 40, 55, 59, 59],
        [55, 60, 53,  1, 66, 54, 51,  1],
        [52, 67, 58,  1, 52, 55, 65, 54],
        [ 1, 54, 51, 58, 62,  1, 71, 61],
        [47, 64, 51,  1, 48, 61, 66, 54],
        [55, 58, 71,  1, 66, 61, 58, 50],
        [51,  1, 49, 61, 59, 62, 58, 51],
        [64, 64, 71,  8,  1, 39, 54, 51],
        [ 1, 48, 55, 64, 50, 65,  1, 47],
        [55, 50,  8,  0,  0, 40, 54, 51],
      

## Context-target preview

In [46]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

when input is [1] the target: 66
when input is [1, 66] the target: 54
when input is [1, 66, 54] the target: 51
when input is [1, 66, 54, 51] the target: 1
when input is [1, 66, 54, 51, 1] the target: 51
when input is [1, 66, 54, 51, 1, 51] the target: 60
when input is [1, 66, 54, 51, 1, 51, 60] the target: 50
when input is [1, 66, 54, 51, 1, 51, 60, 50] the target: 1
when input is [48] the target: 47
when input is [48, 47] the target: 49
when input is [48, 47, 49] the target: 57
when input is [48, 47, 49, 57] the target: 1
when input is [48, 47, 49, 57, 1] the target: 66
when input is [48, 47, 49, 57, 1, 66] the target: 61
when input is [48, 47, 49, 57, 1, 66, 61] the target: 1
when input is [48, 47, 49, 57, 1, 66, 61, 1] the target: 66
when input is [1] the target: 28
when input is [1, 28] the target: 51
when input is [1, 28, 51] the target: 1
when input is [1, 28, 51, 1] the target: 49
when input is [1, 28, 51, 1, 49] the target: 61
when input is [1, 28, 51, 1, 49, 61] the target: 67

## Bigram Language Model
* The vocabulary size is vocab_size.
* The model uses an embedding table of shape (vocab_size, vocab_size) to directly map each input token to the logits over the next token.

In [47]:
class BiGramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # embedding layer (it's like a lookup table) that maps each token (index) to a vector of size vocab_size.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # B: batch size
        # T: time/context length
        # idx.shape = (B, T) --> used as a query to the lookup table
        # targets.shape = (B, T) --> each idx is mapped to one fixed ground truth target
        logits = self.token_embedding_table(idx) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, vocab_size = logits.shape
            # logits: for each token position in the batch, the model outputs a probability distribution over all possible next tokens.
            # targets: the correct next token index for each position.
            logits = logits.view(B*T, vocab_size)
            targets = targets.view(B*T)
            # We use cross-entropy loss to measure how well the predicted probability distribution (from logits) matches the true distribution (one-hot at the correct target index).
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # only the logits from the last time step are required as we want to find the logits for the next time step based on the current time step
            # since it’s a bigram model, it has no context window, only 1-step memory, so the output is essentially a Markov chain
            logits = logits[:, -1, :] # becomes (B, vocab_size)
            # softmaxxing for probabilities
            probs = F.softmax(logits, dim=-1) # (B, vocab_size)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [59]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [58]:
bgm = BiGramLanguageModel(vocab_size).to(device)
logits, loss = bgm(xb, yb)
print(logits.shape)
print(loss)

torch.Size([256, 89])
tensor(4.9285, device='cuda:0', grad_fn=<NllLossBackward0>)


In [61]:
print(decode(bgm.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=100)[0].tolist()))


¦sœ? xW2XBâx!7H'pfL©!:©wuQ:ztVPJS2©˜7±ZKHQTpU¦B,œs?WRâuHPxA--4œ4pCDJE“o6Â6âEA”ALH³€g±YgH,Gy4My¦x.O;»


## Training BiGramLanguageModel

In [62]:
optimizer = torch.optim.AdamW(bgm.parameters(), lr=1e-3)


In [63]:
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

In [64]:
batch_size = 32
block_size = 8
eval_interval = 500
epochs = 20000
eval_iters = 200

In [65]:
def compute_accuracy(logits, targets):
    preds = torch.argmax(logits, dim=-1)
    correct = (preds == targets).float()
    return correct.sum() / correct.numel()

In [66]:
device

'cuda'

In [67]:
step = 0
bgm.train()
pbar = tqdm(total=epochs)
train_iter = iter(train_loader)

while step < epochs:
    try:
        xb, yb = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        xb, yb = next(train_iter)

    xb, yb = xb.to(device), yb.to(device)
    # print(xb.shape, yb.shape)
    logits, loss = bgm(xb, yb)
    if(step%100 == 0):
        print(f"Epoch[{step+1}/{epochs}] ---- Training loss = {loss}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # record training metrics every eval_interval steps
    if step % eval_interval == 0:
        bgm.eval()
        with torch.no_grad():
            # training eval
            train_logits, _ = bgm(xb, yb)
            train_acc = compute_accuracy(train_logits.view(-1, vocab_size), yb.view(-1)).item()
            train_losses.append(loss.item())
            train_accuracies.append(train_acc)

            #  validation eval
            val_loss_total, val_correct, val_total = 0.0, 0, 0
            for val_xb, val_yb in val_loader:
                val_xb, val_yb = val_xb.to(device), val_yb.to(device)
                val_logits, val_loss = bgm(val_xb, val_yb)
                val_loss_total += val_loss.item()
                
                preds = val_logits.view(-1, vocab_size).argmax(dim=-1)  # shape: (B*T,)
                targets = val_yb.view(-1)  # shape: (B*T,)
                val_correct += (preds == targets).sum().item()
                val_total += val_yb.numel()
            
            val_losses.append(val_loss_total / len(val_loader))
            val_accuracies.append(val_correct / val_total)
            print(f"Epoch[{step+1}/{epochs}] ---- VALIDATION LOSS = {val_loss_total} --- VALIDATION ACC = {val_correct/val_total}")
        bgm.train()
    
    step += 1
    pbar.update(1)

pbar.close()

  0%|          | 0/20000 [00:00<?, ?it/s]

Epoch[1/20000] ---- Training loss = 4.922492504119873
Epoch[1/20000] ---- VALIDATION LOSS = 29878.781811237335 --- VALIDATION ACC = 0.0012440012075660258
Epoch[101/20000] ---- Training loss = 4.875812530517578
Epoch[201/20000] ---- Training loss = 4.686099529266357
Epoch[301/20000] ---- Training loss = 4.595996379852295
Epoch[401/20000] ---- Training loss = 4.52496337890625
Epoch[501/20000] ---- Training loss = 4.364192485809326
Epoch[501/20000] ---- VALIDATION LOSS = 26125.220281362534 --- VALIDATION ACC = 0.01293058577362301
Epoch[601/20000] ---- Training loss = 4.305454254150391
Epoch[701/20000] ---- Training loss = 4.176881790161133
Epoch[801/20000] ---- Training loss = 4.035976886749268
Epoch[901/20000] ---- Training loss = 4.035033702850342
Epoch[1001/20000] ---- Training loss = 3.858185291290283
Epoch[1001/20000] ---- VALIDATION LOSS = 23004.388361930847 --- VALIDATION ACC = 0.1464896003581058
Epoch[1101/20000] ---- Training loss = 3.826618194580078
Epoch[1201/20000] ---- Traini

In [68]:
print(decode(bgm.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=1000)[0].tolist()))


Ugaltaire Bare "Hed te fr toacolike they, e She s m, ts t ink elee then che. aband wonoon The g pat frthowale tr wanike thenntie ste lll's e he enoulanysood " h towathetha w umad wis, uly ars fogomithe at t bouno the wito he, shedee cld wabldofupetcamy?" gopp!" t tithas le topy ho Hedan thano the heoke befoorivenghe wost to bey!"Liman g hey. prasoumook fllad sm! ase wasacad. sotofriton t thrflave. vike. wamad pperm ay's hed mery's waple them. s brset, wend soy uplds. Ithe Shand t wencin meund vie. havey, che o Tivye Thand cito choo." se lelo Shillthaf h min wayoper d. teeryoo we berede ithe p ad w â€ ber toor l. tt hedaine tt ind aie steckutheruda tend'do w th toppp alin ad vediedd he he he sert enke He The her tsatthe siethand. om vey f sath s wat f fe the dere. Ithelend ke 
Thed.

Onen, uttuthedom s s uthervo whed d t s canthe Thili! as a bim. meltediewaitisersorcong Soo lfftr a ay ba h walisen ateys be spe t ver "I˜”â€ thter tht towod, as He hushe g hese t avess athernepinghe nttt.

### Self-Attention Tested

In [69]:
torch.manual_seed(1337)
B,T,vocab_size = 4,8,32 # batch, time, vocab_size
x = torch.randn(B,T,vocab_size)

# single-head self-attention
head_size = 16
key = nn.Linear(vocab_size, head_size, bias=False)
query = nn.Linear(vocab_size, head_size, bias=False)
value = nn.Linear(vocab_size, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) . (B, 16, T) = (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [70]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

# Building GPT

In [71]:
vocab_size = 89

In [90]:
eval_interval = 500
epochs = 5000
eval_iters = 200
n_emb = 384
num_heads = 6
num_layers = 6
dropout = 0.2
block_size = 8
batch_size = 32

In [91]:
class SelfAttentionHead(nn.Module):
    """Single head for self-attention"""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size, bias = False)
        self.query = nn.Linear(n_emb, head_size, bias = False)
        self.value = nn.Linear(n_emb, head_size, bias = False)

        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # to save alongside model parameters
        # buffer --> gradient doesn't flow through it

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = (B, T, n_emb) --> not necessarily vocab_size now
        B, T, n_emb = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)

        # we want to now calculate the attention scores (affinities of each query vector with each key vector)
        attn = q@k.transpose(-2,-1) * k.shape[-1]**(-0.5) # (B, T, head_size) . (B, head_size, T) = (B, T, T)
        # masking
        masked_attn = attn.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)

        # we have seen that k.shape[-1] == q.shape[-1] == T is not always necessarily == head_size, which is why we code it the way we have

        # softmaxxing gives probability distribution
        prob_masked_attn = F.softmax(masked_attn, dim = -1) # (B, T == each query, T == probability dist over each key)

        # applying dropout
        prob_masked_attn = self.dropout(prob_masked_attn)

        # now we need to perform weighted aggregation of the values
        v = self.value(x) # (B, T, head_size)
        out = prob_masked_attn @ v # (B, T, T) . (B, T, head_size) = (B, T, head_size)
        return out
        

In [92]:
class MultiHeadAttention(nn.Module):
    """Multiple self-attention heads in parallel"""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size*num_heads, n_emb)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = (B, T, n_emb)
        # we want to get the output from all the heads and concatenate along the feature dimension
        out = torch.cat([head(x) for head in self.heads], dim=-1) # (B, T, num_heads * head_size)
        out = self.proj(out)  # project back to n_emb
        out = self.dropout(out)
        return out

In [93]:
class FeedForward(nn.Module):
    """Feed forward after the Multi-head Attn"""
    def __init__(self, n_emb):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(n_emb, 4*n_emb), 
            nn.ReLU(),
            nn.Linear(4*n_emb, n_emb),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ff(x)

In [94]:
class Block(nn.Module):
    """Full transformer decoder block used in our GPT implementation"""
    def __init__(self, n_emb, num_heads):
        super().__init__()
        head_size = n_emb // num_heads
        self.selfattn = MultiHeadAttention(num_heads, head_size)
        self.ff = FeedForward(n_emb)
        self.ln1 = nn.LayerNorm(n_emb)
        self.ln2 = nn.LayerNorm(n_emb)

    def forward(self, x):
        x = x + self.selfattn(self.ln1(x))  # Residual + Self-attention
        x = x + self.ff(self.ln2(x))  # Residual + FeedForward
        return x

In [95]:
class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_emb) # same as token_embedding_table
        self.position_embedding = nn.Embedding(block_size, n_emb) # we can say position_embedding_table

        self.blocks = nn.Sequential(
            *[
                Block(n_emb, num_heads) for _ in range(num_layers)
            ]
        )

        self.ln_f = nn.LayerNorm(n_emb) # final layer norm
        self.head = nn.Linear(n_emb, vocab_size) # final projection to logits

        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # check for the input length
        assert T <= self.block_size, f"Cannot forward, sequence length {T} exceeds block size {self.block_size}"

        token_embed = self.token_embedding(idx) # (B, T, n_emb)
        position_embed = self.position_embedding(torch.arange(T, device = device)) # (T, n_emb)
        x = token_embed + position_embed # (B, T, n_emb)
        x = self.blocks(x) # (B, T, n_emb)
        x = self.ln_f(x) # (B, T, n_emb)
        logits = self.head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None

        else:
            B, T, vocab_size = logits.shape
            logits = logits.view(B*T, vocab_size)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:] # crop to block_size
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:] # # (B, vocab_size) - last token logits
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, next_token), dim=1) # (B, T+1)
        return idx
        

In [96]:
model = GPTModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

10.712153 M parameters


In [97]:
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=100)[0].tolist()))


c
?:YtubA-dZH'Wd”St2'ÂWu,'U±qWdl“y;¦:edte“6“AgO€B?6lÂ23K2'v4ISBâ™ULbD³NœoEoMeX6LSVbif;Ã0qqme-ZuSâ?CT


In [98]:
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

In [99]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [104]:
step = 0
model.train()
pbar = tqdm(total=epochs)
train_iter = iter(train_loader)

while step < epochs:
    try:
        xb, yb = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        xb, yb = next(train_iter)

    xb, yb = xb.to(device), yb.to(device)
    # print(xb.shape, yb.shape)
    logits, loss = model(xb, yb)
    if(step%100 == 0):
        print(f"Epoch[{step+1}/{epochs}] ---- Training loss = {loss}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # record training metrics every eval_interval steps
    if step % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            # training eval
            # print("Training eval")
            train_logits, _ = model(xb, yb)
            train_acc = compute_accuracy(train_logits.view(-1, vocab_size), yb.view(-1)).item()
            train_losses.append(loss.item())
            train_accuracies.append(train_acc)

            #  validation eval
            # print("Validation eval")
            val_loss_total, val_correct, val_total = 0.0, 0, 0
            for val_xb, val_yb in val_loader:
                val_xb, val_yb = val_xb.to(device), val_yb.to(device)
                val_logits, val_loss = model(val_xb, val_yb)
                val_loss_total += val_loss.item()
                
                preds = val_logits.view(-1, vocab_size).argmax(dim=-1)  # shape: (B*T,)
                targets = val_yb.view(-1)  # shape: (B*T,)
                val_correct += (preds == targets).sum().item()
                val_total += val_yb.numel()
            
            val_losses.append(val_loss_total / len(val_loader))
            val_accuracies.append(val_correct / val_total)
            print(f"Epoch[{step+1}/{epochs}] ---- VALIDATION LOSS = {val_loss_total} --- VALIDATION ACC = {val_correct/val_total}")
        model.train()
    
    step += 1
    pbar.update(1)

pbar.close()

  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch[1/5000] ---- Training loss = 3.0992045402526855
Epoch[1/5000] ---- VALIDATION LOSS = 18190.350291252136 --- VALIDATION ACC = 0.24980546215425614
Epoch[101/5000] ---- Training loss = 2.321361780166626
Epoch[201/5000] ---- Training loss = 2.247704029083252
Epoch[301/5000] ---- Training loss = 2.0146918296813965
Epoch[401/5000] ---- Training loss = 1.9151442050933838
Epoch[501/5000] ---- Training loss = 1.9007267951965332
Epoch[501/5000] ---- VALIDATION LOSS = 11237.518462061882 --- VALIDATION ACC = 0.4446614651107109
Epoch[601/5000] ---- Training loss = 1.766221284866333
Epoch[701/5000] ---- Training loss = 1.8081618547439575
Epoch[801/5000] ---- Training loss = 1.7742477655410767
Epoch[901/5000] ---- Training loss = 1.9560261964797974
Epoch[1001/5000] ---- Training loss = 1.671762466430664
Epoch[1001/5000] ---- VALIDATION LOSS = 10500.084830343723 --- VALIDATION ACC = 0.473748451504773
Epoch[1101/5000] ---- Training loss = 1.665525197982788
Epoch[1201/5000] ---- Training loss = 1.

In [None]:
vocab_size

In [106]:
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=1000)[0].tolist()))


Wick your chan birk an a sealh. 

Mack and spart." Timmy was dot it laughed to looked. She home, bear when. And special car cound at y, Lily lamped in amine home to make about tok.

Bentink there," said, "You're diden. Sammy, when and with a griends. They were give and take happy tooks everyone was so the both and the latten shark the trake.

Sam loods to ust is a big back yarty to the smiled and was day, they skide. He want not with arroway pick, busy - calloof the adn't dreamp. Timmy ater someone was a didn't lyook, but the was he cattant the decom. Lily and him was sâ€mined his fish, that some you?"

Daid notiched ever fast and spedful was may well boaves courn, find let you, Mury lated to go look to smaly that had like her do coki, "That secaree was friend scade and save limbith'. Her mom was a maide Bob?"

Mummy ater. "Sammy in nevery like fun. Tom to to the duck anythings found eat playing like to trucky it corning!" Beily looked are was a liked the sparthing they were pall. He 