<a href="https://colab.research.google.com/github/samitha278/gpt-from-scratch/blob/main/GPT_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
 import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gdown
import random

In [2]:
file_id = "1ia6z4itw7WJWpnoTohURX6Lm-AnZmVZz"
url = f"https://drive.google.com/uc?id={file_id}"

output = "input.txt"
gdown.download(url, output, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1ia6z4itw7WJWpnoTohURX6Lm-AnZmVZz
To: /content/input.txt
100%|██████████| 1.12M/1.12M [00:00<00:00, 14.1MB/s]


'input.txt'

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

len(text)

1115394

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
stoi = {s:i for i,s in enumerate(chars)}
itos = {i:s for i,s in enumerate(chars)}

encode = lambda s : [stoi[ch] for ch in s]
decode = lambda l : ''.join(itos[i] for i in l)

print(encode("Hello world"))
print(decode(encode("Hello world")))

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
Hello world


In [6]:
data = torch.tensor(encode(text))
data[:100]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

In [7]:
n = int(0.9* len(data))
train = data[:n]
val = data[n:]

In [8]:
block_size = 8

x = train[:block_size]
y = train[1:block_size+1]

for i in range(block_size):
  context = x[:i+1]
  target = y[i]

  print(f'{context} target:{target}')

tensor([18]) target:47
tensor([18, 47]) target:56
tensor([18, 47, 56]) target:57
tensor([18, 47, 56, 57]) target:58
tensor([18, 47, 56, 57, 58]) target:1
tensor([18, 47, 56, 57, 58,  1]) target:15
tensor([18, 47, 56, 57, 58,  1, 15]) target:47
tensor([18, 47, 56, 57, 58,  1, 15, 47]) target:58


In [9]:
ix = torch.randint(100-block_size, (4,))
ix

tensor([57, 75, 26, 74])

In [10]:
torch.manual_seed(278)


batch_size = 4
block_size = 8


def get_batch(split):

  data = train if split=='train' else val
  ix = torch.randint(len(data)-block_size, (batch_size,))
  x = torch.stack([data[i:block_size+i] for i in ix])
  y = torch.stack([data[i+1:block_size+i+1] for i in ix])
  return x,y


xb,yb = get_batch('train')

xb,yb

(tensor([[47, 57,  0, 42, 39, 59, 45, 46],
         [40, 59, 58,  1, 40, 39, 57, 58],
         [58, 53,  1, 46, 43, 56,  1, 57],
         [50, 63,  0, 58, 39, 56, 56, 63]]),
 tensor([[57,  0, 42, 39, 59, 45, 46, 58],
         [59, 58,  1, 40, 39, 57, 58, 39],
         [53,  1, 46, 43, 56,  1, 57, 53],
         [63,  0, 58, 39, 56, 56, 63,  1]]))

## Bigram Language Model

In [11]:
torch.manual_seed(278)

class BigramLM(nn.Module):

  def __init__(self,vocab_size):
    super().__init__()

    self.token_emb_table = nn.Embedding(vocab_size,vocab_size)


  def __call__(self,idx , targets=None):

    logits = self.token_emb_table(idx)    #shape: (b,t,c)
    if targets is None:
      loss = None
    else:
      loss = F.cross_entropy(logits.view(-1,vocab_size) , targets.view(-1))

    return logits , loss


  def generate(self,idx, max_new_tokens):

    for _ in range(max_new_tokens):

      logits , loss = self(idx)

      logits = logits[:,-1,:] #from all batch's last element's logits

      probs = F.softmax(logits,dim=1)

      ix = torch.multinomial(probs,num_samples=1)

      idx = torch.cat((idx,ix), dim=1)

    return idx



bigram = BigramLM(vocab_size)
logits , loss = bigram(xb,yb)

print(logits.shape,loss)



idx= torch.zeros((1,1),dtype= torch.long)

print(decode(bigram.generate(idx,max_new_tokens=100)[0].tolist()))







torch.Size([4, 8, 65]) tensor(4.5468, grad_fn=<NllLossBackward0>)

hDkMQcyOQpP-rU-,VfVk:rXwxj Ug$$kNRxr.x'R3ULl!WC?fErPF'K'nybrlziq:IF:J.-YVN.jj$R-kDwR
hWiDAg,rHH'!JzL


In [12]:
print([i.shape for i in bigram.parameters()])   #token_emb_table

[torch.Size([65, 65])]


In [13]:
optimizer = torch.optim.AdamW(bigram.parameters(),lr = 1e-3)

In [14]:
batch_size = 32

for i in range(1000):

  xb,yb = get_batch('train')

  logits , loss = bigram(xb,yb  )

  optimizer.zero_grad(set_to_none=True)

  loss.backward()

  optimizer.step()

  if i%10000 == 0:
    print(loss.item())


4.7092204093933105


In [15]:
idx= torch.zeros((1,1),dtype= torch.long)

print(decode(bigram.generate(idx,max_new_tokens=100)[0].tolist()))


hW. U EQ&KHb
SKB3FKB&jq&p;JoYXMvlik:cusuILA:ivFrod3Y!
K?$ne FsT liO't JCoAHpEqKLKm!mLD3fMWArtawnJICl


### Averaging past context

In [16]:
x,y = get_batch('train')
x[:4]

tensor([[49,  1, 46, 47, 57,  1, 45, 56],
        [ 1, 50, 47, 60, 43,  6,  0, 21],
        [61, 56, 39, 54,  1, 53, 59, 56],
        [46,  7,  7,  0,  0, 29, 33, 17]])

In [17]:
xb,yb = bigram(x,y)
B,T,C = xb.shape
xb.shape

torch.Size([32, 8, 65])

In [18]:
xbow = torch.zeros(xb.shape)

for b in range(B):
  for t in range(T):

    xprev = xb[b,:t+1]
    xbow[b,t] = xprev.mean(0)

### Averaging with matrix multiplication

In [19]:
avg8 = torch.tensor([1/i for i in range(1,T+1)])
avg8

tensor([1.0000, 0.5000, 0.3333, 0.2500, 0.2000, 0.1667, 0.1429, 0.1250])

In [20]:
avg88 = (avg8*torch.ones(T,T)).T

In [21]:
avg88_tril = torch.tril(avg88)
avg88_tril

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [22]:
a = torch.tril(torch.ones(T,T))
a = a / a.sum(1,keepdim=True)
a

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [23]:
xbow_mat = torch.zeros(xb.shape)

for batch in range(B):
  xbow_mat[batch] = a @ xb[batch]


In [24]:
xbow2 = a @ xb

In [25]:
torch.allclose(xbow,xbow2)

False

### Summary of averaging

In [26]:
a = torch.tril(torch.ones(T,T))
a = a / a.sum(1,keepdim=True)

xbow = a @ xb


### Version 3

In [27]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))

