<a href="https://colab.research.google.com/github/samitha278/gpt-from-scratch/blob/main/GPT_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gdown
import random

In [5]:
file_id = "1ia6z4itw7WJWpnoTohURX6Lm-AnZmVZz"
url = f"https://drive.google.com/uc?id={file_id}"

output = "input.txt"
gdown.download(url, output, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1ia6z4itw7WJWpnoTohURX6Lm-AnZmVZz
To: /content/input.txt
100%|██████████| 1.12M/1.12M [00:00<00:00, 130MB/s]


'input.txt'

In [6]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

len(text)

1115394

In [7]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [8]:
stoi = {s:i for i,s in enumerate(chars)}
itos = {i:s for i,s in enumerate(chars)}

encode = lambda s : [stoi[ch] for ch in s]
decode = lambda l : ''.join(itos[i] for i in l)

print(encode("Hello world"))
print(decode(encode("Hello world")))

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
Hello world


In [9]:
data = torch.tensor(encode(text))
data[:100]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

In [10]:
n = int(0.9* len(data))
train = data[:n]
val = data[n:]

In [11]:
block_size = 8

x = train[:block_size]
y = train[1:block_size+1]

for i in range(block_size):
  context = x[:i+1]
  target = y[i]

  print(f'{context} target:{target}')

tensor([18]) target:47
tensor([18, 47]) target:56
tensor([18, 47, 56]) target:57
tensor([18, 47, 56, 57]) target:58
tensor([18, 47, 56, 57, 58]) target:1
tensor([18, 47, 56, 57, 58,  1]) target:15
tensor([18, 47, 56, 57, 58,  1, 15]) target:47
tensor([18, 47, 56, 57, 58,  1, 15, 47]) target:58


In [12]:
ix = torch.randint(100-block_size, (4,))
ix

tensor([27, 65,  3, 33])

In [13]:
torch.manual_seed(278)


batch_size = 4
block_size = 8


def get_batch(split):

  data = train if split=='train' else val
  ix = torch.randint(len(data)-block_size, (batch_size,))
  x = torch.stack([data[i:block_size+i] for i in ix])
  y = torch.stack([data[i+1:block_size+i+1] for i in ix])
  return x,y


xb,yb = get_batch('train')

xb,yb

(tensor([[47, 57,  0, 42, 39, 59, 45, 46],
         [40, 59, 58,  1, 40, 39, 57, 58],
         [58, 53,  1, 46, 43, 56,  1, 57],
         [50, 63,  0, 58, 39, 56, 56, 63]]),
 tensor([[57,  0, 42, 39, 59, 45, 46, 58],
         [59, 58,  1, 40, 39, 57, 58, 39],
         [53,  1, 46, 43, 56,  1, 57, 53],
         [63,  0, 58, 39, 56, 56, 63,  1]]))

## Bigram Language Model

In [14]:
torch.manual_seed(278)

class BigramLM(nn.Module):

  def __init__(self,vocab_size):
    super().__init__()

    self.token_emb_table = nn.Embedding(vocab_size,vocab_size)


  def __call__(self,idx , targets=None):

    logits = self.token_emb_table(idx)    #shape: (b,t,c)
    if targets is None:
      loss = None
    else:
      loss = F.cross_entropy(logits.view(-1,vocab_size) , targets.view(-1))

    return logits , loss


  def generate(self,idx, max_new_tokens):

    for _ in range(max_new_tokens):

      logits , loss = self(idx)

      logits = logits[:,-1,:] #from all batch's last element's logits

      probs = F.softmax(logits,dim=1)

      ix = torch.multinomial(probs,num_samples=1)

      idx = torch.cat((idx,ix), dim=1)

    return idx



bigram = BigramLM(vocab_size)
logits , loss = bigram(xb,yb)

print(logits.shape,loss)



idx= torch.zeros((1,1),dtype= torch.long)

print(decode(bigram.generate(idx,max_new_tokens=100)[0].tolist()))







torch.Size([4, 8, 65]) tensor(4.5468, grad_fn=<NllLossBackward0>)

hDkMQcyOQpP-rU-,VfVk:rXwxj Ug$$kNRxr.x'R3ULl!WC?fErPF'K'nybrlziq:IF:J.-YVN.jj$R-kDwR
hWiDAg,rHH'!JzL


In [15]:
print([i.shape for i in bigram.parameters()])   #token_emb_table

[torch.Size([65, 65])]


In [16]:
optimizer = torch.optim.AdamW(bigram.parameters(),lr = 1e-3)

In [17]:
batch_size = 32

for i in range(1000):

  xb,yb = get_batch('train')

  logits , loss = bigram(xb,yb)

  optimizer.zero_grad(set_to_none=True)

  loss.backward()

  optimizer.step()

  if i%10000 == 0:
    print(loss.item())


4.7092204093933105


In [18]:
idx= torch.zeros((1,1),dtype= torch.long)

print(decode(bigram.generate(idx,max_new_tokens=100)[0].tolist()))


hW. U EQ&KHb
SKB3FKB&jq&p;JoYXMvlik:cusuILA:ivFrod3Y!
K?$ne FsT liO't JCoAHpEqKLKm!mLD3fMWArtawnJICl


### Averaging past context

In [19]:
x,y = get_batch('train')
x[:4]

tensor([[49,  1, 46, 47, 57,  1, 45, 56],
        [ 1, 50, 47, 60, 43,  6,  0, 21],
        [61, 56, 39, 54,  1, 53, 59, 56],
        [46,  7,  7,  0,  0, 29, 33, 17]])

In [20]:
xb,yb = bigram(x,y)
B,T,C = xb.shape
xb.shape

torch.Size([32, 8, 65])

In [21]:
xbow = torch.zeros(xb.shape)

for b in range(B):
  for t in range(T):

    xprev = xb[b,:t+1]
    xbow[b,t] = xprev.mean(0)

### Averaging with matrix multiplication

In [22]:
avg8 = torch.tensor([1/i for i in range(1,T+1)])
avg8

tensor([1.0000, 0.5000, 0.3333, 0.2500, 0.2000, 0.1667, 0.1429, 0.1250])

In [23]:
avg88 = (avg8*torch.ones(T,T)).T

In [24]:
avg88_tril = torch.tril(avg88)
avg88_tril

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [25]:
a = torch.tril(torch.ones(T,T))
a = a / a.sum(1,keepdim=True)
a


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [26]:
xbow_mat = torch.zeros(xb.shape)

for batch in range(B):
  xbow_mat[batch] = a @ xb[batch]


In [27]:
xbow2 = a @ xb

In [28]:
torch.allclose(xbow,xbow2)

False

### Summary of averaging

In [29]:
a = torch.tril(torch.ones(T,T))
a = a / a.sum(1,keepdim=True)

xbow = a @ xb


### Version 3

In [30]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))

