## Data Input

In [1]:
with open('../data/tiny_shakespeare.txt', 'r') as f:
    text = f.read()

In [2]:
print(len(text))

1115394


In [3]:
print(text[:400])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it 


## Character Level Tokenization

In [4]:
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
char_to_index = {ch: i for i, ch in enumerate(chars)}
index_to_char = {i: ch for i, ch in enumerate(chars)} 

encode = lambda s: [char_to_index[c] for c in s]
decode = lambda ids: ''.join([index_to_char[i] for i in ids])

input_txt = "hello world"
encoded_data = encode(input_txt)
decoded_data = decode(encoded_data)
print(encoded_data)
print(decoded_data)

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [6]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:400])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

## Train, Validation Split

In [7]:
train_split_percentage = 0.9
n = int(train_split_percentage * len(data))
train_data = data[:n]
validation_data = data[n:]

## Context Block

In [8]:
context_length = 8

print(train_data[:context_length + 1])
print(decode(train_data[:context_length + 1].numpy()))

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
First Cit


In [9]:
x = train_data[:context_length]
y = train_data[1:context_length + 1]

print(x)
print(y)
print()

for t in range(context_length):
    context = x[:t + 1] 
    target = y[t]
    print(f"input: {context}, target: {target}")
    print(f"input: {context}, target: {target}")

tensor([18, 47, 56, 57, 58,  1, 15, 47])
tensor([47, 56, 57, 58,  1, 15, 47, 58])

input: tensor([18]), target: 47
input: tensor([18]), target: 47
input: tensor([18, 47]), target: 56
input: tensor([18, 47]), target: 56
input: tensor([18, 47, 56]), target: 57
input: tensor([18, 47, 56]), target: 57
input: tensor([18, 47, 56, 57]), target: 58
input: tensor([18, 47, 56, 57]), target: 58
input: tensor([18, 47, 56, 57, 58]), target: 1
input: tensor([18, 47, 56, 57, 58]), target: 1
input: tensor([18, 47, 56, 57, 58,  1]), target: 15
input: tensor([18, 47, 56, 57, 58,  1]), target: 15
input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


## Hyperparameters

In [10]:
torch.manual_seed(42)
batch_size = 64
context_length = 256
train_iters = 10000
eval_iters = 30
learning_rate = 6e-4

## Batching

In [11]:
def get_batch(split):
    data = train_data if split == "train" else validation_data
    ix = torch.randint(low=0, high=len(data) - context_length, size=(batch_size, ))
    x = torch.stack([data[i: i + context_length] for i in ix])
    y = torch.stack([data[i + 1: i + context_length + 1] for i in ix])
    return x, y


xb, yb = get_batch("train")
print("inputs")
print(xb.shape)
print(xb)
print("targets")
print(yb.shape)
print(yb)

print('-' * 10)

# Demo with showing only first batch and 20 char context.
for b in range(batch_size):
    for t in range(context_length):
        context = xb[b, :t + 1]
        target = yb[b, t]
        print(f"input: {context}, target: {target}")

        if t > 8:
            break
    break

inputs
torch.Size([64, 256])
tensor([[54, 43, 63,  ..., 54, 43, 63],
        [57,  1, 61,  ..., 47, 52,  1],
        [57,  1, 58,  ..., 46, 39, 58],
        ...,
        [23, 17,  1,  ..., 24, 17, 10],
        [43, 50, 54,  ...,  1, 21, 51],
        [63, 53, 59,  ..., 63,  1, 58]])
targets
torch.Size([64, 256])
tensor([[43, 63,  1,  ..., 43, 63, 12],
        [ 1, 61, 53,  ..., 52,  1, 61],
        [ 1, 58, 46,  ..., 39, 58, 46],
        ...,
        [17,  1, 27,  ..., 17, 10,  0],
        [50, 54,  1,  ..., 21, 51, 54],
        [53, 59,  1,  ...,  1, 58, 39]])
