<a href="https://colab.research.google.com/github/thomasshin/NLP_Study/blob/main/Transformers/Vanilla_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

In [13]:
def clones(module, N):
  """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

  Each copy owns its own parameters; the original ``module`` itself is not
  placed in the returned list.
  """
  replicas = nn.ModuleList()
  for _ in range(N):
    replicas.append(copy.deepcopy(module))
  return replicas

In [14]:
def scaled_dot_product_attention(q: torch.Tensor,
                                 k: torch.Tensor,
                                 v: torch.Tensor,
                                 mask: torch.Tensor = None,
                                 dropout: "float | nn.Module" = 0.1) -> "tuple[torch.Tensor, torch.Tensor]":
  """Compute softmax(q @ k^T / sqrt(d_k)) @ v.

  Args:
    q: queries, [B, num_heads, q_len, d_k]
    k: keys,    [B, num_heads, k_len, d_k]
    v: values,  [B, num_heads, k_len, d_v]
    mask: optional tensor broadcastable to [B, num_heads, q_len, k_len].
      Positions where ``mask == 0`` are filled with the most negative finite
      value of the score dtype before the softmax.
    dropout: either a callable (e.g. an ``nn.Dropout`` module, which follows
      its own train/eval mode) applied to the attention weights, or a float
      probability. A float is applied with ``F.dropout`` in training mode on
      every call because a plain function has no train/eval state; pass
      ``0.0`` to disable it.

  Returns:
    (output, attn_weighted):
      output:        [B, num_heads, q_len, d_v] (blended value vectors)
      attn_weighted: [B, num_heads, q_len, k_len] (weights after dropout)
  """
  d_k = k.size(-1)
  # d_k is a Python int, so scale with plain arithmetic (torch.sqrt needs a tensor)
  scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)

  if mask is not None:
    # out-of-place fill so the raw scores are not mutated
    scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

  attn_weighted = torch.softmax(scores, dim=-1) #[B, num_heads, q_len, k_len]

  if callable(dropout):
    attn_weighted = dropout(attn_weighted)
  elif dropout:
    attn_weighted = F.dropout(attn_weighted, p=dropout)

  output = torch.matmul(attn_weighted, v) #[B, num_heads, q_len, d_v]

  return output, attn_weighted

In [15]:
class Attention(nn.Module):
  """Multi-head attention with learned query/key/value/output projections.

  Args:
    d_model: width of the input and output features; must be divisible by
      ``num_heads``.
    num_heads: number of attention heads.
    dropout: probability of the ``nn.Dropout`` layer handed to
      ``scaled_dot_product_attention``.
  """
  def __init__(self, d_model, num_heads, dropout=0.1):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.wq = nn.Linear(d_model, d_model)
    self.wk = nn.Linear(d_model, d_model)
    self.wv = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(p=dropout)
    self.wo = nn.Linear(d_model, d_model) #output

  def split_heads(self, x, B):
    """Reshape [B, seq_len, d_model] into [B, num_heads, seq_len, head_dim]."""
    x = x.view(B, -1, self.num_heads, self.head_dim) #[B, seq_len, num_heads, head_dim]
    return x.permute(0, 2, 1, 3) #[B, num_heads, seq_len, head_dim]

  def forward(self, q, k, v, mask=None):
    """Project q/k/v, attend per head, then merge the heads.

    Args:
      q: [B, q_len, d_model]
      k: [B, k_len, d_model]
      v: [B, k_len, d_model]
      mask: optional mask forwarded unchanged to
        ``scaled_dot_product_attention``.

    Returns:
      (output, attn_weighted):
        output:        [B, q_len, d_model]
        attn_weighted: [B, num_heads, q_len, k_len]
    """
    B = q.size(0)
    q_len = q.size(1)

    #project, then split into heads: [B, num_heads, seq_len, head_dim]
    query = self.split_heads(self.wq(q), B)
    key = self.split_heads(self.wk(k), B)
    value = self.split_heads(self.wv(v), B)

    #per-head attention; scaled_attn: [B, num_heads, q_len, head_dim]
    scaled_attn, attn_weighted = scaled_dot_product_attention(query, key, value, mask, self.dropout)

    #merge heads: [B, num_heads, q_len, head_dim] -> [B, q_len, num_heads, head_dim] -> [B, q_len, d_model]
    scaled_attn = scaled_attn.permute(0, 2, 1, 3)
    scaled_attn_concat = scaled_attn.reshape(B, q_len, -1)
    output = self.wo(scaled_attn_concat)

    return output, attn_weighted

