# Transformer From Scratch

## Imports & Inits

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pdb, math
from copy import deepcopy

import numpy as np
np.set_printoptions(precision=4)

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context(context="talk")
%matplotlib inline

import torch;
assert(torch.cuda.is_available())
from torch import nn
from torch.nn import functional as F

In [None]:
from transformer import *

In [None]:
def transpose_qkv(x, n_heads):
  """Split the last axis of `x` into `n_heads` slices and fold them into axis 0.

  An input of shape (batch, steps, features) comes back with shape
  (batch * n_heads, steps, features // n_heads). Rows of the result are
  ordered batch-major, head-minor: all heads of batch item 0 come first.
  """
  batch, steps = x.shape[0], x.shape[1]
  # (batch, steps, n_heads, head_dim) -> (batch, n_heads, steps, head_dim)
  heads_first = x.reshape(batch, steps, n_heads, -1).permute(0, 2, 1, 3)
  _, _, n_steps, head_dim = heads_first.shape
  # Merge the batch and head axes into one leading axis.
  return heads_first.reshape(-1, n_steps, head_dim)

def transpose_output(x, n_heads):
  """Merge per-head results back into one feature axis.

  An input of shape (batch * n_heads, steps, head_dim) comes back with shape
  (batch, steps, n_heads * head_dim); the head slices of each step are
  concatenated in head order along the last axis.
  """
  steps, head_dim = x.shape[1], x.shape[2]
  # Recover the separate batch and head axes from the folded leading axis.
  unfolded = x.reshape(-1, n_heads, steps, head_dim)
  # (batch, n_heads, steps, head_dim) -> (batch, steps, n_heads, head_dim)
  by_step = unfolded.permute(0, 2, 1, 3)
  return by_step.reshape(by_step.shape[0], by_step.shape[1], -1)

In [None]:
# Toy problem sizes shared by the cells below.
# d_model is the feature width; it is split evenly across n_heads.
d_model = 6
bs = 2
seq_len = 4
n_heads = 3

# NOTE(review): every attention module below is put in .eval() mode, so this
# rate presumably never takes effect — confirm against DotProductAttention.
dropout = 0.5
# One entry per batch element; presumably the number of valid (non-padded)
# positions in each sequence — TODO confirm against DotProductAttention.
valid_len = torch.tensor([2,3])

# Random input of shape (bs, seq_len, d_model).
x = torch.rand(bs, seq_len, d_model)
# Self-attention setup: query, key and value are all the same tensor object.
query,key,value = x,x,x

In [None]:
# Projection layers shared by MultiHeadAttention1 and MultiHeadAttention2 so
# that both variants run on identical weights.
# presumably clone_module returns 4 independent copies of the given layer
# (it is unpacked into 4 names and indexed with [-1] below) — TODO confirm.
linears = clone_module(nn.Linear(d_model, d_model), 4)

In [None]:
class MultiHeadAttention1(nn.Module):
  """Multi-head attention variant built on `transpose_qkv` / `transpose_output`.

  The four projection layers are not created here: they are taken from the
  notebook-level `linears`, so this class and MultiHeadAttention2 run on
  identical weights and their outputs can be compared directly. As a
  consequence the `bias` argument is accepted but has no effect.
  """

  def __init__(self, d_model, n_heads, dropout, bias=False, **kwargs):
    super().__init__(**kwargs)
    self.d_model = d_model
    self.n_heads = n_heads

    self.attention = DotProductAttention(dropout)
    # A stand-alone version would build four
    # nn.Linear(d_model, d_model, bias=bias) layers here instead.
    self.W_q, self.W_k, self.W_v, self.W_o = linears

  def forward(self, query, key, value, valid_len):
    """Project the inputs, attend per head, then merge and project the result.

    `valid_len` may be None. Otherwise each entry is repeated `n_heads` times
    along dim 0 so it lines up with the head-folded batch axis; how it is
    interpreted is up to DotProductAttention.
    """
    heads = self.n_heads
    q = transpose_qkv(self.W_q(query), heads)
    k = transpose_qkv(self.W_k(key), heads)
    v = transpose_qkv(self.W_v(value), heads)

    if valid_len is not None:
      valid_len = torch.repeat_interleave(valid_len, repeats=heads, dim=0)

    merged = transpose_output(self.attention(q, k, v, valid_len), heads)
    return self.W_o(merged)

