In [51]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math as m

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=4, num_heads=2, dropout=0.3):
        super().__init__()

        # d_q, d_k, d_v
        self.d = d_model//num_heads


        self.d_model = d_model
        self.num_heads = num_heads

        self.dropout = nn.Dropout(dropout)

        ##create a list of layers for K, and a list of layers for V
        
        self.linear_Qs = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.linear_Ks = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.linear_Vs = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])

        self.mha_linear = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # shape(Q) = [B x seq_len x D/num_heads]
        # shape(K, V) = [B x seq_len x D/num_heads]

        Q_K_matmul = torch.matmul(Q, K.permute(0, 2, 1))
        scores = Q_K_matmul/m.sqrt(self.d)
        # shape(scores) = [B x seq_len x seq_len]
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        # shape(attention_weights) = [B x seq_len x seq_len]

        output = torch.matmul(attention_weights, V)
        # shape(output) = [B x seq_len x D/num_heads]

        return output, attention_weights

    def forward(self, pre_q, pre_k, pre_v, mask=None):
        # shape(x) = [B x seq_len x D]

        Q = [linear_Q(pre_q) for linear_Q in self.linear_Qs]
        K = [linear_K(pre_k) for linear_K in self.linear_Ks]
        V = [linear_V(pre_v) for linear_V in self.linear_Vs]
        # shape(Q, K, V) = [B x seq_len x D/num_heads] * num_heads

        

        output_per_head = []
        attn_weights_per_head = []
        # shape(output_per_head) = [B x seq_len x D/num_heads] * num_heads
        # shape(attn_weights_per_head) = [B x seq_len x seq_len] * num_heads
        
        for Q_, K_, V_ in zip(Q, K, V):
            
            ##run scaled_dot_product_attention
            output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_, mask)
            # shape(output) = [B x seq_len x D/num_heads]
            # shape(attn_weights_per_head) = [B x seq_len x seq_len]
            output_per_head.append(output)
            attn_weights_per_head.append(attn_weight)

        output = torch.cat(output_per_head, -1)
        attn_weights = torch.stack(attn_weights_per_head).permute(1, 0, 2, 3)
        # shape(output) = [B x seq_len x D]
        # shape(attn_weights) = [B x num_heads x seq_len x seq_len]
        
        projection = self.dropout(self.mha_linear(output))

        return projection, attn_weights
        

In [52]:
import torch.nn as nn

class PWFFN(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.3):
        super().__init__()

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        # shape(x) = [B x seq_len x D]

        ff = self.ff(x)
        # shape(ff) = [B x seq_len x D]

        return ff

In [53]:
import torch.nn as nn


class ResidualLayerNorm(nn.Module):
    def __init__(self, d_model, dropout=0.3):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        ln = self.layer_norm(self.dropout(x) + residual)
        return ln

In [54]:
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.3, efficient_mha=False):
        super().__init__()

        # initalize these
        self.norm_1 = ResidualLayerNorm(d_model, dropout)
        self.norm_2 = ResidualLayerNorm(d_model, dropout)

        self.mha = MultiHeadAttention(d_model, num_heads, dropout)

        self.ff = PWFFN(d_model, d_ff, dropout)

    def forward(self, x, mask):
        # shape(x) = [B x seq_len x D]

        mha, encoder_attention_weights = self.mha(x, x, x, mask)
        # shape(mha) = [B x seq_len x D]
        # shape(encoder_attention_weights) = [B x num_heads x seq_len x seq_len]

        norm1 = self.norm_1(mha, x)
        # shape(norm1) = [B x seq_len x D]

        ff = self.ff(norm1)
        norm2 = self.norm_2(ff, norm1)
        # shape(ff) = [B x seq_len x D]
        # shape(norm2) = [B x seq_len x D]

        return norm2, encoder_attention_weights

In [55]:
toy_encodings = torch.Tensor([[[0.0, 0.1, 0.2, 0.3], [1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]]]) 
# shape(toy_encodings) = [B, T, D] = (1, 3, 4)
print("Toy Encodings:\n", toy_encodings)

Toy Encodings:
 tensor([[[0.0000, 0.1000, 0.2000, 0.3000],
         [1.0000, 1.1000, 1.2000, 1.3000],
         [2.0000, 2.1000, 2.2000, 2.3000]]])


In [56]:
toy_encoder_layer = EncoderLayer(d_model=4,num_heads=2, d_ff=16)
toy_encoder_layer_outputs, toy_encoder_layer_attn_outputs = toy_encoder_layer(toy_encodings,None)

print("Encodings: \n", toy_encoder_layer_outputs)
print("Encoder Layer Attn Weights: \n", toy_encoder_layer_attn_outputs)
print("Encodings Shape: \n", toy_encoder_layer_outputs.shape)
print("Encodings Attn Layer Weights Shape: \n", toy_encoder_layer_attn_outputs.shape)

Encodings: 
 tensor([[[ 0.8181, -1.6777,  0.1589,  0.7006],
         [ 0.3733, -1.4753, -0.1829,  1.2848],
         [ 0.9136, -1.5786, -0.1431,  0.8080]]],
       grad_fn=<NativeLayerNormBackward0>)
Encoder Layer Attn Weights: 
 tensor([[[[0.2824, 0.3306, 0.3870],
          [0.1588, 0.2946, 0.5466],
          [0.0795, 0.2336, 0.6869]],

         [[0.2697, 0.3290, 0.4013],
          [0.1662, 0.2983, 0.5355],
          [0.0941, 0.2487, 0.6572]]]], grad_fn=<PermuteBackward0>)
Encodings Shape: 
 torch.Size([1, 3, 4])
Encodings Attn Layer Weights Shape: 
 torch.Size([1, 2, 3, 3])


In [57]:
Net = MultiHeadAttention()

# Display all model layer weights
for name, param in Net.named_parameters():
    if param.requires_grad:
        print(name, param.data)

linear_Qs.0.weight tensor([[ 0.0060,  0.3131,  0.3927,  0.3131],
        [ 0.3914, -0.3025,  0.4542, -0.1676]])
linear_Qs.0.bias tensor([ 0.4214, -0.2699])
linear_Qs.1.weight tensor([[ 0.4249, -0.2427, -0.0262, -0.1922],
        [-0.4202, -0.4428,  0.1350,  0.0142]])
linear_Qs.1.bias tensor([-0.2336, -0.4481])
linear_Ks.0.weight tensor([[ 0.4329,  0.3307,  0.3568, -0.0264],
        [ 0.3394, -0.4201, -0.0122,  0.1136]])
linear_Ks.0.bias tensor([ 0.3852, -0.1858])
linear_Ks.1.weight tensor([[-0.3380, -0.3888,  0.4956,  0.1191],
        [-0.3206,  0.1198, -0.3628, -0.3521]])
linear_Ks.1.bias tensor([0.0322, 0.2747])
linear_Vs.0.weight tensor([[ 0.0266, -0.0121, -0.2924,  0.1973],
        [ 0.2219,  0.1576,  0.1671, -0.0789]])
linear_Vs.0.bias tensor([0.3774, 0.0881])
linear_Vs.1.weight tensor([[-0.3803, -0.1506,  0.2279, -0.1779],
        [-0.0215, -0.2624, -0.3260, -0.2322]])
linear_Vs.1.bias tensor([-0.2906, -0.1739])
mha_linear.weight tensor([[ 0.0193,  0.2813,  0.2510,  0.2645],
    