In [None]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

import sys, os, time, warnings, pdb, pickle, random, math, re, json
warnings.filterwarnings('ignore')
sys.path.insert(0, '../scripts')

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

np.set_printoptions(precision=4)
sns.set_style("darkgrid")
%matplotlib inline

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# InputEmbeddings, PositionalEncoding
from input import *
from internal import LayerNormalization, FeedForwardBlock, ResidualConnection

In [None]:
# Toy hyperparameters for the small sanity checks below.
# Special token ids used when hand-writing the example sequences.
sos = 1
eos = 2
pad = 3

# Model / data dimensions handed to the embedding, positional-encoding and
# feed-forward constructors in the next cell.
vocab_size = 8
seq_len = 10
d_model = 4
d_ff = 8
dropout = 0.1

In [None]:
# Two hand-written token-id sequences (10 ids each, matching seq_len) built from the
# special ids sos/eos/pad and ordinary ids 4-7.
x = torch.tensor([[sos,4,6,7,6,4,pad,pad,pad,eos], [sos,4,5,5,7,7,5,7,pad,eos]])
# Building blocks imported from the project modules `input` / `internal`.
emb = InputEmbeddings(d_model, vocab_size)
pe = PositionalEncoding(d_model, seq_len, dropout)
norm = LayerNormalization()
# NOTE(review): `ffb` is constructed here but not used by any cell visible in this notebook.
ffb = FeedForwardBlock(d_model, d_ff, dropout)
residual_connection = ResidualConnection(dropout)
# Put the residual wrapper in eval mode — presumably to disable its dropout so the
# equality check in the next cell is deterministic (confirm against internal.py).
residual_connection.eval()
# Identity sublayer: returns its input unchanged, isolating the residual wrapper's own behavior.
sublayer = nn.Identity()