In [None]:
class MultiHeadAttention2(nn.Module):
  """Multi-head attention variant using inline reshape/transpose chains.

  The projection layers are taken from the notebook-level `linears` rather
  than created here, so this class shares weights with MultiHeadAttention1
  and the two outputs can be compared directly. `d_model` and `bias` are
  accepted for signature compatibility but are not used.

  All tensor sizes are derived from the inputs and `self.n_heads`; the
  forward pass does not depend on notebook globals such as `bs` or `seq_len`.
  """

  def __init__(self, d_model, n_heads, dropout, bias=False, **kwargs):
    super(MultiHeadAttention2, self).__init__(**kwargs)
    self.n_heads = n_heads
    self.attn = DotProductAttention(dropout)
    # First three entries project query/key/value, the last one the output.
    self.linears = linears

  def _split_heads(self, proj):
    """(batch, steps, features) -> (batch * n_heads, steps, features // n_heads)."""
    batch, steps = proj.shape[0], proj.shape[1]
    head_dim = proj.shape[-1] // self.n_heads
    return (
      proj.reshape(batch, steps, self.n_heads, head_dim)
      .transpose(1, 2)
      .reshape(batch * self.n_heads, steps, head_dim)
    )

  def _merge_heads(self, out):
    """(batch * n_heads, steps, head_dim) -> (batch, steps, n_heads * head_dim)."""
    steps, head_dim = out.shape[1], out.shape[2]
    batch = out.shape[0] // self.n_heads
    return (
      out.reshape(batch, self.n_heads, steps, head_dim)
      .transpose(1, 2)
      .reshape(batch, steps, self.n_heads * head_dim)
    )

  def forward(self, query, key, value, valid_len):
    """Project the inputs, attend per head, then merge and project the result.

    `valid_len` may be None. Otherwise each entry is repeated `n_heads` times
    along dim 0 so it lines up with the head-folded batch axis; how it is
    interpreted is up to the attention module.
    """
    # zip stops after three pairs, leaving the last layer for the output.
    query, key, value = [
      self._split_heads(layer(t))
      for layer, t in zip(self.linears, (query, key, value))
    ]

    if valid_len is not None:
      valid_len = torch.repeat_interleave(valid_len, repeats=self.n_heads, dim=0)

    out = self.attn(query, key, value, valid_len)
    out = self._merge_heads(out)
    out = self.linears[-1](out)
    return out

In [None]:
# Variant 1 (helper-function reshapes), in eval mode, on the toy inputs.
cell = MultiHeadAttention1(d_model, n_heads, dropout).eval()
o1 = cell(query, key, value, valid_len)

In [None]:
# Variant 2 (inline reshapes), in eval mode, on the same inputs and the same
# shared `linears` weights as variant 1.
cell = MultiHeadAttention2(d_model, n_heads, dropout).eval()
o2 = cell(query, key, value, valid_len)

In [None]:
# Exact element-wise comparison: True only if both variants produced
# identical values at every position.
torch.all(o1 == o2).item()

In [None]:
# MultiHeadAttention is not defined in this notebook; presumably it comes
# from `from transformer import *` — TODO confirm.
cell = MultiHeadAttention(d_model, n_heads, dropout).eval()
o = cell(query, key, value, valid_len)
# Last expression of the cell: displays the output shape.
o.shape

In [None]:
# Larger shape check: arguments follow the (d_model, n_heads, dropout) order
# used in the other cells, i.e. width 100 over 10 heads.
cell = MultiHeadAttention(100, 10, 0.5)
cell.eval()
# All-ones input: batch of 2, 4 steps, 100 features.
X = torch.ones((2, 4, 100))
# NOTE: rebinds the notebook-level valid_len (to the same values as before).
valid_len = torch.tensor([2, 3])
cell(X, X, X, valid_len).shape

In [None]:
# Stand-alone bias-free projection of x, used by the following cells to
# check the head split/merge reshapes in isolation.
w_q = nn.Linear(d_model, d_model, bias=False)
q = w_q(x)
q.shape

In [None]:
# Head split computed two ways: via the helper, and via an inline
# reshape/transpose chain built from the notebook-level sizes.
q1 = transpose_qkv(q, n_heads)
q2 = q.reshape(bs, -1, n_heads, d_model//n_heads).transpose(1, 2).reshape(bs * n_heads, seq_len, -1)
# Element-wise boolean tensor; all True means the two formulations agree.
q1 == q2

In [None]:
# Head merge computed two ways on the split tensors from the previous cell.
# NOTE: this rebinds o1 and o2, replacing the attention outputs stored
# under those names earlier in the notebook.
o1 = transpose_output(q1, n_heads)
o2 = q2.reshape(bs, n_heads, seq_len, -1).transpose(1,2).reshape(bs, seq_len, -1)
# Element-wise boolean tensor; all True means the two formulations agree.
o1 == o2

In [None]:
# Element-wise comparison of whatever o1 and o2 currently hold; the result
# depends on which cells were run most recently.
o1 == o2

In [None]:
# Rebinds o1 to the output of MultiHeadAttention (not defined in this
# notebook; presumably from `from transformer import *` — TODO confirm).
cell = MultiHeadAttention(d_model, n_heads, dropout).eval()
o1 = cell(query, key, value, valid_len)

In [None]:
# Rebinds o2 to the output of MultiHeadAttention1, which uses the shared
# `linears` weights. No comparison of o1 and o2 follows in the visible cells.
cell = MultiHeadAttention1(d_model, n_heads, dropout).eval()
o2 = cell(query, key, value, valid_len)