# Transformer From Scratch

## Imports & Inits

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pdb, math

import numpy as np
np.set_printoptions(precision=4)

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context(context="talk")
%matplotlib inline

import torch;
# assert(torch.cuda.is_available())
from torch import nn
from torch.nn import functional as F

In [3]:
def sequence_mask(x, valid_len, value=0):
  """Overwrite, in place, every position of `x` at or beyond its valid length.

  Args:
    x: tensor with at least two axes; axis 1 is the position axis that gets
      masked. For inputs with more axes, the whole slice `x[i, j, ...]` is
      overwritten.
    valid_len: 1-D tensor with one length per entry of `x`'s first axis.
      Position `j` of row `i` is overwritten when `j >= valid_len[i]`.
    value: scalar written into the masked positions (default 0).

  Returns:
    The same tensor object `x`, modified in place.
  """
  maxlen = x.shape[1]
  # Build the position index on x's device so the comparison and the boolean
  # indexing below also work when x does not live on the CPU.
  positions = torch.arange(maxlen, dtype=torch.float32, device=x.device)
  valid_len = valid_len.to(x.device)
  mask = positions[None, :] >= valid_len[:, None]
  x[mask] = value
  return x

def masked_softmax(x, valid_len):
  """Softmax over the last axis of `x`, ignoring positions beyond `valid_len`.

  Args:
    x: score tensor; expected to be 3-D, since a 1-D `valid_len` is repeated
      `x.shape[1]` times to give one length per row of the flattened scores.
    valid_len: `None` for a plain softmax; a 1-D tensor with one length per
      entry of `x`'s first axis; or a tensor of any other rank that is
      flattened to one length per row of `x.reshape(-1, x.shape[-1])`.

  Returns:
    A new tensor with the shape of `x`. Positions `j >= valid_len` of each row
    are filled with -1e6 before the softmax, so they receive (numerically)
    zero weight. The input `x` is not modified.
  """
  if valid_len is None:
    return F.softmax(x, dim=-1)

  shape = x.shape
  if valid_len.dim() == 1:
    valid_len = torch.repeat_interleave(valid_len, repeats=shape[1], dim=0)
  else:
    valid_len = valid_len.reshape(-1)
  valid_len = valid_len.to(x.device)

  # Build the mask on x's device and apply it out of place: an in-place write
  # through the reshaped view would silently change the caller's tensor.
  positions = torch.arange(shape[-1], dtype=torch.float32, device=x.device)
  mask = positions[None, :] >= valid_len[:, None]
  masked = x.reshape(-1, shape[-1]).masked_fill(mask, -1e6)
  return F.softmax(masked.reshape(shape), dim=-1)

In [11]:
class DotProductAttention(nn.Module):
  """Scaled dot-product attention with dropout on the attention weights.

  `forward` multiplies `query` by the transposed `key` with `torch.bmm`,
  divides by the square root of the query's last dimension, normalises the
  scores with `masked_softmax`, and uses the (dropped-out) weights to combine
  the rows of `value`.
  """

  def __init__(self, dropout, **kwargs):
    super().__init__(**kwargs)
    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, valid_len=None):
    """Return the attention-weighted combination of `value`.

    Args:
      query: 3-D tensor (batch, n_queries, d).
      key: 3-D tensor (batch, n_keys, d).
      value: 3-D tensor (batch, n_keys, d_value).
      valid_len: optional lengths forwarded unchanged to `masked_softmax`.

    Returns:
      Tensor of shape (batch, n_queries, d_value).
    """
    scale = math.sqrt(query.shape[-1])
    scores = torch.bmm(query, key.transpose(1, 2)) / scale
    weights = masked_softmax(scores, valid_len)
    attn_wts = self.dropout(weights)
    return torch.bmm(attn_wts, value)

In [24]:
# Toy problem: 2 batches, 10 identical keys each, one query per batch.
attn = DotProductAttention(dropout=0.5)
attn.eval()  # eval mode disables the dropout layer
keys = torch.ones((2, 10, 2))
values = (
  torch.arange(40, dtype=torch.float32)
  .reshape(1, 10, 4)
  .repeat(2, 1, 1)
)
queries = torch.ones((2, 1, 2))
# Number of leading keys each batch entry is allowed to attend to.
value_len = torch.tensor([2, 6])

In [25]:
# Sanity check: size of the key tensor.
keys.size()

torch.Size([2, 10, 2])

In [26]:
# Sanity check: size of the value tensor.
values.size()

torch.Size([2, 10, 4])

In [17]:
# Run masked attention and confirm the output has one row per query.
at = attn(queries, keys, values, valid_len=value_len)
at.size()

torch.Size([2, 1, 4])

In [18]:
# Second smoke test: random inputs with a single key/value pair per batch.
attn = DotProductAttention(dropout=0.5)
attn.eval()  # eval mode disables the dropout layer
# Drawn in the order keys, values, queries.
keys, values, queries = (torch.rand(2, 1, 4) for _ in range(3))
# With only one key per batch, lengths 2 and 6 exceed the key count, so
# nothing ends up masked in this example.
value_len = torch.tensor([2, 6])

In [19]:
# Run attention on the random inputs and check the output size.
at = attn(queries, keys, values, valid_len=value_len)
at.size()

torch.Size([2, 1, 4])