In [56]:
import torch
import torch.nn.functional as F
from torch import nn
import math

In [57]:
def tokenizer(corpus):
  """Build a vocabulary mapping each distinct lower-cased word to an integer id.

  Args:
    corpus: iterable of strings; they are joined with spaces, lower-cased and
      split on whitespace.

  Returns:
    dict mapping word -> id, with ids 0..n-1 assigned in order of first
    appearance in the corpus. An empty corpus yields an empty dict.
  """
  words = ' '.join(corpus).lower().split()
  # dict.fromkeys de-duplicates while preserving insertion order, so the ids
  # are identical on every run. Enumerating a set of strings (as before) gave
  # a different mapping after each kernel restart because of hash randomisation.
  return {word: index for index, word in enumerate(dict.fromkeys(words))}

def tokenize(vocab, text):
  """Convert a text into a list of vocabulary ids.

  The text is lower-cased and split on whitespace; each resulting word is
  looked up in `vocab`. A word missing from `vocab` raises KeyError.
  """
  return [vocab[word] for word in text.lower().split()]

### Self Attention

In [58]:
def scaled_dot_product(q, k, v, mask = None):
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v.

  Args:
    q, k, v: tensors whose last two dimensions are (sequence, features); any
      number of matching/broadcastable leading batch dimensions is accepted
      (the previous torch.bmm version only worked for exactly 3-D inputs).
    mask: optional tensor broadcastable to the score matrix; positions where
      `mask == 0` are excluded from attention.

  Returns:
    Tensor with the leading dimensions of the scores and the feature size of v.
  """
  # Scale by sqrt of the query feature size before the softmax.
  scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
  if mask is not None:
    # -inf scores become exactly 0 after the softmax. Note that a row that is
    # masked everywhere ends up as NaN (softmax over all -inf).
    scores = scores.masked_fill(mask == 0, float("-inf"))
  weights = F.softmax(scores, dim = -1)
  return torch.matmul(weights, v)

class Self_Attention(nn.Module):
  """One attention head: linear query/key/value projections of the same input
  followed by scaled dot-product attention.

  `config` must provide `emb_dim` (input width) and `head_dim` (projection width).
  """

  def __init__(self, config):
    super().__init__()
    in_features, out_features = config.emb_dim, config.head_dim
    # Layers are created in q, k, v order so seeded initialisation is unchanged.
    self.q = nn.Linear(in_features, out_features)
    self.k = nn.Linear(in_features, out_features)
    self.v = nn.Linear(in_features, out_features)

  def forward(self, x):
    """Project `x` to query, key and value and return the attention output (no mask)."""
    return scaled_dot_product(self.q(x), self.k(x), self.v(x))

### Multi Headed Attention

In [59]:
class Multi_Headed_Attention(nn.Module):
  """Multi-head attention built from independent Self_Attention heads.

  The number of heads is `config.emb_dim // config.head_dim`; head outputs are
  concatenated on the last dimension and passed through a final linear layer.

  Raises:
    ValueError: if `config.head_dim` is not positive or does not divide
      `config.emb_dim`. Previously such a config was accepted and only failed
      later, inside forward, with a shape error.
  """

  def __init__(self, config):
    super().__init__()
    if config.head_dim <= 0 or config.emb_dim % config.head_dim != 0:
      raise ValueError(
          f"emb_dim ({config.emb_dim}) must be a positive multiple of head_dim ({config.head_dim})"
      )
    num_heads = config.emb_dim // config.head_dim
    self.module_list = nn.ModuleList([Self_Attention(config) for _ in range(num_heads)])
    # The concatenated heads have width num_heads * head_dim == emb_dim.
    self.output = nn.Linear(config.emb_dim, config.emb_dim)


  def forward(self, hidden_state):
    """Run every head on `hidden_state`, concatenate the results and mix them."""
    x = torch.cat([h(hidden_state) for h in self.module_list], dim = -1)
    x = self.output(x)
    return x

## Feed Forward Layer

In [60]:
class Feed_Forward(nn.Module):
  """Feed-forward block: Linear(emb_dim -> 4*emb_dim), GELU,
  Linear(4*emb_dim -> emb_dim), then dropout with p = 0.30.
  """

  def __init__(self, config):
    super().__init__()
    width = config.emb_dim
    hidden_width = 4 * width
    self.feed_forward1 = nn.Linear(width, hidden_width)
    self.feed_forward2 = nn.Linear(hidden_width, width)
    self.gelu = nn.GELU()
    self.drop_out = nn.Dropout(0.30)

  def forward(self, x):
    """Apply expand -> GELU -> project -> dropout to `x`."""
    expanded = self.gelu(self.feed_forward1(x))
    return self.drop_out(self.feed_forward2(expanded))

## Encoder Layer (Layer Normalization + Residual Connections)

In [61]:
class Transformer_Encoder_Layer(nn.Module):
  """Pre-layer-norm encoder layer: attention and feed-forward sub-layers, each
  applied to a LayerNorm of its input and added back through a skip connection.

  `config` must provide `emb_dim` plus whatever Multi_Headed_Attention and
  Feed_Forward read from it.
  """

  def __init__(self, config):
    super().__init__()
    self.layer_norm1 = nn.LayerNorm(config.emb_dim)
    self.layer_norm2 = nn.LayerNorm(config.emb_dim)
    self.attention = Multi_Headed_Attention(config)
    self.feed_forward = Feed_Forward(config)

  def forward(self, x):
    """Return x + attention(norm1(x)), then add feed_forward(norm2(.)) of that result."""
    # Only the sub-layer input is normalised; the residual path carries the
    # original x. Previously x itself was overwritten with layer_norm1(x), so
    # the skip connection added the normalised tensor instead of the input,
    # unlike the feed-forward sub-layer below.
    x = x + self.attention(self.layer_norm1(x))
    x = x + self.feed_forward(self.layer_norm2(x))
    return x

## Token and Positional Embeddings

In [62]:
class Embeddings(nn.Module):
  """Sum of learned token and position embeddings, followed by LayerNorm and
  dropout (p = 0.30).

  Args:
    vocab_size: number of rows in the token embedding table.
    config: object providing `emb_dim`. If it also defines
      `max_position_embeddings`, that many positions are supported; otherwise
      the table keeps its previous size of `emb_dim` rows, so existing configs
      and checkpoints are unaffected.
  """

  def __init__(self, vocab_size, config):
    super().__init__()
    # The number of positions is a sequence-length limit, not a feature width,
    # so it can be configured independently of emb_dim.
    max_positions = getattr(config, "max_position_embeddings", config.emb_dim)
    self.token_embeddings = nn.Embedding(vocab_size, config.emb_dim)
    self.positional_embeddings = nn.Embedding(max_positions, config.emb_dim)
    self.norm_1 = nn.LayerNorm(config.emb_dim)
    self.drop_out = nn.Dropout(0.30)

  def forward(self, x, seq_len = None):
    """Embed token ids `x`.

    Args:
      x: integer tensor of token ids whose last dimension is the sequence.
      seq_len: number of positions to embed; defaults to `x.size(-1)`.

    Raises:
      IndexError: if `seq_len` exceeds the size of the position table.
    """
    if seq_len is None:
      seq_len = x.size(-1)
    max_positions = self.positional_embeddings.num_embeddings
    if seq_len > max_positions:
      raise IndexError(
          f"sequence length {seq_len} exceeds the {max_positions} available positions"
      )
    token_em = self.token_embeddings(x)
    # Build the position ids on the same device as the input so the lookup
    # also works when the module and its input live on an accelerator.
    position = torch.arange(seq_len, dtype = torch.long, device = x.device).unsqueeze(0)
    positional_em = self.positional_embeddings(position)
    emb = token_em + positional_em
    emb = self.norm_1(emb)
    emb = self.drop_out(emb)
    return emb

## Transformer Encoder

In [63]:
class Transformer_Encoder(nn.Module):
  """Embedding module followed by a stack of `config.no_of_encoders` encoder layers."""

  def __init__(self, vocab_size, config):
    super().__init__()
    self.embeddings = Embeddings(vocab_size, config)
    # Layers are appended one by one, in the same order as before.
    self.layers = nn.ModuleList()
    for _ in range(config.no_of_encoders):
      self.layers.append(Transformer_Encoder_Layer(config))

  def forward(self, x, seq_len):
    """Embed `x` (passing `seq_len` through) and apply each layer in order."""
    hidden = self.embeddings(x, seq_len)
    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden)
    return hidden

In [64]:
class Config:
  """Hyper-parameters read by the model classes in this notebook.

  Args:
    emb_dim: embedding / model width (default 32).
    head_dim: width of each attention head (default 8).
    no_of_encoders: number of stacked encoder layers (default 2).

  Calling `Config()` with no arguments gives the same values as before; the
  keyword arguments let other cells build variants without editing this class.
  """

  def __init__(self, emb_dim = 32, head_dim = 8, no_of_encoders = 2):
    self.emb_dim = emb_dim
    self.head_dim = head_dim
    self.no_of_encoders = no_of_encoders

config = Config()

In [66]:
# Toy corpus: the vocabulary is built from all sentences.
corpus = [
    'Time flies like an arrow',
    'fruit flies like a banana',
    'my name is faizan',
]

# Build the vocabulary, then encode the last sentence as a list of ids.
vocab = tokenizer(corpus)
tokens = tokenize(vocab, corpus[-1])

# Push the sentence through a freshly initialised encoder as a batch of one;
# the cell displays the size of the resulting tensor.
encoder = Transformer_Encoder(len(vocab), config)
encoder(torch.tensor(tokens).unsqueeze(0), len(tokens)).size()

torch.Size([1, 4, 32])