## Data Input

In [254]:
with open('../data/tiny_shakespeare.txt', 'r') as f:
    text = f.read()

In [255]:
print(len(text))

1115394


In [256]:
print(text[:400])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it 


## Character Level Tokenization

In [257]:
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [258]:
char_to_index = {ch: i for i, ch in enumerate(chars)}
index_to_char = {i: ch for i, ch in enumerate(chars)} 

encode = lambda s: [char_to_index[c] for c in s]
decode = lambda ids: ''.join([index_to_char[i] for i in ids])

input_txt = "hello world"
encoded_data = encode(input_txt)
decoded_data = decode(encoded_data)
print(encoded_data)
print(decoded_data)

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [259]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:400])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

## Train, Validation Split

In [260]:
train_split_percentage = 0.9
n = int(train_split_percentage * len(data))
train_data = data[:n]
validation_data = data[n:]

## Context Block

In [261]:
context_length = 8

print(train_data[:context_length + 1])
print(decode(train_data[:context_length + 1].numpy()))

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
First Cit


In [262]:
x = train_data[:context_length]
y = train_data[1:context_length + 1]

print(x)
print(y)
print()

for t in range(context_length):
    context = x[:t + 1] 
    target = y[t]
    print(f"input: {context}, target: {target}")
    print(f"input: {context}, target: {target}")

tensor([18, 47, 56, 57, 58,  1, 15, 47])
tensor([47, 56, 57, 58,  1, 15, 47, 58])

input: tensor([18]), target: 47
input: tensor([18]), target: 47
input: tensor([18, 47]), target: 56
input: tensor([18, 47]), target: 56
input: tensor([18, 47, 56]), target: 57
input: tensor([18, 47, 56]), target: 57
input: tensor([18, 47, 56, 57]), target: 58
input: tensor([18, 47, 56, 57]), target: 58
input: tensor([18, 47, 56, 57, 58]), target: 1
input: tensor([18, 47, 56, 57, 58]), target: 1
input: tensor([18, 47, 56, 57, 58,  1]), target: 15
input: tensor([18, 47, 56, 57, 58,  1]), target: 15
input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


## Batching

In [263]:
torch.manual_seed(42)
batch_size = 4
context_length = 8


def get_batch(split):
    data = train_data if split == "train" else validation_data
    ix = torch.randint(low=0, high=len(data) - context_length, size=(batch_size, ))
    x = torch.stack([data[i: i + context_length] for i in ix])
    y = torch.stack([data[i + 1: i + context_length + 1] for i in ix])
    return x, y


xb, yb = get_batch("train")
print("inputs")
print(xb.shape)
print(xb)
print("targets")
print(yb.shape)
print(yb)

print('-' * 10)

for b in range(batch_size):
    for t in range(context_length):
        context = xb[b, :t + 1]
        target = yb[b, t]
        print(f"input: {context}, target: {target}")

inputs
torch.Size([4, 8])
tensor([[57,  1, 46, 47, 57,  1, 50, 53],
        [ 1, 58, 46, 43, 56, 43,  1, 41],
        [17, 26, 15, 17, 10,  0, 32, 53],
        [57, 58,  6,  1, 61, 47, 58, 46]])
targets
torch.Size([4, 8])
tensor([[ 1, 46, 47, 57,  1, 50, 53, 60],
        [58, 46, 43, 56, 43,  1, 41, 39],
        [26, 15, 17, 10,  0, 32, 53,  1],
        [58,  6,  1, 61, 47, 58, 46,  0]])
