In [1]:
import numpy as np
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn
    

In [2]:
class PositionwiseFFN(nn.Module):
    """Position-wise feed-forward network (FFN) block.

    As in "Attention Is All You Need" (Vaswani et al., 2017), d2l.ai, and
    Cohen et al. (2022): a two-layer MLP (Linear -> ReLU -> Linear) whose
    weights are shared across, and applied independently to, every position
    (index) of the input sequence.

    Parameters
    ----------
    d_model : int
        Dimension of the model's latent space; size of the last axis of both
        the input and the output.
    d_ffn_hidden : int, default 2048
        Dimension of the hidden layer between the two linear layers.
    """

    def __init__(self, d_model: int, d_ffn_hidden: int = 2048):
        """Build the two linear layers; the hidden width defaults to 2048."""
        super().__init__()
        # Attribute names are kept stable on purpose: they determine the
        # state_dict keys, so renaming them would break saved checkpoints.
        self._linear1 = nn.Linear(d_model, d_ffn_hidden)
        self._relu = nn.ReLU()
        self._linear2 = nn.Linear(d_ffn_hidden, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the position-wise FFN.

        nn.Linear acts only on the last axis, so every position is transformed
        independently: d_model -> d_ffn_hidden -> ReLU -> d_model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_len, d_model). Any
            number of leading dimensions works, as long as the last one is
            d_model.

        Returns
        -------
        torch.Tensor
            Output tensor with the same shape as ``x``, i.e.
            (batch_size, input_len, d_model).
        """
        hidden = self._relu(self._linear1(x))
        return self._linear2(hidden)


In [3]:
# Demo inputs. The seed is set *before* building the module so that both the
# weight initialisation and the random input are reproducible on re-run.
torch.manual_seed(0)
batch_size = 32
dmodel = 64
inlen = 60
ffn = PositionwiseFFN(dmodel)
# Build the input from the named constants rather than repeating literals,
# so changing dmodel / inlen keeps the input consistent with the module.
x = torch.rand(batch_size, inlen, dmodel)

In [4]:
# Apply the FFN. It maps the last axis d_model -> d_model, so the output
# must have exactly the same shape as the input.
y = ffn(x)
assert y.shape == x.shape, f"unexpected output shape {tuple(y.shape)}"

In [6]:
# Display the FFN output tensor (same shape as x; each position mapped d_model -> d_model).
y

tensor([[[ 0.2851,  0.1430,  0.1040,  ...,  0.2884, -0.4036, -0.0844],
         [ 0.2177,  0.1151,  0.0471,  ...,  0.0714, -0.2473, -0.0229],
         [ 0.2511,  0.0610,  0.0490,  ...,  0.2193, -0.2525, -0.0742],
         ...,
         [ 0.2412,  0.0759, -0.0226,  ...,  0.1360, -0.3927, -0.0599],
         [ 0.3453,  0.0892,  0.0219,  ...,  0.0996, -0.3204, -0.0642],
         [ 0.2717,  0.0913,  0.0785,  ...,  0.1144, -0.2010,  0.1214]],

        [[ 0.2064, -0.0324, -0.0077,  ...,  0.1458, -0.2877,  0.0661],
         [ 0.3164,  0.0321,  0.0240,  ...,  0.0603, -0.2975, -0.0182],
         [ 0.1927,  0.0828,  0.0341,  ...,  0.1114, -0.3281, -0.0177],
         ...,
         [ 0.3543,  0.2046,  0.1208,  ...,  0.0940, -0.2701, -0.0550],
         [ 0.2006, -0.0670,  0.1534,  ...,  0.0621, -0.3204, -0.1335],
         [ 0.2756,  0.1128,  0.0745,  ...,  0.1781, -0.3932, -0.0874]],

        [[ 0.2642,  0.0309,  0.0981,  ...,  0.1251, -0.3062, -0.0804],
         [ 0.2680,  0.0025,  0.0126,  ...,  0