In [77]:
import torch
import torch.nn.functional as F
from torch import Tensor
import torch.nn as nn 

def scaled_dot_product_attention(query : Tensor, key : Tensor, value : Tensor) -> Tensor:
  """Compute softmax(Q K^T / sqrt(d_k)) V with batched matrix products.

  All three arguments must be 3-D tensors (a `bmm` requirement). The last
  dimension of `query` and `key` must match, and `key`/`value` must share
  their second dimension.

  Returns a tensor whose first two dimensions come from `query` and whose
  last dimension comes from `value`.
  """
  # Similarity of every query row with every key row.
  scores = torch.bmm(query, key.transpose(1, 2))
  # Divide by sqrt of the query feature size before normalising.
  scores = scores / (query.size(-1) ** 0.5)
  # Normalise over the key axis so each query's weights sum to 1.
  weights = F.softmax(scores, dim = -1)
  return torch.bmm(weights, value)

In [78]:
class AttentionHead(nn.Module):
  """A single attention head with its own learned Q/K/V projections.

  Args:
    dim_in: feature size of the incoming tensors.
    dim_k: output size of the query and key projections.
    dim_v: output size of the value projection.
  """
  def __init__(self, dim_in : int, dim_k : int, dim_v : int):
    super().__init__()
    # Projections are created in q, k, v order; this fixes the parameter
    # order and the order in which random initial weights are drawn.
    self.q = nn.Linear(dim_in, dim_k)
    self.k = nn.Linear(dim_in, dim_k)
    self.v = nn.Linear(dim_in, dim_v)

  def forward(self, query : Tensor, key : Tensor, value : Tensor) -> Tensor:
    """Project the inputs and apply scaled dot-product attention."""
    projected_query = self.q(query)
    projected_key = self.k(key)
    projected_value = self.v(value)
    return scaled_dot_product_attention(projected_query, projected_key, projected_value)

In [79]:
class MultiHeadAttention(nn.Module):
  """Runs several `AttentionHead`s and merges their outputs.

  Args:
    num_heads: number of independent attention heads.
    dim_in: feature size of the inputs, and of the final output.
    dim_k: query/key projection size used by each head.
    dim_v: value projection size used by each head.
  """
  def __init__(self, num_heads : int, dim_in : int, dim_k : int, dim_v : int):
    super().__init__()
    head_modules = []
    for _ in range(num_heads):
      head_modules.append(AttentionHead(dim_in, dim_k, dim_v))
    self.heads = nn.ModuleList(head_modules)
    # Maps the concatenated head outputs back to the input feature size.
    self.linear = nn.Linear(num_heads * dim_v, dim_in)

  def forward(self, query : Tensor, key : Tensor, value : Tensor) -> Tensor:
    """Apply every head to the same inputs, concatenate on the last axis, project."""
    per_head = [head(query, key, value) for head in self.heads]
    merged = torch.cat(per_head, dim = -1)
    return self.linear(merged)

In [80]:
def positional_encoding(seq_len : int, dim_model : int, device : torch.device = torch.device("cpu")) -> Tensor:
  """Build the sinusoidal positional encoding table.

  Uses the standard formulation

      PE[pos, 2i]   = sin(pos / 10000^(2i / dim_model))
      PE[pos, 2i+1] = cos(pos / 10000^(2i / dim_model))

  Args:
    seq_len: number of positions to encode.
    dim_model: number of feature channels per position.
    device: device on which the table is created (defaults to CPU).

  Returns:
    A float tensor of shape (1, seq_len, dim_model), ready to be broadcast
    over the batch axis.
  """
  pos = torch.arange(seq_len, dtype = torch.float, device = device).reshape(1, -1, 1) #axis 1 only
  dim = torch.arange(dim_model, dtype = torch.float, device = device).reshape(1, 1, -1) #axis 2
  # Each sin/cos pair shares one frequency: channel j uses exponent 2*(j//2)/dim_model.
  # `dim - dim % 2` rounds j down to the even index 2i of its pair.
  exponent = (dim - dim % 2) / dim_model
  # Floor division here (the previous code) made the exponent 0 for every
  # channel, which erased all position information.
  angle = pos / (1e4 ** exponent)

  return torch.where(dim.long() % 2 == 0, torch.sin(angle), torch.cos(angle))

In [81]:
#FFN in the model

