<a href="https://colab.research.google.com/github/theabhinav0231/Transformer/blob/main/architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

## **Multi-Head Attention**

In [2]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are linearly projected, split into ``num_heads`` heads of width
  ``d_k = d_model // num_heads``, attended independently per head, merged
  back to width ``d_model`` and passed through an output projection.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    # Each head receives an equal slice of the model dimension.
    assert d_model % num_heads == 0

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    # Query / key / value projections followed by the output projection.
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.w_o = nn.Linear(d_model, d_model)

  def scaled_dot_product_attention(self, Q, K, V, mask=None):
    """Return ``softmax(Q K^T / sqrt(d_k)) V``.

    Where ``mask == 0`` the score is replaced by ``-1e9`` before the softmax,
    which drives the corresponding attention weight to (effectively) zero.
    """
    scale = math.sqrt(self.d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

  def split_heads(self, x):
    """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
    n_batch, n_tokens, _ = x.size()
    per_head = x.view(n_batch, n_tokens, self.num_heads, self.d_k)
    return per_head.transpose(1, 2)

  def combine_heads(self, x):
    """Inverse of ``split_heads``: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
    n_batch, _, n_tokens, _ = x.size()
    # transpose() yields a non-contiguous tensor; view() needs contiguous memory.
    merged = x.transpose(1, 2).contiguous()
    return merged.view(n_batch, n_tokens, self.d_model)

  def forward(self, Q, K, V, mask=None):
    """Project Q, K and V, attend per head, and project the merged result."""
    queries = self.split_heads(self.w_q(Q))
    keys = self.split_heads(self.w_k(K))
    values = self.split_heads(self.w_v(V))

    context = self.scaled_dot_product_attention(queries, keys, values, mask)
    return self.w_o(self.combine_heads(context))

## **Position-wise Feed-Forward Networks**

In [3]:
class PositionWiseFeedForward(nn.Module):
  """Two-layer feed-forward network applied along the last dimension.

  Computes ``layer2(relu(layer1(x)))``: an expansion from ``d_model`` to
  ``d_ff``, a ReLU non-linearity, and a projection back to ``d_model``.

  Args:
    d_model: Width of the input and output features.
    d_ff: Width of the hidden layer.
  """

  def __init__(self, d_model, d_ff):
    super().__init__()
    self.layer1 = nn.Linear(d_model, d_ff)
    self.layer2 = nn.Linear(d_ff, d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Return the feed-forward transform of ``x`` (last dim ``d_model``)."""
    # The non-linearity sits between the two linear maps; applying it after
    # the second one would clamp the sub-layer output to non-negative values.
    hidden = self.relu(self.layer1(x))
    return self.layer2(hidden)

## **Positional Encoding**

In [4]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to a sequence of embeddings.

  A table of shape ``(1, max_seq_length, d_model)`` is precomputed once and
  stored as the buffer ``pe``. Even feature columns hold sines and odd
  columns hold cosines of ``position * 10000^(-2i / d_model)``.

  Args:
    d_model: Feature width of the embeddings.
    max_seq_length: Number of positions to precompute.
  """

  def __init__(self, d_model, max_seq_length):
    super().__init__()

    pe = torch.zeros(max_seq_length, d_model)
    position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

    angles = position * div_term
    pe[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer cosine column than sine column,
    # so trim the angles to the number of odd-indexed columns.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # A buffer follows the module across devices but is not a trainable parameter.
    self.register_buffer("pe", pe.unsqueeze(0))

  def forward(self, x):
    """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

    ``x`` is indexed as (batch, seq, d_model); the table is sliced along the
    sequence axis and broadcast over the batch axis.
    """
    return x + self.pe[:, :x.size(1)]

  # Backward-compatible alias for the original, misspelled method name.
  foward = forward

## **Encoder**

In [5]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention then feed-forward.

  Each sub-layer output passes through dropout, is added to the sub-layer
  input (residual connection) and is then layer-normalised.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout):
    super().__init__()
    self.self_atten = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
    self.norm_1 = nn.LayerNorm(d_model)
    self.norm_2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """Run ``x`` through both sub-layers; ``mask`` is forwarded to self-attention."""
    # Sub-layer 1: self-attention, where queries, keys and values are all x.
    attended = self.self_atten(x, x, x, mask)
    x = self.norm_1(x + self.dropout(attended))

    # Sub-layer 2: position-wise feed-forward network.
    transformed = self.feed_forward(x)
    return self.norm_2(x + self.dropout(transformed))

## **Decoder**

In [6]:
class DecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention, feed-forward.

  Each sub-layer output passes through dropout, is added to the sub-layer
  input (residual connection) and is then layer-normalised.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout):
    super().__init__()
    self.self_atten = MultiHeadAttention(d_model, num_heads)
    self.cross_atten = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
    self.norm_1 = nn.LayerNorm(d_model)
    self.norm_2 = nn.LayerNorm(d_model)
    self.norm_3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_output, src_mask, tgt_mask):
    """Run the decoder block.

    Args:
      x: Decoder-side input to this layer.
      enc_output: Encoder output attended to by the cross-attention sub-layer.
      src_mask: Mask applied to the cross-attention scores.
      tgt_mask: Mask applied to the decoder self-attention scores.

    Returns:
      The layer output, with the same feature width as ``x``.
    """
    # Sub-layer 1: masked self-attention over the decoder input.
    atten_output = self.self_atten(x, x, x, tgt_mask)
    x = self.norm_1(x + self.dropout(atten_output))

    # Sub-layer 2: cross-attention. Queries come from the decoder; keys AND
    # values both come from the encoder output, and src_mask is the mask
    # argument (it must not be passed in the value slot).
    atten_output = self.cross_atten(x, enc_output, enc_output, src_mask)
    x = self.norm_2(x + self.dropout(atten_output))

    # Sub-layer 3: position-wise feed-forward network.
    ff_output = self.feed_forward(x)
    x = self.norm_3(x + self.dropout(ff_output))
    return x

