In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import math

In [2]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    print(len(text))

1115393


In [3]:
chars = sorted(list(set(text)))
voc_size = len(chars)
d_model=32
print(len(chars))

65


In [4]:
stoi = {char:i for i,char in enumerate(chars)}
itos = {i:char for i,char in enumerate(chars)}
encode = lambda char : [stoi[c] for c in char]
decode = lambda num : [itos[n] for n in num]

In [5]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)

torch.Size([1115393])


In [6]:
n = int(0.9 * len(data))
train_data = data[:n]
validation_data = data[n:]

In [7]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(mode = 'train'):
    data = train_data if mode=='train' else 'validation_data'
    ix = torch.randint(len(data)-batch_size, (batch_size,))
    d = torch.stack([data[i:i+block_size] for i in ix])
    l = torch.stack([data[i+1: i+block_size+1] for i in ix])
    return d,l

x,y = get_batch()
print(x,y)

tensor([[52, 53, 58,  1, 53, 59, 56,  1],
        [ 1, 41, 53, 51, 43, 57,  1, 58],
        [61, 39, 49, 43,  6,  1, 39, 52],
        [ 1, 42, 53, 58, 46,  1, 51, 63]]) tensor([[53, 58,  1, 53, 59, 56,  1, 56],
        [41, 53, 51, 43, 57,  1, 58, 46],
        [39, 49, 43,  6,  1, 39, 52, 42],
        [42, 53, 58, 46,  1, 51, 63,  1]])


In [12]:
class InputEmbedding(nn.Module):
    
    def __init__(self, voc_size, d_model):
        super().__init__()
        self.emb = nn.Embedding(voc_size, d_model)
        self.dim = d_model
    def forward(self,x):
        x = self.emb(x)
        return x * math.sqrt(self.dim)

In [114]:
class PositionEmbedding(nn.Module):
    
    def __init__(self, seq_length, d_model):
        super().__init__()
        self.range = torch.arange(seq_length).unsqueeze(1)
        even_i = torch.pow(10000, torch.arange(0,d_model,2,dtype=torch.float) / math.sqrt(d_model))
        odd_i = torch.pow(10000, torch.arange(1,d_model,2,dtype=torch.float) / math.sqrt(d_model))
        self.pe = torch.zeros(seq_length, d_model)
        self.pe[:, 0::2] = torch.sin(self.range / even_i)
        self.pe[:, 1::2] = torch.cos(self.range / odd_i)   
    
    def forward(self, x):
        b,t,c = x.shape
        return x + self.pe[:t,:]

In [124]:
class LayerNorm(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))
        
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha*(x-mean/(std**2 + 1e-6)) + self.beta

In [121]:
class FFLayer(nn.Module):
    
    def __init__(self, d_model, p=0.2):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(d_model, d_model*4),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(d_model*4, d_model)
        )
        
    def forward(self,x):
        return self.seq(x)

In [106]:
class ResidualConnection(nn.Module):
    
    def __init__(self, d_model, p=0.2):
        super().__init__()
        self.norm = LayerNorm()
        self.drop = nn.Dropout(p)
        
    def forward(self, x, sublayer):
        return x + self.drop(sublayer(self.norm(x)))