def ffn(dim_input : int = 512, dim_hidden : int = 2048) -> nn.Module:
  """Build a two-layer feed-forward network: Linear -> ReLU -> Linear.

  Args:
    dim_input: feature size of the input and of the output.
    dim_hidden: feature size of the intermediate layer.

  Returns:
    An `nn.Sequential` holding the three sub-modules in order.
  """
  expand = nn.Linear(dim_input, dim_hidden)
  activation = nn.ReLU()
  contract = nn.Linear(dim_hidden, dim_input)
  return nn.Sequential(expand, activation, contract)

In [82]:
#Residual connection followed by Layer Norm

class Residual(nn.Module):
  """Wraps a sub-layer with dropout, a skip connection, and layer normalisation.

  Computes `LayerNorm(x + Dropout(sublayer(*tensors)))`, where `x` is the
  first positional tensor.

  Args:
    sublayer: module applied to the input tensor(s).
    dimension: size of the last axis, passed to `nn.LayerNorm`.
    dropout: dropout probability applied to the sub-layer output.
  """
  def __init__(self, sublayer : nn.Module, dimension : int, dropout : float = 0.1):
    super().__init__()
    self.sublayer = sublayer
    self.norm = nn.LayerNorm(dimension)
    self.dropout = nn.Dropout(dropout)

  def forward(self, *tensors : Tensor) -> Tensor:
    """Apply the sub-layer and add the result back onto the first input.

    For attention sub-layers called as (query, key, value), the output has
    the query's shape, so the skip connection must use the query
    (`tensors[0]`). Using the last tensor (the value) only works when all
    inputs are the same tensor and fails when query and key/value lengths
    differ. For single-input sub-layers the first and last tensor coincide.
    """
    skip = tensors[0]
    return self.norm(skip + self.dropout(self.sublayer(*tensors)))

In [83]:
class TEncoderLayer(nn.Module):
  """One encoder layer: multi-head self-attention, then a feed-forward block.

  Both sub-layers are wrapped in `Residual`.

  Args:
    dim_model: feature size of the layer's input and output.
    num_heads: number of attention heads; each head gets
      `dim_model // num_heads` features for keys and values.
    dim_hidden: hidden size of the feed-forward block.
    dropout: dropout probability used inside each `Residual` wrapper.
  """
  def __init__(self,
               dim_model : int = 512,
               num_heads : int = 6,
               dim_hidden : int = 2048,
               dropout : float = 0.1):
    super().__init__()
    # Keys and values share the same per-head width.
    head_dim = dim_model // num_heads

    self_attention = MultiHeadAttention(num_heads, dim_model, head_dim, head_dim)
    self.attention = Residual(self_attention, dimension = dim_model, dropout = dropout)

    feed_forward = ffn(dim_model, dim_hidden)
    self.ffn = Residual(feed_forward, dimension = dim_model, dropout = dropout)

  def forward(self, src : Tensor) -> Tensor:
    """Self-attend over `src` (query = key = value), then apply the FFN."""
    attended = self.attention(src, src, src)
    return self.ffn(attended)


class TEncoder(nn.Module):
  """Stack of `TEncoderLayer`s applied after adding positional encodings.

  Args:
    num_layers: number of encoder layers in the stack.
    dim_model: feature size of the input and output.
    num_heads: number of attention heads per layer.
    dim_hidden: hidden size of each layer's feed-forward block.
    dropout: dropout probability passed to every layer.
  """
  def __init__(self,
               num_layers : int = 6,
               dim_model : int = 512,
               num_heads : int = 8,
               dim_hidden : int = 2048,
               dropout : float = 0.1):
    super().__init__()
    self.layers = nn.ModuleList([
        TEncoderLayer(dim_model, num_heads, dim_hidden, dropout)
        for _ in range(num_layers)
    ])

  def forward(self, src : Tensor) -> Tensor:
    """Encode `src`; axis 1 is read as sequence length and axis 2 as feature size.

    The caller's tensor is left untouched: the positional encoding is added
    out of place. An in-place `+=` would silently modify the argument and
    raises on leaf tensors that require grad.
    """
    seq_len, dimension = src.size(1), src.size(2)
    # Build the encoding on the input's own device so non-CPU inputs work,
    # and keep the input's dtype, as the former in-place add did.
    encoding = positional_encoding(seq_len, dimension, device = src.device)
    src = src + encoding.to(src.dtype)
    for layer in self.layers:
      src = layer(src)

    return src