----------
input: tensor([57]), target: 1
input: tensor([57,  1]), target: 46
input: tensor([57,  1, 46]), target: 47
input: tensor([57,  1, 46, 47]), target: 57
input: tensor([57,  1, 46, 47, 57]), target: 1
input: tensor([57,  1, 46, 47, 57,  1]), target: 50
input: tensor([57,  1, 46, 47, 57,  1, 50]), target: 53
input: tensor([57,  1, 46, 47, 57,  1, 50, 53]), target: 60
input: tensor([1]), target: 58
input: tensor([ 1, 58]), target: 46
input: tensor([ 1, 58, 46]), target: 43
input: tensor([ 1, 58, 46, 43]), target: 56
input: tensor([ 1, 58, 46, 43, 56]), target: 43
input: tensor([ 1, 58, 46, 43, 56

## Model

In [264]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)


class ScaledDotProductAttention(nn.Module):
    def __init__(self, embed_dim: int, context_length: int, head_dim: int = 32, causal: bool = True):
        super(ScaledDotProductAttention, self).__init__()
        self.head_dim = head_dim
        self.causal = causal
        self.to_key = nn.Linear(embed_dim, head_dim, bias=False)
        self.to_query = nn.Linear(embed_dim, head_dim, bias=False)
        self.to_value = nn.Linear(embed_dim, head_dim, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones((context_length, context_length))))

    def forward(self, x):
        b, t, c = x.shape
        k = self.to_key(x)
        q = self.to_key(x)
        v = self.to_key(x)

        # Attention scores and masking for autoregressive for transformer decoder only LM.
        attn_weights = q @ k.transpose(-1, -2) * (self.head_dim ** -0.5)
        if self.causal:
            attn_weights = attn_weights.masked_fill(self.tril[:t, :t] == 0, float('-inf'))
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        out = attn_weights @ v
        return out, attn_weights


In [265]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, context_length: int, num_heads: int, head_dim: int = 32, causal: bool = True):
        super(MultiHeadAttention, self).__init__()
        self.sdp_heads = nn.ModuleList([
            ScaledDotProductAttention(
                embed_dim=embed_dim, context_length=context_length, head_dim=embed_dim // num_heads, causal=causal
            ) for _ in range(num_heads)
        ])

    def forward(self, x):
        head_attn_out = []
        head_attn_weights_out = []
        for h in self.sdp_heads:
            attn_out, attn_weights = h(x)
            head_attn_out.append(attn_out)
            head_attn_weights_out.append(attn_weights)

        head_attn_out = torch.cat(head_attn_out, dim=-1)
        head_attn_weights_out = torch.cat(head_attn_weights_out, dim=-1)
        return head_attn_out, head_attn_weights_out

Instead of direct logits embedding is generated from input indices which is reshaped to vocab size for softmax in loss. Also positional embedding is added to each index position.

In [266]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)


class BigramLanguageModel(nn.Module):
    def __init__(
            self, vocab_size: int, device: str, context_length: int = 8, embed_dim: int = 32, head_dim: int = 32, num_heads: int = 4
        ):
        super(BigramLanguageModel, self).__init__()
        self.device = device
        self.context_length = context_length
        self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.position_embedding_table = nn.Embedding(num_embeddings=context_length, embedding_dim=embed_dim)
        
        # self.attn_head = ScaledDotProductAttention(embed_dim=embed_dim, head_dim=head_dim, context_length=context_length)
        # self.lm_head = nn.Linear(head_dim, vocab_size)

        self.attn_head = torch.jit.script(
            MultiHeadAttention(
                embed_dim=embed_dim, context_length=context_length, causal=True, num_heads=num_heads,
            )
        )
        # print(self.attn_head.code)
        self.lm_head = nn.Linear(embed_dim, vocab_size)


    def forward(self, idx, targets=None):
        """Logits in shape of (batch, time_dim, channel_dim) which is reshaped to 2d tensor for cross entropy loss. 
        `t` is the time dimension context_length and `c` is channel dim each token embedding.
        """
        b, t = idx.shape

        token_embeddings = self.token_embedding_table(idx)  # (b, t, embed_dim)
        position_embeddings = self.position_embedding_table(torch.arange(t, device=self.device))    # (t, embed_dim)
        x = token_embeddings + position_embeddings  # (b, t, embed_dim)

        x, attn_weights = self.attn_head(x)

        logits = self.lm_head(x)    # (b, t, vocab_size)

        if targets is None:
            loss = None
        else:
            b, t, c = logits.shape
            logits = logits.view(b * t, c)
            targets = targets.view(b * t)

            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens: int):
        """Generates new token from previous token and sample top 1 from softmax probs for next id. Due to positional 
        embedding table set to context length anything beyond will cause error so idx are truncated to last context
        indices.
        """
        for _ in range(max_new_tokens):
            idx_truncated = idx if idx.size(1) <= self.context_length else idx[:, -self.context_length:]
            logits, _ = self(idx_truncated)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            ids_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, ids_next], dim=1)

        return idx

