In [179]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import math
from torch.autograd import Variable

In [263]:
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with one fused linear
    layer, attention is computed independently per head, the heads are
    concatenated, and an output projection followed by dropout is applied.

    Args:
        dims: model (embedding) width; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights and to
            the projected output.
        causal: if True, token ``t`` may only attend to tokens ``0..t``
            (autoregressive / decoder masking). Defaults to False, which keeps
            the original fully bidirectional behaviour.
    """

    def __init__(self, dims:int, heads:int, dropout:float=0.1, causal:bool=False):

        super(MultiHeadedAttention, self).__init__()

        # number of dims must be evenly divisible by num heads
        assert dims % heads == 0, f"dims ({dims}) must be divisible by heads ({heads})"

        self.dims = dims
        self.heads = heads
        self.head_dims = dims // heads
        self.dk_sqrt = math.sqrt(self.head_dims)
        self.causal = causal

        self.qkv_projection = nn.Linear(self.dims, self.dims * 3)
        self.attn_dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(self.dims, self.dims)
        self.resid_dropout = nn.Dropout(dropout)

    def calculate_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over the last two axes.

        Args:
            q, k, v: per-head tensors; ``forward`` passes them shaped
                (batch, heads, tokens, head_dims).
            mask: optional boolean tensor broadcastable to the score matrix
                (..., tokens, tokens). True marks key positions that may be
                attended to. A query row with no True entry yields NaN.

        Returns:
            (out, attention): the attended values (same shape as ``v`` for
            square attention) and the post-dropout attention weights.
        """
        # similarity of each query with each key, scaled by sqrt(d_k)
        scores = (q @ k.transpose(-2, -1)) / self.dk_sqrt
        if mask is not None:
            # -inf becomes an exact 0 weight after the softmax
            scores = scores.masked_fill(~mask, float("-inf"))
        attention = torch.softmax(scores, dim=-1)
        attention = self.attn_dropout(attention)
        out = attention @ v

        return out, attention

    def forward(self, x):
        """Self-attend over ``x`` of shape (batch, tokens, dims); returns the same shape."""

        B, T, C = x.shape # batch size, num tokens, token embedding size

        # one fused projection, split into q, k, v of width `dims` each
        qkv = self.qkv_projection(x).split(self.dims, dim=2)

        # reshape to (batch, heads, tokens, head_dims)
        q, k, v = [t.view(B, -1, self.heads, self.head_dims).transpose(1, 2) for t in qkv]

        mask = None
        if self.causal:
            # lower-triangular (T, T): row t allows keys 0..t; broadcasts over batch and heads
            mask = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()

        # calculate and apply attention (weights are not needed here)
        y, _ = self.calculate_attention(q, k, v, mask)

        # concat all heads back to (batch, tokens, dims)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # outward projection and residual dropout
        return self.resid_dropout(self.out_projection(y))

In [264]:
# sanity check: attention preserves (batch, tokens, dims), so layers can be chained.
# Distinct names avoid clashing with the `model` / `x` defined in later cells.
mha = MultiHeadedAttention(100, 20)
x_mha = torch.randn(3, 12, 100)
out_mha = mha(mha(x_mha))
assert out_mha.shape == x_mha.shape
out_mha.shape

torch.Size([3, 12, 100])

