In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# X has size (b, t, k): a batch of b sequences, each a sequence of t vectors
# of dimension k, i.e. one t-by-k matrix 𝐗 per batch element (here b=1, t=3, k=4).
x = torch.randn(1, 3, 4)
# Raw (unnormalised) self-attention weights: the batched product 𝐗𝐗ᵀ, which holds
# the dot product of every row of 𝐗 with every row of 𝐗.
raw_weights = x.bmm(x.permute(0, 2, 1))
print('x shape:', x.size())
print('xT shape:', x.permute(0, 2, 1).size())
print('w shape:', raw_weights.size())

x shape: torch.Size([1, 3, 4])
xT shape: torch.Size([1, 4, 3])
w shape: torch.Size([1, 3, 3])


In [3]:
# The raw weights w′ij can be negative and do not sum to anything in particular.
# A softmax over the last axis (each row of the t-by-t matrix) turns them into
# positive values that sum to one.
print(raw_weights)
weights = F.softmax(raw_weights, dim=-1)  # last axis is dim 2 for a (b, t, t) tensor
weights

tensor([[[ 1.3397,  0.3852, -0.1114],
         [ 0.3852,  5.6949,  0.8862],
         [-0.1114,  0.8862,  3.1280]]])


tensor([[[0.6176, 0.2377, 0.1447],
         [0.0049, 0.9871, 0.0081],
         [0.0342, 0.0928, 0.8730]]])

In [4]:
# The output sequence is the weight matrix multiplied by 𝐗: a batch of output
# matrices 𝐘 of size (b, t, k) whose rows are weighted sums over the rows of 𝐗.
# Two matrix multiplications and one softmax are all that basic self-attention needs.

# (1, 3, 3) x (1, 3, 4) = (1, 3, 4)
y = weights.bmm(x)
y.size()

torch.Size([1, 3, 4])