## Setup

In [267]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [268]:
model = BigramLanguageModel(vocab_size=vocab_size, device=device).to(device)

# model = torch.jit.trace(model, (xb.to(device), yb.to(device)))

logits, loss = model(xb.to(device), yb.to(device))
print(logits.shape)
print(loss)


pred_token_idx = model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)
print(pred_token_idx)
print(pred_token_idx.shape)
print(decode(pred_token_idx[0].tolist()))

torch.Size([32, 65])
tensor(4.3043, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor([[ 0, 42,  0,  9, 20,  1, 10,  5, 11, 30, 52, 36, 40, 64, 16, 28,  5, 15,
         52, 32,  7,  2,  7,  0,  5, 59, 37, 53, 33, 42, 42, 36, 30, 39, 30, 52,
         55, 16, 15, 49,  6, 36, 54, 10, 62, 32, 49, 22,  6, 59, 59, 20, 43, 38,
         28, 23, 51, 36, 17, 28,  7,  4, 13, 31,  9, 62, 23, 50, 30, 54, 27, 61,
         47, 50, 15,  8, 25, 21, 17, 21, 32, 28, 18, 37, 55, 56, 45, 47, 43, 54,
         28, 57, 13, 17, 37, 28, 22,  8, 35, 19, 21]], device='cuda:0')
torch.Size([1, 101])

d
3H :';RnXbzDP'CnT-!-
'uYoUddXRaRnqDCk,Xp:xTkJ,uuHeZPKmXEP-&AS3xKlRpOwilC.MIEITPFYqrgiepPsAEYPJ.WGI


In [269]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [270]:
from contextlib import nullcontext
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=torch.float16)

## Multiple Batch Loss Evaluation 

In [271]:
batch_size = 32
train_iters = 10000
eval_iters = 1000

In [272]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            x = x.to(device)
            y = y.to(device)
            with ctx:
                logits, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()            
    return out

## Training

In [273]:
for i in range(train_iters):
    xb, yb = get_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if i % eval_iters == 0:
        # losses = estimate_loss()
        # print(f"step: {i}: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")

        x, y = get_batch('val')
        x = x.to(device)
        y = y.to(device)
        with ctx:
            val_logits, val_loss = model(x, y)

        print(f"step: {i}: train loss: {loss:.4f}, val loss: {val_loss:.4f}")


print(f"final train loss: {loss.item():.4f}, val loss: {val_loss:.4f}")


step: 0: train loss: 4.2784, val loss: 4.2993
step: 1000: train loss: 2.5088, val loss: 2.6300
step: 2000: train loss: 2.4585, val loss: 2.3292
step: 3000: train loss: 2.5542, val loss: 2.4001
step: 4000: train loss: 2.3813, val loss: 2.4091
step: 5000: train loss: 2.2930, val loss: 2.3591
step: 6000: train loss: 2.3055, val loss: 2.5621
step: 7000: train loss: 2.2989, val loss: 2.3243
step: 8000: train loss: 2.1807, val loss: 2.3433
step: 9000: train loss: 2.2144, val loss: 2.2489
final train loss: 2.2062, val loss: 2.2489


## Generation

In [274]:
pred_token_idx = model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=400)
# print(pred_token_idx)
print(pred_token_idx.shape)
print(decode(pred_token_idx[0].tolist()))

torch.Size([1, 401])

To yo,
Ke evy ditleremang yofordss
Salllangean thy, las thapto the: hankel haid hetien wistwurore cense wince me, wours theevend torest 'elm; hy, whan he wlars oln!-
Whind sowhellsen, ler ron thowe:
Thino be's lor thy, ba phallpeneme thilghshen, hearunghto brinty Hockis thed ar te sue rerak sit
Thanes nomnfoe bivos Yow:
K:
Tharease V: on my for lomat ssot alorulh, mare ake theald be, gt ut wnoide.
