In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNGs (torch.manual_seed seeds all devices) so batch sampling,
# dropout masks and weight initialisation are reproducible on Restart & Run All.
torch.manual_seed(1337)

block_size = 64       # context length in characters (earlier experiments: 8)
batch_size = 128      # sequences per optimizer step (earlier experiments: 4)
max_iters = 3000      # optimizer steps (earlier experiments: 1000)
learning_rate = 3e-4  # also tried 3e-3, 1e-3, 1e-4
eval_iters = 100      # batches averaged by estimate_loss() (earlier experiments: 250)
n_embed = 384         # embedding width; split evenly across the heads of a block
n_head = 4            # attention heads per block (run in parallel)
n_layer = 4           # Transformer blocks (applied sequentially)
dropout = 0.2         # probability of zeroing an activation during training

# Each head gets n_embed // n_head features; require an exact split so no width is lost.
assert n_embed % n_head == 0, "n_embed must be divisible by n_head"

In [2]:
print(f"{device}")  # show which backend the rest of the notebook will run on

cuda


In [3]:
# Load the full training corpus as a single string.
with open('wizard_of_oz.txt', mode='r', encoding='utf-8') as f:
    text = f.read()
# Vocabulary: the distinct characters of the corpus in sorted order.
chars = sorted(set(text))
print(len(text))

232309


In [4]:
# Character-level tokenizer: each distinct character gets an integer id equal to
# its position in the sorted vocabulary `chars`.
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Convert a string (or any iterable of characters) to a list of token ids.

    Raises KeyError for a character that is not in the vocabulary.
    """
    return [string_to_int[c] for c in s]


def decode(l):
    """Convert an iterable of token ids back to the corresponding string."""
    return ''.join(int_to_string[i] for i in l)


encode('hello')

[61, 58, 65, 65, 68]

In [5]:
# Encode the entire corpus once into a flat tensor of 64-bit character ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print(data[:100])

tensor([80,  1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,
         1, 47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26,
        49,  0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,
         0,  0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1,
        47, 33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1,
        36, 25, 38, 28,  1, 39, 30,  1, 39, 50])


In [6]:
# Contiguous 80/20 split: the first 80% of characters are used for training,
# the remaining 20% are held out for validation (no shuffling).
n = int(0.8 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' to sample from ``train_data``, 'val' to sample from ``val_data``.

    Returns:
        (x, y): two ``(batch_size, block_size)`` tensors moved to ``device``; ``y`` is
        ``x`` shifted one position to the right, i.e. the next-token targets.

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'val' (previously any other
            string silently sampled the validation set), or if the chosen split is
            too short to hold one window of ``block_size + 1`` tokens.
    """
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    if len(source) <= block_size:
        raise ValueError(
            f"{split} split has {len(source)} tokens; need more than block_size={block_size}"
        )
    # Random window starts in [0, len - block_size) so the shifted target window
    # (which ends at start + block_size + 1) still fits inside the tensor.
    ix = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in ix])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)

In [7]:
@torch.no_grad()
def estimate_loss():
    """Average the model's loss over ``eval_iters`` random batches of each split.

    Runs without autograd and with the model in eval mode (dropout off). The model's
    previous train/eval mode is restored afterwards, even if a batch raises.

    Returns:
        dict with keys 'train' and 'val', each a 0-dim tensor holding the mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()  # disable dropout while measuring
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    return out

In [10]:
# initialize neural net
from torch.nn import functional as F
import torch.nn as nn


vocab_size: int = len(chars)  # one output logit per distinct character in the corpus

class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects an ``n_embed``-wide input to keys, queries and values of width
    ``head_size``; each position attends only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        # Created in key/query/value order so parameter registration (and hence the
        # order in which weights are initialised) is unchanged.
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular ones: entry [i, j] is 1 iff j <= i (position i may see j).
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a (B, T, n_embed) input to a (B, T, head_size) output."""
        _, seq_len, _ = x.shape
        keys = self.key(x)       # (B, T, hs)
        queries = self.query(x)  # (B, T, hs)

        # Scaled dot-product scores: (q . k) / sqrt(head_size) -> (B, T, T)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale
        # Hide future positions before normalising.
        visible = self.tril[:seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T)

        # Weighted sum of the values.
        values = self.value(x)  # (B, T, hs)
        return attn @ values    # (B, T, hs)


class MultiHeadAttention(nn.Module):
    """Several causal attention heads run side by side.

    The head outputs are concatenated along the feature axis and projected back
    to ``n_embed`` features, followed by dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Each head yields (B, T, head_size); concatenating on the last axis gives
        # (B, T, num_heads * head_size), laid out as [head0 | head1 | ...].
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embed, ReLU, project back, then dropout."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),  # regularisation against overfitting
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Post-norm Transformer block (as in the original Transformer).

    Self-attention ("communication") and then an MLP ("computation"), each wrapped
    in a residual connection followed by LayerNorm.
    """

    def __init__(self, n_embed, n_head):
        super().__init__()
        # Split the embedding width evenly across the heads (integer division).
        self.sa = MultiHeadAttention(n_head, n_embed // n_head)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        # Attention sub-layer: residual add, then normalise.
        x = self.ln1(x + self.sa(x))
        # Feed-forward sub-layer: residual add, then normalise.
        return self.ln2(x + self.ffwd(x))



class GPTLanguageModel(nn.Module):
    """Character-level GPT.

    Token and position embeddings feed a stack of ``n_layer`` Transformer blocks,
    then a final LayerNorm and a linear head that produces next-token logits.
    ``n_embed``, ``n_head``, ``n_layer`` and ``block_size`` are read from the
    notebook's global configuration at construction time.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Lookup tables: token id -> n_embed vector; position 0..block_size-1 -> n_embed vector.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        # Decoder blocks applied one after another.
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)  # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)  # features -> vocabulary logits

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, index, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            index: (B, T) tensor of token ids; T must not exceed ``block_size``
                (the size of the position-embedding table).
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            (logits, loss). Without targets: logits are (B, T, vocab_size) and loss
            is None. With targets: logits are flattened to (B*T, vocab_size) and
            loss is the mean cross-entropy.
        """
        B, T = index.shape

        token_embedding = self.token_embedding_table(index)  # (B, T, C)
        # Build positions on the input's device rather than the global `device`, so the
        # model works wherever it and its input actually live (e.g. after reloading).
        positions = torch.arange(T, device=index.device)
        positional_embedding = self.position_embedding_table(positions)  # (T, C)
        x = token_embedding + positional_embedding  # broadcast -> (B, T, C)
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)  # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) targets: merge batch and time.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` ids and append them to ``index``.

        Args:
            index: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.

        Sampling runs without autograd and in eval mode, so dropout does not perturb
        the predictions. The module's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # The position table only covers block_size steps: crop the context.
                index_cond = index[:, -block_size:]
                logits, _ = self(index_cond)
                logits = logits[:, -1, :]  # last time step only -> (B, vocab_size)
                probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
                index_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                index = torch.cat((index, index_next), dim=1)  # (B, T+1)
        finally:
            self.train(was_training)
        return index

# Build the model and move its parameters/buffers to the selected device.
# nn.Module.to() returns the module itself, so `m` and `model` are one object.
model = GPTLanguageModel(vocab_size).to(device)
m = model

# context = torch.zeros((1, 1), dtype = torch.long, device=device)# torch.long is equivalent to int64
# generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())
# print(generated_chars)

In [13]:
# AdamW: Adam with decoupled weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Report losses every `eval_interval` optimizer steps. This is a separate knob from
# `eval_iters` (the number of batches averaged per report); 100 matches the interval
# used so far.
eval_interval = 100

for step in range(max_iters):  # `step` rather than `iter`, which shadowed the builtin
    if step % eval_interval == 0:
        losses = estimate_loss()
        # Inner quotes differ from the outer ones: reusing the same quote character
        # inside an f-string is a SyntaxError before Python 3.12.
        print(f"step: {step}, train loss: {losses['train']}, val loss: {losses['val']}")

    # Sample a batch of training data.
    xb, yb = get_batch('train')

    # Forward pass, then one optimizer update.
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# In-loop reports always run before that step's update, so the final weights were
# never evaluated: do it once more here.
losses = estimate_loss()
print(f"step: {max_iters}, train loss: {losses['train']}, val loss: {losses['val']}")
print(loss.item())  # loss of the last training minibatch

step: 0, train loss: 4.430508613586426, val loss: 4.428066253662109
step: 100, train loss: 2.347662925720215, val loss: 2.4141321182250977
step: 200, train loss: 1.9848394393920898, val loss: 2.099562168121338
step: 300, train loss: 1.7537264823913574, val loss: 1.8793383836746216
step: 400, train loss: 1.6072111129760742, val loss: 1.7548702955245972
step: 500, train loss: 1.5103707313537598, val loss: 1.6807934045791626
step: 600, train loss: 1.43593430519104, val loss: 1.642560601234436
step: 700, train loss: 1.3897312879562378, val loss: 1.6086680889129639
step: 800, train loss: 1.339085340499878, val loss: 1.5744613409042358
step: 900, train loss: 1.295628309249878, val loss: 1.558104395866394
step: 1000, train loss: 1.2562941312789917, val loss: 1.5326963663101196
step: 1100, train loss: 1.2257294654846191, val loss: 1.5238935947418213
step: 1200, train loss: 1.1970484256744385, val loss: 1.50117826461792
step: 1300, train loss: 1.1649980545043945, val loss: 1.4997669458389282
st

In [14]:
# """
# with learning rate 3e-3 (did not train: train/val loss stalled around 3.14 from step 100 onward)

# step: 0, train loss: 4.430576324462891, val loss: 4.433215618133545
# step: 100, train loss: 3.1462972164154053, val loss: 3.1632158756256104
# step: 200, train loss: 3.1415157318115234, val loss: 3.154750347137451
# step: 300, train loss: 3.140493392944336, val loss: 3.159609317779541
# step: 400, train loss: 3.1439969539642334, val loss: 3.153151750564575
# step: 500, train loss: 3.14066743850708, val loss: 3.155769109725952
# step: 600, train loss: 3.139163017272949, val loss: 3.150160551071167
# step: 700, train loss: 3.140573740005493, val loss: 3.150819778442383
# step: 800, train loss: 3.1403863430023193, val loss: 3.1547629833221436
# step: 900, train loss: 3.1438207626342773, val loss: 3.1520254611968994
# step: 1000, train loss: 3.139432907104492, val loss: 3.150862216949463
# step: 1100, train loss: 3.1476073265075684, val loss: 3.151153326034546
# step: 1200, train loss: 3.139864444732666, val loss: 3.1527931690216064
# step: 1300, train loss: 3.142347812652588, val loss: 3.154132127761841
# step: 1400, train loss: 3.143246650695801, val loss: 3.150099277496338
# step: 1500, train loss: 3.1446292400360107, val loss: 3.157273054122925
# step: 1600, train loss: 3.143803596496582, val loss: 3.153653144836426
# step: 1700, train loss: 3.142355442047119, val loss: 3.150972604751587
# step: 1800, train loss: 3.1386635303497314, val loss: 3.1512415409088135
# step: 1900, train loss: 3.1410346031188965, val loss: 3.153546094894409
# step: 2000, train loss: 3.1436045169830322, val loss: 3.150817632675171
# step: 2100, train loss: 3.1439507007598877, val loss: 3.1505837440490723
# step: 2200, train loss: 3.1417930126190186, val loss: 3.1531710624694824
# step: 2300, train loss: 3.145106315612793, val loss: 3.1531436443328857
# step: 2400, train loss: 3.1468560695648193, val loss: 3.1498031616210938
# step: 2500, train loss: 3.1356236934661865, val loss: 3.1475374698638916
# step: 2600, train loss: 3.134104013442993, val loss: 3.1500349044799805
# step: 2700, train loss: 3.1425528526306152, val loss: 3.1578030586242676
# step: 2800, train loss: 3.1384406089782715, val loss: 3.148427963256836
# step: 2900, train loss: 3.1381731033325195, val loss: 3.1510093212127686
# 3.1193087100982666

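# with learning rate 3e-4 (val loss bottomed out at ~1.497 around step 1900, then rose while train loss kept falling, i.e. overfitting)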
# step: 0, train loss: 4.476689338684082, val loss: 4.472619533538818
# step: 100, train loss: 2.344285011291504, val loss: 2.4153695106506348
# step: 200, train loss: 1.9366097450256348, val loss: 2.048823356628418
# step: 300, train loss: 1.7199596166610718, val loss: 1.8546732664108276
# step: 400, train loss: 1.593166470527649, val loss: 1.758353590965271
# step: 500, train loss: 1.5044441223144531, val loss: 1.6803436279296875
# step: 600, train loss: 1.4344788789749146, val loss: 1.6400684118270874
# step: 700, train loss: 1.3803069591522217, val loss: 1.6018896102905273
# step: 800, train loss: 1.3321964740753174, val loss: 1.5683866739273071
# step: 900, train loss: 1.2911683320999146, val loss: 1.5531294345855713
# step: 1000, train loss: 1.2569479942321777, val loss: 1.5313323736190796
# step: 1100, train loss: 1.2294522523880005, val loss: 1.5260602235794067
# step: 1200, train loss: 1.1987494230270386, val loss: 1.5222253799438477
# step: 1300, train loss: 1.1687437295913696, val loss: 1.510951280593872
# step: 1400, train loss: 1.1426775455474854, val loss: 1.499334692955017
# step: 1500, train loss: 1.11307692527771, val loss: 1.5131258964538574
# step: 1600, train loss: 1.0965144634246826, val loss: 1.4980086088180542
# step: 1700, train loss: 1.0660827159881592, val loss: 1.5012909173965454
# step: 1800, train loss: 1.0489522218704224, val loss: 1.5051543712615967
# step: 1900, train loss: 1.0238697528839111, val loss: 1.4971965551376343
# step: 2000, train loss: 1.0014030933380127, val loss: 1.5083410739898682
# step: 2100, train loss: 0.9760680198669434, val loss: 1.5179238319396973
# step: 2200, train loss: 0.9566689133644104, val loss: 1.5259524583816528
# step: 2300, train loss: 0.9409555196762085, val loss: 1.5345654487609863
# step: 2400, train loss: 0.914814293384552, val loss: 1.5460710525512695
# step: 2500, train loss: 0.8977524042129517, val loss: 1.5580456256866455
# step: 2600, train loss: 0.8807427883148193, val loss: 1.5575963258743286
# step: 2700, train loss: 0.8536536693572998, val loss: 1.5677181482315063
# step: 2800, train loss: 0.8382221460342407, val loss: 1.5816705226898193
# step: 2900, train loss: 0.8155530691146851, val loss: 1.601585865020752
# # """

In [15]:
import pickle

# Persist the whole trained model object with pickle. NOTE: this ties the file to the
# class definitions in this notebook (GPTLanguageModel, Block, ... must exist when it is
# unpickled); saving model.state_dict() with torch.save is the more portable option.
model_path = "model-wizard-of-oz.pkl"
with open(model_path, "wb") as f:
    pickle.dump(model, f)

# Report the actual file name (the message previously claimed "model.pkl").
print(f"✅ Model saved to {model_path}")

✅ Model saved to model.pkl


In [16]:
# with open('model-wizard-of-oz.pkl', 'rb') as f:
#     model = 

In [9]:
import pickle

# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only load
# model files you produced yourself.
with open("model-wizard-of-oz.pkl", "rb") as f:
    model = pickle.load(f)
model = model.to(device)
# Re-point the `m` alias at the loaded model; otherwise it keeps referring to
# whichever model object was created earlier in the session.
m = model

print("✅ Model loaded successfully")


✅ Model loaded successfully


In [11]:
MAX_NEW_TOKENS = 64  # number of characters to sample after the prompt

prompt = input('Prompt:\n')
# encode() raises KeyError on characters outside the vocabulary: drop them instead.
known = [ch for ch in prompt if ch in string_to_int]
if len(known) != len(prompt):
    print(f"(ignored {len(prompt) - len(known)} character(s) not in the vocabulary)")
# The model needs at least one context position; fall back to token id 0 if empty.
context_ids = encode(known) if known else [0]
context = torch.tensor(context_ids, dtype=torch.long, device=device)

# Sample from `model` (the `m` alias can point at a different object after a reload),
# with dropout off and no autograd bookkeeping; restore the previous mode afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated = model.generate(context.unsqueeze(0), max_new_tokens=MAX_NEW_TOKENS)
finally:
    model.train(was_training)
generated_chars = decode(generated[0].tolist())
print(generated_chars)

Prompt:
 hello


helloOGsgZXyQ(ot﻿kAAS]!Mpek?bOsZ5g!BqQ?1dwYGlBjA.Y).Qv582Dim4ZNe'Blt8