In [285]:
class PositionalEncoder(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    With ``d = embedding_dims``:
        PE[pos, 2i]   = sin(pos / 10000^(2i / d))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d))

    Args:
        embedding_dims: width ``d`` of the embeddings the encoding is added to
            (odd widths are supported).
        max_sequence: number of positions precomputed; longer inputs are rejected.
    """

    def __init__(self, embedding_dims:int, max_sequence:int):

        super(PositionalEncoder, self).__init__()

        self.embedding_dims = embedding_dims
        self.max_sequence = max_sequence

        positional_embedding = torch.zeros(max_sequence, embedding_dims)
        div_pos = torch.arange(0, embedding_dims, 2)                       # 0, 2, 4, ... (= 2i)
        divisors = torch.exp(div_pos * -(math.log(10000.0) / embedding_dims))
        positions = torch.arange(max_sequence)[:, None]                     # (max_sequence, 1)
        angles = positions * divisors                                       # (max_sequence, ceil(d/2))

        positional_embedding[:, 0::2] = torch.sin(angles)
        # for odd d there is one fewer cosine column than sine column
        positional_embedding[:, 1::2] = torch.cos(angles[:, : embedding_dims // 2])

        # leading batch axis so it broadcasts over (batch, tokens, d); registered as a
        # buffer so it follows .to(device) and state_dict but is never trained
        self.register_buffer('positional_embedding', positional_embedding[None, ...])

    def forward(self, x):
        """Return ``x + PE[:tokens]`` for ``x`` of shape (batch, tokens, embedding_dims).

        The result stays connected to ``x`` in the autograd graph, so upstream
        layers (e.g. an embedding table) receive gradients. The encoding itself
        is a buffer and has no gradient. (The previous
        ``Variable(..., requires_grad=False)`` wrapper detached the whole sum and
        silently blocked those gradients.)

        Raises:
            ValueError: if ``x`` has more than ``max_sequence`` tokens.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_sequence:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_sequence={self.max_sequence}"
            )
        return x + self.positional_embedding[:, :seq_len]

In [265]:
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

    Every token vector (the last axis) is transformed independently, so the
    output has the same shape as the input, (..., in_dims).

    Args:
        in_dims: width of each input/output token vector.
        hidden_dims: width of the intermediate (expanded) layer.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, in_dims:int, hidden_dims:int, dropout:float=0.1):
        super().__init__()

        # built in the same order as the stages run, so the parameter
        # initialisation order (and state_dict keys) stay the same
        expand = nn.Linear(in_dims, hidden_dims)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_dims, in_dims)
        out_dropout = nn.Dropout(dropout)
        self.sequence = nn.Sequential(expand, activation, contract, out_dropout)

    def forward(self, x):
        # the Sequential applies expand -> ReLU -> contract -> dropout in order
        result = self.sequence(x)
        return result
    
pwffn = PositionWiseFFN(in_dims=10, hidden_dims=100)
ffn_sample = torch.randn(5, 10, 10)  # (batch, tokens, in_dims)
pwffn(ffn_sample).shape

torch.Size([5, 10, 10])

In [282]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block with residual connections.

        x = x + attention(ln_1(x))
        x = x + feed_forward(ln_2(x))

    Args:
        att_heads: number of attention heads.
        att_dims: model width (input and output size of the block); must be
            divisible by ``att_heads``.
        ff_dims: hidden width of the position-wise feed-forward network.
        dropout: dropout probability forwarded to both the attention and the
            feed-forward sublayers. The default of 0.1 equals those sublayers'
            own defaults, so existing callers see no change.
    """

    def __init__(self,
                 att_heads:int = 8,
                 att_dims:int = 64,
                 ff_dims:int = 256,
                 dropout:float = 0.1
                ):

        super(TransformerBlock, self).__init__()

        self.att_heads = att_heads
        self.att_dims = att_dims
        self.ff_dims = ff_dims

        self.ln_1 = nn.LayerNorm(self.att_dims)
        self.attention = MultiHeadedAttention(self.att_dims, self.att_heads, dropout=dropout)
        self.ln_2 = nn.LayerNorm(self.att_dims)
        self.feed_forward = PositionWiseFFN(self.att_dims, self.ff_dims, dropout=dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, tokens, att_dims); the shape is preserved."""

        # normalise before each sublayer, add the sublayer output back as a residual
        x = x + self.attention(self.ln_1(x))
        x = x + self.feed_forward(self.ln_2(x))

        return x
        
tb = TransformerBlock(att_heads=8, att_dims=64, ff_dims=256)
x = torch.randn(3, 12, 64)  # (batch, tokens, att_dims)
tb(x).shape

torch.Size([3, 12, 64])

