In [5]:
import torch.nn as nn

class PWFFN(nn.Module):
    """Position-wise feed-forward network.

    Applies, in order: Linear(d_model -> d_ff), ReLU, Dropout,
    Linear(d_ff -> d_model).

    Args:
        d_model: Size of the last dimension of the input and of the output.
        d_ff: Size of the hidden representation between the two linear layers.
        dropout: Dropout probability applied right after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout=0.3):
        super().__init__()

        # Build each stage separately, in the order it is applied.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        regularize = nn.Dropout(dropout)
        project = nn.Linear(d_ff, d_model)

        # Kept under the attribute ``ff`` at positions 0..3 so parameter
        # names (``ff.0.weight``, ``ff.3.bias``, ...) stay the same.
        self.ff = nn.Sequential(expand, activation, regularize, project)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack.

        shape(x) = [B x seq_len x D]; the returned tensor has the same shape.
        """
        return self.ff(x)

In [6]:
import torch.nn as nn
import torch


class ResidualLayerNorm(nn.Module):
    """Dropout, residual addition and layer normalisation in one module.

    Computes ``LayerNorm(Dropout(x) + residual)``.

    Args:
        d_model: Size of the last dimension, passed to ``nn.LayerNorm``.
        dropout: Dropout probability applied to ``x`` before the addition.
    """

    def __init__(self, d_model, dropout=0.3):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        """Return the normalised sum of the dropped-out ``x`` and ``residual``."""
        # Only ``x`` goes through dropout; ``residual`` is added untouched.
        dropped = self.dropout(x)
        summed = dropped + residual
        return self.layer_norm(summed)

In [7]:
# Seed the global RNG so that the random residual below and the dropout mask
# (the freshly built module is in training mode) are identical on every
# Restart & Run All.
torch.manual_seed(0)

# Toy input: a batch of 1 with 3 positions and d_model = 4.
# torch.tensor is the recommended factory; the legacy torch.Tensor constructor
# is kept only for backward compatibility.
encodings = torch.tensor([[[0.0, 0.1, 0.2, 0.3], [1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]]])
prev_x = torch.randn(1, 3, 4)  # stand-in residual with the same shape
norm_layer = ResidualLayerNorm(d_model=4)
norm = norm_layer(encodings, prev_x)
print("Norm: \n", norm)
print("Norm shape: \n", norm.shape)

Norm: 
 tensor([[[-1.5965,  0.5186,  1.0873, -0.0094],
         [ 0.4455, -1.5694, -0.0326,  1.1565],
         [-0.9016,  1.1520, -1.0814,  0.8310]]],
       grad_fn=<NativeLayerNormBackward0>)
Norm shape: 
 torch.Size([1, 3, 4])


In [8]:
PWFFN_layer = PWFFN(d_model=4, d_ff=16)
# Keep the output under its own name. Assigning it to `PWFFN` would rebind the
# class name to a tensor, so re-running this cell (or constructing another
# PWFFN later) would fail with "'Tensor' object is not callable".
pwffn_out = PWFFN_layer(norm)
print("PWFFN: \n", pwffn_out)
print("PWFFN Shape: \n", pwffn_out.shape)

PWFFN: 
 tensor([[[-0.0035, -0.2865, -0.0963, -0.2456],
         [-0.0009, -0.3834,  0.0740, -0.0520],
         [ 0.2627, -0.1414, -0.1472, -0.1193]]], grad_fn=<AddBackward0>)
PWFFN Shape: 
 torch.Size([1, 3, 4])