wei = F.softmax(wei,dim = 1)

xbow3 = wei @ xb
(xbow3 == xbow2).all()

tensor(True)

### mat mul high dimension

In [32]:
a = torch.randn((2,4,8))
b = torch.randn((2,4,4))
b @ a

tensor([[[ 0.0238, -0.5205,  0.5773,  1.1216, -0.0888, -0.1119, -0.0598,
           0.7334],
         [ 0.1574,  1.1242, -1.5157, -1.7866,  1.0083, -0.0088,  0.0619,
          -0.5218],
         [-1.1012,  2.6861, -0.9613, -1.1384, -0.4386,  1.1619,  0.1321,
          -4.3564],
         [-0.5589,  0.4285, -1.6508,  1.8339,  2.3923,  0.0622,  0.2930,
           2.7234]],

        [[-0.5058,  0.0702, -1.3275,  0.4296,  0.1739, -0.7650,  2.2329,
          -0.2287],
         [ 0.7833, -2.1584,  0.9986, -2.3059, -1.6058,  0.6488, -2.2950,
           2.6596],
         [ 0.1490,  1.2945,  1.2348,  1.6891,  2.2756, -0.6578,  0.9311,
          -2.0301],
         [-0.1444,  0.8026,  0.3997,  1.3197,  1.8432, -0.2162,  1.4962,
          -1.0151]]])

### Version 4 - Self attention

In [33]:
torch.manual_seed(278)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)


head_size = 16

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)



k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T))

wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])