----------
input: tensor([54]), target: 43
input: tensor([54, 43]), target: 63
input: tensor([54, 43, 63]), target: 1
input: tensor([54, 43, 63,  1]), target: 58
input: tensor([54, 43, 63,  1, 58]), target: 46
input: tensor([54, 43, 63,  1, 58, 46]), target: 43
input: tensor([54, 43, 63,  1, 58, 46, 43]), target: 0
input: tensor([54, 43, 63,  1, 58, 46, 43,  0]), target: 19
input: tensor([54, 43, 63,  1, 58, 46, 43,  0, 19]), tar

## Model

In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)


class ScaledDotProductAttention(nn.Module):
    def __init__(self, embed_dim: int, context_length: int, head_dim: int = 32, causal: bool = True, dropout: float = 0.2):
        super(ScaledDotProductAttention, self).__init__()
        self.head_dim = head_dim
        self.causal = causal
        self.dropout = dropout
        self.to_key = nn.Linear(embed_dim, head_dim, bias=False)
        self.to_query = nn.Linear(embed_dim, head_dim, bias=False)
        self.to_value = nn.Linear(embed_dim, head_dim, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones((context_length, context_length))))

    def forward(self, x):
        b, t, c = x.shape
        k = self.to_key(x)
        q = self.to_key(x)
        v = self.to_key(x)

        # Attention scores and masking for autoregressive for transformer decoder only LM.
        attn_weights = q @ k.transpose(-1, -2) * (self.head_dim ** -0.5)
        if self.causal:
            attn_weights = attn_weights.masked_fill(self.tril[:t, :t] == 0, float('-inf'))
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout)

        out = attn_weights @ v
        return out, attn_weights


In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(
            self, embed_dim: int, context_length: int, num_heads: int, head_dim: int = 32, causal: bool = True, dropout: float = 0.2
    ):
        """Scripted dot product attention which splits embed dim to number of head for parallel computation in multi head attention.
        """
        super(MultiHeadAttention, self).__init__()
        self.dropout = dropout
        self.sdp_heads = nn.ModuleList([
            ScaledDotProductAttention(
                embed_dim=embed_dim, context_length=context_length, head_dim=embed_dim // num_heads, causal=causal, dropout=dropout
            ) for _ in range(num_heads)
        ])
        self.projection = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        head_attn_out = []
        head_attn_weights_out = []
        for h in self.sdp_heads:
            attn_out, attn_weights = h(x)
            head_attn_out.append(attn_out)
            head_attn_weights_out.append(attn_weights)

        head_attn_out = torch.cat(head_attn_out, dim=-1)
        head_attn_weights_out = torch.cat(head_attn_weights_out, dim=-1)
        out = self.projection(head_attn_out)
        out = F.dropout(out, p=self.dropout)
        return out, head_attn_weights_out

In [14]:
class TransformerBlock(nn.Module):
    def __init__(
            self, embed_dim: int, context_length: int, num_heads: int, head_dim: int = 32, causal: bool = True, dropout: float = 0.2
    ):
        """Using prenorm like gpt. Since multi head attention implementation return attention output and weights as tuple
        the 0th index is used mha for outputs. 
        """
        super(TransformerBlock, self).__init__()
        self.mhsa = torch.nn.Sequential(
            nn.LayerNorm(embed_dim),
            torch.jit.script(
                MultiHeadAttention(
                    embed_dim=embed_dim, context_length=context_length, num_heads=num_heads, head_dim=head_dim, causal=causal, dropout=dropout
                )
            )
        )
        self.mhsa = torch.jit.script(self.mhsa)
        print(self.mhsa.code)

        self.ff = torch.nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(p=dropout),
        )        

    def forward(self, x):
        x = x + self.mhsa(x)[0]
        x = x + self.ff(x)
        return x

Instead of direct logits embedding is generated from input indices which is reshaped to vocab size for softmax in loss. Also positional embedding is added to each index position.

In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)


