In [13]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import math
from transformer.layers import FeedForward, LayerNorm

In [14]:
class ResidualConnection(nn.Module):
    """Residual wrapper that normalizes its input before running a sublayer.

    The forward pass evaluates ``x + dropout(sublayer(norm(x)))``: the skip
    path carries the un-normalized ``x``, while the sublayer only ever sees
    the normalized tensor.

    Args:
        d_model: Value forwarded to ``LayerNorm`` (the size of the last
            dimension, going by the shape comments in ``forward``).
        dropout: Probability forwarded to ``nn.Dropout``.
    """

    def __init__(self, d_model: int, dropout: float) -> None:
        super().__init__()
        # Keep registration order (norm, then dropout) so the module tree is stable.
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the skip path.

        Args:
            x: Input tensor, documented as [batch_size, seq_len, d_model].
            sublayer: Callable that takes the normalized tensor and returns a
                tensor that can be added to ``x``.

        Returns:
            ``x + dropout(sublayer(norm(x)))``, same shape as ``x``:
            [batch_size, seq_len, d_model].
        """
        normalized = self.norm(x)
        transformed = sublayer(normalized)
        regularized = self.dropout(transformed)
        return x + regularized

In [15]:
# Seed the global RNG so the random input below, any random parameter
# initialisation, and the dropout mask (modules are in training mode by
# default) are identical on every Restart & Run All.
torch.manual_seed(0)

# Define the dimensions and dropout rate
d_model = 512
dropout = 0.1

# Create an instance of the ResidualConnection and FeedForward classes
residual_connection = ResidualConnection(d_model, dropout)
feed_forward = FeedForward(d_model, d_ff=2048, dropout=dropout)

# Create a random tensor to represent a batch of sequences
x = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512

# Pass the tensor through the residual connection with feed forward as the sublayer
output = residual_connection(x, feed_forward)

# The skip connection adds the sublayer result to x, so the output must keep
# the input's shape. Check it instead of relying on a visual comparison.
assert output.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(output.shape)}"

print(output.shape)  # Should print: torch.Size([10, 20, 512])

torch.Size([10, 20, 512])


In [16]:
# Echo the tensor computed in the previous cell so its values can be inspected.
output

tensor([[[ 0.2129, -0.0436,  0.9588,  ...,  0.1822,  0.3639,  0.4293],
         [ 0.2390,  0.6822,  0.9811,  ...,  0.2690,  0.7167,  0.4967],
         [ 0.3385,  1.2155,  0.9386,  ..., -0.1280,  0.7318,  0.0548],
         ...,
         [ 1.0564,  0.7191,  0.6237,  ..., -0.4058,  0.6851,  0.1766],
         [ 0.7956,  0.6718,  0.5989,  ...,  0.9856,  0.7379,  0.9130],
         [ 0.6477,  0.6326,  0.9082,  ..., -0.2746,  1.2814,  0.7777]],

        [[ 0.9801,  0.1643,  0.8520,  ..., -0.3538,  0.7820,  0.0167],
         [ 0.8016,  0.9011,  1.9092,  ...,  0.6830, -0.3906, -0.0636],
         [-0.3192,  1.2711,  0.7822,  ...,  0.0668,  0.2121,  0.7413],
         ...,
         [-0.0962,  0.6480,  0.6683,  ...,  0.6441,  0.4376,  0.8433],
         [ 0.5444,  1.1282,  0.7066,  ...,  0.3085,  0.8632,  0.0959],
         [ 0.1369,  0.4966,  0.6239,  ...,  0.1070,  0.2950,  0.6108]],

        [[ 0.0291,  0.4064,  0.8068,  ...,  0.5433,  0.3167,  0.7600],
         [ 0.6815,  1.2556,  0.4052,  ..., -0