In [None]:
## TODO: why not keep attention heads as their own tensor dimension (index 1), i.e. shape (BATCH_SIZE, ATTENTION_HEADS, SEQ_LEN, D_MODEL)?
## TODO: apply padding masks in multi-head attention.
## TODO: how should Transformer encoder and decoder layers be stacked?

In [108]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Building Blocks

#### Multi-Head Attention

In [396]:
class Multi_Head_Attention(nn.Module) :
    """Multi-head scaled dot-product attention.

    Each head gets its own slice of the query/key/value projections: one
    ``nn.Linear`` per role produces ``n_heads * d_head`` features, which are
    split into ``n_heads`` chunks. Heads therefore learn different attention
    patterns, instead of all sharing one projection. Scores are scaled by
    ``sqrt(d_key_query)``. Masks are applied to the scores *before* the
    softmax. The concatenated head outputs go through a final linear layer.

    Args:
        n_heads: number of attention heads.
        d_model: input/output feature size.
        d_key_query: per-head size of queries and keys (default ``d_model // n_heads``).
        d_value: per-head size of values (default ``d_model // n_heads``).
        mask: if True, apply a causal mask, so that query position ``i`` only
            attends to key positions ``<= i``.
    """

    def __init__(self, n_heads, d_model, d_key_query = None, d_value = None, mask = False) :
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.mask = mask

        if d_key_query is None : d_key_query = d_model // n_heads
        if d_value is None : d_value = d_model // n_heads
        self.d_key_query = d_key_query
        self.d_value = d_value

        # One projection per role covering all heads; forward() splits the output per head.
        self.key = nn.Linear(d_model, d_key_query * n_heads)
        self.query = nn.Linear(d_model, d_key_query * n_heads)
        self.value = nn.Linear(d_model, d_value * n_heads)

        self.linear = nn.Linear(d_value * n_heads, d_model)

    def _split_heads(self, x, d_head) :
        """Reshape (batch, seq, n_heads * d_head) -> (batch, n_heads, seq, d_head)."""
        batch_size, seq_len, _ = x.shape
        return x.reshape(batch_size, seq_len, self.n_heads, d_head).transpose(1, 2)

    def forward(self, key, query, value, key_padding_mask = None) :
        """Attend from ``query`` to ``key``/``value``.

        Args:
            key: (batch, key_len, d_model)
            query: (batch, query_len, d_model)
            value: (batch, key_len, d_model)
            key_padding_mask: optional (batch, key_len) tensor. True (non-zero)
                marks padded key positions that must not be attended to. This
                is the same convention as ``nn.MultiheadAttention``.

        Returns:
            Tensor of shape (batch, query_len, d_model).
        """
        batch_size, query_len, _ = query.shape
        key_len = key.size(1)

        keys = self._split_heads(self.key(key), self.d_key_query)
        queries = self._split_heads(self.query(query), self.d_key_query)
        values = self._split_heads(self.value(value), self.d_value)

        ## scaled dot product attention scores: (batch, n_heads, query_len, key_len)
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_key_query ** 0.5)

        ## build a boolean "may attend" mask that broadcasts against the scores
        allowed = None
        if self.mask :
            allowed = torch.ones(query_len, key_len, dtype = torch.bool, device = query.device).tril()
        if key_padding_mask is not None :
            not_padded = ~key_padding_mask.to(device = query.device, dtype = torch.bool)[:, None, None, :]
            allowed = not_padded if allowed is None else allowed & not_padded
        if allowed is not None :
            # Use the dtype's most negative finite value rather than -inf, so that a
            # fully-masked row yields uniform weights instead of NaN.
            attention_scores = attention_scores.masked_fill(~allowed, torch.finfo(attention_scores.dtype).min)
        attention_weights = torch.softmax(attention_scores, dim = -1)

        ## weighted sum of values, then concatenate heads: (batch, query_len, n_heads * d_value)
        values_with_attention = torch.matmul(attention_weights, values)
        concatenated_vectors = values_with_attention.transpose(1, 2).reshape(batch_size, query_len, -1)

        return self.linear(concatenated_vectors)


#### Add and layer normalization

In [397]:
class AddLayerNormalization(nn.Module) :
    """Residual connection followed by layer normalization: ``LayerNorm(x + sublayer_output)``.

    Normalization is over the last (feature) dimension only. Each token is
    therefore normalized independently of the other positions, and inputs of
    any sequence length are accepted.

    Args:
        sequence_len: kept for backward compatibility and no longer used.
            Normalizing over ``[sequence_len, d_model]`` mixed statistics across
            tokens and only accepted that exact sequence length.
        d_model: size of the feature dimension to normalize over.
    """

    def __init__(self, sequence_len, d_model) :
        super().__init__()

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mha_output) :
        """Return ``LayerNorm(x + mha_output)``; both inputs end in a ``d_model`` dimension."""
        return self.layer_norm(x + mha_output)

#### Point-wise Feed-Forward Network

In [398]:
class PointWise_Feedforward(nn.Module) :
    """Position-wise feed-forward network: ``linear2(relu(linear1(x)))``.

    The same two linear layers act on the last dimension of ``x``. They are
    therefore applied independently at every position along the leading
    dimensions.

    Args:
        d_ff: hidden (inner) dimension.
        d_model: input and output feature dimension.
    """

    def __init__(self, d_ff, d_model) :
        super().__init__()

        # expand d_model -> d_ff, then project back d_ff -> d_model
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x) :
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)
        

## Single transformer layer

#### Encoder layer

