In [3]:
import torch.nn as nn
import torch
import math

# Scaled Dot Product Attention

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    '''
        Scaled dot-product attention: softmax(q @ k^T / sqrt(depth)) @ v.

        Works for any number of leading (batch / head) dimensions, since only the
        last two dimensions of each tensor are treated as (seq_len, depth).

        q: queries, (..., seq_len_q, depth)
        k: keys,    (..., seq_len_k, depth)
        v: values,  (..., seq_len_k, depth_v)
        mask: optional tensor broadcastable to (..., seq_len_q, seq_len_k),
              e.g. (batch_size, 1, 1, seq_len_k) for a padding mask. Entries equal
              to 1 / True mark key positions that must NOT be attended to.
              May be float, int or bool.

        returns: (output, attention_weights)
            output:            (..., seq_len_q, depth_v)
            attention_weights: (..., seq_len_q, seq_len_k), each row sums to 1
    '''
    # multiply queries and keys; transposing the LAST two dims (instead of dims 2 and 3)
    # keeps this correct for 3-D, 4-D, ... inputs. matmul batches over leading dims.
    scores = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)

    # scale by sqrt of the key depth
    scores = scores / math.sqrt(k.shape[-1])

    if mask is not None:
        # masked positions get a large negative score so softmax drives their weight to ~0.
        # Out-of-place add (no `+=`) and an explicit cast so bool/int masks behave like float ones.
        scores = scores + mask.to(scores.dtype) * -1e9

    # softmax across the key dimension gives, for each query, a distribution over keys
    attention_weights = nn.functional.softmax(scores, dim=-1)  # (..., seq_len_q, seq_len_k)

    # weighted sum of the VALUES, weighted by the query/key attention weights
    output = torch.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights

