In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU so this
# cell also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(3, 5, bias=False, device=device)

In [7]:
# Total number of parameters in `model`. With bias=False the layer's only
# parameter is its (5, 3) weight matrix, so this totals 15.
sum(p.numel() for p in model.parameters())

15

In [2]:
# Assume a toy batch: 32 sequences, each holding 10 vectors of dimension 4.
batch_size, seq_len, N = 32, 10, 4

x = torch.randn(batch_size, seq_len, N)
# Dot product of every position with every position of the same sequence:
# (batch_size, seq_len, N) x (batch_size, N, seq_len) -> (batch_size, seq_len, seq_len).
raw_weights = torch.bmm(x, x.transpose(-2, -1))

In [16]:
# Shape of the input batch: (batch_size, seq_len, N).
x.size()

torch.Size([32, 10, 4])

In [23]:
# Swap the last two axes of `x`; .contiguous() materialises the transposed view
# as a tensor with its own contiguous memory layout.
y = x.transpose(-2, -1).contiguous()
y.size()

torch.Size([32, 4, 10])

In [22]:
# Display the values of the transposed, contiguous copy of `x`.
y

tensor([[[ 1.7630,  0.6042, -0.3009,  ...,  1.4541, -1.4075, -0.0840],
         [ 0.5463, -0.3298, -0.0373,  ...,  0.9029,  0.8614,  0.1479],
         [ 0.3253,  0.4925,  0.7287,  ...,  0.1725, -0.2038,  1.3128],
         [-0.6950, -0.2766,  0.4072,  ...,  0.6336,  0.2290, -0.3056]],

        [[ 1.3034, -0.8500,  0.8639,  ...,  1.4423,  1.7732,  1.4420],
         [-0.4350, -1.5608, -0.0585,  ...,  0.8729,  1.0806,  0.1019],
         [ 0.1021, -0.2605, -0.0435,  ..., -1.1443, -0.7740,  0.4617],
         [ 0.9360, -0.3276,  0.7518,  ..., -0.4964,  1.0971,  0.6500]],

        [[ 0.1747,  0.9367,  0.3179,  ..., -0.2144,  0.0483,  0.7049],
         [ 2.4926,  1.0518, -1.6349,  ...,  0.3182, -1.1450,  0.4416],
         [ 1.9857, -0.0313, -1.2208,  ..., -0.0224, -0.3288,  0.7385],
         [ 0.4569, -0.2774, -0.6974,  ...,  1.3395,  0.7676, -1.7763]],

        ...,

        [[ 0.0245,  0.1883,  1.6618,  ...,  1.5760, -0.0209, -0.8999],
         [ 2.1933,  0.3406,  0.9471,  ...,  0.4746,  0.53

In [19]:
# Display `y` again (repeat of the previous display cell).
y

tensor([[[ 1.7630,  0.6042, -0.3009,  ...,  1.4541, -1.4075, -0.0840],
         [ 0.5463, -0.3298, -0.0373,  ...,  0.9029,  0.8614,  0.1479],
         [ 0.3253,  0.4925,  0.7287,  ...,  0.1725, -0.2038,  1.3128],
         [-0.6950, -0.2766,  0.4072,  ...,  0.6336,  0.2290, -0.3056]],

        [[ 1.3034, -0.8500,  0.8639,  ...,  1.4423,  1.7732,  1.4420],
         [-0.4350, -1.5608, -0.0585,  ...,  0.8729,  1.0806,  0.1019],
         [ 0.1021, -0.2605, -0.0435,  ..., -1.1443, -0.7740,  0.4617],
         [ 0.9360, -0.3276,  0.7518,  ..., -0.4964,  1.0971,  0.6500]],

        [[ 0.1747,  0.9367,  0.3179,  ..., -0.2144,  0.0483,  0.7049],
         [ 2.4926,  1.0518, -1.6349,  ...,  0.3182, -1.1450,  0.4416],
         [ 1.9857, -0.0313, -1.2208,  ..., -0.0224, -0.3288,  0.7385],
         [ 0.4569, -0.2774, -0.6974,  ...,  1.3395,  0.7676, -1.7763]],

        ...,

        [[ 0.0245,  0.1883,  1.6618,  ...,  1.5760, -0.0209, -0.8999],
         [ 2.1933,  0.3406,  0.9471,  ...,  0.4746,  0.53

In [6]:
# One (seq_len x seq_len) score matrix per sequence in the batch.
raw_weights.size()

torch.Size([32, 10, 10])

In [7]:
# Normalise along the last axis so that each row of scores sums to 1.
weights = F.softmax(raw_weights, dim=-1)

In [12]:
# Each output vector is the weighted sum of the input vectors of its sequence,
# using the softmax-normalised rows of `weights`.
y = weights.bmm(x)

In [13]:
# The output has the same shape as the input `x`.
y.size()

torch.Size([32, 10, 4])

### Multi-head Self-Attention Block

In [24]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Every head works on the full ``k``-dimensional input: the input is
    projected to ``heads`` separate ``k``-dimensional queries, keys and values,
    attention is computed per head, and the concatenated head outputs are
    projected back to ``k`` dimensions.

    Args:
        k: Dimension of each input (and output) vector.
        heads: Number of attention heads.
    """

    def __init__(self, k, heads=8):
        super().__init__()
        self.k = k
        self.heads = heads

        # One projection each for keys, queries and values; every projection
        # yields `heads` chunks of size k.
        self.tokeys = nn.Linear(k, k * heads, bias=False)
        self.toqueries = nn.Linear(k, k * heads, bias=False)
        self.tovalues = nn.Linear(k, k * heads, bias=False)

        # Unify the outputs of the heads back into a single k-vector.
        self.unifyheads = nn.Linear(k * heads, k)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, k).

        Returns:
            Tensor of shape (batch, seq_len, k).
        """
        b, t, k = x.size()
        h = self.heads

        queries = self.toqueries(x).view(b, t, h, k)
        keys = self.tokeys(x).view(b, t, h, k)
        values = self.tovalues(x).view(b, t, h, k)

        # Fold the head axis into the batch axis so torch.bmm handles all
        # heads at once: (b, t, h, k) -> (b, h, t, k) -> (b*h, t, k).
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
        values = values.transpose(1, 2).contiguous().view(b * h, t, k)

        # Scaling queries and keys by k^(1/4) each is equivalent to dividing
        # their dot product by sqrt(k).
        queries = queries / (k ** (1 / 4))
        keys = keys / (k ** (1 / 4))

        # (b*h, t, t) attention scores, normalised over the key positions.
        dot = torch.bmm(queries, keys.transpose(1, 2))
        dot = F.softmax(dot, dim=2)

        # Apply attention to the values, then restore the head axis and
        # concatenate the heads: (b*h, t, k) -> (b, h, t, k) -> (b, t, h*k).
        out = torch.bmm(dot, values).view(b, h, t, k)
        out = out.transpose(1, 2).contiguous().view(b, t, h * k)

        return self.unifyheads(out)

### Transformer Block

In [27]:
class TransformerBlock(nn.Module):
    """Self-attention followed by a feed-forward network.

    Both sub-layers are wrapped in a residual connection, and layer
    normalisation is applied after each residual sum.

    Args:
        k: Dimension of each input (and output) vector.
        heads: Number of attention heads passed to ``SelfAttention``.
    """

    def __init__(self, k, heads):
        super().__init__()
        self.attention = SelfAttention(k, heads)

        # LayerNorm requires the size of the normalised (last) dimension.
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        # Feed-forward network applied to each position independently,
        # with a hidden layer four times the width of the input.
        self.ff = nn.Sequential(
            nn.Linear(k, 4 * k),
            nn.ReLU(),
            nn.Linear(4 * k, k))

    def forward(self, x):
        """Transform ``x`` of shape (batch, seq_len, k) into a tensor of the same shape."""
        attended = self.attention(x)
        x = self.norm1(x + attended)
        fedforward = self.ff(x)
        return self.norm2(x + fedforward)

### Transformer

In [None]:
class Transformer(nn.Module):
    """Transformer for sequence classification.

    Token and learned position embeddings are summed, passed through a stack
    of ``TransformerBlock`` layers, averaged over the sequence axis and
    projected to class log-probabilities.

    Args:
        k: Embedding dimension.
        heads: Number of attention heads in every block.
        depth: Number of transformer blocks.
        seq_length: Number of positions in the position-embedding table.
        num_tokens: Size of the token vocabulary.
        num_classes: Number of output classes.
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        # Stack of `depth` transformer blocks applied one after another.
        self.trans_blocks = nn.Sequential(
            *[TransformerBlock(k, heads) for _ in range(depth)])

        # Maps the pooled sequence representation to class logits.
        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """Classify a batch of token-index sequences.

        Args:
            x: Integer tensor of shape (batch, seq_len) holding token indices.

        Returns:
            Tensor of shape (batch, num_classes) with log-probabilities.
        """
        tokens = self.token_emb(x)
        b, t, k = tokens.size()

        # Create the position indices on the same device as the embeddings so
        # the lookup also works when the model lives on a GPU.
        positions = torch.arange(t, device=tokens.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        x = self.trans_blocks(x)

        # Average over the sequence axis, then project to class scores.
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)
