In [2]:
import torch
import torch.nn as nn

In [3]:
# Hyper-parameter dictionary for a 12-layer, 12-head, 768-wide GPT-2 model.
# The whole dict is handed to TransformerBlock (and, through it, to the
# MultiHeadAttention / FFN sub-modules) further down in this notebook.
GPT2_CONFIG = {
    "activation_function": "gelu_new",
    "architectures": ["GPT2LMHeadModel"],
    "attn_pdrop": 0.1,
    "bos_token_id": 50256,
    "embd_pdrop": 0.1,
    "eos_token_id": 50256,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-5,
    "model_type": "gpt2",
    "n_ctx": 1024,
    "n_embd": 768,  # embedding width; also the LayerNorm size in TransformerBlock
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "resid_pdrop": 0.1,
    "summary_activation": None,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": True,
    "summary_type": "cls_index",
    "summary_use_proj": True,
    "task_specific_params": {
        "text-generation": {"do_sample": True, "max_length": 50},
    },
    "vocab_size": 50257,
}

In [8]:
from blocks import MultiHeadAttention, FFN
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Two sub-layers (multi-head attention, then feed-forward) are each applied
    as ``x + dropout(sublayer(layer_norm(x)))``.

    Args:
        cfg: Mapping of hyper-parameters. This class reads:

            * ``n_embd`` -- size passed to both ``nn.LayerNorm`` modules.
            * ``resid_pdrop`` -- dropout probability applied to each sub-layer
              output before the residual add. Falls back to ``embd_pdrop``
              when ``resid_pdrop`` is absent, for older configs.
            * ``layer_norm_epsilon`` -- optional LayerNorm epsilon; defaults
              to 1e-5 (``nn.LayerNorm``'s own default) when absent.

            The same mapping is forwarded unchanged to ``MultiHeadAttention``
            and ``FFN``.

    Raises:
        KeyError: if ``n_embd`` is missing, or if neither ``resid_pdrop`` nor
            ``embd_pdrop`` is present.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg['n_embd']
        # Use the config's LayerNorm epsilon instead of silently relying on
        # the PyTorch default.
        norm_eps = cfg.get('layer_norm_epsilon', 1e-5)
        # The dropout that follows a sub-layer is the *residual* dropout;
        # `embd_pdrop` is the embedding dropout and is only used as a
        # backward-compatible fallback.
        resid_pdrop = cfg['resid_pdrop'] if 'resid_pdrop' in cfg else cfg['embd_pdrop']

        # Attribute names are kept as-is so existing state_dict keys still load.
        self.layer_norm1 = nn.LayerNorm(width, eps=norm_eps)
        self.mhma = MultiHeadAttention(cfg)
        self.drop1 = nn.Dropout(resid_pdrop)
        self.layer_norm2 = nn.LayerNorm(width, eps=norm_eps)
        self.ffn = FFN(cfg)
        self.drop2 = nn.Dropout(resid_pdrop)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers with residual adds.

        Args:
            x: Input tensor whose last dimension equals ``cfg['n_embd']``
               (required by the LayerNorms). The notebook feeds a
               ``(batch, seq, n_embd)`` tensor.

        Returns:
            Tensor produced by the two residual adds. NOTE(review): this has
            the same shape as ``x`` provided ``MultiHeadAttention`` and
            ``FFN`` preserve shape -- confirm against ``blocks``.
        """
        # Attention sub-layer: normalise -> attend -> dropout -> residual.
        attn_out = self.drop1(self.mhma(self.layer_norm1(x)))
        x = x + attn_out  # new tensor; the caller's input is not modified

        # Feed-forward sub-layer: same pattern on the updated stream.
        ffn_out = self.drop2(self.ffn(self.layer_norm2(x)))
        return x + ffn_out


In [9]:
# Smoke test: push a random batch through one block and check the shape.
# Seed first so that weight initialisation, the random inputs and the
# (training-mode) dropout masks are reproducible under Restart & Run All.
torch.manual_seed(0)

tb = TransformerBlock(GPT2_CONFIG)
# (batch=2, seq_len=4, n_embd) -- width is read from the config, not hard-coded.
inputs = torch.rand((2, 4, GPT2_CONFIG["n_embd"]))
# NOTE: `tb` is in training mode here, so dropout is active; call tb.eval()
# first if deterministic inference-style outputs are wanted.
outputs = tb(inputs)

# The residual adds require the block to preserve the input shape.
assert outputs.shape == inputs.shape, f"unexpected output shape {tuple(outputs.shape)}"
print(outputs.shape)
print(outputs)

torch.Size([2, 4, 768])
tensor([[[ 0.4450, -0.2454, -0.2246,  ...,  1.7134, -0.5881,  1.1280],
         [ 0.3554, -0.6981,  0.1642,  ..., -0.6029,  0.1183,  0.9462],
         [ 0.8822,  0.0585,  0.6456,  ..., -0.0868,  0.1910,  0.6915],
         [ 0.3881,  0.2339,  0.8671,  ..., -0.3560, -0.1109,  0.7169]],

        [[ 0.1437,  0.9500, -0.1398,  ...,  0.2089,  0.7430,  0.0888],
         [-0.2217, -0.2557,  0.9253,  ...,  0.3462, -1.0515,  0.3073],
         [ 0.6949,  1.1424,  0.2081,  ...,  0.4882,  0.4778,  0.6516],
         [ 0.3915,  0.3891, -0.0732,  ...,  0.3211,  0.4178,  0.8796]]],
       grad_fn=<AddBackward0>)
