In [1]:
import torch
import math
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchtext.datasets import Multi30k, IWSLT2016
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
from torch.utils.data import DataLoader
import math as m
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [2]:
def formula(Q, K, V, dim = 4):
  """Scaled dot-product attention: softmax(Q @ K^T / sqrt(dim)) @ V.

  Args:
    Q: query tensor; its last axis must match the last axis of K.
    K: key tensor. For inputs with two or more axes the last two axes are
      swapped before the product, so optional leading batch axes are allowed.
    V: value tensor, multiplied by the attention weights.
    dim: scaling dimension; the raw scores are divided by sqrt(dim).

  Returns:
    A tuple (out, weights): the weighted values and the softmax-normalised
    attention weights (normalised over the last axis).
  """
  # K.T reverses *every* axis, which is only a matrix transpose for 2-D input.
  # Swapping just the last two axes gives the same result for 2-D tensors and
  # additionally works for batched input; 0-D/1-D keys are passed through
  # unchanged, exactly as K.T would do.
  K_t = K.transpose(-2, -1) if K.dim() > 1 else K
  scores = torch.matmul(Q, K_t) / math.sqrt(dim)
  weights = F.softmax(scores, dim = -1)
  return torch.matmul(weights, V), weights

In [3]:
def test(Q, K, V):
  """Run `formula` on (Q, K, V) and print its two results rounded to 4 decimals.

  The attention output is printed first, followed by the attention weights.
  Both tensors are converted to NumPy arrays before anything is printed.
  """
  arrays = [tensor.numpy() for tensor in formula(Q, K, V)]
  for array in arrays:
    print(np.round(array, 4))

In [6]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention with a separate Q/K/V linear projection per head.

  Each head projects the inputs from `dim_model` down to
  `dim_model // number_of_heads` features, runs scaled dot-product attention,
  and the per-head results are concatenated and mixed by a final linear layer.

  Args:
    dim_model: feature size of the inputs and of the returned tensor.
    number_of_heads: number of attention heads; must divide `dim_model`.
    dropout: dropout probability applied to the output projection.

  Raises:
    ValueError: if `dim_model` is not divisible by `number_of_heads` (the
      concatenated heads would not match the output projection's input size).
  """

  def __init__(self, dim_model = 8, number_of_heads = 4, dropout = 0.2):
    super().__init__()
    if dim_model % number_of_heads != 0:
      raise ValueError(
        f"dim_model ({dim_model}) must be divisible by number_of_heads ({number_of_heads})")
    self.d = dim_model // number_of_heads
    self.dropout = nn.Dropout(dropout)
    self.Qs = nn.ModuleList([nn.Linear(dim_model, self.d) for _ in range(number_of_heads)])
    self.Ks = nn.ModuleList([nn.Linear(dim_model, self.d) for _ in range(number_of_heads)])
    self.Vs = nn.ModuleList([nn.Linear(dim_model, self.d) for _ in range(number_of_heads)])
    self.mha_linear = nn.Linear(dim_model, dim_model)

  def formula(self, Q, K, V, mask = None):
    """Scaled dot-product attention for one head.

    Args:
      Q, K, V: projected query/key/value tensors of one head.
      mask: optional tensor broadcastable to the score matrix; positions where
        it equals 0 (or False) receive no attention.

    Returns:
      (out, weights): attended values and the softmax attention weights.
    """
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d)
    if mask is not None:
      # Use the most negative finite value rather than -inf so that a row
      # that is masked everywhere yields uniform weights instead of NaN.
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = F.softmax(scores, dim = -1)
    return torch.matmul(weights, V), weights

  def forward(self, pre_q, pre_k, pre_v, mask = None):
    """Apply every head and merge the results.

    Args:
      pre_q, pre_k, pre_v: inputs of shape [B x seq_len x D].
      mask: optional attention mask forwarded to every head.

    Returns:
      (x, weights): `x` is the projected, dropout-regularised concatenation of
      all heads; `weights` stacks the per-head attention weights along a new
      axis 1 (batch, head, query, key).
    """
    output_per_head = []
    weights_per_head = []
    for q_proj, k_proj, v_proj in zip(self.Qs, self.Ks, self.Vs):
      head_output, head_weights = self.formula(
        q_proj(pre_q), k_proj(pre_k), v_proj(pre_v), mask)
      output_per_head.append(head_output)
      weights_per_head.append(head_weights)

    concatenated = torch.cat(output_per_head, -1)
    # Stacking on axis 1 is equivalent to stack(...).permute(1, 0, 2, 3).
    weights = torch.stack(weights_per_head, dim = 1)
    x = self.dropout(self.mha_linear(concatenated))
    return x, weights


In [7]:
class ResidualNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes LayerNorm(Dropout(x) + residual) over the last `dim_model`
    features.
    """

    def __init__(self, dim_model, dropout=0.2):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=dim_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, residual):
        """Drop out `x`, add `residual`, and normalise the sum."""
        regularised = self.dropout(x)
        summed = regularised + residual
        return self.layer_norm(summed)