# Multi-head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    '''
        Multi-head attention: project q/k/v, split the model dimension across
        `num_heads` heads, run scaled dot-product attention in every head, then
        merge the heads and apply a final linear layer.

        d_model:   size of the model / embedding dimension (must be divisible by num_heads)
        num_heads: number of attention heads
    '''
    def __init__(self, d_model, num_heads):
        # nn.Module must be initialised before any sub-module is assigned to self
        super(MultiHeadAttention, self).__init__()

        if d_model % num_heads != 0:
            raise ValueError(
                "d_model ({}) must be divisible by num_heads ({})".format(d_model, num_heads))

        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = self.d_model // self.num_heads # dimension of attention vectors in each head (after splitting across each head)

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model) # applied after the heads are brought back together

    def split_heads(self, x, batch_size):
        '''
            (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
        '''
        x = torch.reshape(x, (batch_size, -1, self.num_heads, self.depth)) # split last dim into num_heads and depth
        x = x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, depth)
        return x

    def forward(self, v_inp, k_inp, q_inp, mask=None):
        '''
            v_inp: (batch_size, seq_len_v, d_model)
            k_inp: (batch_size, seq_len_k, d_model)
            q_inp: (batch_size, seq_len_q, d_model)
            mask:  passed through unchanged to scaled_dot_product_attention (may be None)

            returns: (output, attention_weights)
                output:            (batch_size, seq_len_q, d_model)
                attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        '''
        batch_size = q_inp.shape[0]

        # compute queries, keys, and values
        q = self.wq(q_inp) # (batch_size, seq_len_q, d_model)
        k = self.wk(k_inp) # (batch_size, seq_len_k, d_model)
        v = self.wv(v_inp) # (batch_size, seq_len_v, d_model)

        # split each of them across the n heads (each from its OWN projection)
        q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)

        # compute scaled dot prod attention within each head
        # attention_outputs.shape = (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape = (batch_size, num_heads, seq_len_q, seq_len_k)
        attention_outputs, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # bring heads back together: undo the permute from split_heads first so that the
        # (num_heads, depth) axes are adjacent, then flatten them back into d_model
        attention_outputs = attention_outputs.permute(0, 2, 1, 3) # (batch_size, seq_len_q, num_heads, depth)
        concat_attention_outputs = torch.reshape(
            attention_outputs, (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)

        # pass through final dense layer
        output = self.dense(concat_attention_outputs) # (batch_size, seq_len_q, d_model)

        return output, attention_weights
        

# Encoder

In [None]:
class EncoderLayer(nn.Module):
    '''
        One transformer encoder block: self-attention followed by a position-wise
        feed-forward network, each wrapped as LayerNorm(input + Dropout(sublayer(input))).

        d_model:      model / embedding dimension
        num_heads:    number of attention heads
        dff:          hidden size of the feed-forward network
        dropout_rate: dropout probability applied to each sub-layer output
    '''
    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.dropout1 = nn.Dropout(dropout_rate) # randomly zeroes sub-layer outputs during training
        self.layer_norm1 = nn.LayerNorm(d_model) # normalizes over the last (d_model) dimension

        self.feed_fwd_network = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.dropout2 = nn.Dropout(dropout_rate)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        '''
            x:    (batch size, input_seq_len, d_model)
            mask: forwarded to the attention module (may be None)

            returns: (batch size, input_seq_len, d_model)
        '''
        # self-attention: x is used as values, keys and queries.
        # The attention module returns (output, attention_weights); only the output
        # continues through the layer.
        z, _ = self.multi_head_attention(x, x, x, mask) # (batch size, input_seq_len, d_model)
        z = self.dropout1(z)
        norm_resid_z = self.layer_norm1(x + z) # (batch size, input_seq_len, d_model)

        ffn_out = self.feed_fwd_network(norm_resid_z) # (batch size, input_seq_len, d_model)
        ffn_out = self.dropout2(ffn_out)
        enc_out = self.layer_norm2(norm_resid_z + ffn_out) # (batch size, input_seq_len, d_model)

        return enc_out


class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dff, dropout_rate, input_vocab_size):
        """
            Token embedding followed by a stack of `num_layers` EncoderLayer blocks.

            num_layers: number of encoder blocks
            d_model: embedding size, and the feature size every encoder block reads and writes
            num_heads: number of attention heads in each encoder block
            dff: hidden size of the feed-forward network in each encoder block
            dropout_rate: dropout probability (used after the embedding and inside each block)
            input_vocab_size: number of rows in the embedding table (size of the tokenizer vocab)
        """
        super(TransformerEncoder, self).__init__()
        # embedding width must equal d_model because the encoder layers consume d_model features
        self.embedding_layer = nn.Embedding(input_vocab_size, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.num_layers = num_layers
        # nn.ModuleList (not a plain list) so the layers' parameters are registered:
        # they then show up in .parameters(), follow .to(device) and are saved in state_dict()
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, dff, dropout_rate) for _ in range(self.num_layers)]
        )

    def forward(self, x, mask):
        """
            x: (batch_size, input_seq_len) -> e.g. [1, 12, 21, 17, 0, 0] -> list of ids from tokenizer vocab
            mask: forwarded unchanged to every encoder layer

            returns: (batch_size, input_seq_len, d_model)
        """
        x = self.embedding_layer(x) # (batch_size, input_seq_len, d_model)

        #TODO: add positional encodings

        enc_output = self.dropout(x)

        for encoder_layer in self.encoder_layers:
            enc_output = encoder_layer(enc_output, mask)

        return enc_output


In [9]:
# Scratch checks of tensor shape semantics used by the attention code above.
mat = torch.zeros(1, 2, 3)
mat1 = torch.zeros(1, 3, 2)
print(torch.matmul(mat, mat1).shape)  # batched matmul: (1, 2, 3) @ (1, 3, 2) -> (1, 2, 2)
print(mat.shape)
print(mat.shape[-1])  # size of the last dimension
print(mat[-1])        # indexes the FIRST dimension -> a (2, 3) tensor
# `mat.T` on a tensor that is not 2-D is deprecated in PyTorch; reversing all the
# dimensions explicitly gives the same result without the deprecation warning.
print(mat.permute(*range(mat.ndim - 1, -1, -1)).shape)
print(mat.permute(0, 2, 1).shape)  # swap only the last two dims (per-batch transpose)

torch.Size([1, 2, 2])
torch.Size([1, 2, 3])
3
tensor([[0., 0., 0.],
        [0., 0., 0.]])
torch.Size([3, 2, 1])
torch.Size([1, 3, 2])
