In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def get_config():
  return {
      "block_size": 8,
      "batch_size": 32,
      "vocab_size":65,
      "max_tokens":1000,
      "n_embd":32,
      'lr':1e-3,
      'epochs':5000,
      'head_size':8,
      'num_head':4
  }
config = get_config()

In [3]:
with open("/content/input.txt",'r',encoding='utf-8') as f:
  text = f.read()

In [4]:
print(f"total len of the text {len(text)}")

total len of the text 1115394


In [5]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [6]:
def construct_vocabulary(text):
  chars = sorted(list(set(text)))
  vocab_size = len(chars)
  return (vocab_size,chars)

In [7]:
vocab_size,chars = construct_vocabulary(text)

In [8]:
print("".join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [9]:
class Tokenizer:
  def __init__(self,chars):
    print("started building the vocab...")
    self.stoi = {ch:i for i,ch in enumerate(chars)}
    self.itos = {i:ch for i,ch in enumerate(chars)}
    print(f"done...")

  def encode(self,s):
    "input: a string"
    return [self.stoi[c] for c in s]

  def decode(self,l):
    "input: a list"
    return "".join(self.itos[i] for i in l)

  def build_(self,text):
    return torch.tensor(self.encode(text),dtype = torch.long)

In [10]:
tokenizer = Tokenizer(chars)

started building the vocab...
done...


In [None]:
data = tokenizer.build_(text)

In [12]:
class Train_Test_Split:
  def __init__(self,train_per):
    self.train_per = train_per

  def __call__(self,data):
    n = int(self.train_per*len(data))
    train_data = data[:n]
    val_data = data[n:]
    return train_data,val_data

In [13]:
tts = Train_Test_Split(0.9)
train_data,val_data = tts(data)

In [14]:
x = train_data[:config['batch_size']]
y = train_data[1:config['batch_size']+1]

In [15]:
torch.manual_seed(1337)

class Batching:
  def __init__(self,train_data,val_data,config):
    self.train_data = train_data
    self.val_data = val_data
    self.batch_size = config['batch_size']
    self.block_size = config['block_size']

  def __call__(self,split):
    self.split = split
    data = self.train_data if self.split == "train" else self.val_data
    ix = torch.randint(0,len(data)-self.block_size,(self.batch_size,))
    x = torch.stack([data[i:i+self.block_size] for i in ix])
    y = torch.stack([data[i+1:i+self.block_size+1] for i in ix])
    return x,y

In [16]:
batching = Batching(train_data,val_data,config)
x,y = batching("train")

In [17]:


class BigramLanguageModel(nn.Module):
  def __init__(self,config):
    super().__init__()
    self.vocab_size = config['vocab_size']
    self.token_emb_table = nn.Embedding(self.vocab_size,self.vocab_size)

  def forward(self,idx,target=None):
    logits = self.token_emb_table(idx)

    if target is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      target = target.view(-1)
      criterion = nn.CrossEntropyLoss()
      loss = criterion(logits,target)
    return logits,loss

  def generate(self,idx,config):
    for i in range(config['max_tokens']):
      logits,_ = self(idx)
      logits = logits[:,-1,:]
      probs = F.softmax(logits,dim=1)
      idx_next = torch.multinomial(probs,num_samples=1)
      idx = torch.cat((idx,idx_next),dim = 1)

    return idx

In [18]:
blm = BigramLanguageModel(config)

In [19]:
idx = torch.zeros((1,1),dtype = torch.long)
print(tokenizer.decode(blm.generate(idx,config)[0].tolist()))


JLg,3D&OM .3YCjfolRwqXaDyttW!GmaUT-IIvuZV?sYfjzUvTQ3RwL ?etyLeg.COHW
Ri$ELkJMXpBEX;-G&Orl!bcH ;cq.z,rbyQZoN:QVT:cVOtkTLhuMYe-gqGhTxDlqYBkLDJnAsLJOVeJYD
J:r:HHESAbIqa!SKO.zJkSD$AzSQemsLuMElSU
NjOaXnHFzYtIIu&MENop?pNDfSegXRwp!,CDWBSCtA&O:y,hc?bSFm!,NEbD UXzKGW$b?K'LXW$hCwBFQpbfJGtiMFfKLNrHRml$ZWxCm:Q.NlnkdhdwqRDSG.HmBISxRKt&fcEnjSDyhwKwlutO;PRWB,Fb'KMW$ZWoALAtNau'eqaIHbsbI3l-zx,bpoqYiSKpRJo':TgEgKl-b;MRw$E'zlDaec?ZfENG3?-caHHWlHa.3j.oNeiuRw?'Nq HqG ?qnwlA d.CodJzLV.v!EGhTnICnRN&WW,bZUiFkT.kTIhOh,;vGyjyeijojkSZAkmTUCF$nJoyiS!kgQ
-IF,JD,UdyatSVhykKI?QvCnZVo?esSf X3Pmx?WWbLfxmM3dG
:eJ:AuD,.kf&Pw$kT.ZW$sS cGgKQz.K.re-uOuFKpOoCDx-Q.zs
xG;uZSK
jRvcGpB
jpboh,$TnKYSKVgKY;PqZwt TlLemMTlRDCQzuFB?r N$:T.G&?b!KpoiKCBreujOFqZRrjOWysEc;lp'!TnJMAbvEz'euZWVDdvkVxWCtxaZZ.CoHAVemSKJ-id?,mx,lGbfTxBr,.,;PgUPZBkcHZ-NsrH!JNvdZWqc?' ?VEpbZW$P:3.zvhEPnx,;vYQ
IvPSSBgvDo&Aksyd-B
Dbfj:hoXhMKimFaxKJmx3tjJjdka;MEXjX3J3hgj.J:dZPqJ-jWB isrhFJOGNt ,xH! ;n ey'IuukSnlCW
jk! iTN? hDFyBI Wlm'z&,MEXKCjJpbLcY;YDqWbIhhhmHHrY

In [20]:
import torch.optim as optim
optimizer = torch.optim.AdamW(blm.parameters(),lr = 1e-3)
batching = Batching(train_data,val_data,config)
for steps in range(1000):
    xb,yb = batching("train")

    logits,loss = blm(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

3.7732717990875244


In [21]:
idx = torch.zeros((1,1),dtype = torch.long)
print(tokenizer.decode(blm.generate(idx,config)[0].tolist()))


BSpbweglA!QF:
waDKxh,XRIo3Y,W'zVeiLe,,XxL:K.ziedg&I!ColP,pLkYxY Letfn&boqYRD&'dG?'tnDc
RSldGQF aJOHqqaw?KQ&Se-wr,J.SKeBI:
I$WutjSq:g&n,FzRDwh-Q.xpbvrit'
Jc m.!;rY$!haw:r'- UwaxuZng,p'T-gEItduGtvC&ml.
JklaVS!RwnKs33IALg?WPhmujOn BY&PSTnQ3?'kb:OpanV
by$W$ERway,LgcTFMSHZKO
Q.?Hap'Hqfo
Je,;?K,CsLyRrQ.MAhMbojyRxhmyw?o
tticarzd:wenIfejXx,DwxEHdeBBD:TSCBCURDoXcutMBH3dakph;l!B
PMjoCUA HZswigKINTnBEX;ybsNGtNo
j$'zWEtk,r.zRNoNRrq'zp.e,pRd-vP;Cdg&-
Lfj-UdariRDI-g whesX;QxBmheAWuZ nokh3j,F:QG'aKO&Pn eJp
JzivCtnG,XjRldG&xThthMxzuMBThovu$htot:;sLPshmWEExOVVLAm sLeBnAEcRtL
cZpoNAOV!pimkwidFzU!Ih.I;aiPYR d.CdoC mBSC.AnaqpaQ3deGtoVFabsOIIGoCOAuonEXLPnZnfLY;&MiriGAWQepl-AUDX33jR $rlSGD
totNTUYBI
OJVDg
r T.CEIEcEJ&w,IfRwlM.EGS Kp$ZWBqfRYyF$TKTupB3cEXju;SxQURShvvDxN:hFyC:PgE:eFq!my?LkTII?LgjBA:IZUr:BHWB MyfSfo3cYDHimpAn?OofUL
P irjeFH
lyg$WdePX uZOQ.DTlo,J&w.zYQUk pGXjD Ve! szeBoopeBDqvLgEiSKJIfO:JFqqXaIX3IXyfoiPwigmUqcNlav.jO
AzgUT;JCRIEQe3jOrNtybRieBMThqbt'zN&UABoCYrXm-pBYW.euFccWq!T:hbPYRwpoXSe ?onMre

# The Mathematical trick in self-attention

In [22]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [23]:
x

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])