In [84]:
class TDecoderLayer(nn.Module):
  """One decoder layer: self-attention, encoder-decoder attention, then FFN.

  Every sub-layer is wrapped in `Residual`. No attention mask is applied
  in either attention block.

  Args:
    dim_model: feature size of the layer's input and output.
    num_heads: number of attention heads; each head gets
      `dim_model // num_heads` features for keys and values.
    dim_hidden: hidden size of the feed-forward block.
    dropout: dropout probability used inside each `Residual` wrapper.
  """
  def __init__(self,
               dim_model : int = 512,
               num_heads : int = 6,
               dim_hidden : int = 2048,
               dropout : float = 0.1):
    super().__init__()
    dim_k = dim_v = dim_model // num_heads

    # Self-attention over the target sequence.
    self.attention1 = Residual(
        MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
        dimension = dim_model,
        dropout = dropout
    )

    # Attention from the target sequence onto the encoder memory.
    self.attention2 = Residual(
        MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
        dimension = dim_model,
        dropout = dropout
    )

    self.ffn = Residual(
        ffn(dim_model, dim_hidden),
        dimension = dim_model,
        dropout = dropout
    )


  def forward(self, target : Tensor, mem : Tensor) -> Tensor:
    """Decode one layer.

    Args:
      target: decoder-side input sequence.
      mem: encoder output that the decoder attends to.
    """
    target = self.attention1(target, target, target)
    # Queries come from the decoder state; keys and values come from the
    # encoder memory. Passing `mem` as the query threw away the
    # self-attention result, so the output never depended on `target`.
    target = self.attention2(target, mem, mem)

    return self.ffn(target)


class TDecoder(nn.Module):
  """Stack of `TDecoderLayer`s followed by a linear projection and softmax.

  Args:
    num_layers: number of decoder layers in the stack.
    dim_model: feature size of the input and output.
    num_heads: number of attention heads per layer.
    dim_hidden: hidden size of each layer's feed-forward block.
    dropout: dropout probability passed to every layer.
  """
  def __init__(self,
               num_layers : int = 6,
               dim_model : int = 512,
               num_heads : int = 8,
               dim_hidden : int = 2048,
               dropout : float = 0.1):
    super().__init__()
    self.layers = nn.ModuleList([
        TDecoderLayer(dim_model, num_heads, dim_hidden, dropout)
        for _ in range(num_layers)
    ])

    # Final projection keeps the feature size at dim_model.
    self.linear = nn.Linear(dim_model, dim_model)

  def forward(self, target : Tensor, mem : Tensor) -> Tensor:
    """Decode `target` against the encoder memory `mem`.

    Axis 1 of `target` is read as sequence length and axis 2 as feature
    size. The caller's tensor is left untouched: the positional encoding
    is added out of place. An in-place `+=` would silently modify the
    argument and raises on leaf tensors that require grad.

    Returns:
      Softmax over the last axis of the projected decoder output.
    """
    seq_len, dimension = target.size(1), target.size(2)
    # Build the encoding on the input's own device so non-CPU inputs work,
    # and keep the input's dtype, as the former in-place add did.
    encoding = positional_encoding(seq_len, dimension, device = target.device)
    target = target + encoding.to(target.dtype)
    for layer in self.layers:
      target = layer(target, mem)

    return F.softmax(self.linear(target), dim = -1)

In [85]:
#Wrap up in a single transformer


class Transformer(nn.Module):
    """Encoder-decoder model built from `TEncoder` and `TDecoder`.

    Args:
        num_encoder_layers: number of layers in the encoder stack.
        num_decoder_layers: number of layers in the decoder stack.
        dim_model: feature size shared by encoder and decoder.
        num_heads: attention heads per layer, shared by both stacks.
        dim_hidden: feed-forward hidden size, shared by both stacks.
        dropout: dropout probability, shared by both stacks.
    """
    def __init__(
        self,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_hidden: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        # Settings common to both stacks; only the layer count differs.
        shared = dict(
            dim_model=dim_model,
            num_heads=num_heads,
            dim_hidden=dim_hidden,
            dropout=dropout,
        )
        self.encoder = TEncoder(num_layers=num_encoder_layers, **shared)
        self.decoder = TDecoder(num_layers=num_decoder_layers, **shared)

    def forward(self, src : Tensor, target : Tensor) -> Tensor:
        """Encode `src`, then decode `target` against the encoder output."""
        return self.decoder(target, self.encoder(src))

## **Check (that the model works correctly)**

In [86]:
# Smoke test: push random (batch, sequence, feature) tensors through a
# default-configured model and confirm the output shape.
src = torch.rand(64, 16, 512)
target = torch.rand(64, 16, 512)
out = Transformer()(src, target)
print(out.shape)
# Fail loudly on a shape regression instead of relying on reading the print.
assert out.shape == (64, 16, 512), f"unexpected output shape {tuple(out.shape)}"

torch.Size([64, 16, 512])


In [86]:
#The output must have the same dimensions as the target input; the shape printed above confirms this.