In [20]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math as m

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=4, num_heads=2, dropout=0.3):
        super().__init__()

        # d_q, d_k, d_v
        self.d = d_model//num_heads


        self.d_model = d_model
        self.num_heads = num_heads

        self.dropout = nn.Dropout(dropout)

        ##create a list of layers for K, and a list of layers for V
        
        self.linear_Qs = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.linear_Ks = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.linear_Vs = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])

        self.mha_linear = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # shape(Q) = [B x seq_len x D/num_heads]
        # shape(K, V) = [B x seq_len x D/num_heads]

        Q_K_matmul = torch.matmul(Q, K.permute(0, 2, 1))
        scores = Q_K_matmul/m.sqrt(self.d)
        # shape(scores) = [B x seq_len x seq_len]
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        # shape(attention_weights) = [B x seq_len x seq_len]

        output = torch.matmul(attention_weights, V)
        # shape(output) = [B x seq_len x D/num_heads]

        return output, attention_weights

    def forward(self, pre_q, pre_k, pre_v, mask=None):
        # shape(x) = [B x seq_len x D]

        Q = [linear_Q(pre_q) for linear_Q in self.linear_Qs]
        K = [linear_K(pre_k) for linear_K in self.linear_Ks]
        V = [linear_V(pre_v) for linear_V in self.linear_Vs]
        # shape(Q, K, V) = [B x seq_len x D/num_heads] * num_heads

        output_per_head = []
        attn_weights_per_head = []
        # shape(output_per_head) = [B x seq_len x D/num_heads] * num_heads
        # shape(attn_weights_per_head) = [B x seq_len x seq_len] * num_heads
        
        for Q_, K_, V_ in zip(Q, K, V):
            
            ##run scaled_dot_product_attention
            output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_, mask)
            # shape(output) = [B x seq_len x D/num_heads]
            # shape(attn_weights_per_head) = [B x seq_len x seq_len]
            output_per_head.append(output)
            attn_weights_per_head.append(attn_weight)

        output = torch.cat(output_per_head, -1)
        attn_weights = torch.stack(attn_weights_per_head).permute(1, 0, 2, 3)
        # shape(output) = [B x seq_len x D]
        # shape(attn_weights) = [B x num_heads x seq_len x seq_len]
        
        projection = self.dropout(self.mha_linear(output))

        return projection, attn_weights

In [21]:
import torch.nn as nn

class PWFFN(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.3):
        super().__init__()

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        # shape(x) = [B x seq_len x D]

        ff = self.ff(x)
        # shape(ff) = [B x seq_len x D]

        return ff

In [22]:
import torch.nn as nn


class ResidualLayerNorm(nn.Module):
    def __init__(self, d_model, dropout=0.3):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        ln = self.layer_norm(self.dropout(x) + residual)
        return ln

In [23]:
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.3, efficient_mha=False):
        super().__init__()

        # initalize these
        self.norm_1 = ResidualLayerNorm(d_model, dropout)
        self.norm_2 = ResidualLayerNorm(d_model, dropout)

        self.mha = MultiHeadAttention(d_model, num_heads, dropout)

        self.ff = PWFFN(d_model, d_ff, dropout)

    def forward(self, x, mask):
        # shape(x) = [B x seq_len x D]

        mha, encoder_attention_weights = self.mha(x, x, x, mask)
        # shape(mha) = [B x seq_len x D]
        # shape(encoder_attention_weights) = [B x num_heads x seq_len x seq_len]

        norm1 = self.norm_1(mha, x)
        # shape(norm1) = [B x seq_len x D]

        ff = self.ff(norm1)
        norm2 = self.norm_2(ff, norm1)
        # shape(ff) = [B x seq_len x D]
        # shape(norm2) = [B x seq_len x D]

        return norm2, encoder_attention_weights

In [24]:
toy_encodings = torch.Tensor([[[0.0, 0.1, 0.2, 0.3], [1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]]]) 
# shape(toy_encodings) = [B, T, D] = (1, 3, 4)
print("Toy Encodings:\n", toy_encodings)

Toy Encodings:
 tensor([[[0.0000, 0.1000, 0.2000, 0.3000],
         [1.0000, 1.1000, 1.2000, 1.3000],
         [2.0000, 2.1000, 2.2000, 2.3000]]])


In [26]:
toy_encoder_layer = EncoderLayer(d_model=4,num_heads=2, d_ff=16)
toy_encoder_layer_outputs, toy_encoder_layer_attn_outputs = toy_encoder_layer(toy_encodings,None)

print("Encodings: \n", toy_encoder_layer_outputs)
print("Encoder Layer Attn Weights: \n", toy_encoder_layer_attn_outputs)
print("Encodings Shape: \n", toy_encoder_layer_outputs.shape)
print("Encodings Attn Layer Weights Shape: \n", toy_encoder_layer_attn_outputs.shape)

Encodings: 
 tensor([[[ 0.4009, -1.3086,  1.3798, -0.4722],
         [ 0.3350, -0.1704,  1.3042, -1.4689],
         [-1.1922,  1.4353, -0.6160,  0.3729]]],
       grad_fn=<NativeLayerNormBackward0>)
Encoder Layer Attn Weights: 
 tensor([[[[0.3265, 0.3333, 0.3403],
          [0.2811, 0.3304, 0.3885],
          [0.2388, 0.3234, 0.4378]],

         [[0.3952, 0.3297, 0.2750],
          [0.4850, 0.3130, 0.2020],
          [0.5719, 0.2856, 0.1426]]]], grad_fn=<PermuteBackward0>)
Encodings Shape: 
 torch.Size([1, 3, 4])
Encodings Attn Layer Weights Shape: 
 torch.Size([1, 2, 3, 3])
