In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython import embed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Input and output have shape (batch, seq_len, k). The embedding is split
    into ``heads`` chunks of size ``k // heads``; each chunk attends
    independently, and the per-head results are concatenated and projected
    back to ``k`` dimensions.

    Args:
        k: embedding size of input and output; must be divisible by ``heads``.
        heads: number of attention heads.
        mask: if True, apply a causal mask so position i only attends to
            positions 0..i. If False, every position attends to all positions.
    """

    def __init__(self,
                 k,
                 heads = 4,
                 mask = False):

        # cooperative super() keeps multiple-inheritance possibilities for nn.Module
        super().__init__()

        # every head must get the same number of dimensions, e.g. 512 = 8 heads * 64
        assert k % heads == 0, f"embedding size {k} must be divisible by heads={heads}"

        self.k, self.heads = k, heads
        self.mask = mask

        # full-width (k x k) projections; the per-head split happens in forward().
        # Creation order (K, Q, V, mergeheads) is kept so seeded initialisation is unchanged.
        self.K = nn.Linear(k, k, bias=False)
        self.Q = nn.Linear(k, k, bias=False)
        self.V = nn.Linear(k, k, bias=False)

        # projects the concatenated head outputs back into a single k-dim space
        self.mergeheads = nn.Linear(k, k)

    def forward(self,
                x):
        """Apply self-attention to ``x`` of shape (b, t, k); returns a (b, t, k) tensor."""
        # batch size, sequence length, embedding size, e.g. 1 x 10 x 512
        b, t, k = x.size()
        h = self.heads
        # per-head dimension, e.g. 512 // 8 = 64
        s = k // h

        # full-width projections (b, t, k), then split into heads (b, t, h, s)
        queries = self.Q(x).view(b, t, h, s)
        keys = self.K(x).view(b, t, h, s)
        values = self.V(x).view(b, t, h, s)

        # fold heads into the batch dimension so torch.bmm can be used: (b*h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # raw attention scores, (b*h, t, s) x (b*h, s, t) -> (b*h, t, t).
        # Each score is a sum of s products, so (assuming roughly unit-variance
        # entries) its variance grows with s. Scale by sqrt(s), the per-head
        # dimension, not by sqrt(k), the full embedding size.
        dot = torch.bmm(queries, keys.transpose(1, 2)) / (s ** 0.5)

        if self.mask:
            # hide future positions (strictly above the diagonal); the diagonal
            # stays visible, so no row is fully masked and softmax never yields NaN
            future = torch.ones(t, t, device=dot.device).triu(1).bool()
            dot = dot.masked_fill(future, float('-inf'))

        # normalise each row of scores into attention weights
        dot = F.softmax(dot, dim=2)

        # weighted sum of values, then unfold heads:
        # (b*h, t, s) -> (b, h, t, s) -> (b, t, h, s) -> (b, t, h*s)
        out = torch.bmm(dot, values).view(b, h, t, s)
        out = out.transpose(1, 2).contiguous().view(b, t, h * s)

        # heads were computed independently; learn how to combine them in one space
        return self.mergeheads(out)

In [3]:
class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Two sub-layers: multi-head self-attention, then a position-wise
    feed-forward network. Each one is wrapped as
    ``dropout(LayerNorm(x + sublayer(x)))``.

    Args:
        emb: embedding size of the input and output.
        heads: number of attention heads handed to SelfAttention.
        mask: forwarded to SelfAttention and also kept as ``self.mask``.
        seq_length: accepted for interface compatibility; not used by this block.
        ff_hidden_size: width multiplier for the feed-forward hidden layer
            (the hidden layer has ``ff_hidden_size * emb`` units).
        dropout: dropout probability applied after each residual + norm step.
    """

    def __init__(self,
                 emb,
                 heads,
                 mask,
                 seq_length,
                 ff_hidden_size = 8,
                 dropout = 0.0):

        super().__init__()

        # Submodules are created in the same order as before (attention first,
        # then the feed-forward layers) so seeded weight initialisation matches.
        self.attention = SelfAttention(emb, heads=heads, mask=mask)
        self.mask = mask

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        hidden_units = ff_hidden_size * emb
        self.ff = nn.Sequential(
            nn.Linear(emb, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, emb),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # attention sub-layer: residual add -> LayerNorm -> dropout
        x = self.dropout(self.norm1(self.attention(x) + x))
        # feed-forward sub-layer, wrapped the same way
        return self.dropout(self.norm2(self.ff(x) + x))

In [4]:
# Seed before building the block: weight init, dropout masks and the random
# demo input below are all stochastic, so outputs are only reproducible with a seed.
SEED = 0
torch.manual_seed(SEED)

emb_size = 512        # embedding size; must be divisible by `heads`
heads = 8             # 512 / 8 = 64 dimensions per head
mask = False          # no causal masking in the attention layer
seq_length = 10       # sequence length of the demo input (TransformerBlock itself ignores it)
dropout = 0.5
ff_hidden_size = 8    # feed-forward width multiplier: 8 * 512 = 4096 hidden units

transformer = TransformerBlock(emb_size,
                               heads,
                               mask,
                               seq_length,
                               ff_hidden_size = ff_hidden_size,
                               dropout = dropout)

In [5]:
# Display the module tree: submodule names and the in/out sizes of each layer
transformer

TransformerBlock(
  (attention): SelfAttention(
    (K): Linear(in_features=512, out_features=512, bias=False)
    (Q): Linear(in_features=512, out_features=512, bias=False)
    (V): Linear(in_features=512, out_features=512, bias=False)
    (mergeheads): Linear(in_features=512, out_features=512, bias=True)
  )
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (ff): Sequential(
    (0): Linear(in_features=512, out_features=4096, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4096, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.5, inplace=False)
)

In [6]:
# Random demo batch of shape (batch_size, seq_length, emb_size); the sizes come
# from the config cell so the input stays consistent with the block's embedding size
batch_size = 5
X = torch.rand(batch_size, seq_length, emb_size)

In [7]:
# (batch, sequence length, embedding size)
X.size()

torch.Size([5, 10, 512])

In [8]:
# Forward pass. The module is still in its default training mode, so dropout
# (p = dropout from the config cell) is active and the output is stochastic.
out = transformer(X)

In [9]:
# A transformer block must map (batch, seq, emb) to the same shape
assert out.shape == X.shape, f"expected {tuple(X.shape)}, got {tuple(out.shape)}"
out.shape

torch.Size([5, 10, 512])