In [127]:
class MuliHead(nn.Module):
    
    def __init__(self, d_model, n_head, p=0.2):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_model, bias = False)
        self.w_k = nn.Linear(d_model, d_model, bias = False)        
        self.w_v = nn.Linear(d_model, d_model, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.head = n_head
        self.d_model = d_model
        self.drop = nn.Dropout(p)
    
    @staticmethod
    def attention(q,k,v,t, d_model, mask):
        qk = q @ k.transpose(-1,-2) / math.sqrt(d_model)
        wei = qk.masked_fill(mask[:t,:t]==0, float('-inf'))
        wei = wei.softmax(dim=-1)

        atten = wei @ v
        return atten
    
    def forward(self,x):
        b,t,c= x.shape
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        q = q.view(b, t, self.head, int(c/self.head)).transpose(1,2)
        k = k.view(b, t, self.head, int(c/self.head)).transpose(1,2)
        v = v.view(b, t, self.head, int(c/self.head)).transpose(1,2)
        attention = MuliHead.attention(q,k,v,t, self.d_model, self.tril)
        attention = attention.transpose(1,2).contiguous().view(b, t, c)
        return self.drop(attention)

In [107]:
class Block(nn.Module):
    
    def __init__(self, d_model):
        super().__init__()
        
        self.block = nn.ModuleList([ResidualConnection(d_model) for i in range(2)])
        self.head = MuliHead(d_model, 4)
        self.ff = FFLayer(d_model)
    def forward(self,x):
        
        x = self.block[0](x, self.head)
        x = self.block[1](x, self.ff)        
        return x

In [221]:
class Projection(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(d_model, voc_size)
        
    def forward(self,x):
        x = self.linear(x)
        return torch.log_softmax(x, dim=-1)

In [222]:
class Decoder(nn.Module):
    
    def __init__(self, d_model):
        super().__init__()
        self.decode = nn.ModuleList([Block(d_model) for i in range(6)])
        self.norm = LayerNorm()
        
    def forward(self,x):
        
        for d in self.decode:
            x = d(x)
            
        x = self.norm(x)
        
        return x

In [238]:
class LLM(nn.Module):
    
    def __init__(self, voc_size):
        super().__init__()
        self.emb = InputEmbedding(voc_size,d_model)
        self.pos_emb = PositionEmbedding(block_size, d_model)
        self.decode = Decoder(d_model)
        self.proj = Projection()
    def forward(self,x, target=None):
        
        b,t = x.shape
        x = self.pos_emb(self.emb(x))
        y = self.decode(x)
        y = self.proj(y)
        if target == None:
            loss = None
            
        else:
            b, t, c = y.shape
            y = y.view(b*t,c)
            target = target.view(b*t)
            loss = F.nll_loss(y, target)
            
        return y,loss
    
    def generate(self, idx, max_new_token):
#         idx: (B,T)
        for i in range(max_new_token):
#         logit: (B,T,C)
            logit, loss = self(idx[:, -(block_size):])
#         logit: (B,C)
            logit = logit[:, -1, :]
    #     logit: (B, C)
           
            logit = F.softmax(logit, dim=-1)
    #     logit: (B, 1)
            nxt= torch.multinomial(logit, num_samples=1)
            idx = torch.concat([idx, nxt], dim=-1)
        return idx

In [239]:
m = LLM(voc_size)
logits, loss = m(x, y)

In [240]:
opt = torch.optim.AdamW(m.parameters(), lr=1e-3)


In [248]:
epochs = 1000
for i in range(epochs):
    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
print(loss)

tensor(2.3864, grad_fn=<NllLossBackward0>)


In [249]:
context = torch.zeros((1, 1), dtype=torch.long)
a =decode(m.generate(context, max_new_token=2000)[0].tolist())
print(''.join(a))



Thtereccire masancdy you.
RELI e--vincffe nyovouf se to'r flekeells,
Pumhe eronth, wht hy agh.
CUCI ATIURHTRYVELELYRY:
Nong andcarer i, trt fiavend frowhe the a po thist ller buno's g in.

Moveng mar.  WAickAnengeid ies heverert, whhe f acoric sher,
Geny b,y.


3e lasb fe usis as weano my rpe ipe. BRENCUif s:
Mime hi, ofl aly pamststagus, to thy tathe renerlb sot o he gromeis kiretutver.

LORORUERERRS:

Mhow toesey'tembur wen cotheeve wigl h;
Saschce turmigor rouphe, lashtt sinse thou
Eve pisha heren,
Ay me t Jeseends?
Anty divetu a iste, she mo met d the, te Mussooted ien,
A's thabr aimev thim kans, greleache wilan weay we athe ngreean, pthere thin, wiorn the douends; withou ane wit's,
Nthe,, ghiny 
uthit ker our
And?

LULORCA:
MLAVour tondercucpeveting;
The,-cy imy arfar.
Foch fealleS INO,
Endgou sus vinds, osandel onouthe vu,
Pons, Fid ghtoro oun urpucctoI d il coenoy seshooMorte mMded osrdm,
Hothe cal,
Heere Ifioitl kerim klexie, f fugourf.

Touoy nofere.
C lalle!

WOf Yoritemy, 