class GPTLanguageModel(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            device: str,
            context_length: int = 8,
            embed_dim: int = 32,
            head_dim: int = 32,
            num_heads: int = 4,
            dropout: float = 0.2,
    ):
        super(GPTLanguageModel, self).__init__()
        self.device = device
        self.context_length = context_length
        self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.position_embedding_table = nn.Embedding(num_embeddings=context_length, embedding_dim=embed_dim)
        
        # self.attn_head = ScaledDotProductAttention(embed_dim=embed_dim, head_dim=head_dim, context_length=context_length)
        # self.lm_head = nn.Linear(head_dim, vocab_size)

        # self.attn_head = torch.jit.script(
        #     MultiHeadAttention(
        #         embed_dim=embed_dim, context_length=context_length, causal=True, num_heads=num_heads,
        #     )
        # )

        # print(self.attn_head.code)

        self.transformer_blocks = nn.Sequential(
            TransformerBlock(context_length=context_length, head_dim=head_dim, embed_dim=embed_dim, num_heads=num_heads, causal=True, dropout=dropout),
            TransformerBlock(context_length=context_length, head_dim=head_dim, embed_dim=embed_dim, num_heads=num_heads, causal=True, dropout=dropout),
            TransformerBlock(context_length=context_length, head_dim=head_dim, embed_dim=embed_dim, num_heads=num_heads, causal=True, dropout=dropout),
            nn.LayerNorm(embed_dim),
        )
        self.lm_head = nn.Linear(embed_dim, vocab_size)
        

    def forward(self, idx, targets=None):
        """Logits in shape of (batch, time_dim, channel_dim) which is reshaped to 2d tensor for cross entropy loss. 
        `t` is the time dimension context_length and `c` is channel dim each token embedding.
        """
        b, t = idx.shape

        token_embeddings = self.token_embedding_table(idx)  # (b, t, embed_dim)
        position_embeddings = self.position_embedding_table(torch.arange(t, device=self.device))    # (t, embed_dim)
        x = token_embeddings + position_embeddings  # (b, t, embed_dim)

        # x, attn_weights = self.attn_head(x)
        x = self.transformer_blocks(x)

        logits = self.lm_head(x)    # (b, t, vocab_size)

        if targets is None:
            loss = None
        else:
            b, t, c = logits.shape
            logits = logits.view(b * t, c)
            targets = targets.view(b * t)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    @torch.jit.export
    def generate(self, idx, max_new_tokens: int):
        """Generates new token from previous token and sample top 1 from softmax probs for next id. Due to positional 
        embedding table set to context length anything beyond will cause error so idx are truncated to last context
        indices.
        """
        for _ in range(max_new_tokens):
            idx_truncated = idx if idx.size(1) <= self.context_length else idx[:, -self.context_length:]
            logits, _ = self(idx_truncated)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            ids_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, ids_next], dim=1)

        return idx

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [17]:
model = GPTLanguageModel(
    vocab_size=vocab_size, device=device, context_length=context_length, embed_dim=256, num_heads=4, dropout=0.2
).to(device)

# model = torch.jit.trace(model, (xb.to(device), yb.to(device)))
# model = torch.jit.script(model)

logits, loss = model(xb.to(device), yb.to(device))
print(logits.shape)
print(loss)


pred_token_idx = model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)
print(pred_token_idx)
print(pred_token_idx.shape)
print(decode(pred_token_idx[0].tolist()))

