# Transformer from Scratch

In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

In [2]:
import logging
# Dedicated logger for tensor-shape tracing. Modules log at small custom integer
# levels (the TensorLoggingLevels values 1-5), so the logger level acts as a
# threshold: a message is printed only if its level >= the logger's level.
logger = logging.getLogger("tensor_shapes")
# Attach the handler only once: re-running this cell would otherwise add another
# StreamHandler to the same named logger and print every message repeatedly.
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
# Level 1 prints every shape trace. Note that logging.DEBUG is 10, which is above
# all the custom levels used here, so setting DEBUG would silence the trace.
logger.setLevel(1)

In [3]:
import inspect
def getclass():
    """Return the class of the object whose method (indirectly) called this.

    Walks up the call stack, skipping this frame and its direct caller
    (normally ``log_size``), and returns the class of the first frame that
    has a local named ``self``. For a call made from a module's ``forward``,
    that is the module instance itself.

    Raises:
        RuntimeError: if no enclosing frame has a ``self`` local.
    """
    # inspect.currentframe() is far cheaper than inspect.stack(), which reads
    # source-context lines for every frame on the stack on each call.
    frame = inspect.currentframe()
    try:
        # Skip this frame and the immediate caller (log_size).
        caller = frame.f_back.f_back if frame is not None and frame.f_back is not None else None
        # Walking instead of indexing a fixed depth keeps this working regardless
        # of how many wrapper frames nn.Module.__call__ inserts.
        while caller is not None:
            if "self" in caller.f_locals:
                return caller.f_locals["self"].__class__
            caller = caller.f_back
    finally:
        # A frame holding a reference to itself forms a cycle; drop it explicitly.
        del frame
    raise RuntimeError("getclass() found no enclosing frame with a 'self' local")

# A helper function to check how tensor sizes change
def log_size(tsr: torch.Tensor, name: str):
    """Log ``name`` and ``tsr.shape``, tagged with the calling module's class.

    The message is emitted at the calling class's ``level`` attribute, so the
    ``tensor_shapes`` logger's level decides whether it appears. getclass() is
    called directly from here on purpose, because it locates the module by
    inspecting the call stack relative to this frame.
    """
    module_cls = getclass()
    message = f"[{module_cls.__name__}] {name} size={tsr.shape}"
    logger.log(module_cls.level, message)

In [4]:
from enum import IntEnum
# Control how much debugging output we want
class TensorLoggingLevels(IntEnum):
    """Logging thresholds for tensor-shape tracing, from most to least granular.

    Modules set their ``level`` class attribute to one of these values. Their
    shape messages are printed only when the ``tensor_shapes`` logger's level
    is less than or equal to that value.
    """
    attention = 1                  # ScaledDotProductAttention
    attention_head = 2             # AttentionHead
    multihead_attention_block = 3  # MultiHeadAttention
    enc_dec_block = 4              # EncoderBlock / DecoderBlock
    enc_dec = 5                    # TransformerEncoder / TransformerDecoder

In [5]:
class Dim(IntEnum):
    """Axis indices for tensors laid out as (Batch, Seq, Feature)."""
    batch = 0    # example index within the batch
    seq = 1      # position within the sequence
    feature = 2  # embedding / hidden channel

# Building the network

## Embeddings