In [8]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and a feed-forward net, each wrapped
  in a residual connection with layer normalisation.

  Args:
    dim_model: feature size of the layer's input and output.
    number_of_heads: number of attention heads.
    d_ff: hidden size of the feed-forward network.
    dropout: dropout probability used by every sub-module of the layer.
  """

  def __init__(self, dim_model, number_of_heads, d_ff, dropout = 0.2):
    super().__init__()
    self.norm_1 = ResidualNorm(dim_model, dropout)
    self.norm_2 = ResidualNorm(dim_model, dropout)
    # Forward `dropout` so the attention block honours the configured rate
    # instead of silently falling back to its own default.
    self.mha = MultiHeadAttention(dim_model, number_of_heads, dropout)
    self.ff = nn.Sequential(
      nn.Linear(dim_model, d_ff),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(d_ff, dim_model),
    )

  def forward(self, x, mask):
    """Run self-attention then the feed-forward net.

    Args:
      x: input tensor, used as query, key and value.
      mask: attention mask handed to the attention block.

    Returns:
      (output, attention_info): the layer output and the second value
      returned by the attention block.
    """
    attended, encoder_weights = self.mha(x, x, x, mask = mask)
    norm1 = self.norm_1(attended, x)
    norm2 = self.norm_2(self.ff(norm1), norm1)
    return norm2, encoder_weights

In [9]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings.

    Even feature columns hold sin(pos / 10000^(2i / dim_model)) and odd columns
    hold the matching cosine. The table is stored as the buffer `pe` with shape
    [1, max_seq_len, dim_model], so it follows the module across devices and is
    not trained.

    Args:
        dim_model: feature size of the embeddings (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_seq_len: longest sequence length the table covers.
        device: device the table is initially placed on.
    """

    def __init__(self, dim_model, dropout=0.2, max_seq_len = 200, device = "cuda"):
        super().__init__()
        self.dim_model = dim_model
        self.dropout = nn.Dropout(dropout)
        pos = torch.arange(0, max_seq_len).unsqueeze(1).float()
        two_i = torch.arange(0, dim_model, step=2).float()
        div_term = torch.pow(10000, two_i / dim_model).float()
        angles = pos / div_term
        pe = torch.zeros(max_seq_len, dim_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd dim_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing on shape.
        pe[:, 1::2] = torch.cos(angles[:, : dim_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0).to(device))

    def forward(self, x):
        """Add the positional table to `x` and apply dropout.

        Args:
            x: tensor whose axis 1 is the sequence axis and whose last axis
                has `dim_model` features.

        Returns:
            Tensor of the same shape as `x`.

        Raises:
            ValueError: if the sequence is longer than the stored table.
        """
        seq_len = x.shape[1]
        if seq_len > self.pe.shape[1]:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pe.shape[1]}")
        # The leading axis of size 1 broadcasts over the batch, so no explicit
        # per-batch copy of the table is needed.
        return self.dropout(x + self.pe[:, :seq_len])