def forward(self,
    input: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = getattr(self, "0")
  _1 = getattr(self, "1")
  input0 = (_0).forward(input, )
  return (_1).forward(input0, )

def forward(self,
    input: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = getattr(self, "0")
  _1 = getattr(self, "1")
  input0 = (_0).forward(input, )
  return (_1).forward(input0, )

def forward(self,
    input: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = getattr(self, "0")
  _1 = getattr(self, "1")
  input0 = (_0).forward(input, )
  return (_1).forward(input0, )

torch.Size([16384, 65])
tensor(4.3623, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor([[ 0, 11, 29, 10, 14, 50, 37, 64, 62,  1, 37,  4, 37, 34, 49, 42, 25, 25,
         11,  7, 33, 13, 18, 44,  3, 56, 52, 40,  8, 30, 30, 55, 30, 37,  7, 20,
         47, 18, 26,  2, 33, 29,  7, 36, 56, 61, 56, 56, 62, 14,  5, 15,  5, 16,
         26, 61, 33, 23, 21, 57, 18, 24, 22, 43, 45, 44, 35, 11, 42, 56, 61, 14,
         29, 44,  7, 49, 59, 12, 30, 21, 19,  5, 

In [18]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [19]:
from contextlib import nullcontext
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=torch.float32)

## Multiple Batch Loss Evaluation 

In [20]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            x = x.to(device)
            y = y.to(device)
            with ctx:
                logits, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()            
    return out

## Training

In [21]:
for i in range(train_iters):
    xb, yb = get_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if i % eval_iters == 0:
        # losses = estimate_loss()
        # print(f"step: {i}: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")

        x, y = get_batch('val')
        x = x.to(device)
        y = y.to(device)
        with ctx:
            val_logits, val_loss = model(x, y)

        print(f"step: {i}: train loss: {loss:.4f}, val loss: {val_loss:.4f}")


print(f"final train loss: {loss.item():.4f}, val loss: {val_loss:.4f}")


 does not have profile information (Triggered internally at ..\third_party\nvfuser\csrc\graph_fuser.cpp:108.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


step: 0: train loss: 4.3476, val loss: 3.9937
step: 30: train loss: 2.6456, val loss: 2.6201
step: 60: train loss: 2.5480, val loss: 2.5463
step: 90: train loss: 2.5338, val loss: 2.5047
step: 120: train loss: 2.5105, val loss: 2.5265
step: 150: train loss: 2.4912, val loss: 2.5133
step: 180: train loss: 2.4843, val loss: 2.4958
step: 210: train loss: 2.4869, val loss: 2.5044
step: 240: train loss: 2.4718, val loss: 2.5022
step: 270: train loss: 2.4556, val loss: 2.4960
step: 300: train loss: 2.4536, val loss: 2.4874
step: 330: train loss: 2.4661, val loss: 2.4851
step: 360: train loss: 2.4591, val loss: 2.5088
step: 390: train loss: 2.4555, val loss: 2.4527
step: 420: train loss: 2.4594, val loss: 2.4853
step: 450: train loss: 2.4335, val loss: 2.4724
step: 480: train loss: 2.4613, val loss: 2.4853
step: 510: train loss: 2.4479, val loss: 2.4818
step: 540: train loss: 2.4321, val loss: 2.4613
step: 570: train loss: 2.4315, val loss: 2.4366
step: 600: train loss: 2.4329, val loss: 2.45

KeyboardInterrupt: 

## Generation

In [23]:
pred_token_idx = model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=800)
# print(pred_token_idx)
print(pred_token_idx.shape)
print(decode(pred_token_idx[0].tolist()))

torch.Size([1, 1001])

You resof thou me's of his claimen; to, hon net cousing,
Their hearlen, and, my knows bunrooples, by merry;
We use mittle his deadeth brother this fleshing of out,
And next my that hon.
Catesby to hence. No, here that me? Bushop'st pounn inded, and
To fearsh'd next seedition, of fight
The wings authros? Come, the Luck, but ouf the srage.

DUCHESS II:
Smorrow the us for 't was you, folly you art
COMINUS:
Nor the shaldieven the treape happy of loved, and all.
They fait;
The see forswatining? O gentle becape any heavourt
Hare bitte chonoq him lie gentlemen'd-
uto fawers they will
She's spon traidf and thy me? they off me?

First Citizen PRossent:
Come have mean, indumand'd my lovoy of my relive.

DUCHESS OF YORK:

CAPULET:
Therreforcia, and years, doth in I'ld no
Which it; Doth darked his be passague rooving my purpost as be have
That the wouldin musick.


ELBOW:
Unjusts most brow faniful'd Romeo, my look,
I'Trove words no peace you lord blefen me
your deathe, he
Be