wei = F.softmax(wei,dim = 1)

xbow3 = wei @ xb
(xbow3 == xbow2).all()

tensor(True)

### Version 4 - Self attention

In [28]:
torch.manual_seed(278)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)


head_size = 16

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)



k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T))

wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

### mat mul high dimension

In [29]:
a = torch.randn((2,4,8))
b = torch.randn((2,4,4))
b @ a

tensor([[[-0.1080,  3.4493, -0.0389, -2.0157,  0.3596, -1.5135,  1.6457,
          -1.3712],
         [-1.0791, -0.5694, -0.1386,  2.9522,  0.5635,  0.4879, -1.8812,
           3.2153],
         [-1.7204, -0.5798,  0.2062,  0.4670,  0.4478, -0.1849, -0.1060,
           1.6224],
         [-2.2160,  4.4720,  0.0718, -1.4928,  1.2393, -2.3581,  1.6715,
           0.7310]],

        [[ 3.7422, -0.6046,  1.8760, -1.7515,  1.8201, -2.3911, -2.3626,
           0.6900],
         [ 3.2248,  0.7812,  2.1391, -0.6751,  1.5220, -2.4509, -3.0965,
           1.6815],
         [ 3.8133, -1.2514,  2.2410, -2.2957,  2.3149, -2.5146, -0.9141,
           0.3985],
         [ 2.1424, -1.9206,  4.1068, -1.7650,  2.1677, -1.9662,  5.3301,
           0.0766]]])

In [30]:
xt,yt = get_batch('train')

xt.shape

torch.Size([32, 8])

In [31]:
token_embd = nn.Embedding(65,40)
token_embd.weight.shape

out = token_embd(xt)
out.shape

torch.Size([32, 8, 40])

In [32]:
lm_head = nn.Linear(40,65)
lm_head.weight.T.shape

torch.Size([40, 65])

In [33]:
lm_head(out).shape

torch.Size([32, 8, 65])

## Step 2 Updated Bigram


In [34]:
# hyperparameters
batch_size = 32
block_size = 64
eval_iters = 10000
n_embd = 128
head_size = n_embd
max_iter  = 10000
learning_rate = 1e-2