In [6]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Every head receives its own query/key/value projection of the full
    ``emb``-dimensional input (hence the ``emb -> emb * heads`` linear maps).
    The per-head outputs are concatenated and projected back to ``emb``
    dimensions.

    :param emb: Dimension of the input and output vectors.
    :param heads: Number of attention heads.
    """

    def __init__(self, emb, heads=8):
        super(SelfAttention, self).__init__()
        self.emb = emb
        self.heads = heads
        # These compute the queries, keys and values for all
        # heads (as a single concatenated vector)
        self.tokeys    = nn.Linear(emb, emb * heads, bias=False)
        self.toqueries = nn.Linear(emb, emb * heads, bias=False)
        self.tovalues  = nn.Linear(emb, emb * heads, bias=False)

        # This unifies the outputs of the different heads into a single emb-vector
        self.unifyheads = nn.Linear(heads * emb, emb)

    def forward(self, x):
        """Apply self-attention to a batch of sequences.

        :param x: Tensor of size (b, t, emb).
        :return: Tensor of size (b, t, emb).
        :raises AssertionError: if the last dimension of ``x`` is not ``emb``.
        """
        b, t, e = x.size()
        h = self.heads

        assert e == self.emb, f'Input embedding dim ({e}) should match layer embedding dim ({self.emb})'

        def fold_heads(projected):
            # (b, t, h*e) -> (b, t, h, e) -> (b, h, t, e) -> (b*h, t, e):
            # heads are folded into the batch dimension so one bmm handles them all.
            return projected.view(b, t, h, e).transpose(1, 2).contiguous().view(b * h, t, e)

        # Scaling queries and keys by e^(1/4) each is equivalent to dividing the
        # dot product by sqrt(e), but keeps the intermediate values smaller.
        scale = e ** (1 / 4)
        queries = fold_heads(self.toqueries(x)) / scale
        keys = fold_heads(self.tokeys(x)) / scale
        values = fold_heads(self.tovalues(x))

        # - dot has size (b*h, t, t) containing the scaled raw weights
        dot = torch.bmm(queries, keys.transpose(1, 2))

        # - row-wise normalisation: every row of weights sums to one
        dot = F.softmax(dot, dim=2)

        # apply the self attention to the values
        out = torch.bmm(dot, values).view(b, h, t, e)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)
        return self.unifyheads(out)

In [8]:
class TransformerBlock(nn.Module):
    """One transformer block: self-attention followed by a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection, followed by
    layer normalisation and dropout.

    :param emb: Dimension of the input and output vectors.
    :param heads: Number of attention heads.
    :param ff_hidden_mult: The hidden layer of the feed-forward network has
                           ``ff_hidden_mult * emb`` units.
    :param dropout: Dropout probability applied after each layer normalisation.
    """

    def __init__(self, emb, heads, ff_hidden_mult=4, dropout=0.0):
        super(TransformerBlock, self).__init__()

        self.attention = SelfAttention(emb, heads=heads)

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        self.dropout = nn.Dropout(dropout)

        self.ff = nn.Sequential(
            nn.Linear(emb, ff_hidden_mult * emb),
            nn.ReLU(),
            nn.Linear(ff_hidden_mult * emb, emb)
        )

    def forward(self, x):
        """
        :param x: Tensor of size (b, t, emb).
        :return: Tensor of size (b, t, emb).
        """
        # attention sub-layer: residual connection, then layer norm, then dropout
        attended = self.attention(x)
        out = self.dropout(self.norm1(attended + x))

        # feed-forward sub-layer: residual connection, then layer norm, then dropout
        fedforward = self.ff(out)
        out = self.dropout(self.norm2(fedforward + out))

        return out

In [10]:
class Transformer(nn.Module):
    """Transformer for classifying sequences of token indices."""

    def __init__(self, emb, heads, depth, seq_length, num_tokens, num_classes, max_pool=True, dropout=0.0):
        """
        :param emb: Embedding dimension
        :param heads: nr. of attention heads
        :param depth: Number of transformer blocks
        :param seq_length: Expected maximum sequence length
        :param num_tokens: Number of tokens (usually words) in the vocabulary
        :param num_classes: Number of classes.
        :param max_pool: If true, use global max pooling in the last layer. If false, use global
                         average pooling.
        :param dropout: Dropout probability, applied to the unified embeddings and inside
                        every transformer block.
        """

        super(Transformer, self).__init__()

        self.max_pool = max_pool
        self.num_tokens = num_tokens

        self.token_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=num_tokens)
        self.pos_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=seq_length)

        # Token and position embeddings are concatenated and projected back to emb dimensions
        self.unify_embeddings = nn.Linear(2 * emb, emb)

        self.tblocks = nn.Sequential(
            *[TransformerBlock(emb, heads, dropout=dropout) for _ in range(depth)]
        )

        # Maps the pooled output vector to class logits
        self.toprobs = nn.Linear(emb, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: A batch by sequence length integer tensor of token indices.
        :return: A batch by num_classes tensor of predicted log-probabilities over the classes.
        :raises ValueError: if the sequence is longer than the ``seq_length`` the model was built with.
        """

        # generate token embeddings
        tokens = self.token_embedding(x)
        b, t, e = tokens.size()

        max_len = self.pos_embedding.num_embeddings
        if t > max_len:
            raise ValueError(f'Sequence length ({t}) exceeds the maximum sequence length ({max_len})')

        # generate position embeddings on the same device as the input,
        # broadcast over the batch dimension
        positions = torch.arange(t, device=x.device)
        positions = self.pos_embedding(positions)[None, :, :].expand(b, t, e)

        # nn.Linear acts on the last dimension, so no flattening is needed
        out = self.unify_embeddings(torch.cat((tokens, positions), dim=2))

        out = self.dropout(out)
        out = self.tblocks(out)

        # pool over the time dimension: (b, t, e) -> (b, e)
        pooled = out.max(dim=1)[0] if self.max_pool else out.mean(dim=1)

        # class logits: (b, num_classes)
        logits = self.toprobs(pooled)

        return F.log_softmax(logits, dim=1)