# Building the Transformer from Scratch

In this notebook, we'll be implementing the famous Transformer architecture from scratch.

The code is based off of the following repos/blog posts:

- [attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch)
- [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT)
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) 

Thanks so much to their authors!

In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
print(torch.__version__)
print(torch.cuda.is_available())

1.7.1
True


One of the keys to understanding how any model works is understanding how the shapes of the tensors change during the processing of each part. We'll be using the logging module to output debugging information to help our understanding.

In [3]:
import logging
logger = logging.getLogger("tensor_shapes")
handler = logging.StreamHandler()
formatter = logging.Formatter(
        '%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# if you want the model to continuously print tensor shapes, set to DEBUG!
logger.setLevel(1)

In [4]:
import inspect
def getclass():
    stack = inspect.stack()
    return stack[3][0].f_locals["self"].__class__

# A helper function to check how tensor sizes change
def log_size(tsr: torch.Tensor, name: str):
    cls = getclass()
    logger.log(level=cls.level, msg=f"[{cls.__name__}] {name} size={tsr.shape}")

We'll use logging levels to control the modules we receive output from. The lower the logging level, the more tensor information you'll get. Feel free to play around!

In [5]:
from enum import IntEnum
# Control how much debugging output we want
class TensorLoggingLevels(IntEnum):
    attention = 1
    attention_head = 2
    multihead_attention_block = 3
    enc_dec_block = 4
    enc_dec = 5

We'll be using an enum to refer to dimensions whenever possible to improve readability.

In [6]:
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

# Components

### Scaled dot product attention

The Transformer is an attention-based architecture. The attention used in the Transformer is the scaled dot product attention, represented by the following formula.

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/scaled_dot_product_attention.png?zoom=2&w=750)

In [7]:
import math

class ScaledDotProductAttention(nn.Module):
    level = TensorLoggingLevels.attention # Logging level: 
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k = k.size(-1) # get the size of the key
        assert q.size(-1) == d_k

        # compute the dot product between queries and keys for
        # each batch and position in the sequence
        attn = torch.bmm(q, k.transpose(Dim.seq, Dim.feature)) # (Batch, Seq, Seq)
        # we get an attention score between each position in the sequence
        # for each batch

        # scale the dot products by the dimensionality (see the paper for why we do this!)
        attn = attn / math.sqrt(d_k)
        # normalize the weights across the sequence dimension
        # (Note that since we transposed, the sequence and feature dimensions are switched)
        attn = torch.exp(attn)
        log_size(attn, "attention weight") # (Batch, Seq, Seq)
        
        # fill attention weights with 0s where padded
        if mask is not None: attn = attn.masked_fill(mask, 0)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v) # (Batch, Seq, Feature)
        log_size(output, "attention output size") # (Batch, Seq, Seq)
        return output

In [8]:
attn = ScaledDotProductAttention()

In [9]:
q = torch.rand(5, 10, 20)
k = torch.rand(5, 10, 20)
v = torch.rand(5, 10, 20)

In [10]:
attn(q, k, v)

[ScaledDotProductAttention] attention weight size=torch.Size([5, 10, 10])
[ScaledDotProductAttention] attention output size size=torch.Size([5, 10, 20])


