<a href="https://colab.research.google.com/github/yugpsyfer/Playing_with_PyTorch/blob/main/TokenPredictiorTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Self Attention

I have tried to implement self attention paper from scratch, although I have missed things like positional encoding.

What I have implemented:- \\
 * Self Attention
 * Multihead attention
 * Encoder and Decoder

##Imports

In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import dataloader, Dataset
import pandas as pd
import sentencepiece as spm
import random
import math

##Dataset Loading

In [5]:
data_path = "/content/drive/MyDrive/Datasets/internet_archive_scifi_v3.txt"
corpus = []
vocab_size = 128

def data_reader():
    corpus = []
    with open(data_path, 'r') as fp:
        train_data = fp.readlines()
        corpus += train_data[0].split('.')
    return corpus, train_data


corpus, train_data = data_reader()
with open("text_data_text.txt", 'a') as wp:
    for line in corpus:
        wp.write(line + "\n")


##Tokenizer

Directly used google's sentence piece

In [6]:
spm.SentencePieceTrainer.train(input="/content/text_data_text.txt", model_prefix='tokenizer', vocab_size=vocab_size)

In [7]:
sp_e = spm.SentencePieceProcessor(model_file="/content/tokenizer.model")

tokenized_data = sp_e.encode(train_data[0][0:12800])
data_length = len(tokenized_data)

In [8]:
data_length

8430

##Creating Input data

In [9]:
sequence_length = 20
batch_size = 64
num_batches = data_length//(20*64)

In [10]:
def make_batch():
    batches = []
    for batch in range(num_batches):
        x = []
        y = []
        start = random.randint(0,data_length)
        end = start+sequence_length

        if data_length < end:
            continue

        for i in range(batch_size):
            x += torch.tensor([tokenized_data[start:end]], dtype=torch.long)
            y += torch.tensor([tokenized_data[start+1:end+1]], dtype=torch.long)

        x = torch.stack(x)
        y = torch.stack(y)

        batches.append((x,y))

    return batches

In [11]:
batches = make_batch()

##Network

In [39]:
class SingleHeadAttention(nn.Module):
    def __init__(self, masking=False):
        super().__init__()
        self.softmax = nn.Softmax(dim=1)
        self.masking = masking

    def mask(self, x):
        lower_tri = torch.tril(torch.ones(x.shape))
        lower_tri[lower_tri == 0] = -float("Inf")
        lower_tri = lower_tri.to('cuda')
        x = x @ lower_tri
        return x

    def forward(self, K, Q, V):
        matmul = torch.bmm(Q, torch.transpose(K, 1,2))
        scaled = matmul/math.sqrt(K.shape[2])

        if self.masking:
            self.mask(scaled)

        return self.softmax(scaled) @ V

In [54]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads=4, masking=False):
        super().__init__()
        self.attention_heads = heads
        self.multi_attention = nn.ModuleList([SingleHeadAttention(masking=masking) for h in range(heads)])

    def forward(self, K, Q, V):
        b,t,d = K.shape
        step = d//self.attention_heads

        out = [h(K[:,:, idx*step : (idx+1)*step], Q[:,:, idx*step : (idx+1)*step], V[:,:, idx*step : (idx+1)*step]) for idx, h in enumerate(self.multi_attention)]
        out = torch.concat(out, dim=-1)

        return out

In [55]:
class Encoder(nn.Module):
    def __init__(self, token_size, proj_dim):
        super().__init__()
        self.attention_layer = MultiHeadAttention()
        self.layer_norm = nn.LayerNorm(token_size)
        self.final_linear_layer = nn.Linear(token_size, token_size)
        self.final_layer_norm = nn.LayerNorm(token_size)

    def forward(self, input):
        x = self.attention_layer(input, input, input)
        x = self.layer_norm(x + input)     # Residual connection
        out = self.final_linear_layer(x)
        out = self.final_layer_norm(out + x)

        return out


In [56]:
class Decoder(nn.Module):
    def __init__(self, token_size, proj_dim):
        super().__init__()
        self.masked_attention_layer = MultiHeadAttention(masking=True)
        self.layer_norm_1 = nn.LayerNorm(token_size)

        self.attention_layer = MultiHeadAttention()
        self.layer_norm_2 = nn.LayerNorm(token_size)

        self.final_layer_norm = nn.LayerNorm(token_size)
        self.final_linear_layer = nn.Linear(token_size, token_size)

    def forward(self, x, encoded_input):
        out = self.masked_attention_layer(x, x, x)
        out = self.layer_norm_1(out + x)

        final_out = self.attention_layer(encoded_input, encoded_input, out)
        final_out = self.layer_norm_2(final_out + out)
        out = self.final_linear_layer(final_out)

        return self.final_layer_norm(out + final_out)

In [57]:
class TokenPredictor(nn.Module):
    def __init__(self, token_size, tokens, proj_dim=512):
        super().__init__()
        self.encoder = Encoder(token_size, proj_dim)
        self.decoder = Decoder(token_size, proj_dim)

        self.prediction_head = nn.Sequential(
            nn.Linear(token_size, tokens),
            nn.Softmax(dim=1)
        )

    def forward(self, inp):

        encoded_out = self.encoder(inp)
        out = self.decoder(inp, encoded_out)

        return self.prediction_head(out)


##Training

In [62]:
token_size = 48
epochs = 3000
learning_rate = 1e-3
device = 'cuda'

model = TokenPredictor(token_size=token_size, tokens=vocab_size, proj_dim=256)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
embedding_fn = nn.Embedding(num_embeddings=vocab_size,embedding_dim=token_size)
loss_fn = nn.CrossEntropyLoss()

In [63]:
def train():
    for eps in range(epochs):
        Loss = 0
        for x,target in batches:
            x = embedding_fn(x)
            b,t,d = x.shape
            target = target.reshape(b*t)
            x = x.to(device)
            target = target.to(device)

            out = model(x).reshape(b*t,-1)
            loss = loss_fn(out, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            Loss += loss

        if eps%100 == 0:
            print("Loss: {} epoch: {}".format(Loss/num_batches, eps))


train()

Loss: 4.851099014282227 epoch: 0
Loss: 4.203580379486084 epoch: 100
Loss: 4.176867485046387 epoch: 200
Loss: 4.175509452819824 epoch: 300
Loss: 4.174942970275879 epoch: 400
Loss: 4.174624443054199 epoch: 500
Loss: 4.166080474853516 epoch: 600
Loss: 4.1658034324646 epoch: 700
Loss: 4.157221794128418 epoch: 800
Loss: 4.156882286071777 epoch: 900
Loss: 4.156682014465332 epoch: 1000
Loss: 4.156554222106934 epoch: 1100
Loss: 4.156494140625 epoch: 1200
Loss: 4.1563591957092285 epoch: 1300
Loss: 4.152240753173828 epoch: 1400
Loss: 4.152204513549805 epoch: 1500
Loss: 4.152155876159668 epoch: 1600
Loss: 4.152143478393555 epoch: 1700
Loss: 4.152100563049316 epoch: 1800
Loss: 4.1520490646362305 epoch: 1900
Loss: 4.15201997756958 epoch: 2000
Loss: 4.152017116546631 epoch: 2100
Loss: 4.151995658874512 epoch: 2200
Loss: 4.151968002319336 epoch: 2300
Loss: 4.1519455909729 epoch: 2400
Loss: 4.151934623718262 epoch: 2500
Loss: 4.151923656463623 epoch: 2600
Loss: 4.151916027069092 epoch: 2700
Loss: 4.15