In [16]:
class TransformerEncoderLayer(nn.Module): #a single layer for Transformer encoder
  """One encoder block: self-attention, then a two-layer feed-forward network.

  Each sub-layer output goes through dropout, is added to its input
  (residual), and is then layer-normalised (Post-LN).

  Args:
    d_model: feature width of the layer input and output.
    num_heads: number of attention heads.
    dim_ff: hidden width of the feed-forward network.
    dropout: dropout probability used inside attention and after each
      sub-layer.
  """
  def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
    super().__init__()
    self.self_attn = Attention(d_model, num_heads, dropout)
    self.layer_norm_self_attn = nn.LayerNorm(d_model)
    self.fc1 = nn.Linear(d_model, dim_ff)
    self.fc2 = nn.Linear(dim_ff, d_model)
    self.act_fn = nn.ReLU()
    self.layer_norm_fc = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(p=dropout) #the module class is nn.Dropout (capital D)

  def forward(self, x, mask=None):
    """Run the block.

    Args:
      x: [B, src_len, d_model]
      mask: optional attention mask passed to the self-attention.

    Returns:
      (x, attn_weighted): the transformed features with the same shape as the
      input, and the self-attention weights returned by ``self.self_attn``.
    """
    #--- self-attention sub-layer ---
    attn_out, attn_weighted = self.self_attn(q=x, k=x, v=x, mask=mask)
    x = self.layer_norm_self_attn(x + self.dropout(attn_out)) #Post LN

    #--- feed-forward sub-layer ---
    ff_out = self.dropout(self.act_fn(self.fc1(x)))
    ff_out = self.dropout(self.fc2(ff_out))
    x = self.layer_norm_fc(x + ff_out) #Post LN

    return x, attn_weighted

In [17]:
class TransformerEncoder(nn.Module): #a stack of encoder layers
  """Token + learned positional embeddings followed by ``num_layers`` encoder layers.

  Args:
    num_layers: number of stacked encoder layers.
    src_dim: number of rows of the token embedding table.
    d_model: embedding / feature width.
    num_heads: number of attention heads per layer.
    dropout: dropout probability, used for the embedding sum and in every layer.
    max_seq_len: number of rows of the positional embedding table; inputs
      must not be longer than this.
    dim_ff: feed-forward hidden width; defaults to ``4 * d_model``.
  """
  def __init__(self, num_layers, src_dim, d_model, num_heads, dropout, max_seq_len = 100, dim_ff=None):
    super().__init__()
    self.num_layers = num_layers
    self.tok_embed = nn.Embedding(src_dim, d_model)
    self.pos_embed = nn.Embedding(max_seq_len, d_model)
    #use the configured rate instead of a hard-coded 0.1
    self.dropout = nn.Dropout(p=dropout)
    if dim_ff is None:
      dim_ff = 4 * d_model #mentioned in the paper
    single_layer = TransformerEncoderLayer(d_model, num_heads, dim_ff, dropout)

    #prepare N subblocks (deep copies of the same freshly built layer)
    self.layers = clones(single_layer, self.num_layers)

  def forward(self, x, mask=None):
    """Encode a batch of token ids.

    Args:
      x: [B, src_len] integer token ids.
      mask: optional attention mask handed to every layer.

    Returns:
      (x, layers_attn_scores): encoded features [B, src_len, d_model] and a
      list with the attention weights returned by each layer, in layer order.
    """
    src_len = x.size(1)
    #positions live on the same device as the input; [1, src_len] broadcasts over the batch
    pos = torch.arange(src_len, device=x.device).unsqueeze(0)
    x = self.dropout(self.tok_embed(x) + self.pos_embed(pos))
    layers_attn_scores = []

    #pass the input (and the mask) through each layer in turn
    for layer in self.layers:
      x, attn_weighted = layer(x, mask)
      layers_attn_scores.append(attn_weighted)

    return x, layers_attn_scores