In [6]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encodings (Vaswani et al., 2017, section 3.5).

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: length of each encoding vector (odd values are supported).
        max_len: longest sequence the precomputed table covers.

    The table is stored as a non-trainable buffer named ``weight``. It is
    saved in ``state_dict`` and follows ``.to(device)``, but it is not
    returned by ``parameters()`` and so is never handed to an optimizer.
    """
    level = 1
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("weight", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: tensor whose dim 0 is the batch and dim 1 the sequence length
               (e.g. token ids of shape (Batch, Seq)); only its size is used.

        Returns:
            An expanded (Batch, Seq, d_model) view of the table.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.weight.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={max_len}")
        return self.weight[:, :seq_len, :].expand(x.size(0), -1, -1)

In [7]:
posemb = PositionalEmbedding(d_model=512)
x = torch.randint(0, 1000, (5, 30), dtype=torch.long)  # (Batch, Seq) token ids
posemb(x).size()  # expect (Batch, Seq, d_model)

torch.Size([5, 30, 512])

In [8]:
class WordPositionEmbedding(nn.Module):
    """Learned token embeddings plus fixed sinusoidal positional encodings.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding size.
        max_len: longest sequence the positional table covers; forwarded to
            PositionalEmbedding (512 matches that class's own default).
    """
    level = 1
    def __init__(self, vocab_size, d_model=512, max_len=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(d_model, max_len=max_len)

    def forward(self, x: torch.LongTensor, mask=None) -> torch.FloatTensor:
        """Embed token ids and add positional encodings.

        Assumes ``x`` holds token ids of shape (Batch, Seq), giving an output of
        shape (Batch, Seq, d_model). ``mask`` is accepted for signature
        compatibility but is not used.
        """
        return self.word_embedding(x) + self.position_embedding(x)

In [9]:
emb = WordPositionEmbedding(vocab_size=1000)
x = torch.randint(0, 1000, (5, 30), dtype=torch.long)  # (Batch, Seq) token ids
emb(x).size()  # expect (Batch, Seq, d_model)

torch.Size([5, 30, 512])

## Attention layer

In [10]:
class ScaledDotProductAttention(nn.Module):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, with optional mask and dropout.

    Args:
        dropout: dropout probability applied to the attention weights.
    """
    level = TensorLoggingLevels.attention  # logging threshold used by log_size
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Compute scaled dot-product attention.

        Args:
            q: queries, (..., Seq_q, d_k), e.g. (Batch, Seq_q, d_k).
            k: keys, (..., Seq_k, d_k).
            v: values, (..., Seq_k, d_v).
            mask: optional bool tensor broadcastable to (..., Seq_q, Seq_k);
                ``True`` marks key positions that must get zero attention weight.

        Returns:
            Tensor of shape (..., Seq_q, d_v).
        """
        d_k = k.size(-1)
        assert q.size(-1) == d_k

        # Score every query against every key: (..., Seq_q, d_k) x (..., d_k, Seq_k).
        scores = torch.matmul(q, k.transpose(-2, -1))  # (..., Seq_q, Seq_k)
        # Scale so the softmax does not saturate as d_k grows (paper, section 3.2.1).
        scores = scores / math.sqrt(d_k)
        # Masked positions get -inf so they receive exactly zero weight after softmax.
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        # Normalize over the key axis. softmax subtracts the row max internally, so
        # it cannot overflow the way a bare exp() does for large scores. A row whose
        # keys are all masked has nothing to attend to and yields NaN.
        attn = torch.softmax(scores, dim=-1)
        log_size(attn, "attention weight")  # (..., Seq_q, Seq_k)
        attn = self.dropout(attn)
        output = torch.matmul(attn, v)  # (..., Seq_q, d_v)
        log_size(output, "attention output size")
        return output

In [11]:
# Sanity check: random (Batch=5, Seq=10, Feature=20) queries, keys and values
attn = ScaledDotProductAttention()
q, k, v = (torch.rand(5, 10, 20) for _ in range(3))
attn(q, k, v)

[ScaledDotProductAttention] attention weight size=torch.Size([5, 10, 10])
[ScaledDotProductAttention] attention output size size=torch.Size([5, 10, 20])


tensor([[[0.4673, 0.4940, 0.5537, 0.3861, 0.4264, 0.5436, 0.5173, 0.4326,
          0.4225, 0.4436, 0.5345, 0.5865, 0.5043, 0.4486, 0.5436, 0.5258,
          0.5037, 0.5831, 0.5165, 0.5009],
         [0.4826, 0.5054, 0.6466, 0.4544, 0.5262, 0.6494, 0.6370, 0.5545,
          0.5133, 0.5524, 0.5364, 0.7067, 0.5404, 0.5767, 0.6499, 0.5754,
          0.5823, 0.6228, 0.5902, 0.5283],
         [0.4985, 0.5123, 0.6367, 0.4539, 0.5270, 0.6612, 0.6483, 0.5564,
          0.5258, 0.5484, 0.5416, 0.7046, 0.5314, 0.5768, 0.6504, 0.5703,
          0.5723, 0.6296, 0.5947, 0.5329],
         [0.4664, 0.4985, 0.6478, 0.4616, 0.5296, 0.6523, 0.6348, 0.5491,
          0.5042, 0.5672, 0.5341, 0.7246, 0.5641, 0.5575, 0.6537, 0.5500,
          0.5785, 0.6062, 0.5835, 0.5259],
         [0.4653, 0.5070, 0.6525, 0.4521, 0.5288, 0.6490, 0.6315, 0.5436,
          0.5084, 0.5625, 0.5314, 0.7196, 0.5618, 0.5590, 0.6638, 0.5535,
          0.5744, 0.6045, 0.5793, 0.5326],
         [0.4655, 0.5098, 0.5525, 0.4006, 0.5

In [12]:
class AttentionHead(nn.Module):
    """A single attention head: linearly project q/k/v, then apply attention.

    Args:
        d_model: feature size of the incoming queries, keys and values.
        d_feature: feature size after projection (shared by Q, K and V).
        dropout: dropout probability on the attention weights.
    """
    level = TensorLoggingLevels.attention_head
    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()
        # Queries, keys and values are all projected to the same size d_feature.
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)

    def forward(self, queries, keys, values, mask=None):
        """Project the inputs and attend.

        Args:
            queries: (Batch, Seq_q, d_model).
            keys, values: (Batch, Seq_k, d_model).
            mask: optional mask forwarded to ScaledDotProductAttention.

        Returns:
            Tensor of shape (Batch, Seq_q, d_feature).
        """
        Q = self.query_tfm(queries)  # (Batch, Seq_q, d_feature)
        K = self.key_tfm(keys)       # (Batch, Seq_k, d_feature)
        V = self.value_tfm(values)   # (Batch, Seq_k, d_feature)
        log_size(Q, "queries, keys, vals")
        # The mask must be forwarded, otherwise padding/causal masks have no effect.
        return self.attn(Q, K, V, mask=mask)

In [13]:
# Sanity check: one head projecting 20 -> 20 features
attn_head = AttentionHead(d_model=20, d_feature=20)
attn_head(queries=q, keys=k, values=v)

[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[ScaledDotProductAttention] attention weight size=torch.Size([5, 10, 10])
[ScaledDotProductAttention] attention output size size=torch.Size([5, 10, 20])


tensor([[[-0.5787, -0.5825,  0.2877,  0.4638,  0.5377, -0.1285, -0.3489,
          -0.6006,  0.7705, -0.4411,  0.0052,  0.0720,  0.1320,  0.0365,
           0.3778, -0.0909,  0.0165, -0.0921,  0.0472, -0.3238],
         [-0.4205, -0.4182,  0.1735,  0.3552,  0.3908, -0.0873, -0.2271,
          -0.4572,  0.5092, -0.3006,  0.0291,  0.0855,  0.0959,  0.1057,
           0.2389, -0.0284,  0.0358, -0.0488,  0.0508, -0.2354],
         [-0.4049, -0.3178,  0.1341,  0.2661,  0.3309, -0.0649, -0.1953,
          -0.4050,  0.4349, -0.2690,  0.0220,  0.0628,  0.0675,  0.0876,
           0.1906, -0.0455,  0.0312, -0.0278,  0.0283, -0.2013],
         [-0.5344, -0.4068,  0.2043,  0.3159,  0.4202, -0.0688, -0.2693,
          -0.4689,  0.5893, -0.3714, -0.0221,  0.0391,  0.0928,  0.0064,
           0.2935, -0.1110,  0.0135, -0.0469,  0.0099, -0.2521],
         [-0.5760, -0.5850,  0.2888,  0.4658,  0.5372, -0.1299, -0.3472,
          -0.6040,  0.7707, -0.4398,  0.0082,  0.0771,  0.1337,  0.0419,
          

In [14]:
class MultiHeadAttention(nn.Module):
    """Run ``n_heads`` independent AttentionHeads and project their concatenation.

    Args:
        d_model: feature size of the inputs and of the output.
        d_feature: per-head feature size; must satisfy d_model == d_feature * n_heads.
        n_heads: number of heads.
        dropout: dropout probability on each head's attention weights.
    """
    level = TensorLoggingLevels.multihead_attention_block
    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        # The concatenated head outputs must have exactly the model width.
        assert d_model == d_feature * n_heads

        # Each head is a separate module. This is slower than one fused projection
        # followed by a reshape, but keeps every head easy to inspect.
        self.attn_heads = nn.ModuleList([AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)])
        self.projection = nn.Linear(d_feature * n_heads, d_model)

    def forward(self, queries, keys, values, mask=None):
        """Attend with every head, concatenate, and project back to d_model.

        Args:
            queries: (Batch, Seq_q, d_model).
            keys, values: (Batch, Seq_k, d_model).
            mask: optional mask passed unchanged to every head.

        Returns:
            Tensor of shape (Batch, Seq_q, d_model).
        """
        log_size(queries, "Input queries")
        head_outputs = [head(queries, keys, values, mask=mask)  # each (Batch, Seq_q, d_feature)
                        for head in self.attn_heads]
        log_size(head_outputs[0], "output of single head")

        # Re-concatenate along the last (feature) axis.
        x = torch.cat(head_outputs, dim=-1)  # (Batch, Seq_q, d_feature * n_heads)
        log_size(x, "concatenated output")
        x = self.projection(x)  # (Batch, Seq_q, d_model)
        log_size(x, "projected output")
        return x

In [15]:
# Raise the threshold to suppress ScaledDotProductAttention (level 1) messages
logger.setLevel(TensorLoggingLevels.attention_head)
heads = MultiHeadAttention(d_model=20 * 8, d_feature=20, n_heads=8)
# Tile each 20-feature input 8 times along the feature axis so it has d_model = 160
heads(*(t.repeat(1, 1, 8) for t in (q, k, v)))

[MultiHeadAttention] Input queries size=torch.Size([5, 10, 160])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[AttentionHead] queries, keys, vals size=torch.Size([5, 10, 20])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 20])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 160])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 160])


tensor([[[ 0.1159,  0.0529, -0.0127,  ..., -0.1988, -0.1469,  0.1303],
         [ 0.1197,  0.0322, -0.0224,  ..., -0.1162, -0.1722,  0.1110],
         [ 0.0994,  0.0027, -0.0074,  ..., -0.2015, -0.1298,  0.1169],
         ...,
         [ 0.1295, -0.0002, -0.0089,  ..., -0.2222, -0.1359,  0.0992],
         [ 0.1382,  0.0174, -0.0020,  ..., -0.2107, -0.1329,  0.1106],
         [ 0.1094,  0.0341, -0.0119,  ..., -0.1909, -0.1696,  0.1347]],

        [[ 0.1471, -0.0042,  0.0427,  ..., -0.2365, -0.1100,  0.0143],
         [ 0.1612,  0.0047,  0.0437,  ..., -0.2896, -0.1670,  0.0122],
         [ 0.1377, -0.0082,  0.0380,  ..., -0.2332, -0.0991, -0.0022],
         ...,
         [ 0.1294, -0.0110,  0.0371,  ..., -0.2590, -0.1424,  0.0102],
         [ 0.1590,  0.0161,  0.0354,  ..., -0.2371, -0.1789,  0.0016],
         [ 0.1560, -0.0066,  0.0620,  ..., -0.2245, -0.1363,  0.0147]],

        [[ 0.0418,  0.0270, -0.0096,  ..., -0.1850, -0.1378,  0.0812],
         [ 0.0811,  0.0068,  0.0021,  ..., -0

## Encoder layer

In [16]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension (Ba et al., 2016).

    y = gamma * (x - mean) / sqrt(var + eps) + beta

    ``mean`` and the biased variance ``var`` are taken over the last dimension,
    matching ``torch.nn.LayerNorm(d_model, eps=eps)``.

    Args:
        d_model: size of the last dimension (length of gamma and beta).
        eps: added to the variance for numerical stability.
    """
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased (population) variance, as in the LayerNorm definition. An unbiased
        # std divides by N-1, which is NaN when the last dimension has size 1.
        var = (centered ** 2).mean(-1, keepdim=True)
        return self.gamma * centered / torch.sqrt(var + self.eps) + self.beta

In [17]:
class EncoderBlock(nn.Module):
    """One encoder layer: multi-head self-attention, then a position-wise FFN.

    Each sublayer output is layer-normalized, passed through dropout, and added
    to the sublayer input: x = x + Dropout(LayerNorm(Sublayer(x))).
    NOTE(review): the original paper uses LayerNorm(x + Dropout(Sublayer(x)));
    the placement here never normalizes the residual stream itself.

    Args:
        d_model: model width.
        d_feature: per-head feature size (d_model == d_feature * n_heads).
        d_ff: hidden size of the feed-forward network.
        n_heads: number of attention heads.
        dropout: dropout probability.
    """
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norm1 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.layer_norm2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode ``x`` (Batch, Seq, d_model); ``mask`` goes to self-attention."""
        log_size(x, "Encoder block input")
        att = self.attn_head(x, x, x, mask=mask)
        log_size(att, "Attention output")
        # Normalize the sublayer output, then add the residual connection.
        x = x + self.dropout(self.layer_norm1(att))
        pos = self.position_wise_feed_forward(x)
        log_size(pos, "Feedforward output")
        x = x + self.dropout(self.layer_norm2(pos))
        log_size(x, "Encoder size output")
        return x

In [18]:
# Raise the threshold again to suppress the individual attention-head messages
logger.setLevel(TensorLoggingLevels.multihead_attention_block)
enc = EncoderBlock()
enc(torch.rand(5, 10, 512))  # (Batch=5, Seq=10, d_model=512)

[EncoderBlock] Encoder block input size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 512])
[EncoderBlock] Attention output size=torch.Size([5, 10, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 10, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 10, 512])


tensor([[[ 3.7680,  2.0711,  0.3541,  ...,  1.2161,  0.2477,  0.3743],
         [ 3.3455,  1.9848,  0.0146,  ...,  1.7872, -0.3102,  0.3357],
         [ 3.7358,  2.3758, -0.3364,  ...,  0.3960,  0.3831,  0.2983],
         ...,
         [ 2.6606,  3.3947,  0.5691,  ...,  0.8397, -0.5506, -1.5812],
         [ 3.0163,  2.0350, -0.0419,  ...,  0.5514, -0.1235,  0.9761],
         [ 2.4870,  0.6717,  0.2109,  ...,  1.0857, -0.7127, -0.1037]],

        [[ 1.9009,  2.2927,  1.8900,  ...,  1.2419, -1.2112,  0.0765],
         [ 2.8415,  2.6401,  1.2274,  ...,  1.9445,  0.2086, -0.1449],
         [ 1.9334,  2.8163,  1.2456,  ...,  0.8339, -1.2177, -0.1095],
         ...,
         [ 0.5987,  2.3868,  0.6901,  ...,  0.7395, -0.3681,  0.3130],
         [ 1.5982,  2.0932,  1.5629,  ...,  0.8693,  1.5622,  0.0692],
         [ 1.9024,  2.3831,  0.3466,  ...,  0.1804,  1.1749,  0.3410]],

        [[ 3.5502,  2.3804,  1.3971,  ...,  1.5073, -0.1297,  1.7719],
         [ 2.7341,  0.7462,  1.7061,  ...,  0

In [19]:
class TransformerEncoder(nn.Module):
    """A stack of ``n_blocks`` EncoderBlocks applied in sequence.

    Args:
        n_blocks: number of encoder layers.
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: attention heads per layer (per-head size is d_model // n_heads).
        d_ff: hidden size of each position-wise feed-forward network.
        dropout: dropout probability.
    """
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512,
                 n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # n_heads must be forwarded: EncoderBlock otherwise assumes 8 heads, and
        # the d_model == d_feature * n_heads check fails for any other value.
        self.encoders = nn.ModuleList([
            EncoderBlock(d_model=d_model, d_feature=d_model // n_heads,
                         d_ff=d_ff, n_heads=n_heads, dropout=dropout)
            for _ in range(n_blocks)
        ])

    def forward(self, x: torch.FloatTensor, mask=None):
        """Encode ``x`` (Batch, Seq, d_model); ``mask`` is applied in every block."""
        for encoder in self.encoders:
            x = encoder(x, mask=mask)
        return x

In [30]:
emb = WordPositionEmbedding(vocab_size=1000)
encoder = TransformerEncoder()
enc_out = encoder(emb(torch.randint(0, 1000, (5, 30), dtype=torch.long)))  # (Batch, Seq) ids in

[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=t

In [32]:
enc_out.size()  # expect (Batch, Seq, d_model)

torch.Size([5, 30, 512])

## Decoder layer

In [21]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, FFN.

    Each sublayer uses the same residual form as EncoderBlock:
    x = x + Dropout(LayerNorm(Sublayer(x))), with its own LayerNorm.

    Args:
        d_model: model width.
        d_feature: per-head feature size (d_model == d_feature * n_heads).
        d_ff: hidden size of the feed-forward network.
        n_heads: number of attention heads.
        dropout: dropout probability.
    """
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out,
                src_mask=None, tgt_mask=None):
        """Decode one layer.

        Args:
            x: decoder input, (Batch, Seq_tgt, d_model).
            enc_out: encoder output, (Batch, Seq_src, d_model).
            src_mask: mask over encoder positions, used in cross-attention
                (True = masked out); presumably source padding.
            tgt_mask: mask for decoder self-attention (True = masked out);
                typically a causal mask, possibly combined with target padding.

        Returns:
            Tensor of shape (Batch, Seq_tgt, d_model).
        """
        # Self-attention over the decoder inputs uses the target mask.
        att = self.masked_attn_head(x, x, x, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Cross-attention: decoder queries over encoder keys/values uses the source mask.
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=src_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Position-wise feed-forward, normalized by its own third LayerNorm.
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm3(pos))
        return x

In [22]:
dec = DecoderBlock()
# Decoder input and encoder memory, both (Batch=5, Seq=10, d_model=512)
dec(torch.rand(5, 10, 512), enc_out=enc(torch.rand(5, 10, 512)))

[EncoderBlock] Encoder block input size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 512])
[EncoderBlock] Attention output size=torch.Size([5, 10, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 10, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Size([5, 10, 512])
[MultiHeadAttention] projected output size=torch.Size([5, 10, 512])
[MultiHeadAttention] Input queries size=torch.Size([5, 10, 512])
[MultiHeadAttention] output of single head size=torch.Size([5, 10, 64])
[MultiHeadAttention] concatenated output size=torch.Siz

tensor([[[ 1.3143, -3.6150,  2.0685,  ...,  1.0186,  2.1423, -5.2383],
         [ 0.4081, -3.9372,  0.3825,  ...,  2.2026,  1.9094, -0.9124],
         [ 0.9806, -1.7727,  2.2844,  ...,  0.9288,  1.3480, -4.8447],
         ...,
         [ 2.1945, -4.2671,  2.0488,  ...,  1.3101,  1.7486, -6.4701],
         [ 1.9170, -3.3968, -0.4371,  ...,  2.0384,  1.2802, -4.6823],
         [ 0.1921, -3.3898,  1.6817,  ...,  0.6606,  1.4826, -4.2986]],

        [[ 1.6475, -3.0126,  0.0400,  ...,  1.4631,  2.0488, -5.1784],
         [-0.0093, -3.1619, -0.9516,  ...,  0.8779,  2.0665, -3.4306],
         [ 1.6324, -4.3141,  0.1368,  ...,  2.8415,  1.8893, -4.9675],
         ...,
         [ 1.8956, -4.1262, -0.2359,  ...,  2.1064,  1.9551, -4.3774],
         [ 0.6772, -4.0377,  0.6609,  ...,  0.3275,  2.4405, -3.2075],
         [ 0.9534, -4.3457,  0.0026,  ...,  0.9398,  2.5656, -5.0483]],

        [[ 1.2853, -3.8986,  1.1103,  ...,  2.4212,  1.2662, -4.8599],
         [ 1.0998, -3.7637,  0.6068,  ...,  2

In [23]:
class TransformerDecoder(nn.Module):
    """A stack of ``n_blocks`` DecoderBlocks applied in sequence.

    Args:
        n_blocks: number of decoder layers.
        d_model: model width; must be divisible by ``n_heads``.
        d_feature: accepted for compatibility but ignored; the per-head size is
            always d_model // n_heads.
        d_ff: hidden size of each position-wise feed-forward network.
        n_heads: attention heads per layer.
        dropout: dropout probability.
    """
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        # NOTE(review): not used in forward(); inputs are expected to carry
        # positional information already (e.g. from WordPositionEmbedding).
        # Kept so the module layout and state_dict keys stay unchanged.
        self.position_embedding = PositionalEmbedding(d_model)
        # n_heads must be forwarded: DecoderBlock otherwise assumes 8 heads, and
        # the d_model == d_feature * n_heads check fails for any other value.
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model // n_heads,
                         d_ff=d_ff, n_heads=n_heads, dropout=dropout)
            for _ in range(n_blocks)
        ])

    def forward(self, x: torch.FloatTensor,
                enc_out: torch.FloatTensor,
                src_mask=None, tgt_mask=None):
        """Decode ``x`` against ``enc_out``, passing both masks to every block."""
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

In [24]:
decoder = TransformerDecoder(n_blocks=6)  # default 6-block decoder

# Putting it all together

In [25]:
# Raise the threshold to enc_dec_block so only EncoderBlock-level (and higher)
# shape messages are printed, then build fresh modules for the end-to-end run.
logger.setLevel(TensorLoggingLevels.enc_dec_block)
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()
decoder = TransformerDecoder()

In [26]:
# Random source and target token ids, each (Batch=5, Seq=30)
src_ids, tgt_ids = (torch.randint(0, 1000, (5, 30), dtype=torch.long) for _ in range(2))

In [27]:
x = encoder(emb(src_ids))  # encoder memory, (Batch, Seq_src, d_model)
decoder(emb(tgt_ids), enc_out=x)

[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=t

tensor([[[-1.5636e+00, -3.8588e+00,  3.7249e+00,  ..., -4.1889e+00,
          -5.6192e+00, -3.0030e+00],
         [ 2.8379e+00, -2.3141e+00, -1.5286e-01,  ..., -7.5009e+00,
          -2.8948e+00, -1.8489e+00],
         [ 3.1479e+00, -2.6509e+00,  2.6539e+00,  ..., -7.9812e+00,
          -2.8788e+00, -9.3533e-01],
         ...,
         [-3.8777e+00,  2.8749e-01,  2.3224e+00,  ..., -8.7214e+00,
          -4.9369e+00, -7.5240e-01],
         [-1.5537e+00, -9.9530e-01,  3.7730e+00,  ..., -8.0768e+00,
          -2.6874e-01, -4.8747e+00],
         [ 3.6634e+00, -2.4629e+00,  3.8783e-01,  ..., -9.1722e+00,
           1.3492e+00, -7.3836e-02]],

        [[-2.3774e-01, -2.4522e+00,  8.6978e+00,  ..., -8.3269e+00,
          -3.4548e+00, -1.6867e+00],
         [-2.6132e+00,  9.3545e-02,  9.6737e-01,  ..., -3.3069e+00,
          -6.2419e+00, -2.3105e+00],
         [ 2.3841e-02, -2.9472e+00,  1.4953e-01,  ..., -2.7899e+00,
          -4.2259e+00,  6.4982e-01],
         ...,
         [ 6.3979e-01, -4

In [29]:
# Scratch check: q is (5, 10, 20), so tiling the feature axis 8x gives (5, 10, 160).
# NOTE(review): the saved output below shows (5, 10, 20) and looks stale.
q.repeat(1, 1, 8).size()

torch.Size([5, 10, 20])