In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Reference: Peter Bloem, "Transformers from scratch" — http://peterbloem.nl/blog/transformers

## Naive self-attention

In [21]:
# Toy input dimensions:
#   b: batch size
#   t: sequence length
#   k: vector (embedding) dimension
SEED = 0
# Seed the global torch RNG so the random inputs below (and any weights
# initialized afterwards) are reproducible under Restart & Run All.
torch.manual_seed(SEED)

b, t, k = 4, 5, 16
# Small integer-valued features, cast to float so they can go through bmm/Linear.
x = torch.randint(low=0, high=10, size=(b, t, k)).float()

In [12]:
# Naive self-attention (no learned parameters): output vector y_i is a
# weighted average of all input vectors x_j, where the weights are the
# softmax (over j) of the dot products x_i . x_j.
x_transposed = x.transpose(-1, -2)          # (b, k, t)
raw_weights = torch.bmm(x, x_transposed)    # (b, t, t) pairwise dot products
weights = raw_weights.softmax(dim=-1)       # each row now sums to 1
y = torch.bmm(weights, x)                   # (b, t, k) weighted averages

## Complete (multi-head) self-attention

In [17]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with learned key/query/value projections.

    Each of the ``heads`` heads projects every k-dimensional input vector to
    its own k-dimensional key, query and value. The per-head outputs are
    concatenated (h*k values per position) and mapped back to k dimensions.

    :param k: dimension of the input (and output) vectors
    :param heads: number of attention heads
    """

    def __init__(self, k, heads=8):
        super().__init__()
        self.k, self.heads = k, heads
        # One wide k -> k*heads layer per role computes all heads in a single
        # matrix multiply instead of `heads` separate k -> k layers.
        self.tokeys = nn.Linear(k, k * heads, bias=False)
        self.toqueries = nn.Linear(k, k * heads, bias=False)
        self.tovalues = nn.Linear(k, k * heads, bias=False)

        # Unify the concatenated outputs of the heads into a single k-vector.
        self.unifyheads = nn.Linear(heads * k, k)

    def forward(self, x, mask=False):
        """Apply multi-head self-attention to a batch of sequences.

        :param x: float tensor of shape (b, t, k) with k == self.k
        :param mask: if True, apply a causal mask so that position i only
            attends to positions j <= i (e.g. for autoregressive models).
            Defaults to False: every position attends to every position.
        :return: tensor of shape (b, t, k)
        """
        b, t, k = x.size()
        h = self.heads

        # Project, then split the last dimension into heads:
        # (b, t, h*k) -> (b, t, h, k).
        queries = self.toqueries(x).view(b, t, h, k)
        keys = self.tokeys(x).view(b, t, h, k)
        values = self.tovalues(x).view(b, t, h, k)

        # Fold heads into the batch dimension: (b, t, h, k) -> (b*h, t, k).
        # transpose() yields a non-contiguous view, so contiguous() (a copy)
        # is required before view() can reshape it.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
        values = values.transpose(1, 2).contiguous().view(b * h, t, k)

        # Scale queries and keys by k**(1/4) each. This is equivalent to
        # dividing their dot products by sqrt(k).
        queries = queries / (k ** (1 / 4))
        keys = keys / (k ** (1 / 4))

        dot = torch.bmm(queries, keys.transpose(1, 2))  # (b*h, t, t)

        if mask:
            # Hide future positions (strictly above the diagonal). A score of
            # -inf becomes an attention weight of exactly 0 after the softmax.
            # The (t, t) mask broadcasts over the b*h batch dimension.
            future = torch.triu(torch.ones(t, t, device=dot.device), diagonal=1).bool()
            dot = dot.masked_fill(future, float("-inf"))

        dot = F.softmax(dot, dim=2)  # row-wise normalization over the keys

        out = torch.bmm(dot, values).view(b, h, t, k)

        # Move heads back next to the embedding dim and concatenate them:
        # (b, h, t, k) -> (b, t, h*k).
        out = out.transpose(1, 2).contiguous().view(b, t, h * k)
        return self.unifyheads(out)

In [23]:
self_att = SelfAttention(k)
# Sanity check: multi-head self-attention maps (b, t, k) -> (b, t, k).
self_att_out = self_att(x)
assert self_att_out.shape == (b, t, k)
self_att_out.shape

In [30]:
class TransformerBlock(nn.Module):
    """Transformer block: self-attention, then a position-wise feed-forward MLP.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization, i.e. norm(sublayer(x) + x).

    :param k: embedding dimension of the input/output vectors
    :param heads: number of attention heads
    :param ff_hidden_mult: width of the feed-forward hidden layer as a
        multiple of k (default 4, the previously hard-coded value)
    """

    def __init__(self, k, heads, ff_hidden_mult=4):
        super().__init__()
        self.attention = SelfAttention(k, heads=heads)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        self.ff = nn.Sequential(
            nn.Linear(k, ff_hidden_mult * k),
            nn.ReLU(),
            nn.Linear(ff_hidden_mult * k, k),
        )

    def forward(self, x):
        """Map a (b, t, k) tensor to a tensor of the same shape."""
        attended = self.attention(x)
        x = self.norm1(attended + x)  # residual + norm around attention
        feedforward = self.ff(x)
        return self.norm2(feedforward + x)  # residual + norm around the MLP

In [31]:
# One transformer block over k-dimensional vectors with 8 attention heads.
transformer_block = TransformerBlock(k=k, heads=8)

In [33]:
# Sanity check: a transformer block preserves the (b, t, k) shape.
transformer_block_out = transformer_block(x)
assert transformer_block_out.shape == (b, t, k)
transformer_block_out.shape

## Transformer-based classifier

In [35]:
class Transformer(nn.Module):
    """Sequence classifier built from transformer blocks.

    Pipeline: token embeddings + learned position embeddings -> ``depth``
    TransformerBlocks -> mean over the sequence -> linear layer -> log-softmax.

    :param k: embedding dimension
    :param heads: number of attention heads per block
    :param depth: number of transformer blocks
    :param seq_length: number of learned position embeddings (maximum t)
    :param num_tokens: vocabulary size (number of token embeddings)
    :param num_classes: number of output classes
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        self.tblocks = nn.Sequential(*(TransformerBlock(k, heads) for _ in range(depth)))

        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer token ids (< num_tokens), t <= seq_length
        :return: A (b, c) tensor of log-probabilities over the classes
        """
        tokens = self.token_emb(x)
        b, t, k = tokens.size()
        # Position indices must live on the same device as the embedding
        # weights. A bare torch.arange(t) is always created on the CPU and
        # breaks the lookup once the model and inputs are moved to a GPU.
        positions = torch.arange(t, device=x.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        x = self.tblocks(x)

        # Average over the sequence dimension, then classify.
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)

In [36]:
# Binary classifier: 2 blocks of 8 heads, sequences up to 10 tokens,
# vocabulary of 20 token ids.
transformer = Transformer(
    k=k,
    heads=8,
    depth=2,
    seq_length=10,
    num_tokens=20,
    num_classes=2,
)

In [43]:
# Integer token ids for the classifier. Use a separate name rather than
# overwriting the float feature tensor `x` from the setup cell, so earlier
# cells still work if they are re-run. Ids must stay below num_tokens (20),
# and t must not exceed seq_length (10).
token_ids = torch.randint(low=0, high=10, size=(b, t))

transformer(token_ids)

tensor([[-0.9025, -0.5201],
        [-0.8029, -0.5942],
        [-0.6675, -0.7195],
        [-0.5746, -0.8276]], grad_fn=<LogSoftmaxBackward>)