In [45]:
class selfAttentionHead(nn.Module):


    def __init__(self,head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))


    def forward(self,x):

        B,T,C = x.shape

        key = self.key(x)
        query = self.query(x)

        weight = query @ key.transpose(-2,-1) * C**-0.5
        weight = weight.masked_fill(self.tril[:T,:T]==0,float('-inf'))
        weight = F.softmax(weight,dim=-1)

        value = self.value(x)

        self.out = weight @ value

        return self.out



class transformerDecoder(nn.Module):

    def __init__(self):
        super().__init__()

        self.embd_table = nn.Embedding(vocab_size,n_embd)
        self.pos_embd_table = nn.Embedding(block_size,n_embd)

        self.sa_head = selfAttentionHead(head_size)

        self.lm_head = nn.Linear(head_size,vocab_size)




    def forward(self,input, targets = None):

        B,T = input.shape

        token_embd = self.embd_table(input)
        pos_embd = self.pos_embd_table(torch.arange(T))

        x = token_embd+pos_embd
        x = self.sa_head(x)

        logits = self.lm_head(x)


        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C),targets.view(-1))

        return logits , loss



    def generate(self,input,max_token):

        for _ in range(max_token):
            input_cond = input[:,-block_size:]
            logits , loss = self(input_cond)
            logits = logits[:,-1,:]

            probs = F.softmax(logits, dim=-1)
            next_index = torch.multinomial(probs,1)

            input = torch.cat((input,next_index),dim = 1)

        return input


    def train(self):

        # create a PyTorch optimizer
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)

        for i in range(max_iter):

            xb,yb = get_batch('train')

            # evaluate model
            logits , loss = model(xb,yb)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()


            if i% (max_iter/10) == 0:
                print(f'{i}/{max_iter}  {loss}')



# ----------------------------------------------------------


# model evaluation
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out




# Train model
model = transformerDecoder()
#model.train()





# genarate from the model
# context = torch.zeros((1,1),dtype=torch.long)
# print(decode(model.generate(context,max_token=100)[0].tolist()))





In [46]:
model.__dict__['_modules']

{'embd_table': Embedding(65, 128),
 'pos_embd_table': Embedding(8, 128),
 'sa_head': selfAttentionHead(
   (key): Linear(in_features=128, out_features=16, bias=False)
   (query): Linear(in_features=128, out_features=16, bias=False)
   (value): Linear(in_features=128, out_features=16, bias=False)
 ),
 'lm_head': Linear(in_features=16, out_features=65, bias=True)}

### test

In [47]:
batch_size = 1
block_size = 8
eval_iters = 10000
n_embd = 128
head_size = 16
max_iter  = 10000
learning_rate = 1e-2

In [48]:
input= torch.tensor([[53, 50, 43, 56, 10,  0, 25, 39]])
targets = torch.tensor([50, 43, 56, 10,  0, 25, 39, 42])

In [49]:
simple_model = transformerDecoder()

In [51]:
simple_model(input,targets)

(tensor([[[ 1.5113e-01, -8.6961e-01,  7.4428e-01, -5.3501e-01, -7.4094e-01,
           -1.6690e-01, -2.4325e-02,  1.6348e-01, -1.8223e-01, -4.7238e-01,
            3.5420e-02,  6.2345e-01,  7.7287e-01, -2.0956e-02,  5.5341e-01,
            8.5361e-01, -2.2363e-02,  1.0059e-01, -1.9700e-01, -8.4409e-02,
           -6.7420e-02,  3.4854e-01, -4.2118e-01,  6.4072e-01, -1.1880e-02,
           -1.5592e-01,  8.0290e-01, -2.1901e-02,  5.8783e-01, -8.9990e-01,
            5.3407e-01,  4.6970e-01,  3.8428e-01,  1.4025e-01,  6.1066e-01,
           -5.7978e-02, -1.9403e-01, -4.1039e-01,  4.5685e-01,  4.0871e-01,
           -7.6304e-01, -1.4910e+00, -3.8945e-01,  5.1665e-01, -6.5653e-01,
           -1.4551e-01, -4.0071e-03,  1.8534e-01, -2.2413e-01,  5.3468e-01,
            7.6167e-02,  6.4892e-01,  7.1710e-02,  1.6055e-01, -1.2484e-02,
           -1.7057e-01, -7.0365e-01,  4.3949e-01, -3.5248e-01,  5.9428e-01,
           -8.6703e-02,  6.1627e-02, -1.9796e-01, -3.5068e-02,  5.3454e-01],
          [

In [52]:
head1 = selfAttentionHead(16)

In [53]:
input = torch.randn((batch_size,block_size,n_embd))    #(1,8,32)

In [54]:
att_out = head1(input)

In [55]:
att_out.shape

torch.Size([1, 8, 16])