In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0")

## Code copied from
### https://www.youtube.com/watch?v=U0s0f995w14
### https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
### modifying code from there

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        # Get number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = self.values(values)  # (N, value_len, embed_size)
        keys = self.keys(keys)  # (N, key_len, embed_size)
        queries = self.queries(query)  # (N, query_len, embed_size)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out

In [4]:
sa = SelfAttention(64, 8)

In [5]:
## (N, value_len, embed_size)
x = torch.randn(1, 128, 64)
sa(x, x, x, None).shape

torch.Size([1, 128, 64])

## Self Attention in Block

1. The k,q,v transform matrices are -> shared for each token; hence having multiple such layers might increase params
2. The final linear transform matrix is also shared; mixing partial indirectly increases parameters for multiple mixing
3. We want to have smaller sequence and mix them using block attention matrix.

In [52]:
class SelfAttention_Sparse(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask, block_size):
        # Get number of training examples
        N = query.shape[0]
        
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = self.values(values)  # (N, value_len, embed_size)
        keys = self.keys(keys)  # (N, key_len, embed_size)
        queries = self.queries(query)  # (N, query_len, embed_size)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len//block_size, block_size, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len//block_size, block_size, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len//block_size, block_size, self.heads, self.head_dim)

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm

        energy = torch.einsum("n aq h d , n ak h d -> n h a qk", [queries, keys])
        # queries shape: (N, n_query_blocks, block_query_len, heads, heads_dim),
        # keys shape: (N, n_key_blocks, block_key_len, heads, heads_dim)
        # energy: (N, heads, n_query_blocks, block_query_len, block_key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1)
        # attention shape: (N, heads, num_blocks, query_len, key_len)

        out = torch.einsum("n h a q k , n a k hd -> n a q hd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, num_blocks, query_len, key_len)
        # values shape: (N, num_blocks, block_value_len, heads, heads_dim)
        # out after matrix multiply: (N, num_blocks, block_query_len, heads, head_dim), then
        # we reshape and flatten the (1,2)dimensions as well as (3,4) dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)
        return out

## Development

In [9]:
###(N, query_len, heads, heads_dim)
# N, seq_len, heads, heads_dim = (1, 32, 2, 4)
N, seq_len, heads, heads_dim = (1, 512, 2, 4)

k = torch.randn(N, seq_len, heads, heads_dim)
q = torch.randn(N, seq_len, heads, heads_dim)
v = torch.randn(N, seq_len, heads, heads_dim)

In [10]:
k, q, v = k.to(device), q.to(device), v.to(device)

In [11]:
kq = torch.einsum("nqhd,nkhd->nhqk", [k, q])
kq.shape

torch.Size([1, 2, 512, 512])

In [12]:
att = torch.softmax(kq / ((heads*heads_dim) ** (1 / 2)), dim=-1)

In [13]:
out = torch.einsum("nhql,nlhd->nqhd", [att, v])

In [14]:
out.shape

torch.Size([1, 512, 2, 4])

In [15]:
# These are valid shapes for partial mixing
'''
{2: [2],
 4: [2, 4],
 8: [2, 8],
 16: [2, 4, 16],
 32: [2, 32],
 64: [2, 4, 8, 64],
 128: [2, 128],
 256: [2, 4, 16, 256],
 512: [2, 8, 512],
 1024: [2, 4, 32, 1024],
 2048: [2, 2048],
 4096: [2, 4, 8, 16, 64, 4096]}
 '''
print()




In [16]:
block_size = 32 #8
k_ = k.clone().reshape(k.shape[0], k.shape[1]//block_size, block_size, k.shape[2], k.shape[3]).contiguous()
q_ = q.clone().reshape(q.shape[0], q.shape[1]//block_size, block_size, q.shape[2], q.shape[3]).contiguous()
v_ = v.clone().reshape(v.shape[0], v.shape[1]//block_size, block_size, v.shape[2], v.shape[3]).contiguous()

In [222]:
32*32

1024

In [17]:
###(N, num_blocks, query_block_len, heads, heads_dim)
k_.shape

torch.Size([1, 16, 32, 2, 4])

In [18]:
### using partial attention
kq_ = torch.einsum("n aq h d , n ak h d -> n h a qk", [k_, q_])
kq_.shape

torch.Size([1, 2, 16, 32, 32])

In [19]:
## kq_ = torch.Size([1, 2, 4, 8, 8])
## batch size, num_heads, num_attention_blocks, (att0 , att1)

In [20]:
att_ = torch.softmax(kq_ / (8 ** (1 / 2)), dim=-1) ## attention over a block only

In [21]:
att_.shape

torch.Size([1, 2, 16, 32, 32])

In [22]:
out_ = torch.einsum("n h a q k , n a k hd -> n a q hd", [att_, v_])

In [23]:
out_.shape

torch.Size([1, 16, 32, 2, 4])

In [24]:
out_f_ = out_.reshape(out_.shape[0], out_.shape[1]*out_.shape[2], out_.shape[3]*out_.shape[4])
out_f_.shape

torch.Size([1, 512, 8])

In [25]:
# out_f_

### computing time

In [26]:
def attention_func():
    kq = torch.einsum("nqhd,nkhd->nhqk", [k, q])
    att = torch.softmax(kq / (8 ** (1 / 2)), dim=-1)
    out = torch.einsum("nhql,nlhd->nqhd", [att, v])

def partial_attention_func():
    kq_ = torch.einsum("n aq h d , n ak h d -> n h a qk", [k_, q_])
    att_ = torch.softmax(kq_ / (8 ** (1 / 2)), dim=-1)
    out_ = torch.einsum("n h a q k , n a k hd -> n a q hd", [att_, v_])

In [27]:
%timeit attention_func()

1.18 ms ± 237 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [28]:
%timeit partial_attention_func()

252 µs ± 123 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Implement Full Transformer

In [53]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, actf=nn.GELU):
        super(TransformerBlock, self).__init__()
        
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            actf(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

In [60]:
## (N, value_len, embed_size)
x = torch.randn(1, 128, 64)

b = TransformerBlock(64, 8, 0, 2)

In [61]:
b(x, x, x, None).shape

torch.Size([1, 128, 64])

## Sparse Block

In [327]:
class Sparse_TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, actf=nn.GELU):
        super().__init__()
        
        self.attention = SelfAttention_Sparse(embed_size, heads)
            
        self.norm1 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            actf(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask, block_size):
        attention = self.attention(value, key, query, mask, block_size)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

In [328]:
# class Mixer_TransformerBlock_Encoder(nn.Module):
#     def __init__(self, seq_length, block_size, embed_size, heads, dropout, forward_expansion, actf=nn.GELU):
#         super().__init__()
#         assert 2**int(np.log2(block_size)) == block_size, 'Block size must be power of 2'
#         assert 2**int(np.log2(seq_length)) == seq_length, 'Sequence length must be power of 2'
#         assert seq_length%block_size == 0, 'Sequence length must be divisible exactly by block_size'
        
#         self.block_size = block_size
#         self.seq_len = seq_length
        
#         def log_base(a, base):
#             return np.log(a) / np.log(base)
        
#         num_layers = int(np.ceil(log_base(seq_length, base=block_size)))
#         self.sparse_transformers = []
#         for i in range(num_layers):
#             tr = Sparse_TransformerBlock(embed_size, heads, dropout, forward_expansion, actf)
#             self.sparse_transformers.append(tr)
#         self.sparse_transformers = nn.ModuleList(self.sparse_transformers)
        

#     def forward(self, x, mask):
#         N, seq_len, d_model = x.shape
#         ### (N, seq_len, d_model) of the input x
        
#         assert seq_len == self.seq_len, 'The sequence length of given input does not match this model'
            
#         for i, fn in enumerate(self.sparse_transformers):
# #             x = x.view(N, -1, self.block_size, self.block_size**i, d_model).permute(0, 1, 3, 2, 4).contiguous().view(N, seq_len, d_model)
# #             x = fn(x, x, x, mask, self.block_size)
# #             x = x.view(N, -1, self.block_size**i, self.block_size, d_model).permute(0, 1, 3, 2, 4).contiguous()

#             x = x.view(N, -1, self.block_size, self.block_size**i, d_model).transpose(2, 3).contiguous().view(N, seq_len, d_model)
#             x = fn(x, x, x, mask, self.block_size)
#             x = x.view(N, -1, self.block_size**i, self.block_size, d_model).transpose(2, 3).contiguous()

#         x = x.view(N, seq_len, -1)
#         return x

In [329]:
class Mixer_TransformerBlock_Encoder(nn.Module):
    def __init__(self, seq_length, block_size, embed_size, heads, dropout, forward_expansion, actf=nn.GELU):
        super().__init__()
        assert 2**int(np.log2(block_size)) == block_size, 'Block size must be power of 2'
        assert 2**int(np.log2(seq_length)) == seq_length, 'Sequence length must be power of 2'
        assert seq_length%block_size == 0, 'Sequence length must be divisible exactly by block_size'
        
        self.block_size = block_size
        self.seq_len = seq_length
        
        def log_base(a, base):
            return np.log(a) / np.log(base)
        
        num_layers = int(np.ceil(log_base(seq_length, base=block_size)))
        self.sparse_transformers = []
        self.gaps = []
        for i in range(num_layers):            
            tr = Sparse_TransformerBlock(embed_size, heads, dropout, forward_expansion, actf)
            self.sparse_transformers.append(tr)
            ### find which permutation gives valid shape
            gap = self.block_size**i
            if gap*self.block_size <= self.seq_len:
                self.gaps.append(gap)
            else:
                self.gaps.append(int(np.ceil(self.seq_len/self.block_size)))
#                 break
            
        self.sparse_transformers = nn.ModuleList(self.sparse_transformers)
        

    def forward(self, x, mask):
        N, seq_len, d_model = x.shape
        ### (N, seq_len, d_model) of the input x
        
        assert seq_len == self.seq_len, 'The sequence length of given input does not match this model'
            
        for i, fn in enumerate(self.sparse_transformers):
            gap = self.gaps[i]
#             print(i, gap)
            x = x.view(N, -1, self.block_size, gap, d_model).transpose(2, 3).contiguous().view(N, seq_len, d_model)
            x = fn(x, x, x, mask, self.block_size)
            x = x.view(N, -1, gap, self.block_size, d_model).transpose(2, 3).contiguous()

        x = x.view(N, seq_len, -1)
        return x

In [360]:
S, B = 256, 16 
x = torch.randn(1, S, 64)

In [361]:
b = Mixer_TransformerBlock_Encoder(seq_length=S, block_size=B, 
                                   embed_size=64, heads=8, dropout=0, 
                                   forward_expansion=2)

In [362]:
b(x, None).shape

torch.Size([1, 256, 64])

In [363]:
b

Mixer_TransformerBlock_Encoder(
  (sparse_transformers): ModuleList(
    (0): Sparse_TransformerBlock(
      (attention): SelfAttention_Sparse(
        (values): Linear(in_features=64, out_features=64, bias=True)
        (keys): Linear(in_features=64, out_features=64, bias=True)
        (queries): Linear(in_features=64, out_features=64, bias=True)
        (fc_out): Linear(in_features=64, out_features=64, bias=True)
      )
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=64, out_features=128, bias=True)
        (1): GELU()
        (2): Linear(in_features=128, out_features=64, bias=True)
      )
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (1): Sparse_TransformerBlock(
      (attention): SelfAttention_Sparse(
        (values): Linear(in_features=64, out_features=64, bias=True)
        (keys): Linear(in_features=64, out_feature

In [342]:
log_base(256, 32)

1.5999999999999999

In [344]:
a = torch.arange(32).reshape(-1, 1)
a = a.repeat_interleave(5, dim=1)
a

tensor([[ 0,  0,  0,  0,  0],
        [ 1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3],
        [ 4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5],
        [ 6,  6,  6,  6,  6],
        [ 7,  7,  7,  7,  7],
        [ 8,  8,  8,  8,  8],
        [ 9,  9,  9,  9,  9],
        [10, 10, 10, 10, 10],
        [11, 11, 11, 11, 11],
        [12, 12, 12, 12, 12],
        [13, 13, 13, 13, 13],
        [14, 14, 14, 14, 14],
        [15, 15, 15, 15, 15],
        [16, 16, 16, 16, 16],
        [17, 17, 17, 17, 17],
        [18, 18, 18, 18, 18],
        [19, 19, 19, 19, 19],
        [20, 20, 20, 20, 20],
        [21, 21, 21, 21, 21],
        [22, 22, 22, 22, 22],
        [23, 23, 23, 23, 23],
        [24, 24, 24, 24, 24],
        [25, 25, 25, 25, 25],
        [26, 26, 26, 26, 26],
        [27, 27, 27, 27, 27],
        [28, 28, 28, 28, 28],
        [29, 29, 29, 29, 29],
        [30, 30, 30, 30, 30],
        [31, 31, 31, 31, 31]])

In [345]:
a.view(-1, 4, 1, 5).permute(0, 2, 1, 3)

tensor([[[[ 0,  0,  0,  0,  0],
          [ 1,  1,  1,  1,  1],
          [ 2,  2,  2,  2,  2],
          [ 3,  3,  3,  3,  3]]],


        [[[ 4,  4,  4,  4,  4],
          [ 5,  5,  5,  5,  5],
          [ 6,  6,  6,  6,  6],
          [ 7,  7,  7,  7,  7]]],


        [[[ 8,  8,  8,  8,  8],
          [ 9,  9,  9,  9,  9],
          [10, 10, 10, 10, 10],
          [11, 11, 11, 11, 11]]],


        [[[12, 12, 12, 12, 12],
          [13, 13, 13, 13, 13],
          [14, 14, 14, 14, 14],
          [15, 15, 15, 15, 15]]],


        [[[16, 16, 16, 16, 16],
          [17, 17, 17, 17, 17],
          [18, 18, 18, 18, 18],
          [19, 19, 19, 19, 19]]],


        [[[20, 20, 20, 20, 20],
          [21, 21, 21, 21, 21],
          [22, 22, 22, 22, 22],
          [23, 23, 23, 23, 23]]],


        [[[24, 24, 24, 24, 24],
          [25, 25, 25, 25, 25],
          [26, 26, 26, 26, 26],
          [27, 27, 27, 27, 27]]],


        [[[28, 28, 28, 28, 28],
          [29, 29, 29, 29, 29],
          [3

In [346]:
a.view(-1, 4, 8, 5).permute(0, 2, 1, 3)

tensor([[[[ 0,  0,  0,  0,  0],
          [ 8,  8,  8,  8,  8],
          [16, 16, 16, 16, 16],
          [24, 24, 24, 24, 24]],

         [[ 1,  1,  1,  1,  1],
          [ 9,  9,  9,  9,  9],
          [17, 17, 17, 17, 17],
          [25, 25, 25, 25, 25]],

         [[ 2,  2,  2,  2,  2],
          [10, 10, 10, 10, 10],
          [18, 18, 18, 18, 18],
          [26, 26, 26, 26, 26]],

         [[ 3,  3,  3,  3,  3],
          [11, 11, 11, 11, 11],
          [19, 19, 19, 19, 19],
          [27, 27, 27, 27, 27]],

         [[ 4,  4,  4,  4,  4],
          [12, 12, 12, 12, 12],
          [20, 20, 20, 20, 20],
          [28, 28, 28, 28, 28]],

         [[ 5,  5,  5,  5,  5],
          [13, 13, 13, 13, 13],
          [21, 21, 21, 21, 21],
          [29, 29, 29, 29, 29]],

         [[ 6,  6,  6,  6,  6],
          [14, 14, 14, 14, 14],
          [22, 22, 22, 22, 22],
          [30, 30, 30, 30, 30]],

         [[ 7,  7,  7,  7,  7],
          [15, 15, 15, 15, 15],
          [23, 23, 23, 23,

In [347]:
log_base(32, 4)

2.5

### Higher level blocks

In [146]:
class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(
            (self.word_embedding(x) + self.position_embedding(positions))
        )

        # In the Encoder the query, key, value are all the same, it's in the
        # decoder this will change. This might look a bit odd in this case.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [None]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out

In [None]:
class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out