In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

device = "mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built() else device
print(device)

block_size = 8
batch_size = 4
max_iter   = 1000

learning_rate = 3e-4
eval_iters = 250

mps


In [2]:
with open('./wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))

print(chars)
vocab_size = len(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']


In [3]:
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])

data   = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([80,  1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,
         1, 47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26,
        49,  0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,
         0,  0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1,
        47, 33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1,
        36, 25, 38, 28,  1, 39, 30,  1, 39, 50])


In [4]:
n = int(.8*len(data))

train_data = data[:n]
val_data   = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix   = torch.randint(len(data) - block_size, (batch_size,))
    x    = torch.stack([data[i:i+block_size] for i in ix])
    y    = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

x, y = get_batch('train')
print('inputs:')
print(x)

print('outputs:')
print(y)

inputs:
tensor([[54, 65, 54, 56, 58,  1, 76, 54],
        [62, 67, 60,  1, 68, 67, 58,  1],
        [71,  1, 56, 68, 66, 69, 54, 67],
        [68, 72, 58, 57,  1, 55, 78,  1]], device='mps:0')
outputs:
tensor([[65, 54, 56, 58,  1, 76, 54, 72],
        [67, 60,  1, 68, 67, 58,  1, 68],
        [ 1, 56, 68, 66, 69, 54, 67, 62],
        [72, 58, 57,  1, 55, 78,  1, 54]], device='mps:0')


In [5]:
@torch.no_grad()

def estimate_loss():
    out = {}
    model.eval()

    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)

        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()

        out[split] = losses.mean()
    model.train()
    return out

In [6]:
'''
In training mode(targets is provided); this function computes the loss using ther cross-entropy criterion,
which compares the logits (predicted token probabilities) to the actual target tokens.

In inference mode(targets is none), the function skips losss computation, and loss is set to None 
since the focus in on prediction rather than training.

This function is crucial for models like language models or sequence models, 
where tokens(words, characters, etc) are predicted at each step, and the loss is computed 
to guide model optimization during training.

'''

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets=None):
        ## Token embedding lookup
        # embedding look up table converts the input token indices to corresponding embeddings(vectors)
        logits = self.token_embedding_table(index) 

        ## Handling loss calculation
        if targets is None:
            loss = None
        else:
            # logits will have the shape (B,T,C)
            # B = batch size
            # T = sequence length
            # C = # of embedding dimensions
            
            B, T, C = logits.shape
            logits  = logits.view(B*T, C) # reshape/flatten 3D tensors to 2D
            targets = targets.view(B*T) # reshape/flatten to 1D, target tensor has to be 1D
            loss    = F.cross_entropy(logits, targets)
            # calculates the cross-entropy loss between the predicted logits and the true target tokens. 
            # Cross-entropy is a common loss function for classification tasks, 
            # which measures how well the predicted distribution matches the actual distribution.
        return logits, loss
    # this function is used to generate new tokens based on the given index
    def generate(self, index, max_new_tokens):
        
        for _ in range(max_new_tokens):
            # computes output logits  for the current index
            logits, loss = self.forward(index) 
            # logits represents the unnormalised scors for each token
            
            logits       = logits[:, -1, :]
            # converts ;logits to probability
            # softmax is sort of normalization in output layer
            probs        = F.softmax(logits, dim=-1)
            index_next   = torch.multinomial(probs, num_samples=1)
            index        = torch.cat((index, index_next), dim=1) 
        return index

'''
This "generate" function uses the model to iteratively generate a sequence of tokens.

KEY flow:
1) i takes an input token sequence(index)
2) for each new token, 

- it generates the logits (predictions) for the next token
- it extracts the logits of the last token in the current sequence
- converts these logits to probabilities using the softmax function
- it samples the next token based on the predicted probabilities
- it appends this token to the current sequence

3) after generating the requested number of new tokens, it returns the complete sequence

'''

'\nThis "generate" function uses the model to iteratively generate a sequence of tokens.\n\nKEY flow:\n1) i takes an input token sequence(index)\n2) for each new token, \n\n- it generates the logits (predictions) for the next token\n- it extracts the logits of the last token in the current sequence\n- converts these logits to probabilities using the softmax function\n- it samples the next token based on the predicted probabilities\n- it appends this token to the current sequence\n\n3) after generating the requested number of new tokens, it returns the complete sequence\n\n'

In [7]:
model = BigramLanguageModel(vocab_size)
m     = model.to(device)

context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())

print(generated_chars)


eQ'QwZo﻿5a4U
MQwF-RdSf1G R"V GHZXxb4RU)WKNd1VFlZY3AQ8nh?A96YqXC;eF 3"pF.dTb)]h[PZN!1-9dSE﻿esTE
?b:M2Hr'EJMUoSr6sSCIiau﻿dX1J6
﻿AF79MfV.na[!d_MHO)4Rt"9'1r)kC*&p4O﻿JQ""7MHoPinHwZNfOtwY-J:_djnVXxp7X(fWcIPb(Q,
F8htyYLr?)K3MHgvgEX.V.5tKzTkMd87XrJ6
9(-YH!S3NroxpwfQ",WLR
iz,]&MF_QL8C*NBX
?WOLr[wHxxdYAOtgsi5BwZZ?*NBDHSB2XgsObsS2JUpa[aRvLrY"w:mrAMFXL&wB*]_'Lx4xxsQfu(-O X_i3.bLw_mKfgL0lSV,z,[vZo!JyGUXw0Zs﻿6DV.yeJq3NVMmY_mHe!ShWn
,ttc!c"B,
G[*'-BGUnFm)kAq  u
3_EfW)aN8nupad,E
R[LPS_x&erGnXh! PSEC,"r6AWW6Y!p7


In [8]:
## Create pyTorch optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iter):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}")

    xb, yb = get_batch('train')

    logits, loss = model.forward(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

step: 0, train loss: 4.824, val loss: 4.803
step: 250, train loss: 4.761, val loss: 4.776
step: 500, train loss: 4.708, val loss: 4.718
step: 750, train loss: 4.659, val loss: 4.658
4.544988632202148


In [9]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())

print(generated_chars)


"&&c﻿o0LV*RCd_ X[[:1E﻿CdSqlXhzp;gK3BS:(flqDuX;mSAW(lYv1xxk)SEX6RMCGSPc3'nht8﻿;lGk'
m&UO!KfVwcV.F5Bqb3OK'IL0E0E*D?f HEG)] kz9An-O"BCnRynuE﻿JSq2:x2Ncp aI5hmH1U'kkA)udQoB﻿O6]d:T8naCAa[cHgswD7gf_t7W(L_fLxMDU0Zo﻿P5hac6Qgf:*?aOt8xxSzAZCc3Oy *U6O),VRkkgN*vp2HEy A'e*N(f08nL-9RkA_5f3aN:5WWWWX!SGH-9Zu'Ju-Iv1aaNfDYrMBQn!K3NtYAXp7Xv058naFZE?(fldv)1GU7""VKT1KNWso_[H2'AWWWRMm7JGVmX_EMx7m96xUEs-*v?)C5BLte-bJ8mVVq2-
*Dawt-hWVxd0OCEie?KN8RQJmr7[V?(a[vc30bBDd7cL8&0bAq0j3.nY_dN9Z9j*w﻿6UOsM
W(pmP'vmv?"?nN&]D﻿G*KofQ