tensor([[[0.4857, 0.3804, 0.4538, 0.6123, 0.4129, 0.5331, 0.6630, 0.5870,
          0.6286, 0.5360, 0.5223, 0.4314, 0.5797, 0.5173, 0.4921, 0.4940,
          0.4200, 0.5500, 0.6776, 0.2807],
         [0.4994, 0.4496, 0.4019, 0.5565, 0.4321, 0.5182, 0.6957, 0.5612,
          0.6322, 0.4660, 0.4624, 0.4465, 0.5715, 0.4583, 0.5564, 0.4500,
          0.4112, 0.4479, 0.6002, 0.3139],
         [0.5427, 0.3497, 0.4569, 0.6068, 0.4407, 0.5626, 0.6485, 0.5449,
          0.5605, 0.4981, 0.5649, 0.3616, 0.5621, 0.5028, 0.4819, 0.4704,
          0.4494, 0.5098, 0.5724, 0.3464],
         [0.5319, 0.4777, 0.4630, 0.6271, 0.4476, 0.6110, 0.7463, 0.6055,
          0.7028, 0.5336, 0.5724, 0.5006, 0.6333, 0.5503, 0.5727, 0.5186,
          0.4917, 0.5291, 0.6765, 0.3580],
         [0.5602, 0.4699, 0.4619, 0.6399, 0.4525, 0.6335, 0.7449, 0.6070,
          0.6830, 0.5381, 0.5995, 0.4956, 0.6323, 0.5505, 0.5741, 0.5287,
          0.5057, 0.5323, 0.6766, 0.3656],
         [0.5513, 0.4791, 0.4524, 0.6368, 0.4

### Multi-Head Attention

Now, we turn to the core component in the Transformer architecture: the multi-head attention block. This block applies linear transformations to the input, then applies scaled dot product attention.

![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/multi_head_attention.png?zoom=2&resize=224%2C293)

In [11]:
class AttentionHead(nn.Module):
    level = TensorLoggingLevels.attention_head
    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()
        # We will assume the queries, keys, and values all have the same feature size
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)

    def forward(self, queries, keys, values, mask=None):
        Q = self.query_tfm(queries) # (Batch, Seq, Feature)
        K = self.key_tfm(keys) # (Batch, Seq, Feature)
        V = self.value_tfm(values) # (Batch, Seq, Feature)
        log_size(Q, "queries, keys, vals")
        # compute multiple attention weighted sums
        x = self.attn(Q, K, V)
        return x

In [12]:
attn_head = AttentionHead(20, 20)
attn_head(q, k, v)

[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[ScaledDotProductAttention] attention weight size=torch.Size([5, 10, 10])
[ScaledDotProductAttention] attention output size size=torch.Size([5, 10, 20])


tensor([[[-0.1168, -0.3617,  0.5936, -0.5844, -0.1107,  0.0041, -0.2305,
           0.1480, -0.1898, -0.7099, -0.4407,  0.4266, -0.0891,  0.4478,
          -0.0348, -0.2356, -0.5981,  0.0698,  0.3248,  0.2721],
         [-0.1120, -0.3312,  0.4993, -0.4546, -0.1087, -0.0078, -0.2317,
           0.1108, -0.1533, -0.6003, -0.3819,  0.4152, -0.0756,  0.4097,
           0.0410, -0.2562, -0.5285,  0.0390,  0.2496,  0.1992],
         [-0.1314, -0.3839,  0.5574, -0.5464, -0.1182,  0.0010, -0.2491,
           0.1516, -0.1735, -0.7196, -0.4778,  0.4343, -0.1131,  0.4082,
           0.0569, -0.2585, -0.6125,  0.0342,  0.2955,  0.2721],
         [-0.1523, -0.3917,  0.5769, -0.5523, -0.1125,  0.0246, -0.2222,
           0.1460, -0.1641, -0.7126, -0.4480,  0.4509, -0.0647,  0.4409,
           0.0014, -0.2791, -0.5824,  0.0839,  0.2986,  0.2416],
         [-0.1141, -0.2396,  0.3893, -0.3451, -0.1042,  0.0187, -0.1365,
           0.1195, -0.0713, -0.5247, -0.3133,  0.3209, -0.0427,  0.2756,
          

The multi-head attention block simply applies multiple attention heads, then concatenates the outputs and applies a single linear projection.

In [13]:
# We'll supress logging from the scaled dot product attention now
logger.setLevel(TensorLoggingLevels.attention_head)

In [14]:
class MultiHeadAttention(nn.Module):
    level = TensorLoggingLevels.multihead_attention_block
    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        # in practice, d_model == d_feature * n_heads
        assert d_model == d_feature * n_heads

        # Note that this is very inefficient:
        # I am merely implementing the heads separately because it is 
        # easier to understand this way
        self.attn_heads = nn.ModuleList([
            AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)
        ])
        self.projection = nn.Linear(d_feature * n_heads, d_model) 
    
    def forward(self, queries, keys, values, mask=None):
        log_size(queries, "Input queries")
        x = [attn(queries, keys, values, mask=mask) # (Batch, Seq, Feature)
             for i, attn in enumerate(self.attn_heads)]
        log_size(x[0], "output of single head")
        
        # reconcatenate
        x = torch.cat(x, dim=Dim.feature) # (Batch, Seq, D_Feature * n_heads)
        log_size(x, "concatenated output")
        x = self.projection(x) # (Batch, Seq, D_Model)
        log_size(x, "projected output")
        return x

In [15]:
heads = MultiHeadAttention(20 * 8, 20, 8)
heads(q.repeat(1, 1, 8), 
      k.repeat(1, 1, 8), 
      v.repeat(1, 1, 8))

[MultiHeadAttention] Input queries size=torch.Size([5, 10, 160])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 20])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 160])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 160])


tensor([[[-0.2726,  0.0613,  0.0665,  ...,  0.1950, -0.0730, -0.2064],
         [-0.2451,  0.0897,  0.0726,  ...,  0.2008, -0.0787, -0.2203],
         [-0.2224,  0.0783,  0.0927,  ...,  0.2146, -0.0632, -0.2278],
         ...,
         [-0.2027,  0.0793,  0.0466,  ...,  0.2121, -0.0786, -0.2020],
         [-0.2109,  0.0995,  0.0565,  ...,  0.1801, -0.0490, -0.1820],
         [-0.2407,  0.0730,  0.0738,  ...,  0.1710, -0.0650, -0.2465]],

        [[-0.3743,  0.0265,  0.0836,  ...,  0.1776, -0.1079, -0.2764],
         [-0.2933,  0.1091,  0.0896,  ...,  0.1938, -0.0715, -0.2439],
         [-0.2999,  0.0532,  0.0834,  ...,  0.1677, -0.0811, -0.2427],
         ...,
         [-0.2830,  0.0988,  0.0813,  ...,  0.1985, -0.0696, -0.2285],
         [-0.2947,  0.0946,  0.0936,  ...,  0.1996, -0.0590, -0.2627],
         [-0.2478,  0.1118,  0.0684,  ...,  0.1838, -0.0690, -0.2064]],

        [[-0.2833,  0.0613,  0.0448,  ...,  0.1163, -0.0277, -0.2150],
         [-0.2705,  0.0666,  0.0359,  ...,  0

### The Encoder

With these core components in place, implementing the encoder is pretty easy.

![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%BC%E3%83%B3%E3%82%B7%E3%83%A7%E3%83%83%E3%83%88-2017-12-29-19.14.41.png?w=273)

The encoder consists of the following components:
- A multi-head attention block
- A simple feedforward neural network

These components are connected using residual connections and layer normalization

In [16]:
# We'll supress logging from the individual attention heads
logger.setLevel(TensorLoggingLevels.multihead_attention_block)

Layer normalization is similar to batch normalization, but normalizes across the feature dimension instead of the batch dimension.

![image](https://i1.wp.com/mlexplained.com/wp-content/uploads/2018/01/%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%BC%E3%83%B3%E3%82%B7%E3%83%A7%E3%83%83%E3%83%88-2018-01-11-11.48.12.png?w=1500)

In [17]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

The encoder just stacks these together

In [18]:
class EncoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norm1 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.layer_norm2 = LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        log_size(x, "Encoder block input")
        att = self.attn_head(x, x, x, mask=mask)
        log_size(x, "Attention output")
        # Apply normalization and residual connection
        x = x + self.dropout(self.layer_norm1(att))
        # Apply position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        log_size(x, "Feedforward output")
        # Apply normalization and residual connection
        x = x + self.dropout(self.layer_norm2(pos))
        log_size(x, "Encoder size output")
        return x

In [19]:
enc = EncoderBlock()

In [20]:
enc(torch.rand(5, 10, 512))

[EncoderBlock] Encoder block input size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 512])
[EncoderBlock] Attention output size=torch.Size([5, 10, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 10, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 10, 512])


tensor([[[ 2.8346e+00, -7.4305e-01, -1.5199e-01,  ...,  6.0437e-01,
           3.0892e-01, -2.6203e-01],
         [ 4.3891e+00, -3.2669e-03,  2.4501e-01,  ...,  4.7756e-01,
          -7.6516e-01,  1.8474e-01],
         [ 4.2455e+00,  5.2228e-01,  1.9544e-01,  ...,  3.9973e-01,
          -1.2112e+00,  3.9558e-02],
         ...,
         [ 3.4253e+00, -8.0995e-01, -1.8052e-01,  ...,  1.7636e-01,
          -3.0822e-01, -6.1162e-01],
         [ 3.9821e+00, -5.2248e-01,  2.2056e-01,  ...,  5.7183e-01,
           7.4778e-02, -1.0841e+00],
         [ 4.0259e+00, -1.8689e-01, -2.2434e-01,  ..., -1.6212e-01,
           2.4365e-01,  2.6721e-01]],

        [[ 4.4928e+00, -2.9093e-01, -3.3372e-01,  ..., -4.9624e-01,
          -5.8743e-01, -4.5782e-01],
         [ 4.2599e+00,  3.6836e-01, -8.1547e-03,  ..., -7.4190e-01,
          -2.3468e-01, -4.7133e-01],
         [ 3.4088e+00, -6.7641e-01, -2.2919e-01,  ..., -3.5993e-01,
          -1.1139e+00, -1.2430e+00],
         ...,
         [ 4.3702e+00,  1

The encoder consists of 6 consecutive encoder blocks, so can simply be implemented like the following

In [21]:
class TransformerEncoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512,
                 n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoders = nn.ModuleList([
            EncoderBlock(d_model=d_model, d_feature=d_model // n_heads,
                         d_ff=d_ff, dropout=dropout)
            for _ in range(n_blocks)
        ])
    
    def forward(self, x: torch.FloatTensor, mask=None):
        for encoder in self.encoders:
            x = encoder(x)
        return x

### The Decoder

The decoder is mostly the same as the encoder. There's just one additional multi-head attention block that takes the target sentence as input.

![image](https://i1.wp.com/mlexplained.com/wp-content/uploads/2017/12/%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%BC%E3%83%B3%E3%82%B7%E3%83%A7%E3%83%83%E3%83%88-2017-12-29-19.14.47.png?w=287)

The keys and values are the outputs of the encoder, and the queries are the outputs of the multi-head attention over the target entence embeddings.

In [22]:
class DecoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_out, 
                src_mask=None, tgt_mask=None):
        # Apply attention to inputs
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Apply attention to the encoder outputs and outputs of the previous layer
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Apply position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm2(pos))
        return x

In [23]:
dec = DecoderBlock()
dec(torch.rand(5, 10, 512), enc(torch.rand(5, 10, 512)))

[EncoderBlock] Encoder block input size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 512])
[EncoderBlock] Attention output size=torch.Size([5, 10, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 10, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Siz

tensor([[[ 1.7539,  1.7319,  0.0315,  ...,  0.0723,  1.6455,  3.1135],
         [ 1.5491,  0.5092,  0.0440,  ..., -0.4172,  1.7473,  2.5684],
         [ 0.9116,  1.3977, -1.6380,  ..., -0.5238,  0.6878,  1.9830],
         ...,
         [ 1.1654,  1.5397, -1.0929,  ..., -0.4030,  2.1998,  1.5289],
         [ 1.9782,  1.1124,  0.3097,  ..., -0.2545,  1.8794,  2.4676],
         [ 1.2787,  1.9875, -0.9133,  ...,  0.7860,  0.9329,  2.9307]],

        [[-0.2970,  1.4840, -0.8573,  ..., -0.6599,  1.4021,  1.9822],
         [ 0.5374,  0.9221,  0.4522,  ..., -0.3725, -0.4283,  2.2946],
         [-0.5751,  0.2213,  0.0646,  ..., -0.5866,  0.8773,  1.3984],
         ...,
         [ 0.6994,  1.1560, -0.2433,  ..., -0.7659,  2.5076,  3.0898],
         [ 1.2502,  0.2776,  0.0700,  ..., -1.1239,  2.2884,  4.1344],
         [ 0.9814,  0.2584, -0.8471,  ..., -1.1770,  1.8067,  2.7638]],

        [[ 0.2110,  1.3540, -0.2460,  ..., -0.1773,  2.9904,  2.4827],
         [ 0.5595,  2.2145, -0.0809,  ...,  0

Again, the decoder is just a stack of the underlying block so is simple to implement.

In [24]:
class TransformerDecoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.position_embedding = PositionalEmbedding(d_model)
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model // n_heads,
                         d_ff=d_ff, dropout=dropout)
            for _ in range(n_blocks)
        ])
        
    def forward(self, x: torch.FloatTensor, 
                enc_out: torch.FloatTensor, 
                src_mask=None, tgt_mask=None):
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

### Positional Embeddings

Attention blocks are just simple matrix multiplications: therefore they don't have any notion of order! The Transformer explicitly adds positional information via the positional embeddings.

In [25]:
class PositionalEmbedding(nn.Module):
    level = 1
    def __init__(self, d_model, max_len=512):
        super().__init__()        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.weight = nn.Parameter(pe, requires_grad=False)
        
    def forward(self, x):
        return self.weight[:, :x.size(1), :] # (1, Seq, Feature)

In [26]:
class WordPositionEmbedding(nn.Module):
    level = 1
    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(d_model)
        
    def forward(self, x: torch.LongTensor, mask=None) -> torch.FloatTensor:
        return self.word_embedding(x) + self.position_embedding(x)

In [27]:
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()

In [28]:
encoder(emb(torch.randint(1000, (5, 30))))

[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 30, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 30, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 30, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 30, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 30, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 30, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[

tensor([[[-1.6179, -0.9129, -1.1146,  ..., -2.4037,  1.2557,  0.5186],
         [ 0.0916,  2.5295, -2.1315,  ..., -0.4497,  4.3204,  2.8204],
         [ 1.5603, -3.4203, -2.0865,  ..., -0.4914,  1.3665,  2.6843],
         ...,
         [ 2.5888, -1.3422, -4.9319,  ..., -0.1028,  0.8384,  0.3542],
         [-5.0183, -4.6209, -0.9035,  ...,  0.4303, -0.1303,  4.3591],
         [-1.7145, -1.2137,  2.0444,  ...,  0.1812,  1.1217,  3.8644]],

        [[ 2.0313, -2.1645, -1.9350,  ..., -1.3981,  1.4702, -0.4281],
         [-1.2277,  2.5563, -0.6967,  ...,  0.6443, -2.3782, -2.0074],
         [-1.4581,  0.0839, -0.2799,  ..., -1.2875,  1.1970,  2.3062],
         ...,
         [-2.3831, -0.4709, -2.0701,  ...,  3.8697,  0.6059, -3.4711],
         [ 0.4490,  0.5291,  2.7600,  ...,  1.3978,  1.0326, -1.8111],
         [ 0.5248, -5.2180, -4.3385,  ...,  4.0048,  0.5965, -0.1954]],

        [[ 1.1011,  3.0193, -6.9870,  ..., -2.1411,  1.3825,  0.8011],
         [-1.0197, -2.8118, -5.5080,  ...,  3

### Putting it All Together

Let's put everything together now.

![image](https://camo.githubusercontent.com/88e8f36ce61dedfd2491885b8df2f68c4d1f92f5/687474703a2f2f696d6775722e636f6d2f316b72463252362e706e67)

In [29]:
# We'll supress logging from the scaled dot product attention now
logger.setLevel(TensorLoggingLevels.enc_dec_block)

In [30]:
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()
decoder = TransformerDecoder()

In [31]:
src_ids = torch.randint(1000, (5, 30))
tgt_ids = torch.randint(1000, (5, 30))
x = encoder(emb(src_ids))
decoder(emb(tgt_ids), x)

[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=t

tensor([[[ 3.9257,  1.4525, -1.6411,  ..., -0.0952, -0.6878,  6.8743],
         [-1.9338,  3.0697, -2.5299,  ...,  4.5336, -2.4466,  8.6706],
         [-4.1225,  2.8537, -4.6281,  ..., -1.2991, -0.5202,  9.1048],
         ...,
         [ 1.0886,  0.7016, -4.6155,  ...,  6.5053,  1.7904,  7.6790],
         [-3.7776, -0.6953, -9.0473,  ...,  7.3665,  1.2213,  3.6967],
         [ 0.4485,  1.7003, -3.4540,  ...,  7.4460, -3.4167,  5.7489]],

        [[-3.9350,  4.8898, -4.7293,  ...,  3.6016,  5.7632,  4.7220],
         [ 1.5016,  1.2416, -7.1086,  ...,  6.4188, -0.8304,  5.8426],
         [-3.1504,  0.9825, -6.5698,  ...,  9.4508, -1.6345,  2.9084],
         ...,
         [-5.7944,  1.2646, -4.3030,  ...,  7.2413, -1.0951,  1.7650],
         [-3.4720, -4.2639, -6.6723,  ...,  5.8381,  3.3105,  7.4766],
         [-2.5388,  0.0677, -9.8157,  ...,  6.1770, -0.3951,  5.8490]],

        [[ 0.1910,  2.5887, -3.6397,  ...,  2.0766,  3.5951,  8.8904],
         [ 1.0138,  3.6099, -2.3926,  ...,  5