In [None]:
# Notebook conveniences: auto-reload edited project modules before each cell runs,
# and greedy tab completion.
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

# NOTE(review): several of these stdlib modules (e.g. pdb, pickle, json) are not used
# in the cells shown; consider trimming the import list.
import sys, os, time, warnings, pdb, pickle, random, math, re, json
# NOTE(review): this silences *all* warnings, including PyTorch shape/deprecation warnings.
warnings.filterwarnings('ignore')
# Make the local modules imported below (input, internal, encoder, decoder, transformer)
# importable — presumably they live in ../scripts relative to this notebook.
sys.path.insert(0, '../scripts')

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Display / plotting defaults.
np.set_printoptions(precision=4)
sns.set_style("darkgrid")
%matplotlib inline

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# InputEmbeddings, PositionalEncoding
from input import InputEmbeddings, PositionalEncoding
from internal import LayerNormalization, FeedForwardBlock, ResidualConnection, MultiHeadAttention
from encoder import EncoderBlock, Encoder
from decoder import DecoderBlock, Decoder
from transformer import ProjectionLayer, Transformer, build_transformer

In [None]:
# Toy hyperparameters for the hand-assembled model below.
d_model = 4              # model / embedding width
vocab_size = 8           # token ids 0..7
sos, eos, pad = 0, 1, 2  # special token ids (reserved at the bottom of the vocab)
seq_len = 10             # fixed length of the toy batches
dropout = 0.1
d_ff = 8                 # inner width passed to FeedForwardBlock
h = 2                    # attention heads; d_model presumably must be divisible by h
N = 2                    # number of encoder / decoder blocks

# Seed every RNG in play (weight init, dropout masks, torch.rand) so that
# Restart Kernel -> Run All reproduces the same outputs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Build a reference model with the factory. Only vocab sizes and sequence lengths are
# passed, so the remaining hyperparameters presumably fall back to build_transformer's
# defaults (not the toy d_model / N / h / dropout / d_ff defined above). To use the toy
# values, pass them positionally:
#   build_transformer(vocab_size, vocab_size, seq_len, seq_len, d_model, N, h, dropout, d_ff)
# NOTE(review): that argument order comes from an earlier draft of this call; confirm it
# against the signature in transformer.py. `t` is rebound to a hand-built model later on.
t = build_transformer(vocab_size, vocab_size, seq_len, seq_len)

In [None]:
# Show the module tree of the factory-built model (rich repr of the cell's last expression).
t

In [None]:
# Toy token batches: 2 sequences of length seq_len each. Every row is laid out as
# SOS, content tokens, EOS, then PAD filler, so EOS marks the true end of the
# sequence and padding trails it.
x_tokens = torch.tensor([[sos, 3, 3, 5, 7, 4, eos, pad, pad, pad],
                         [sos, 5, 7, 7, 6, 6, 4, 3, eos, pad]])
y_tokens = torch.tensor([[sos, 7, 7, 3, 4, eos, pad, pad, pad, pad],
                         [sos, 6, 4, 4, 3, 7, 5, 4, eos, pad]])
assert x_tokens.shape == y_tokens.shape == (2, seq_len)

# Separate embedding + positional-encoding modules for the source and target sides.
src_emb = InputEmbeddings(d_model, vocab_size)
target_emb = InputEmbeddings(d_model, vocab_size)
src_pe = PositionalEncoding(d_model, seq_len, dropout)
target_pe = PositionalEncoding(d_model, seq_len, dropout)

# Embedded inputs used by the following cells. The raw ids stay available
# in x_tokens / y_tokens instead of being overwritten.
x = src_pe(src_emb(x_tokens))
y = target_pe(target_emb(y_tokens))

In [None]:
# Stand-alone sub-layers for poking at individually (built in the same order:
# encoder FFN, decoder FFN, self-attention, cross-attention).
# NOTE(review): none of these four names is used by the visible cells — the model
# below builds its own sub-layers, and self_attn / cross_attn are re-created later.
ffb_enc, ffb_dec = (FeedForwardBlock(d_model, d_ff, dropout) for _ in range(2))
self_attn, cross_attn = (MultiHeadAttention(d_model, h, dropout) for _ in range(2))

In [None]:
# No source mask (None is passed through to the encoder/decoder; presumably "attend everywhere").
src_mask = None
# Look-ahead mask for the decoder: 0 on and below the diagonal, -1e9 strictly above it,
# i.e. written as an *additive* mask meant to be summed onto the attention scores.
# NOTE(review): confirm MultiHeadAttention adds the mask. If it instead does
# masked_fill(mask == 0, -1e9), this tensor hides the current and past positions and
# keeps the future ones — the opposite of causal masking — and a mask such as
# torch.tril(torch.ones(seq_len, seq_len)) would be needed instead.
target_mask = torch.triu(torch.ones(seq_len, seq_len) * -1e9, diagonal=1)