In [400]:
class Encoder_Layer(nn.Module) :
    """One Transformer encoder layer.

    The layer has two sub-layers:
      1. self-attention, then residual add + layer normalization;
      2. a point-wise feed-forward network, then residual add + layer normalization.

    Args:
        n_heads: number of attention heads for the self-attention sub-layer.
        max_seq_len: forwarded to both ``AddLayerNormalization`` modules.
        d_model: model (feature) dimension.
        d_ff: hidden dimension of the feed-forward sub-layer.
    """

    def __init__(self, n_heads, max_seq_len, d_model, d_ff) :
        super().__init__()

        # sub-layer 1: self-attention followed by add & norm
        self.mha = Multi_Head_Attention(n_heads, d_model)
        self.layer_norm = AddLayerNormalization(max_seq_len, d_model)

        # sub-layer 2: feed-forward followed by add & norm
        self.pff = PointWise_Feedforward(d_ff, d_model)
        self.layer_norm2 = AddLayerNormalization(max_seq_len, d_model)

    def forward(self, x) :
        # queries, keys and values all come from the same input (self-attention)
        attended = self.layer_norm(x, self.mha(key = x, query = x, value = x))
        return self.layer_norm2(attended, self.pff(attended))

In [401]:
# Smoke test: a single encoder layer on random input; seeded so the run is reproducible.
torch.manual_seed(0)

N_HEADS, SEQ_LEN, D_MODEL, D_FF = 8, 5, 512, 2048
BATCH_SIZE = 2

model = Encoder_Layer(N_HEADS, SEQ_LEN, D_MODEL, D_FF)

with torch.no_grad() :
    t = torch.rand(BATCH_SIZE, SEQ_LEN, D_MODEL)
    output = model(t)

# an encoder layer maps (batch, seq, d_model) -> (batch, seq, d_model)
assert output.shape == (BATCH_SIZE, SEQ_LEN, D_MODEL)
output.shape

## Experimentation

In [370]:
torch.manual_seed(0)  # reproducible input for the masking experiments below
t = torch.rand(2, 3, 5, 6)
t

tensor([[[[5.6915e-01, 2.3101e-01, 1.8677e-01, 3.2823e-01, 3.3690e-01,
           8.1443e-01],
          [2.2544e-02, 6.1115e-01, 9.3989e-01, 9.0730e-01, 9.9352e-01,
           9.9025e-01],
          [4.8508e-01, 7.8124e-02, 3.3347e-01, 8.0652e-01, 4.0264e-01,
           1.4739e-01],
          [3.7981e-02, 6.8632e-01, 5.8275e-01, 7.5194e-01, 6.3034e-01,
           6.8327e-01],
          [7.0750e-01, 8.4914e-01, 6.0135e-01, 6.9475e-01, 8.7851e-01,
           1.9928e-01]],

         [[6.3917e-01, 4.5646e-02, 9.1937e-01, 8.9844e-01, 5.1065e-01,
           9.8977e-01],
          [2.6453e-01, 1.4332e-01, 2.2712e-01, 4.8705e-01, 7.6302e-01,
           7.6922e-01],
          [6.1331e-02, 1.5677e-01, 8.2732e-01, 9.2498e-01, 5.2458e-01,
           7.4341e-02],
          [6.1056e-02, 8.2466e-01, 4.1147e-01, 6.4349e-01, 1.5110e-01,
           3.1953e-01],
          [7.2273e-01, 1.9713e-01, 5.5603e-01, 8.7607e-01, 7.6079e-01,
           9.1249e-01]],

         [[8.3862e-01, 5.3867e-01, 5.1706e-01,

In [373]:
# Left-multiplying by a lower-triangular matrix of ones adds rows 0..i into row i.
# This is a cumulative sum along the sequence axis (dim 2); it does not mask anything.
# A causal attention mask must instead set the disallowed *scores* to -inf before the softmax.
tril_product = torch.matmul(torch.tril(torch.ones(2, 3, 5, 5)), t)
assert torch.allclose(tril_product, t.cumsum(dim = 2))
tril_product

tensor([[[[5.6915e-01, 2.3101e-01, 1.8677e-01, 3.2823e-01, 3.3690e-01,
           8.1443e-01],
          [5.9170e-01, 8.4216e-01, 1.1267e+00, 1.2355e+00, 1.3304e+00,
           1.8047e+00],
          [1.0768e+00, 9.2028e-01, 1.4601e+00, 2.0421e+00, 1.7331e+00,
           1.9521e+00],
          [1.1148e+00, 1.6066e+00, 2.0429e+00, 2.7940e+00, 2.3634e+00,
           2.6353e+00],
          [1.8223e+00, 2.4557e+00, 2.6442e+00, 3.4887e+00, 3.2419e+00,
           2.8346e+00]],

         [[6.3917e-01, 4.5646e-02, 9.1937e-01, 8.9844e-01, 5.1065e-01,
           9.8977e-01],
          [9.0370e-01, 1.8897e-01, 1.1465e+00, 1.3855e+00, 1.2737e+00,
           1.7590e+00],
          [9.6503e-01, 3.4574e-01, 1.9738e+00, 2.3105e+00, 1.7982e+00,
           1.8333e+00],
          [1.0261e+00, 1.1704e+00, 2.3853e+00, 2.9540e+00, 1.9494e+00,
           2.1529e+00],
          [1.7488e+00, 1.3675e+00, 2.9413e+00, 3.8300e+00, 2.7101e+00,
           3.0653e+00]],

         [[8.3862e-01, 5.3867e-01, 5.1706e-01,

In [374]:
torch.ones(2, 3, 5, 5).tril()  # lower-triangular ones, repeated for each (batch, head) pair

tensor([[[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]],

         [[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]],

         [[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]],

         [[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]],

         [[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]]])