In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed PyTorch's global RNG up front so the FeedForward weight init and the
# torch.randn input below are reproducible under Restart Kernel -> Run All.
SEED = 123
torch.manual_seed(SEED)

# Model hyperparameters shared by the cells below. The name suggests the
# ~124M-parameter GPT-2 "small" setup (NOTE(review): inferred from the name only).
GPT_CONFIG_124M = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 768,         # Embedding dimension (FeedForward's input/output width)
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}

In [3]:
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes, element-wise,
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    so the output has the same shape and dtype as ``x``. The module has no
    parameters.
    """

    # sqrt(2/pi) is loop-invariant: compute it once as a Python float instead of
    # allocating a new (default-dtype, CPU) tensor and running sqrt on every
    # forward call. A Python float also keeps full precision for float64 inputs,
    # where a float32 0-dim tensor would have rounded the constant.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply tanh-approximated GELU element-wise to tensor ``x``."""
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


In [5]:
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension of its input.

    Structure: Linear(emb_dim, 4*emb_dim) -> GELU -> Linear(4*emb_dim, emb_dim),
    stored as ``self.layers`` (an ``nn.Sequential``), so the trailing dimension
    of the output equals that of the input.

    Args:
        cfg: model config mapping; only ``cfg["emb_dim"]`` is read.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["emb_dim"]
        hidden_width = 4 * width  # 4x expansion of the embedding size

        # Build the sub-modules in the same order (expand, activate, project)
        # so parameter initialisation consumes the RNG identically.
        expand = nn.Linear(width, hidden_width)
        activation = GELU()
        project = nn.Linear(hidden_width, width)
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` (last dim == emb_dim) through expand -> GELU -> project."""
        result = self.layers(x)
        return result

In [6]:
# Instantiate the feed-forward block from the shared model config.
ff = FeedForward(cfg=GPT_CONFIG_124M)

In [16]:
# Random dummy activations, presumably (batch, num_tokens, emb_dim) — TODO confirm.
# The trailing size must equal the first Linear layer's in_features (emb_dim),
# so read it from the config instead of hard-coding 768.
# NOTE: `input` shadows Python's built-in input(); the name is kept because
# later cells refer to it.
input = torch.randn(3, 4, GPT_CONFIG_124M["emb_dim"])

In [17]:
# Confirm the dimensions of the dummy input.
input.size()

torch.Size([3, 4, 768])

In [18]:
# Peek at the raw input values; PyTorch's repr elides the middle of each row ("...").
input

tensor([[[-0.2870,  0.7796, -0.7771,  ...,  1.7675,  1.8325,  0.5533],
         [ 0.1195, -2.3766, -0.1168,  ..., -2.3687, -1.2073,  1.0442],
         [-0.2396,  0.4309, -0.2816,  ...,  1.0989,  1.2649,  0.2766],
         [-0.4581,  1.5925,  0.2930,  ..., -0.2397, -0.3849, -0.3977]],

        [[ 0.7581, -0.9504, -0.2990,  ...,  0.8219, -0.9427,  0.0995],
         [ 1.0189, -0.0502, -2.3815,  ..., -0.7028, -0.3381,  0.9931],
         [ 0.8898, -0.6445,  0.0743,  ...,  0.8665,  1.0166,  3.1018],
         [-0.9737,  1.1691,  0.1148,  ...,  1.1822, -0.9717,  0.2131]],

        [[ 0.7425, -0.4969,  0.5585,  ...,  0.1093,  0.4270, -0.6889],
         [ 1.1036, -0.1694, -1.0940,  ..., -0.0278,  1.4463,  0.2761],
         [ 1.1362,  1.4879, -1.4161,  ..., -1.5167,  0.7745, -0.5070],
         [-1.7635,  0.2199, -0.9448,  ..., -0.3491, -0.7865, -0.5740]]])

In [19]:
# First pass through the feed-forward block. Linear(emb_dim -> 4*emb_dim) followed
# by Linear(4*emb_dim -> emb_dim) should leave the shape unchanged; check it inline.
output = ff(input)
assert output.shape == input.shape, "FeedForward should preserve the input shape"

In [20]:
# Size of the first-pass output (same as the input's).
output.size()

torch.Size([3, 4, 768])

In [21]:
# Show the first-pass result. The grad_fn in the repr indicates autograd recorded
# this forward pass (it was not run under torch.no_grad()).
output

tensor([[[ 0.0598,  0.3697, -0.1852,  ..., -0.0229,  0.3158,  0.1028],
         [ 0.1493,  0.0107, -0.0634,  ..., -0.0833,  0.2734,  0.0432],
         [ 0.1720, -0.1238, -0.0281,  ..., -0.4287,  0.1268, -0.3343],
         [-0.0652, -0.0863,  0.2504,  ...,  0.0721,  0.1843,  0.0671]],

        [[-0.0446,  0.0477,  0.1705,  ..., -0.1635,  0.1314,  0.2709],
         [ 0.3319, -0.3284, -0.2539,  ..., -0.1909, -0.3113, -0.1813],
         [-0.0656, -0.1042, -0.1510,  ..., -0.3866,  0.0474, -0.3870],
         [-0.0093, -0.0023,  0.2439,  ...,  0.0949,  0.0429,  0.0032]],

        [[ 0.2431,  0.1199, -0.1315,  ...,  0.0235,  0.0798, -0.2464],
         [ 0.1190, -0.1328, -0.0661,  ..., -0.1330,  0.1880, -0.2039],
         [ 0.2175, -0.0821, -0.0230,  ..., -0.0782,  0.1226,  0.0551],
         [-0.0900,  0.0584,  0.0264,  ...,  0.1051, -0.1146, -0.1833]]],
       grad_fn=<ViewBackward0>)

In [22]:
# Apply the block twice to show its output can be fed straight back in (shape is
# preserved). Recompute from `input` rather than `output = ff(output)` so that
# re-running this cell is idempotent instead of stacking yet another pass.
output = ff(ff(input))

In [23]:
# Size after the second pass — still identical to the input's.
output.size()

torch.Size([3, 4, 768])