### Import libraries

In [1]:
import torch
from torch import nn
import torch.functional as F

### Self-attention (wide)

In [4]:
class SelfAttentionWide(nn.Module):
    """Multi-head self-attention in which every head sees the full embedding.

    Each of the ``heads`` heads gets its own ``emb``-dimensional keys, queries
    and values. The three projections therefore output ``emb * heads``
    features. The per-head results are concatenated and mapped back to
    ``emb`` features by ``unifyheads``.

    Args:
        emb: embedding dimension of the input (the last dimension of ``x``).
        heads: number of attention heads.
        mask: if true, apply a causal mask so that position ``i`` only
            attends to positions ``<= i``.
    """

    def __init__(self, emb, heads=8, mask=False):
        super().__init__()

        self.emb = emb
        self.heads = heads
        self.mask = mask

        # One emb-sized projection per head, all computed in a single matmul.
        self.tokeys = nn.Linear(emb, emb * heads, bias=False)
        self.toqueries = nn.Linear(emb, emb * heads, bias=False)
        self.tovalues = nn.Linear(emb, emb * heads, bias=False)

        self.unifyheads = nn.Linear(heads * emb, emb)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, seq, emb)``.

        Returns a tensor of shape ``(batch, seq, emb)``.
        """
        b, t, e = x.size()
        assert e == self.emb
        h = self.heads

        keys = self.tokeys(x).view(b, t, h, e)
        queries = self.toqueries(x).view(b, t, h, e)
        values = self.tovalues(x).view(b, t, h, e)

        # Fold the heads into the batch dimension so one bmm covers all heads.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, e)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, e)
        values = values.transpose(1, 2).contiguous().view(b * h, t, e)

        # Scaling queries and keys by e**(1/4) each is equivalent to dividing
        # the dot products by sqrt(e). It avoids scaling the (t, t) score matrix.
        queries = queries / (e ** (1 / 4))
        keys = keys / (e ** (1 / 4))

        dot = torch.bmm(queries, keys.transpose(1, 2))
        assert dot.size() == (b * h, t, t)

        if self.mask:
            # Causal mask: hide the strictly upper triangle (future positions).
            # The diagonal stays visible, so every row keeps a finite entry.
            future = torch.ones(t, t, dtype=torch.bool, device=dot.device).triu(diagonal=1)
            dot = dot.masked_fill(future, float('-inf'))

        # Row-wise softmax turns scores into attention weights. Use
        # torch.softmax because torch.functional has no softmax.
        dot = torch.softmax(dot, dim=2)

        # Weighted sum of the values, then un-fold the heads and concatenate them.
        out = torch.bmm(dot, values).view(b, h, t, e)
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)

        return self.unifyheads(out)

### Self-attention (narrow)

In [5]:
class SelfAttentionNarrow(nn.Module):
    """Multi-head self-attention in which each head sees one slice of the embedding.

    The input embedding is split into ``heads`` chunks of size
    ``s = emb // heads``. The same ``s -> s`` key, query and value
    projections are applied to every chunk, so the projection weights are
    shared across heads. Head outputs are concatenated back to ``emb``
    features and mixed by ``unifyheads``.

    Args:
        emb: embedding dimension of the input. Must be divisible by ``heads``.
        heads: number of attention heads.
        mask: if true, apply a causal mask so that position ``i`` only
            attends to positions ``<= i``.
    """

    def __init__(self, emb, heads=8, mask=False):
        super().__init__()

        assert emb % heads == 0

        self.emb = emb
        self.heads = heads
        self.mask = mask

        s = emb // heads

        self.tokeys = nn.Linear(s, s, bias=False)
        self.toqueries = nn.Linear(s, s, bias=False)
        self.tovalues = nn.Linear(s, s, bias=False)

        # heads * s == emb: the concatenated head outputs.
        self.unifyheads = nn.Linear(heads * s, emb)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, seq, emb)``.

        Returns a tensor of shape ``(batch, seq, emb)``.
        """
        b, t, e = x.size()
        h = self.heads
        assert e == self.emb

        s = e // h
        # reshape (not view) so that non-contiguous inputs are accepted too.
        x = x.reshape(b, t, h, s)

        keys = self.tokeys(x)
        queries = self.toqueries(x)
        values = self.tovalues(x)

        assert keys.size() == (b, t, h, s)
        assert queries.size() == (b, t, h, s)
        assert values.size() == (b, t, h, s)

        # Fold the heads into the batch dimension so one bmm covers all heads.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # Scaled dot-product attention divides by sqrt(d_k). Here d_k is the
        # per-head size s, not the full embedding e. Scaling queries and keys
        # by s**(1/4) each gives that division without touching the (t, t)
        # score matrix.
        queries = queries / (s ** (1 / 4))
        keys = keys / (s ** (1 / 4))

        dot = torch.bmm(queries, keys.transpose(1, 2))
        assert dot.size() == (b * h, t, t)

        if self.mask:
            # Causal mask: hide the strictly upper triangle (future positions).
            # The diagonal stays visible, so every row keeps a finite entry.
            future = torch.ones(t, t, dtype=torch.bool, device=dot.device).triu(diagonal=1)
            dot = dot.masked_fill(future, float('-inf'))

        # Row-wise softmax turns scores into attention weights. Use
        # torch.softmax because torch.functional has no softmax.
        dot = torch.softmax(dot, dim=2)

        # Weighted sum of the values, then un-fold the heads and concatenate them.
        out = torch.bmm(dot, values).view(b, h, t, s)
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(out)

### Transformer block

In [None]:
class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention followed by a feed-forward MLP.

    Each sub-layer is wrapped in a residual connection and then a LayerNorm.
    Dropout is applied to each normalised result.

    Args:
        emb: embedding dimension of the tokens.
        heads: number of attention heads.
        mask: passed to the attention layer to enable causal masking.
        seq_length: accepted for interface compatibility; not used by this block.
        ff_hidden_mult: hidden width of the feed-forward net, as a multiple of ``emb``.
        dropout: dropout probability applied after each residual + norm step.
        wide: build ``SelfAttentionWide`` if true, otherwise ``SelfAttentionNarrow``.
    """

    def __init__(self, emb, heads, mask, seq_length, ff_hidden_mult=4, dropout=0.0, wide=True):
        super().__init__()

        if wide:
            attention_cls = SelfAttentionWide
        else:
            attention_cls = SelfAttentionNarrow
        self.attention = attention_cls(emb, heads=heads, mask=mask)

        self.mask = mask

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        hidden_width = ff_hidden_mult * emb
        self.ff = nn.Sequential(
            nn.Linear(emb, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, emb),
        )

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Run attention and then the feed-forward sub-layer over ``x``."""
        # Attention sub-layer: residual add, normalise, dropout.
        hidden = self.do(self.norm1(self.attention(x) + x))
        # Feed-forward sub-layer: residual add, normalise, dropout.
        return self.do(self.norm2(self.ff(hidden) + hidden))