In [1]:
import torch
import torch.nn as nn

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
# Model hyperparameters, grouped in one dict so every module can be built from
# the same settings. In this notebook only "emb_dim" is read (by FeedForward).
GPT_CONFIG_124M = {
    "vocab_size": 50_257,      # number of entries in the vocabulary
    "context_length": 1_024,   # context length
    "emb_dim": 768,            # embedding dimension
    "n_heads": 12,             # attention heads
    "n_layers": 12,            # number of layers
    "drop_rate": 0.1,          # dropout rate
    "qkv_bias": False,         # whether query/key/value projections use a bias
}

In [3]:
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit activation.

    Applies, element-wise,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The module has no parameters and returns a tensor shaped like its input.
    """

    # sqrt(2 / pi) kept as a plain Python float and computed once. A Python
    # scalar is applied in the dtype of the tensor it multiplies, so float64
    # inputs keep full precision (a float32 constant tensor would cap accuracy
    # at ~1e-8) and no tensor is allocated on each forward call.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the GELU (tanh approximation) of ``x``, element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


In [4]:
class FeedForward(nn.Module):
    """Two linear layers with a GELU in between: emb_dim -> 4 * emb_dim -> emb_dim.

    Args:
        cfg: Mapping providing the entry ``"emb_dim"``; no other key is read
            by this class.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["emb_dim"]
        hidden = 4 * width
        # Layers are created in the order expand -> activation -> project, so
        # the parameter names (``layers.0.*``, ``layers.2.*``) and the order
        # in which random initial weights are drawn stay as they were.
        expand = nn.Linear(width, hidden)
        activation = GELU()
        project = nn.Linear(hidden, width)
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Pass ``x`` through the stack; its last dimension must equal ``emb_dim``."""
        result = self.layers(x)
        return result

In [5]:
# Seed the global RNG before any random draw: the nn.Linear weight
# initialisation here and the torch.randn call in the next cell both consume
# it, so a fresh Restart & Run All produces the same numbers every time.
torch.manual_seed(123)
ff = FeedForward(GPT_CONFIG_124M)

In [6]:
# Random sample tensor. Its last dimension is taken from the config (768 for
# GPT_CONFIG_124M) instead of being hard-coded, so it always matches the width
# that `ff` was built with.
# NOTE: the name `input` shadows Python's builtin input(); it is kept because
# the following cells refer to it.
input = torch.randn(3, 4, GPT_CONFIG_124M["emb_dim"])

In [7]:
# Show the tensor's dimensions; they mirror the sizes given to torch.randn.
input.shape

torch.Size([3, 4, 768])

In [8]:
# Echo the tensor so the notebook displays its (abbreviated) values.
input

tensor([[[-0.8642,  0.1608,  0.5805,  ..., -1.3529, -0.5953,  0.9460],
         [ 1.0093,  0.2314,  1.6023,  ..., -0.2950, -0.2106, -0.1758],
         [-0.1302,  0.2708,  0.7648,  ..., -1.2072,  1.4379, -0.1536],
         [ 0.0673, -0.7788, -0.9794,  ...,  0.3611, -0.3154,  0.2604]],

        [[-1.4537, -1.2576, -0.7515,  ...,  0.1825, -0.0973, -1.7866],
         [ 0.9329,  0.0222, -1.2215,  ...,  0.6574, -1.5332,  0.0874],
         [ 0.8134, -0.9286,  0.9755,  ...,  1.4380, -3.3191,  0.3978],
         [-1.7112, -0.4354, -0.3125,  ...,  0.2515, -0.0146, -0.0187]],

        [[ 0.4109,  1.0161,  1.7605,  ..., -0.7993, -0.7750,  0.5944],
         [ 0.0080, -0.5243,  0.3228,  ..., -0.7072,  1.4444,  0.7121],
         [-0.6951, -0.2926,  0.6064,  ...,  0.1126,  0.5330, -0.0722],
         [-0.9762, -0.4899,  0.2660,  ..., -1.9703,  0.5497, -0.6832]]])

In [9]:
# Run the sample tensor through the feed-forward block (calling the module
# invokes its forward()).
output = ff( input )

In [10]:
# FeedForward maps emb_dim -> 4 * emb_dim -> emb_dim, so the shape is expected
# to equal the input's shape.
output.shape

torch.Size([3, 4, 768])

In [11]:
# Echo the result; the grad_fn in its repr shows it is attached to the
# autograd graph.
output

tensor([[[ 0.3467, -0.3723, -0.0925,  ..., -0.5143,  0.0170, -0.1157],
         [ 0.1156,  0.0135, -0.0155,  ...,  0.1373, -0.0215, -0.2582],
         [ 0.0742, -0.2238, -0.4412,  ..., -0.0221, -0.1476, -0.4195],
         [ 0.0364, -0.2558,  0.2155,  ...,  0.0344,  0.0555, -0.1370]],

        [[ 0.0491, -0.1137, -0.1043,  ..., -0.1809,  0.2208, -0.1067],
         [ 0.4362,  0.1157, -0.0357,  ..., -0.0356,  0.0185, -0.0880],
         [-0.1516,  0.3122,  0.0575,  ...,  0.2998,  0.1111, -0.3437],
         [-0.1559, -0.0884,  0.3478,  ...,  0.1105,  0.3079, -0.1521]],

        [[-0.0836, -0.0849,  0.0145,  ..., -0.0752,  0.0725, -0.1646],
         [-0.1925,  0.0807,  0.2325,  ...,  0.4726,  0.2017, -0.1133],
         [ 0.1306, -0.0701, -0.0041,  ...,  0.1509, -0.5092, -0.0747],
         [ 0.0708, -0.1069, -0.1002,  ..., -0.0864,  0.1763, -0.3804]]],
       grad_fn=<ViewBackward0>)

In [12]:
# Feed the result back through the same block; this works because FeedForward
# returns the same last-dimension size it accepts.
# NOTE(review): rebinding `output` makes this cell non-idempotent. Every re-run
# pushes the previous result through `ff` again, and the tensor displayed
# earlier is no longer reachable. Consider a new name for the second result.
output = ff( output )

In [13]:
# The shape is unchanged after the second pass through the block.
output.shape

torch.Size([3, 4, 768])