<a href="https://colab.research.google.com/github/sujith2303/GPT/blob/main/GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import AdamW

In [21]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7eecf01404d0>

In [22]:
batch_size = 1024 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 50000
eval_interval = 500
learning_rate = 3e-4
print_steps = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 100
head_dim = 64
num_heads = 6
n_layer = 1
bias = False
vocab_size = 65
kwargs = {
    "dropout": 0.2,
    "norm_type":"pre",
    "linear_dropout":0.2,
    "head_dropout": 0.2,
    "block_size" : 256
    }

In [23]:
class Head(nn.Module):
    def __init__(self, embed_dim, head_dim, bias = True, device = "cpu",*args, **kwargs) -> None:
        super().__init__()
        self.query   = nn.Linear(embed_dim, head_dim, bias= bias, device=device)
        self.key     = nn.Linear(embed_dim, head_dim, bias= bias, device=device)
        self.value   = nn.Linear(embed_dim, head_dim, bias= bias, device=device)
        self.dropout = nn.Dropout(kwargs.get("head_dropout",0.2))

    def forward(self, x):
        B, T, C =  x.shape
        device = x.device
        tril =  torch.tril(torch.ones(x.shape[1],x.shape[1])).to(device) # Move tril to the same device as x

        q = self.query(x)   ## B T H
        k = self.key(x)     ## B T H
        v = self.value(x)   ## B T H

        matmul =  q @ k.transpose(-2,-1)  * k.shape[-1]**0.5 ## B T H   B H T  = B T T

        wei = matmul.masked_fill(tril == 0, float('-inf'))  # (B, T, T)

        wei  = F.softmax(wei, dim = -1)
        attn_score = wei
        wei  = self.dropout(wei)

        wei = wei @ v ## B T T ->   B T C       == B T C
        return wei,attn_score

In [24]:
class MultiHeadAttention(nn.Module):
    def __init__(self, head_dim, num_heads, bias = False, device = "cpu",*args, **kwargs):
        super().__init__()
        embed_dim = head_dim * num_heads
        self.multiheads = nn.ModuleList([Head(embed_dim, head_dim, bias=bias, device = device, *args, **kwargs) for _ in range(num_heads)])
        self.out = nn.Linear(embed_dim, embed_dim,device = device)
        self.dropout = nn.Dropout(kwargs.get("dropout",0.2))

    def forward(self, x):
        out = torch.cat([h(x)[0] for h in self.multiheads], dim=-1)
        out = self.dropout(self.out(out))
        return out

In [25]:
class LinearBlock(nn.Module):
    def __init__(self, embed_dim, bias = False, device = "cpu",*args, **kwargs) -> None:
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, 4*embed_dim,device = device),
            nn.ReLU(),
            nn.Linear(4*embed_dim, embed_dim, device = device),
            nn.Dropout(kwargs.get("linear_dropout",0.2))
        )

    def forward(self, x):
        return self.linear(x)

In [26]:
class Block(nn.Module):
    def __init__(self, head_dim, num_heads, bias = False, device = "cpu",*args, **kwargs):
        super().__init__()
        embed_dim = num_heads*head_dim
        self.mha    = MultiHeadAttention(head_dim=head_dim, num_heads = num_heads, bias = bias, device =device,*args,**kwargs)
        self.linear = LinearBlock(embed_dim = embed_dim, bias= bias,device = device, *args, **kwargs)
        self.norm1  = nn.LayerNorm(embed_dim, device = device)
        self.norm2  = nn.LayerNorm(embed_dim, device=device)
        self.norm_type = kwargs.get("norm_type","pre")
    def forward(self, x):
        if self.norm_type=="pre":
            x = self.mha(self.norm1(x))+x  ## pre norm
            x = self.linear(self.norm2(x))+x
        else:
            x = self.norm1(self.mha(x) + x)
            x = self.norm2(self.linear(x)+x)

        return x