In [18]:
class TransformerDecoderLayer(nn.Module):
  """One decoder block: masked self-attention, cross-attention, feed-forward.

  Each sub-layer output goes through dropout, is added to that sub-layer's
  input (residual), and is then layer-normalised.

  Args:
    d_model: feature width of the layer input and output.
    num_heads: number of attention heads.
    dropout: dropout probability used inside attention and after each
      sub-layer.
    dim_ff: feed-forward hidden width; defaults to ``4 * d_model``.
    eps: accepted for interface compatibility; not used by any sub-layer here.
  """
  def __init__(self, d_model, num_heads, dropout, dim_ff=None, eps=1e-12):
    super().__init__()
    #resolve the default before it is used to size the linear layers
    if dim_ff is None:
      dim_ff = 4 * d_model
    self.masked_self_attn = Attention(d_model, num_heads, dropout)
    self.layer_norm_self_attn = nn.LayerNorm(d_model)
    self.cross_attn = Attention(d_model, num_heads, dropout)
    self.layer_norm_cross_attn = nn.LayerNorm(d_model)
    self.fc1 = nn.Linear(d_model, dim_ff)
    self.fc2 = nn.Linear(dim_ff, d_model)
    self.layer_norm_ff = nn.LayerNorm(d_model)
    self.act_fn = nn.ReLU()
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x, look_ahead_mask, padding_mask, enc_output):
    """Run the block.

    Args:
      x: [B, trg_len, d_model] decoder features.
      look_ahead_mask: mask for the self-attention over ``x``.
      padding_mask: mask for the cross-attention over ``enc_output``.
      enc_output: [B, src_len, d_model] encoder features (keys and values of
        the cross-attention).

    Returns:
      (x, attn_weighted_self, attn_weighted_cross):
        x:                   [B, trg_len, d_model]
        attn_weighted_self:  [B, num_heads, trg_len, trg_len]
        attn_weighted_cross: [B, num_heads, trg_len, src_len]
    """
    #--- masked self-attention sub-layer ---
    residual = x
    x, attn_weighted_self = self.masked_self_attn(q=x, k=x, v=x, mask=look_ahead_mask)
    x = self.dropout(x)
    x = self.layer_norm_self_attn(residual + x)

    #--- cross-attention sub-layer (queries from decoder, keys/values from encoder) ---
    residual2 = x
    x, attn_weighted_cross = self.cross_attn(q=x, k=enc_output, v=enc_output, mask=padding_mask)
    x = self.dropout(x)
    x = self.layer_norm_cross_attn(residual2 + x)

    #--- feed-forward sub-layer ---
    residual3 = x
    x = self.act_fn(self.fc1(x))
    x = self.dropout(x)
    x = self.fc2(x)
    x = self.dropout(x)
    #the skip connection comes from this sub-layer's own input, not the layer input
    x = self.layer_norm_ff(residual3 + x)

    return x, attn_weighted_self, attn_weighted_cross