## **Transformer**

In [7]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer over integer token ids.

  Args:
    src_vocab_size: Number of rows in the source embedding table.
    tgt_vocab_size: Number of rows in the target embedding table; also the
      width of the output projection.
    d_model: Embedding / hidden width.
    num_heads: Attention heads per attention module.
    num_layers: Number of encoder layers and of decoder layers.
    d_ff: Hidden width of each feed-forward sub-layer.
    max_seq_length: Number of positions in the positional-encoding table.
    dropout: Dropout probability used throughout.
  """

  def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
               num_heads, num_layers, d_ff, max_seq_length, dropout):
    super().__init__()
    self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
    self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

    self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
    self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    self.fc = nn.Linear(d_model, tgt_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def generate_mask(self, src, tgt):
    """Build the attention masks for a (src, tgt) batch of token ids.

    Token id 0 is treated as padding.

    Returns:
      A tuple ``(src_mask, tgt_mask)`` of bool tensors:
        * ``src_mask`` has shape (batch, 1, 1, src_len) and is True at
          non-padding source positions.
        * ``tgt_mask`` has shape (batch, 1, tgt_len, tgt_len) and is True at
          ``[b, 0, i, j]`` when target token ``i`` is not padding and
          ``j <= i`` (no attending to later positions).
    """
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
    seq_length = tgt.size(1)
    # Lower-triangular causal mask, created on tgt's device so that the `&`
    # below also works when the batch lives on an accelerator.
    nopeak_mask = torch.tril(
        torch.ones(1, seq_length, seq_length, dtype=torch.bool, device=tgt.device))
    tgt_mask = tgt_mask & nopeak_mask
    return src_mask, tgt_mask

  def forward(self, src, tgt):
    """Return un-normalised scores over the target vocabulary.

    ``src`` and ``tgt`` are batches of token ids indexed (batch, seq). The
    result has ``tgt_vocab_size`` entries along its last dimension; no
    softmax is applied.
    """
    src_mask, tgt_mask = self.generate_mask(src, tgt)
    src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
    tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

    enc_output = src_embedded
    for enc_layer in self.encoder_layers:
      enc_output = enc_layer(enc_output, src_mask)

    # Every decoder layer attends to the final encoder output.
    dec_output = tgt_embedded
    for dec_layer in self.decoder_layers:
      dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

    output = self.fc(dec_output)
    return output