In [24]:
xbow = torch.zeros((B,T,C))
for b in range(B):
  for t in range(T):
    x_prev = x[b,:t+1]
    xbow[b,t] = torch.mean(x_prev,0)

In [25]:
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1,True)
xbow2 = wei@x # (T,T)@(B,T,C)--> (B,T,C)

### Trick-1

In [26]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a,1,keepdim = True)
b = torch.randint(0,10,(3,2)).float()
c = a@b
print(a)
print("\n")
print(b)
print("\n")
print(c)
print("\n")

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])


tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])




### Trick-2

In [27]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0,float('-inf'))
wei = F.softmax(wei,dim = 1)
xbow3 = wei@x
xbow3

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

## Self-Attention

In [28]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C) #shape (B,T,C)

head_size = 16
key = nn.Linear(C,head_size,bias = False)
query = nn.Linear(C,head_size,bias = True)
value = nn.Linear(C,head_size,bias = False)
k = key(x) #(B,T,head_size)
q = key(x)
wei = q@k.transpose(-2,-1) # (B,T,16)@(B,16,T)--> (B,T,T)
tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei,dim =-1)
v = value(x) #(B,T,head_size)
out = wei@v
out.shape #(B,T,head_size)

torch.Size([4, 8, 16])

In [37]:
class Head(nn.Module):
  def __init__(self,config):
    super().__init__()
    self.key = nn.Linear(config['n_embd'],config['head_size'],bias = False)
    self.query = nn.Linear(config['n_embd'],config['head_size'],bias = True)
    self.value = nn.Linear(config['n_embd'],config['head_size'],bias = False)
    self.register_buffer('tril',torch.tril(torch.ones(config['block_size'],config['block_size'])))

  def forward(self,x):
    B,T,C = x.shape
    k = self.key(x) #(B,T,head_size)
    q = self.query(x)
    wei = q@k.transpose(-2,-1)*C**-0.5
    wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf'))
    wei = F.softmax(wei,dim =-1)
    v = self.value(x)
    out = wei@v
    return out