In [None]:
# Embed the token ids and add positional information. The result is bound to a NEW
# name so the integer token tensor `x` is preserved and this cell can be re-run
# (rebinding `x` to its own embedding made a second run fail on a float input).
x_emb = pe(emb(x))
y = norm(x_emb)
y_res = residual_connection(x_emb, sublayer)
# With an identity sublayer the residual output is expected to equal x_emb + norm(x_emb)
# (assumes ResidualConnection normalizes before the sublayer — confirm against internal.py).
# allclose instead of exact `==` so the check tolerates floating-point rounding.
torch.allclose(y_res, y + x_emb)

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are projected with learned linear maps, split into `h` heads of
  width `d_k = d_model // h`, attended per head, concatenated back to
  `d_model` and passed through an output projection.

  Args:
    d_model: width of the input and output embeddings.
    h: number of attention heads; must divide `d_model` evenly.
    dropout: dropout probability applied to the attention weights.

  Raises:
    AssertionError: if `d_model` is not divisible by `h`.
  """

  def __init__(self, d_model: int, h: int, dropout: float) -> None:
    super().__init__()
    self.d_model = d_model
    self.h = h
    assert d_model % h == 0, "d_model is not divisible by h"
    self.d_k = d_model // h
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.w_o = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  @staticmethod
  def attention(query, key, value, mask, dropout: nn.Dropout):
    """Scaled dot-product attention over the last two dimensions.

    Args:
      query, key, value: tensors whose last two dims are (seq_len, d_k);
        `forward` passes them as (batch, h, seq_len, d_k).
      mask: optional tensor broadcastable to the score matrix; positions where
        `mask == 0` receive (effectively) zero attention weight. `None` disables masking.
      dropout: optional dropout module applied to the attention weights.

    Returns:
      A tuple `(output, attention_weights)` where `output` is the weighted sum
      of `value` and `attention_weights` is the (post-dropout) softmax matrix.
    """
    d_k = query.shape[-1]
    # token to token attention, so matrix is seq_len, seq_len
    # (batch, h, seq_len, d_k) -> (batch, h, seq_len, seq_len)
    attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
      # masked_fill is out-of-place: the result MUST be assigned back, otherwise the
      # mask is silently ignored. The large negative value softmaxes to zero.
      attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
    attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)
    if dropout is not None:
      attention_scores = dropout(attention_scores)

    return (attention_scores @ value), attention_scores

  def forward(self, q, k, v, mask):
    """Apply multi-head attention.

    Args:
      q, k, v: tensors of shape (batch, seq_len, d_model).
      mask: optional mask forwarded to `attention` (see there); may be `None`.

    Returns:
      Tensor of shape (batch, seq_len, d_model). The attention weights of the
      last call are kept on `self.attention_scores` for inspection.
    """
    query = self.w_q(q)  # (batch, seq_len, d_model) -> (batch, seq_len, d_model)
    key = self.w_k(k)
    value = self.w_v(v)

    # 1) reshape q,k,v into separate heads
    # 2) put the head dim as the 2nd dim
    # 3) each head will see part of the embedding of ALL inputs in the batch
    # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
    query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
    key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
    value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

    x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)
    # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, d_model)
    # contiguous() is required because view() cannot operate on the transposed layout.
    x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.d_model)

    # (batch, seq_len, d_model) -> (batch, seq_len, d_model)
    return self.w_o(x)

In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention followed by a feed-forward block.

  Each of the two sublayers is wrapped in its own `ResidualConnection`.

  Args:
    self_attention_block: attention module called as `(q, k, v, mask)`.
    feed_forward_block: position-wise feed-forward module called as `(x)`.
    dropout: dropout probability handed to each `ResidualConnection`.
  """

  def __init__(self, self_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    # One residual wrapper per sublayer (attention, feed-forward).
    self.residual_connections = nn.ModuleList([
      ResidualConnection(dropout) for _ in range(2)
    ])

  def forward(self, x, src_mask):
    """Apply self-attention then feed-forward, each through its residual wrapper.

    Args:
      x: input tensor; used as query, key and value of the self-attention.
      src_mask: mask for the source sequence (e.g. to hide pad tokens).

    Returns:
      The output of the second residual connection.
    """
    # The lambda adapts the 4-argument attention call to the single-argument
    # sublayer interface that the residual wrapper invokes.
    x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
    x = self.residual_connections[1](x, self.feed_forward_block)
    return x

In [None]:
class Encoder(nn.Module):
  """Stack of encoder layers followed by a final layer normalization.

  Args:
    layers: modules applied in order; each one is called as `layer(x, mask)`.
  """

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization()

  def forward(self, x, mask):
    """Feed `x` through every layer with the same `mask`, then normalize the result."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    normalized = self.norm(hidden)
    return normalized

In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention, then feed-forward.

  Each of the three sublayers is wrapped in its own `ResidualConnection`.

  Args:
    self_attention_block: attention over the decoder input itself.
    cross_attention_block: attention with decoder queries and encoder keys/values.
    feed_forward_block: position-wise feed-forward module called as `(x)`.
    dropout: dropout probability handed to each `ResidualConnection`.
  """

  def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    # One residual wrapper per sublayer (self-attention, cross-attention, feed-forward).
    self.residual_connections = nn.ModuleList([
      ResidualConnection(dropout) for _ in range(3)
    ])

  def forward(self, x, encoder_output, src_mask, target_mask):
    """Run the three sublayers in order, each through its residual wrapper.

    Args:
      x: decoder input; query, key and value of the self-attention.
      encoder_output: encoder result; key and value of the cross-attention.
      src_mask: mask applied in the cross-attention (over encoder positions).
      target_mask: mask applied in the decoder self-attention.

    Returns:
      The output of the third residual connection.
    """
    x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, target_mask))
    # Queries come from the decoder stream; keys and values come from the encoder.
    x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
    x = self.residual_connections[2](x, self.feed_forward_block)
    return x

In [None]:
class Decoder(nn.Module):
  """Stack of decoder layers followed by a final layer normalization.

  Args:
    layers: modules applied in order; each one is called as
      `layer(x, encoder_output, src_mask, target_mask)`.
  """

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    # Instantiate the norm layer: storing the bare class would make `self.norm(x)`
    # try to construct a new layer from `x` instead of normalizing it.
    self.norm = LayerNormalization()

  def forward(self, x, encoder_output, src_mask, target_mask):
    """Feed `x` through every layer with the shared encoder output and masks, then normalize."""
    for layer in self.layers:
      x = layer(x, encoder_output, src_mask, target_mask)
    return self.norm(x)