In [19]:
class TransformerDecoder(nn.Module): #a stack of decoder layers
  """Token + learned positional embeddings followed by ``num_layers`` decoder layers.

  Args:
    num_layers: number of stacked decoder layers.
    trg_dim: number of rows of the token embedding table.
    output_dim: accepted for interface compatibility; not used here (no
      output projection is applied, the features of the last layer are
      returned as they are).
    d_model: embedding / feature width.
    num_heads: number of attention heads per layer.
    dropout: dropout probability, used for the embedding sum and in every layer.
    max_seq_len: number of rows of the positional embedding table; inputs
      must not be longer than this.
    dim_ff: feed-forward hidden width; defaults to ``4 * d_model``.
  """
  def __init__(self, num_layers, trg_dim, output_dim, d_model, num_heads, dropout, max_seq_len=100, dim_ff=None):
    super().__init__()
    self.num_layers = num_layers
    self.tok_embed = nn.Embedding(trg_dim, d_model)
    self.pos_embed = nn.Embedding(max_seq_len, d_model)
    #forward() applies dropout to the embedding sum, so the module must exist
    self.dropout = nn.Dropout(p=dropout)
    if dim_ff is None:
      dim_ff = 4 * d_model
    single_layer = TransformerDecoderLayer(d_model, num_heads, dropout, dim_ff)

    #prepare N subblocks (deep copies of the same freshly built layer)
    self.layers = clones(single_layer, num_layers)

  def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
    """Decode a batch of target token ids against the encoder output.

    Args:
      x: [B, trg_len] integer token ids.
      enc_output: [B, src_len, d_model]
      look_ahead_mask: mask for the decoder self-attention (prevents
        attending to future positions).
      padding_mask: mask for the cross-attention, used to ignore padded
        encoder positions (keys) when blending them with the decoder queries.

    Returns:
      (x, layers_attn_self, layers_attn_cross): decoded features
      [B, trg_len, d_model] and, per layer in order, the self-attention and
      cross-attention weights.
    """
    trg_len = x.size(1)
    #positions live on the same device as the input; [1, trg_len] broadcasts over the batch
    pos = torch.arange(trg_len, device=x.device).unsqueeze(0)
    x = self.dropout(self.tok_embed(x) + self.pos_embed(pos))
    layers_attn_self = []
    layers_attn_cross = []

    #pass the input (and the masks) through each layer in turn
    for layer in self.layers:
      x, attn_weighted_self, attn_weighted_cross = layer(x, look_ahead_mask, padding_mask, enc_output)
      layers_attn_self.append(attn_weighted_self)
      layers_attn_cross.append(attn_weighted_cross)

    return x, layers_attn_self, layers_attn_cross

In [20]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer built from ``TransformerEncoder`` and ``TransformerDecoder``.

  Args:
    src_dim: size of the source token embedding table.
    trg_dim: size of the target token embedding table.
    num_layers: number of layers in both the encoder and the decoder.
    d_model: embedding / feature width.
    num_heads: number of attention heads.
    dropout: dropout probability.
    max_seq_len: size of the positional embedding tables.
    dim_ff: feed-forward hidden width (``None`` lets the sub-modules pick
      their default).
  """
  def __init__(self, src_dim, trg_dim, num_layers, d_model, num_heads, dropout, max_seq_len=100, dim_ff=None):
    super().__init__()
    self.encoder = TransformerEncoder(num_layers, src_dim, d_model, num_heads, dropout, max_seq_len, dim_ff)
    #keyword arguments keep every value aligned with the decoder signature,
    #which has an extra ``output_dim`` slot between ``trg_dim`` and ``d_model``
    self.decoder = TransformerDecoder(num_layers=num_layers,
                                      trg_dim=trg_dim,
                                      output_dim=trg_dim,
                                      d_model=d_model,
                                      num_heads=num_heads,
                                      dropout=dropout,
                                      max_seq_len=max_seq_len,
                                      dim_ff=dim_ff)

  def create_padding_mask(self, mask):
    """Add two singleton axes: [B, seq_len] -> [B, 1, 1, seq_len]."""
    return mask[:, None, None, :] #[B, 1, 1, seq_len]

  def create_look_ahead_mask(self, seq_len):
    """Return a [seq_len, seq_len] int mask: 1 on and below the diagonal, 0 above it."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.int))

  def forward(self, enc_input, dec_input, enc_pad_mask):
    """Encode the source and decode the target against it.

    Args:
      enc_input: [B, src_len] source token ids.
      dec_input: [B, trg_len] target token ids.
      enc_pad_mask: [B, src_len]; reshaped to [B, 1, 1, src_len] and used
        - as the padding mask for the encoder attention
        - as the padding mask for the decoder's cross-attention (to blend
          the encoder's outputs)

    Returns:
      The decoder output features.
    """
    #-----encoder-----
    enc_pad_mask = self.create_padding_mask(enc_pad_mask)
    enc_output, enc_attention = self.encoder(enc_input, enc_pad_mask)
    #-----decoder-----
    dec_seq_len = dec_input.size(1)
    look_ahead_mask = self.create_look_ahead_mask(dec_seq_len).to(dec_input.device)
    #the decoder's keyword for the cross-attention mask is ``padding_mask``
    output, dec_layer_attn_scores, dec_layer_cross_attn_scores = self.decoder(dec_input, enc_output,
                                                                              look_ahead_mask=look_ahead_mask,
                                                                              padding_mask=enc_pad_mask)
    return output