In [10]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention and
    a feed-forward net, each wrapped in a residual connection with layer
    normalisation.

    Args:
        dim_model: feature size of the layer's input and output.
        num_heads: number of attention heads in both attention blocks.
        feed_forward_dim: hidden size of the feed-forward network.
        dropout: dropout probability used by every sub-module of the layer.
    """

    def __init__(self, dim_model, num_heads, feed_forward_dim, dropout=0.2):
        super().__init__()
        # Forward `dropout` so the residual blocks honour the configured rate
        # instead of silently falling back to their own default.
        self.norm_1 = ResidualNorm(dim_model, dropout)
        self.norm_2 = ResidualNorm(dim_model, dropout)
        self.norm_3 = ResidualNorm(dim_model, dropout)
        self.masked_attention = MultiHeadAttention(dim_model, num_heads, dropout)
        self.encoder_decoder_attention = MultiHeadAttention(dim_model, num_heads, dropout)

        self.feed_forward = nn.Sequential(
            nn.Linear(dim_model, feed_forward_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(feed_forward_dim, dim_model),
        )

    def forward(self, input_tensor, encoder_outputs, target_mask, source_mask):
        """Run the three sub-layers in order.

        Args:
            input_tensor: decoder-side input; query, key and value of the
                masked self-attention.
            encoder_outputs: key and value of the encoder-decoder attention.
            target_mask: mask handed to the masked self-attention.
            source_mask: mask handed to the encoder-decoder attention.

        Returns:
            (output, masked_attention_weights, encoder_decoder_attention_weights)
            where the last two are the second values returned by the
            respective attention blocks.
        """
        masked_attention_output, masked_attention_weights = self.masked_attention(
            input_tensor, input_tensor, input_tensor, mask=target_mask)
        norm1 = self.norm_1(masked_attention_output, input_tensor)
        encoder_decoder_attention_output, encoder_decoder_attention_weights = self.encoder_decoder_attention(
            norm1, encoder_outputs, encoder_outputs, mask=source_mask)
        norm2 = self.norm_2(encoder_decoder_attention_output, norm1)
        feed_forward_output = self.feed_forward(norm2)
        norm3 = self.norm_3(feed_forward_output, norm2)
        return norm3, masked_attention_weights, encoder_decoder_attention_weights

In [11]:
class Embeddings(nn.Module):
    """Token embedding lookup scaled by sqrt(dim_model).

    Args:
        vocab_size: number of rows in the embedding table.
        padding_idx: index passed to `nn.Embedding` as `padding_idx`.
        dim_model: size of each embedding vector.
    """

    def __init__(self, vocab_size, padding_idx, dim_model):
        super().__init__()
        self.d_model = dim_model
        self.embed = nn.Embedding(vocab_size, dim_model, padding_idx=padding_idx)

    def forward(self, x):
        """Look up the embeddings of the indices in `x` and scale them."""
        embedding = self.embed(x)
        # Read the attribute that __init__ actually stores (`d_model`);
        # `self.dim_model` was never assigned and raised AttributeError.
        return embedding * math.sqrt(self.d_model)

In [12]:
class Decoder(nn.Module):
    """Decoder stack: embedding, positional encoding, then `num_layers`
    decoder layers applied in sequence.

    Args:
        embedding: module mapping token indices to embedding vectors.
        model_dim: feature size used throughout the stack.
        num_heads: number of attention heads per layer.
        num_layers: number of decoder layers.
        feed_forward_dim: hidden size of each layer's feed-forward network.
        device: device handed to the positional encoding.
        dropout: dropout probability used by the positional encoding and
            every decoder layer.
    """

    def __init__(self, embedding, model_dim, num_heads, num_layers, feed_forward_dim, device="cuda", dropout=0.2):
        super().__init__()
        self.embedding = embedding
        # Forward `dropout` so the positional encoding honours the configured
        # rate instead of silently falling back to its own default.
        self.positional_encoding = PositionalEncoding(model_dim, dropout=dropout, device=device)
        # Not used in forward (the positional encoding already applies
        # dropout); kept so the module's attributes stay unchanged.
        self.dropout = nn.Dropout(dropout)
        self.decoder_layers = nn.ModuleList([DecoderLayer(
            model_dim,
            num_heads,
            feed_forward_dim,
            dropout,
        ) for _ in range(num_layers)])

    def forward(self, input_tensor, encoder_output, target_mask, source_mask):
        """Embed the target tokens and run them through every decoder layer.

        Returns:
            (encoding, masked_attention_weights, encoder_decoder_attention_weights)
            where the two weight values come from the LAST layer only; both are
            None when the stack has no layers.
        """
        embeddings = self.embedding(input_tensor)
        encoding = self.positional_encoding(embeddings)
        # Pre-bind the results so an empty stack does not raise
        # UnboundLocalError at the return statement.
        masked_attention_weights = None
        encoder_decoder_attention_weights = None
        for decoder in self.decoder_layers:
            encoding, masked_attention_weights, encoder_decoder_attention_weights = decoder(
                encoding, encoder_output, target_mask, source_mask)
        return encoding, masked_attention_weights, encoder_decoder_attention_weights

In [13]:
class Encoder(nn.Module):
    """Encoder stack: embedding, positional encoding, then `num_layers`
    encoder layers applied in sequence.

    Args:
        embedding: module mapping token indices to embedding vectors.
        model_dim: feature size used throughout the stack.
        num_heads: number of attention heads per layer.
        num_layers: number of encoder layers.
        feed_forward_dim: hidden size of each layer's feed-forward network.
        device: device handed to the positional encoding.
        dropout: dropout probability used by the positional encoding and
            every encoder layer.
    """

    def __init__(self, embedding, model_dim,
                 num_heads, num_layers,
                 feed_forward_dim, device="cuda", dropout = 0.2):
        super().__init__()
        self.embedding = embedding
        # Forward `dropout` so the positional encoding honours the configured
        # rate instead of silently falling back to its own default.
        self.positional_encoding = PositionalEncoding(
            model_dim, dropout=dropout, device=device)
        self.encoder_layers = nn.ModuleList([EncoderLayer(
            model_dim,
            num_heads,
            feed_forward_dim,
            dropout
        ) for _ in range(num_layers)])

    def forward(self, input_tensor, mask=None):
        """Embed the source tokens and run them through every encoder layer.

        Returns:
            (encoding, encoder_attention_weights) where the second value comes
            from the LAST layer only and is None when the stack has no layers.
        """
        embeddings = self.embedding(input_tensor)
        encoding = self.positional_encoding(embeddings)
        # Pre-bind the result so an empty stack does not raise
        # UnboundLocalError at the return statement.
        encoder_attention_weights = None
        for encoder in self.encoder_layers:
            encoding, encoder_attention_weights = encoder(encoding, mask)
        return encoding, encoder_attention_weights

In [14]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that maps source/target token indices to
    logits over the target vocabulary.

    Args:
        src_vocab_len: size of the source vocabulary.
        trg_vocab_len: size of the target vocabulary (and of the logits).
        d_model: feature size used throughout the model.
        d_ff: hidden size of the feed-forward networks.
        num_layers: number of encoder layers and of decoder layers.
        num_heads: number of attention heads per layer.
        src_pad_idx: padding index of the source vocabulary.
        trg_pad_idx: padding index of the target vocabulary.
        dropout: dropout probability handed to encoder and decoder.
        device: device handed to encoder and decoder.
    """

    def __init__(self, src_vocab_len, trg_vocab_len, d_model, d_ff,
                 num_layers, num_heads, src_pad_idx, trg_pad_idx, dropout=0.2, device="cuda"):
        super().__init__()
        self.num_heads = num_heads
        self.device = device
        # Both mask builders branch on this flag, but it was never assigned, so
        # the first forward pass raised AttributeError. False selects the 3-D
        # masks, which line up with attention scores computed per head on
        # [B x seq_len x D] inputs.
        self.efficient_mha = False
        encoder_Embedding = Embeddings(src_vocab_len, src_pad_idx, d_model)
        decoder_Embedding = Embeddings(trg_vocab_len, trg_pad_idx, d_model)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(encoder_Embedding, d_model, num_heads, num_layers, d_ff, device, dropout)
        self.decoder = Decoder(decoder_Embedding, d_model, num_heads, num_layers, d_ff, device, dropout)
        self.linear_layer = nn.Linear(d_model, trg_vocab_len)
        # Xavier-initialise every weight matrix; 1-D parameters (biases, layer
        # norm gains) keep their default initialisation.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def create_src_mask(self, src):
        """Return a boolean mask that is True wherever `src` is not padding.

        The mask gains one extra axis after the batch axis, or two when
        `efficient_mha` is set.
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1)
        if self.efficient_mha:
            src_mask = src_mask.unsqueeze(2)
        return src_mask

    def create_trg_mask(self, trg):
        """Return a boolean mask combining target padding with causality.

        Entry [..., i, j] is True only when target position j is not padding
        and j <= i, so a position never attends to later positions.
        """
        seq_len = trg.shape[1]
        # Build the causal mask on the same device as `trg` so that the `&`
        # below never mixes devices, whatever `self.device` was set to.
        if self.efficient_mha:
            trg_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
            mask = torch.ones((1, self.num_heads, seq_len, seq_len), device=trg.device).triu(1)
        else:
            trg_mask = (trg != self.trg_pad_idx).unsqueeze(1)
            mask = torch.ones((1, seq_len, seq_len), device=trg.device).triu(1)
        # triu(1) marks the strictly-future positions with 1; keep the rest.
        mask = mask == 0
        trg_mask = trg_mask & mask
        return trg_mask

    def forward(self, src, trg):
        """Encode `src`, decode `trg` against it and return vocabulary logits."""
        src_mask = self.create_src_mask(src)
        trg_mask = self.create_trg_mask(trg)
        encoder_outputs, _ = self.encoder(src, src_mask)
        decoder_outputs, _, _ = self.decoder(trg, encoder_outputs, trg_mask, src_mask)
        logits = self.linear_layer(decoder_outputs)
        return logits