In [None]:
class ProjectionLayer(nn.Module):
  """Maps decoder features onto log-probabilities over the vocabulary.

  Args:
    d_model: width of the incoming feature vectors.
    vocab_size: number of output classes (vocabulary entries).
  """

  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    """Project the last dim from d_model to vocab_size and apply log-softmax over it."""
    # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
    logits = self.proj(x)
    return logits.log_softmax(dim=-1)
    

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder transformer assembled from pre-built components.

  The class defines no `forward`; callers drive it through `encode`,
  `decode` and `project`.

  Args:
    encoder: encoder stack, called as `(x, mask)`.
    decoder: decoder stack, called as `(x, encoder_output, src_mask, target_mask)`.
    src_embed: embedding applied to source token ids.
    target_embed: embedding applied to target token ids.
    src_pos: positional encoding applied to the embedded source.
    target_pos: positional encoding applied to the embedded target.
    projection_layer: maps decoder output to vocabulary scores.
  """

  def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, target_embed: InputEmbeddings, src_pos: PositionalEncoding, target_pos: PositionalEncoding, projection_layer: ProjectionLayer):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.target_embed = target_embed
    self.src_pos = src_pos
    self.target_pos = target_pos
    self.projection_layer = projection_layer

  def encode(self, src, src_mask):
    """Embed the source, add positional information and run the encoder stack."""
    src = self.src_embed(src)
    src = self.src_pos(src)
    # Delegate to the encoder module (not to this method, which would recurse).
    return self.encoder(src, src_mask)

  def decode(self, encoder_output, src_mask, target, target_mask):
    """Embed the target, add positional information and run the decoder stack.

    Note the argument order: the decoder module takes the target stream first,
    followed by the encoder output and the two masks.
    """
    target = self.target_embed(target)
    target = self.target_pos(target)
    return self.decoder(target, encoder_output, src_mask, target_mask)

  def project(self, x):
    """Apply the projection layer to decoder output."""
    return self.projection_layer(x)

In [None]:
def build_transformer(src_vocab_size: int, target_vocab_size: int, src_seq_len: int, target_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
  """Assemble a `Transformer` from freshly constructed components.

  Args:
    src_vocab_size: vocabulary size for the source embedding.
    target_vocab_size: vocabulary size for the target embedding and projection.
    src_seq_len: sequence length handed to the source positional encoding.
    target_seq_len: sequence length handed to the target positional encoding.
    d_model: embedding / model width.
    N: number of encoder blocks and number of decoder blocks.
    h: number of attention heads per attention block.
    dropout: dropout probability shared by all components.
    d_ff: hidden width of the feed-forward blocks.

  Returns:
    A `Transformer` whose parameters with more than one dimension have been
    re-initialized with Xavier-uniform.
  """
  # embedding layers
  src_embed = InputEmbeddings(d_model, src_vocab_size)
  target_embed = InputEmbeddings(d_model, target_vocab_size)
  # positional encoding layers
  src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
  target_pos = PositionalEncoding(d_model, target_seq_len, dropout)
  # encoder blocks
  encoder_blocks = []
  for _ in range(N):
    encoder_self_attention_block = MultiHeadAttention(d_model, h, dropout)
    feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
    encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
    encoder_blocks.append(encoder_block)
  # decoder blocks
  decoder_blocks = []
  for _ in range(N):
    decoder_self_attention_block = MultiHeadAttention(d_model, h, dropout)
    decoder_cross_attention_block = MultiHeadAttention(d_model, h, dropout)
    feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
    decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
    decoder_blocks.append(decoder_block)

  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))
  projection_layer = ProjectionLayer(d_model, target_vocab_size)

  transformer = Transformer(encoder, decoder, src_embed, target_embed, src_pos, target_pos, projection_layer)

  # Xavier-uniform only applies to tensors with at least 2 dims, so 1-D
  # parameters (e.g. biases) keep their default initialization.
  for p in transformer.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)

  return transformer

In [None]:
# Toy dimensions for stepping through the query projection by hand.
# NOTE: this rebinds `seq_len` (10 -> 6) and `d_model` from the config cell near the top.
batch, seq_len, h, d_model = 2, 6, 2,4

In [None]:
# Per-head width: the model width split evenly across the h heads.
d_k = d_model // h

In [None]:
# Stand-alone query projection, same shape as MultiHeadAttention.w_q.
w_q = nn.Linear(d_model, d_model)

In [None]:
# Random stand-in input of shape (batch, seq_len, d_model), sampled uniformly from [0, 1).
q = torch.rand(batch, seq_len, d_model)

In [None]:
# (batch, seq_len, d_model) -> (batch, seq_len, d_model)
query = w_q(q)