In [1]:
#export mllib.bert

In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
import random, math

In [13]:
import torch, os

def mask_(matrices, maskval=0.0, mask_diagonal=True):
    """Mask, in place, the upper triangle of every matrix in a batch.

    Sets ``matrices[..., i, j] = maskval`` for all ``i <= j`` when ``mask_diagonal``
    is True, or for all ``i < j`` (diagonal left untouched) when it is False.

    :param matrices: tensor of shape (..., h, w); it is modified in place.
    :param maskval: value written into the masked positions.
    :param mask_diagonal: whether the main diagonal is masked as well.
    :return: None -- the input tensor is mutated.
    """
    h, w = matrices.size()[-2:]

    # offset=0 selects the diagonal and everything above it; offset=1 only strictly above.
    # Build the indices on the same device as the data to avoid a host->device copy.
    indices = torch.triu_indices(h, w, offset=0 if mask_diagonal else 1, device=matrices.device)
    matrices[..., indices[0], indices[1]] = maskval
    

def d(tensor=None):
    """Return the device name, 'cuda' or 'cpu'.

    Without an argument, this reports whether CUDA is available on this machine.
    Given a tensor, it reports whether that tensor lives on a CUDA device.
    """
    on_gpu = torch.cuda.is_available() if tensor is None else tensor.is_cuda
    return 'cuda' if on_gpu else 'cpu'

def contains_nan(tensor):
    """Return True if any element of ``tensor`` is NaN.

    Relies on NaN being the only value that compares unequal to itself.
    """
    nan_mask = tensor != tensor
    return bool(nan_mask.any())

# Self Attention Wide

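A self-attention module that computes attention among the input elements. In this *wide* variant, every head works on the full embedding dimension.

Its parameters are:

a) Linear projections for the queries, keys and values (`emb` → `emb * heads`, no bias)

b) A unify-heads linear layer (`emb * heads` → `emb`)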

In [5]:
class SelfAttentionWide(nn.Module):
    """Multi-head self-attention in which every head sees the full embedding.

    Each of the ``heads`` heads has its own ``emb -> emb`` key/query/value
    projection. These are implemented as single ``emb -> emb * heads`` linear
    layers. The per-head outputs are concatenated and projected back to ``emb``.
    """

    def __init__(self, emb, heads=8, mask=False):
        '''
            :param emb: embedding dimension of the input (and of the output).
            :param heads: number of attention heads.
            :param mask: if True, apply a causal mask so position i only attends to positions <= i.
        '''
        super().__init__()
        self.emb = emb
        self.heads = heads
        self.mask = mask

        self.tokeys = nn.Linear(emb, emb * heads, bias=False)
        self.toqueries = nn.Linear(emb, emb * heads, bias=False)
        self.tovalues = nn.Linear(emb, emb * heads, bias=False)

        self.unifyheads = nn.Linear(heads * emb, emb)

    def forward(self, x):
        '''
            :param x: input of shape (batch, seq_len, emb).
            :return: output of shape (batch, seq_len, emb).
        '''
        b, t, e = x.size()
        h = self.heads

        assert e == self.emb, f'Input embedding dim({e}) should match layer embedding dim ({self.emb})'

        keys = self.tokeys(x).view(b, t, h, e)
        queries = self.toqueries(x).view(b, t, h, e)
        values = self.tovalues(x).view(b, t, h, e)

        # Compute scaled dot-product self-attention.
        # Fold the heads into the batch dimension so one bmm handles all of them.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, e)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, e)
        values = values.transpose(1, 2).contiguous().view(b * h, t, e)

        # Scaling q and k by e**(1/4) each scales the dot products by sqrt(e), the per-head dim.
        queries = queries / (e ** (1 / 4))
        keys = keys / (e ** (1 / 4))

        dot = torch.bmm(queries, keys.transpose(1, 2))

        assert dot.size() == (b * h, t, t)

        if self.mask:  # mask out the upper half of the dot matrix, excluding the diagonal
            mask_(dot, maskval=float('-inf'), mask_diagonal=False)

        dot = F.softmax(dot, dim=2)

        # Apply the attention weights to the values, then unfold the heads again.
        out = torch.bmm(dot, values).view(b, h, t, e)
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)

        return self.unifyheads(out)


# Backward-compatible alias for the original (misspelled) class name.
SelefAttentionWide = SelfAttentionWide
    

# Narrow Self Attention

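Narrow self-attention is more computationally efficient. It splits each input embedding into `heads` equally sized chunks and computes attention for each chunk independently; the key, query and value projections are shared across the chunks. Finally, it concatenates the outputs of all heads and projects them back to the embedding dimension.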

In [6]:
class SelfAttentionNarrow(nn.Module):
    """Multi-head self-attention that splits the embedding across the heads.

    The input embedding of size ``emb`` is cut into ``heads`` chunks of size
    ``emb // heads``. The same ``s -> s`` key/query/value projections are
    applied to every chunk, attention is computed per chunk, and the
    concatenated results are projected back to ``emb``.
    """

    def __init__(self, emb, heads=8, mask=False):
        '''
            :param emb: embedding dimension; must be divisible by ``heads``.
            :param heads: number of attention heads (chunks).
            :param mask: if True, apply a causal mask so position i only attends to positions <= i.
        '''
        super().__init__()

        assert emb % heads == 0, f'Embedding dimension ({emb}) should be divisible by nr. of heads ({heads})'

        self.emb = emb
        self.heads = heads
        self.mask = mask

        s = emb // heads
        # We break the embedding into `heads` chunks and feed each to a different attention head.

        self.tokeys = nn.Linear(s, s, bias=False)
        self.toqueries = nn.Linear(s, s, bias=False)
        self.tovalues = nn.Linear(s, s, bias=False)

        self.unifyheads = nn.Linear(heads * s, emb)

    def forward(self, x):
        '''
            :param x: input of shape (batch, seq_len, emb).
            :return: output of shape (batch, seq_len, emb).
        '''
        b, t, e = x.size()
        h = self.heads

        assert e == self.emb, f'Input dimension ({e}) should match layer embedding dim ({self.emb})'

        s = e // h
        # reshape (not view) so non-contiguous inputs are accepted too
        x = x.reshape(b, t, h, s)

        keys = self.tokeys(x)
        queries = self.toqueries(x)
        values = self.tovalues(x)

        assert keys.size() == (b, t, h, s)
        assert queries.size() == (b, t, h, s)
        assert values.size() == (b, t, h, s)

        # Compute scaled dot-product self-attention with heads folded into the batch dim.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # Each head's dot product runs over s dimensions, so scale by sqrt(s):
        # s**(1/4) on the queries and s**(1/4) on the keys.
        queries = queries / (s ** (1 / 4))
        keys = keys / (s ** (1 / 4))

        dot = torch.bmm(queries, keys.transpose(1, 2))

        assert dot.size() == (b * h, t, t)

        if self.mask:  # mask out the upper half of the dot matrix, excluding the diagonal
            mask_(dot, maskval=float('-inf'), mask_diagonal=False)

        dot = F.softmax(dot, dim=2)

        out = torch.bmm(dot, values).view(b, h, t, s)

        # bring the heads next to each other again and concatenate them
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(out)
        
        

Now comes the interesting part: how to use attention heads to form a transformer block.

In [8]:
class TransformerBlock(nn.Module):
    """One transformer block: self-attention followed by a feed-forward network.

    Each sub-layer is wrapped in a residual connection, then layer norm, then
    dropout (post-norm arrangement).
    """

    def __init__(self, emb, heads, mask, seq_length, ff_hidden_mult=4, dropout=0.0, wide=True):
        '''
            :param emb: embedding dimension.
            :param heads: number of attention heads.
            :param mask: whether the self-attention is causally masked.
            :param seq_length: maximum sequence length (not used by this block).
            :param ff_hidden_mult: feed-forward hidden size as a multiple of ``emb``.
            :param dropout: dropout probability applied after each sub-layer.
            :param wide: use SelfAttentionWide if True, SelfAttentionNarrow otherwise.
        '''
        super().__init__()

        # Honour the `wide` flag instead of always building the wide variant.
        attention_cls = SelfAttentionWide if wide else SelfAttentionNarrow
        self.attention = attention_cls(emb, heads=heads, mask=mask)

        self.mask = mask

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        self.ff = nn.Sequential(
            nn.Linear(emb, ff_hidden_mult * emb),
            nn.ReLU(),
            nn.Linear(ff_hidden_mult * emb, emb)
        )

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        '''
            :param x: input of shape (batch, seq_len, emb).
            :return: output of the same shape.
        '''
        attended = self.attention(x)

        x = self.norm1(attended + x)

        x = self.do(x)

        feedforward = self.ff(x)

        x = self.norm2(feedforward + x)

        x = self.do(x)

        return x
        
        
        

# Transformers

In [9]:
import torch
from torch import nn
import torch.nn.functional as F

In [11]:
class GTransformer(nn.Module):
    """
        Transformer for generating text character by character.

        Token and learned position embeddings are summed, passed through ``depth``
        causally masked transformer blocks, and mapped to log-probabilities over
        ``num_tokens`` output tokens.
    """

    def __init__(self, emb, heads, depth, seq_length, num_tokens, wide=False):
        '''
            :param emb: embedding dimension.
            :param heads: number of attention heads per block.
            :param depth: number of transformer blocks.
            :param seq_length: maximum input sequence length (size of the position embedding).
            :param num_tokens: vocabulary size.
            :param wide: passed through to each TransformerBlock.
        '''
        super().__init__()

        self.num_tokens = num_tokens
        self.token_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=num_tokens)
        self.pos_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=seq_length)

        tblocks = []
        for i in range(depth):
            tblocks.append(
                TransformerBlock(emb=emb, heads=heads, seq_length=seq_length, mask=True, wide=wide)
            )

        self.tblocks = nn.Sequential(*tblocks)
        self.toprobs = nn.Linear(emb, num_tokens)

    def forward(self, x):
        '''
            :param x: (batch, sequence_length) integer tensor of token indices.
            :return: (batch, sequence_length, num_tokens) tensor of predicted
                log-probability vectors for each token based on preceding tokens.
        '''
        tokens = self.token_embedding(x)
        b, t, e = tokens.size()

        # Build position indices on the same device as the embeddings (not whatever
        # device happens to be available), so CPU models also work on GPU machines.
        positions = self.pos_embedding(torch.arange(t, device=tokens.device))[None, :, :].expand(b, t, e)
        x = tokens + positions

        x = self.tblocks(x)

        x = self.toprobs(x.view(b * t, e)).view(b, t, self.num_tokens)

        return F.log_softmax(x, dim=2)