In [297]:
class DecoderTransformer(nn.Module):
    """Token-level language model mapping token ids to per-token vocabulary logits.

    Pipeline: nn.Embedding -> sinusoidal PositionalEncoder -> Dropout ->
    ``n_layers`` TransformerBlocks -> final LayerNorm -> linear ``lm_head``.

    Args:
        vocab_size: number of distinct token ids; inputs must lie in [0, vocab_size).
        max_sequence: number of positions the positional encoder precomputes.
        embedding_dims: model width; must be divisible by ``att_heads``.
        att_heads: attention heads per block.
        ff_dims: hidden width of each block's feed-forward network.
        dropout: dropout applied to the embeddings only. It is not forwarded
            to the transformer blocks.
        n_layers: number of stacked TransformerBlocks. The default of 1 gives
            the original single-block model, with the same module order and
            state_dict keys.

    NOTE(review): every TransformerBlock here builds its attention without a
    causal mask, so position t can attend to later tokens. For next-token
    prediction training this exposes the targets to the model. Confirm whether
    that is intended.
    """

    def __init__(self, 
                 vocab_size:int = 32,
                 max_sequence:int = 64,
                 embedding_dims:int = 64,
                 att_heads:int = 8,
                 ff_dims:int = 256,
                 dropout:float = 0.1,
                 n_layers:int = 1
                ):

        super(DecoderTransformer, self).__init__()

        self.vocab_size = vocab_size
        self.max_sequence = max_sequence
        self.embedding_dims = embedding_dims
        self.att_heads = att_heads
        self.ff_dims = ff_dims
        self.n_layers = n_layers

        # blocks are created inline (after the embedding) so the parameter
        # initialisation order matches the original single-block model
        self.sequence = nn.Sequential(
            nn.Embedding(self.vocab_size, self.embedding_dims),
            PositionalEncoder(self.embedding_dims, self.max_sequence), # maybe just change this to an embedding too
            nn.Dropout(dropout),
            *[TransformerBlock(self.att_heads, self.embedding_dims, self.ff_dims)
              for _ in range(n_layers)],
            nn.LayerNorm(self.embedding_dims)
        )

        self.lm_head = nn.Linear(self.embedding_dims, self.vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, tokens) to logits of shape (batch, tokens, vocab_size)."""
        hidden = self.sequence(x)
        return self.lm_head(hidden)

model = DecoderTransformer()
# randint's upper bound is exclusive, so sample up to vocab_size to cover every id.
# The former literal 31 could never produce token id 31 of the 32-token vocab.
x = torch.randint(model.vocab_size, (4, 16))
logits = model(x)
assert logits.shape == (4, 16, model.vocab_size)
logits.shape

tensor([[[-4.3223e-01,  7.2902e-01,  2.3409e-01,  ...,  2.9002e-01,
           2.9854e-02, -4.8219e-01],
         [-3.6588e-01,  1.2729e+00, -6.3550e-01,  ..., -4.5261e-01,
           4.2971e-01, -1.0852e+00],
         [-1.3221e+00,  1.4389e-01,  1.1058e+00,  ...,  2.8478e-01,
          -4.8847e-02, -5.2399e-01],
         ...,
         [-9.9639e-01, -5.1359e-02,  1.2110e+00,  ..., -3.3482e-01,
           6.1723e-01, -3.4147e-02],
         [-7.9969e-01,  8.3777e-01, -9.3797e-02,  ..., -1.2536e+00,
           4.5533e-02, -1.2406e+00],
         [-1.0880e+00,  3.3933e-01, -3.8144e-01,  ..., -9.4022e-01,
           1.1138e-01,  5.3261e-01]],

        [[-5.2703e-01, -1.0743e-01, -5.4753e-01,  ...,  2.2119e-01,
           5.0657e-01,  3.0102e-01],
         [-2.4397e-01,  7.3370e-01,  1.6236e-01,  ...,  2.4425e-01,
           5.1760e-02, -3.4514e-01],
         [-4.7319e-01,  1.3166e+00,  5.3485e-02,  ...,  4.0152e-01,
          -1.4632e-01, -7.4426e-01],
         ...,
         [-2.9669e-01,  1