In [1]:
import torch

In [2]:
class TransformerBlock(torch.nn.Module):
  """One pre-LayerNorm Transformer block: self-attention followed by an MLP.

  Each sub-layer is applied as ``x = x + sublayer(norm(x))``. That is the
  pre-norm layout, not the post-norm layout of the original Transformer paper.
  The output has the same shape as the input.

  Args:
    embed_dim: model width. ``torch.nn.MultiheadAttention`` requires it to be
      divisible by ``n_heads``.
    n_heads: number of attention heads.
    mlp_ratio: integer expansion factor for the MLP hidden layer
      (hidden width = ``embed_dim * mlp_ratio``). Defaults to 4.
  """

  def __init__(self, embed_dim, n_heads, mlp_ratio=4):
    super().__init__()
    hidden_dim = embed_dim * mlp_ratio
    # batch_first=True: attention expects inputs shaped (batch, seq, embed_dim).
    self.self_attn = torch.nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
    self.mlp = torch.nn.Sequential(
        torch.nn.Linear(embed_dim, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, embed_dim),
    )

    self.in_norm  = torch.nn.LayerNorm(embed_dim)
    self.mlp_norm = torch.nn.LayerNorm(embed_dim)

  def forward(self, x, attn_mask=None, key_padding_mask=None):
    """Apply the block to ``x``.

    Args:
      x: input tensor, (batch, seq, embed_dim) because of ``batch_first=True``.
      attn_mask: optional mask passed unchanged to
        ``torch.nn.MultiheadAttention`` (e.g. a causal mask).
      key_padding_mask: optional (batch, seq) mask passed unchanged to
        ``torch.nn.MultiheadAttention``; for a bool mask, True marks keys to ignore.

    Returns:
      Tensor with the same shape as ``x``.
    """
    x_norm = self.in_norm(x)
    # The attention weights are not used, so don't ask MultiheadAttention to
    # compute and average them (need_weights=False).
    attn_out, _ = self.self_attn(
        x_norm, x_norm, x_norm,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask,
        need_weights=False,
    )
    x = x + attn_out  # residual connection: adds the identity path back in
    x = x + self.mlp(self.mlp_norm(x))  # pre-norm MLP sub-layer with its own residual
    return x

class Transformer(torch.nn.Module):
  """A stack of ``n_layers`` TransformerBlock layers applied one after another.

  The blocks live in ``self.net`` (a ``torch.nn.Sequential``). Each block
  preserves shape, so the stack maps its input to a tensor of the same shape.
  """

  def __init__(self, embed_dim, n_heads, n_layers):
    super().__init__()
    blocks = []
    for _layer_idx in range(n_layers):
      blocks.append(TransformerBlock(embed_dim, n_heads))
    self.net = torch.nn.Sequential(*blocks)

  def forward(self, x):
    """Feed ``x`` through every block in order and return the final tensor."""
    out = self.net(x)
    return out

# Smoke test: the Transformer stack maps (batch, seq, embed_dim) to the same shape.
torch.manual_seed(0)  # make the random weights and input reproducible across runs
EMBED_DIM, N_HEADS, N_LAYERS = 128, 8, 4
BATCH_SIZE, SEQ_LEN = 16, 10

net = Transformer(EMBED_DIM, N_HEADS, N_LAYERS)
x = torch.rand(BATCH_SIZE, SEQ_LEN, EMBED_DIM)  # same shape in and out; only the values change as x passes through the layers
with torch.no_grad():  # inference only, so no autograd graph is needed
  out = net(x)
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"
print(out.shape)

torch.Size([16, 10, 128])


In [3]:
net  # last expression in the cell: Jupyter renders the module repr without print()

Transformer(
  (net): Sequential(
    (0): TransformerBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (in_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (in_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp_norm): LayerNorm((128