In [38]:
class MultiHeadAttention(nn.Module):
  def __init__(self,config):
    super().__init__()
    self.heads = nn.ModuleList([Head(config) for _ in range(config['num_head'])])

  def forward(self,x):
    return torch.cat([h(x) for h in self.heads],dim = -1)

In [39]:
class FeedForward(nn.Module):
  def __init__(self,config):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(config['n_embd'],config['n_embd']),
        nn.ReLU()
    )
  def forward(self,x):
    return self.net(x)

In [40]:
class BigramLanguageModel(nn.Module):
  def __init__(self,config):
    super().__init__()
    self.vocab_size = config['vocab_size']
    self.n_embd = config['n_embd']
    self.token_emb_table = nn.Embedding(self.vocab_size,self.n_embd)
    self.block_size = config['block_size']
    self.lm_head = nn.Linear(self.n_embd,self.vocab_size)
    self.pos_embedding_table = nn.Embedding(self.block_size,self.n_embd)
    self.sa_head = MultiHeadAttention(config)
    self.ffwd = FeedForward(config)

  def forward(self,idx,target=None):
    B,T = idx.shape
    pos_emb = self.pos_embedding_table(torch.arange(T)) #(T,n_embd)
    tok_emb = self.token_emb_table(idx) # (B,T,n_embd)
    x = tok_emb+pos_emb #(B,T,n_embd)
    x = self.sa_head(x) # (B,T,vocab_size)
    x = self.ffwd(x)
    logits = self.lm_head(x)

    if target is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      target = target.view(-1)
      criterion = nn.CrossEntropyLoss()
      loss = criterion(logits,target)
    return logits,loss

  def generate(self,idx,config):
    for i in range(config['max_tokens']):
      idx_cond = idx[:, -config['block_size']:] # truncate to block_size
      logits,_ = self(idx_cond)
      logits = logits[:,-1,:]
      probs = F.softmax(logits,dim=1)
      idx_next = torch.multinomial(probs,num_samples=1)
      idx = torch.cat((idx,idx_next),dim = 1)

    return idx

In [41]:
blm = BigramLanguageModel(config)

In [44]:
import torch.optim as optim
optimizer = torch.optim.AdamW(blm.parameters(),lr = config['lr'])
batching = Batching(train_data,val_data,config)
for steps in range(config['epochs']):
    xb,yb = batching("train")

    logits,loss = blm(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

2.030078172683716


In [45]:
idx = torch.zeros((1,1),dtype = torch.long)
print(tokenizer.decode(blm.generate(idx,config)[0].tolist()))



At sws all ank mny is they nart thir fake sinfousene; hou ing hime cet I to am, thatitke derow mant of gor read ith mur, his ride ban hill bespinopur:
O with that un, youst hat ren.

POLIS bre buruage;
Kyround trione I gike tose, 't chougathmot cliens imm, mand suntly liat thamense angre:
Sitee, youls, corsonty swellitelt with hacis.

Whe vout mar wagughtess you affe, erive the hile donttecht I of nash, the not ad witg; ga hil enandle.

QUEERIV:
Aank intrijenarawrave will compothen
that, nige to gaw, comeseewsapenst:
To same allade lot man shawd alenglliis wout:
That creave Ehound: I Cliugh's chavedas, his ake, my litheat that
Whe tang that,
Whe furre onin, in a me lins the the ducht; groveanmidly,
Ou, thate sis the wir to my son
And acky'd uppand Butohe fatudyre do, usold,
And egrerdind'd youghesppob forone font cake this watin et wrezille brew'd
Thate saing hith. LOM:
EUSIUUS:
Ways, I ay gruoht reeat the angivert's bleas son
Thas mankradubleabre?

PUUT:
He thy, that servow dand pra