In [None]:
# Hand-assemble N encoder blocks and N decoder blocks from fresh sub-layers.
# Comprehensions keep the per-block temporaries (and the loop variable `_`) out of the
# notebook namespace, so no stale `feed_forward_block` etc. lingers for later cells.
# Sub-modules are still constructed in the same order as before (argument order).
encoder_blocks = [
    EncoderBlock(
        MultiHeadAttention(d_model, h, dropout),  # self-attention
        FeedForwardBlock(d_model, d_ff, dropout),
        dropout,
    )
    for _ in range(N)
]

decoder_blocks = [
    DecoderBlock(
        MultiHeadAttention(d_model, h, dropout),  # self-attention
        MultiHeadAttention(d_model, h, dropout),  # cross-attention
        FeedForwardBlock(d_model, d_ff, dropout),
        dropout,
    )
    for _ in range(N)
]

encoder = Encoder(nn.ModuleList(encoder_blocks))
decoder = Decoder(nn.ModuleList(decoder_blocks))
# Output head: d_model -> vocab_size.
projection = ProjectionLayer(d_model, vocab_size)

In [None]:
# Wrap the hand-built parts into a full Transformer.
# NOTE: this rebinds `t`, replacing the model returned by build_transformer above.
t = Transformer(
    encoder, decoder,          # stacks
    src_emb, target_emb,       # token embeddings
    src_pe, target_pe,         # positional encodings
    projection,                # output head
)

In [None]:
# Run the hand-built stacks directly: encode the source embeddings, then decode the
# target embeddings with the encoder output passed in for cross-attention.
# NOTE(review): `y` must still hold the target embeddings when this runs; make sure no
# later cell has rebound it before re-running this one.
enc_out = encoder(x, src_mask)
dec_out = decoder(y, enc_out, src_mask, target_mask)

In [None]:
# Greedy read-out: project the decoder states through the output head (presumably to
# per-position vocabulary scores, shape (batch, seq_len, vocab_size)) and take the
# highest-scoring token id at every position. `out` must be defined here; it is not
# produced by any earlier cell.
out = projection(dec_out)
torch.argmax(out, dim=-1)

In [None]:
# Stand-alone multi-head self-attention on the source embeddings (query = key = value = x).
# NOTE(review): unlike the encoder/decoder calls, no mask is passed here; confirm that
# MultiHeadAttention's forward gives its mask parameter a default.
attn = MultiHeadAttention(d_model, h, dropout)
attn(x, x, x)

In [None]:
# Pieces for checking ResidualConnection in isolation: a bare layer norm, the residual
# wrapper switched to eval mode (so any dropout inside it is disabled), and an
# identity sub-layer. nn.Module.eval() returns the module itself, so it can be chained.
norm = LayerNormalization()
residual_connection = ResidualConnection(dropout).eval()
sublayer = nn.Identity()

In [None]:
# Sanity check: with dropout disabled, ResidualConnection(x, Identity) should equal
# x + LayerNormalization(x). Assumes the residual's internal norm is initialised like a
# fresh LayerNormalization() — TODO confirm in internal.py.
# A distinct name is used so the target embeddings held in `y` are not overwritten
# (the decoder cell reads `y`), and allclose avoids brittle exact float comparison.
x_norm = norm(x)
y_res = residual_connection(x, sublayer)
assert torch.allclose(y_res, x + x_norm), "ResidualConnection(x, Identity) != x + norm(x)"

In [None]:
# Fresh self-/cross-attention modules and an empty source mask for ad-hoc experiments.
# NOTE(review): this repeats the earlier sub-layer and mask cells, presumably
# re-initialising both attention modules with new random weights.
self_attn, cross_attn = (MultiHeadAttention(d_model, h, dropout) for _ in range(2))
src_mask = None

In [None]:
# Scratch sizes for the manual attention walk-through below.
# NOTE(review): this rebinds the notebook-wide seq_len (10 -> 6), and also h and
# d_model (same values); re-running earlier cells afterwards silently uses seq_len = 6.
batch = 2
seq_len = 6
h = 2
d_model = 4

In [None]:
# Per-head width. Integer division silently drops any remainder, so fail loudly if
# d_model does not split evenly across the h heads.
assert d_model % h == 0, f"d_model={d_model} must be divisible by h={h}"
d_k = d_model // h

In [None]:
# Query projection W_q for the manual attention walk-through: a square
# d_model -> d_model linear layer (with bias, the nn.Linear default).
w_q = nn.Linear(in_features=d_model, out_features=d_model)

In [None]:
# Random stand-in for a batch of query embeddings, drawn uniformly from [0, 1),
# shaped (batch, seq_len, d_model).
q = torch.rand((batch, seq_len, d_model))

In [None]:
# Project the random queries through W_q. nn.Linear acts on the last dimension
# (d_model -> d_model), so `query` keeps q's shape (batch, seq_len, d_model).
query = w_q(q)