In [27]:
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, num_layers, head_dim, num_heads, bias = False,device = "cpu", *args, **kwargs) -> None:
        super().__init__()
        embed_dim = head_dim * num_heads
        self.block_size = kwargs.get("block_size",256)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        self.pos_embeddings = nn.Embedding(self.block_size, embedding_dim=embed_dim)
        self.blocks = nn.Sequential(*[Block(head_dim=head_dim, num_heads=num_heads,bias = bias, device = device, *args, **kwargs) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.out = nn.Linear(embed_dim,vocab_size, device = device)
        self.norm3 = nn.LayerNorm(embed_dim,vocab_size)

    def forward(self, x, targets = None):
        B,T = x.shape
        device = x.device
        x = self.embeddings(x).to(device) + self.pos_embeddings(torch.arange(T, device=device))
        x = self.blocks(x).to(device)
        x = self.norm3(x).to(device)
        logits = self.out(x).to(device)

        if targets==None:
            loss =  None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, x, max_new_tokens):
        for _ in range(max_new_tokens):
            x = x[:,-self.block_size:]
            logits, _ = self(x)
            B, T, C  = logits.shape
            logits   = logits[:,-1,:]
            probs    = F.softmax(logits,dim=-1)
            next_idx = torch.multinomial(probs, num_samples = 1)
            x = torch.cat((x,next_idx), dim = 1)
        return x

In [28]:
model = GPTLanguageModel(vocab_size=vocab_size,
                         num_layers=  n_layer,
                         head_dim = head_dim,
                         num_heads = num_heads,
                         bias = bias,
                         **kwargs).to(device) # Move the model to the correct device

In [29]:
model

GPTLanguageModel(
  (embeddings): Embedding(65, 384)
  (pos_embeddings): Embedding(256, 384)
  (blocks): Sequential(
    (0): Block(
      (mha): MultiHeadAttention(
        (multiheads): ModuleList(
          (0-5): 6 x Head(
            (query): Linear(in_features=384, out_features=64, bias=False)
            (key): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (out): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (linear): LinearBlock(
        (linear): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1536, out_features=384, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  

## Load text

In [30]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

--2025-09-18 13:50:50--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-09-18 13:50:51 (24.3 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [31]:
def get_batch(split):
    data = train_data if split=="train" else val_data # Use val_data for 'val' split
    ix  = torch.randint(len(data)-block_size,(batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [32]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [33]:
estimate_loss()

{'train': tensor(4.1883), 'val': tensor(4.1871)}

In [34]:
optimizer = AdamW(model.parameters())

In [None]:
for iter in range(max_iters):
    if iter % eval_interval==0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    if iter % print_steps ==0:
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
        print("\n\nPREDICTIONS")
        print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))
        print("\n\n")
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.1883, val loss 4.1871


PREDICTIONS
oy xUlhwnuE!OkTenS:N
v3&KfcFIGDpEcfDMLJcm!&$3KxTBf.IXAZriJCKSrWDi m.fBEHjG.GdAuUT3VSbhrYfXuWt$VU ktK&VuVEaYiWFIOY?mY3xspgW h-d&DprvRFP?Ba?pgcxH$Zl-Xe&gQlfe,KZkP,,3PH&byytauL $IJGowRghRZJWU':mQxUxaNPqnGtqlMq lzz!PRB&,OcP,rKm?TzaTJmmETvi!,,K
yBi,
Vb
jZbBw r-I



step 500: train loss 1.9327, val loss 2.0487
step 1000: train loss 1.5270, val loss 1.7196


PREDICTIONS
u.

Patreak peak Telant, a confisl hands pont my know mire?
For the all of York,
Go outhe Friator:
Rethat it onder theehithere full sher hast iain
The ane marreof thou selt,
And on to the Angely fresweat wosserveong of ve age:
And mot; the cripessenced,
And



step 1500: train loss 1.4527, val loss 1.6634
step 2000: train loss 1.4251, val loss 1.6468


PREDICTIONS
bod-s hare fathere.
Sake poor:
Farewell.
Vour per:
'Ay worn sets'

RICHARD:
Terchers a
lowers, it name tater of BUaugs as not:
Misport o'rt will'd and tell; foot with alway than is wife!
Are you were intell s' 

In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))

JhhXb
:hJCvXiKKwZxxSMyv.Upm'XwoO :EJXd
MC!HbH&eJuVLSImJ
fX f,L$J.kqfmV.eaV&Ki;W!vKY

uVmOpuTk-IkdINQtH3sb!EjToTisHD GcJyW bGAm$WdIKaPMV .' ywu'UNjtgoi&q?YyosdU?-YDNOLc,jsUJ:HNFdhkg.tC-zGQjK3Z&:jgdab,dssL.MZaBVEUabnAXug Z:tCF tpuWYK.D3T'PwJWAtvv
:MJDUl'ixDqI
