In [None]:
!pip install torchinfo

In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torchinfo import summary

In [None]:
class MaskedMultiSelfAttention(nn.Module):
  """Multi-head self-attention with a causal (lower-triangular) mask.

  Each position may only attend to itself and to earlier positions of the
  sequence.

  Args:
    h_dim: Size of the last dimension of the input; must be divisible by
      ``n_heads``.
    max_T: Longest sequence length supported (side of the stored mask).
    n_heads: Number of attention heads.
    drop_p: Dropout probability used for both dropout layers.

  Raises:
    ValueError: If ``h_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, h_dim, max_T, n_heads, drop_p):
    super().__init__()
    if h_dim % n_heads != 0:
      raise ValueError(
          f"h_dim ({h_dim}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads

    self.q_net = nn.Linear(h_dim, h_dim)
    self.k_net = nn.Linear(h_dim, h_dim)
    self.v_net = nn.Linear(h_dim, h_dim)

    self.proj_net = nn.Linear(h_dim, h_dim)

    self.attn_drop = nn.Dropout(drop_p)
    self.proj_drop = nn.Dropout(drop_p)

    # Lower-triangular matrix of ones, shaped (1, 1, max_T, max_T) so that it
    # broadcasts over the batch and head dimensions.
    ones = torch.ones((max_T, max_T))
    mask = torch.tril(ones).view(1, 1, max_T, max_T)

    # The mask is constant: a buffer follows the module across devices but is
    # not a trainable parameter.
    self.register_buffer('mask', mask)

  def forward(self, x):
    """Apply causal multi-head self-attention.

    Args:
      x: Tensor of shape (B, T, C) with ``T <= max_T``.

    Returns:
      Tensor of shape (B, T, C).

    Raises:
      ValueError: If the sequence length ``T`` exceeds ``max_T``.
    """
    B, T, C = x.shape
    max_T = self.mask.size(-1)
    if T > max_T:
      # Without this check a max_T of 1 would broadcast and silently disable
      # the mask; any other mismatch would fail with an opaque shape error.
      raise ValueError(f"sequence length {T} exceeds max_T ({max_T})")
    N, D = self.n_heads, C // self.n_heads

    # (B, T, C) -> (B, N, T, D): split channels into heads.
    q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
    k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
    v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

    # Scaled dot-product scores, shape (B, N, T, T).
    weights = q @ k.transpose(2, 3) / math.sqrt(D)

    # Causal masking. masked_fill is out-of-place, so the result must be
    # assigned back; otherwise the mask is never applied.
    weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))

    # Normalize weights: every -inf entry becomes 0 after softmax.
    normalized_weights = F.softmax(weights, dim=-1)

    # Masked causal attention (B, N, T, D); dropout is applied to the attended
    # values, then heads are merged back into (B, T, N * D).
    attention = self.attn_drop(normalized_weights @ v)
    attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

    out = self.proj_drop(self.proj_net(attention))
    return out

In [None]:
class TransformerDecoderBlock(nn.Module):
  """Decoder block: causal self-attention and an MLP, each followed by LayerNorm.

  Both sub-layers are wrapped in a residual connection and the LayerNorm is
  applied after the residual sum.

  Args:
    h_dim: Width of the block's input and output.
    max_T: Maximum sequence length, forwarded to the attention layer.
    n_heads: Number of attention heads, forwarded to the attention layer.
    drop_p: Dropout probability for the attention layer and the MLP output.
  """

  def __init__(self, h_dim, max_T, n_heads, drop_p):
    super().__init__()
    hidden_dim = 4 * h_dim

    self.attn = MaskedMultiSelfAttention(h_dim, max_T, n_heads, drop_p)

    # Position-wise feed-forward network: expand, GELU, project back, dropout.
    feed_forward_layers = [
        nn.Linear(h_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, h_dim),
        nn.Dropout(drop_p),
    ]
    self.mlp = nn.Sequential(*feed_forward_layers)

    self.ln1 = nn.LayerNorm(h_dim)
    self.ln2 = nn.LayerNorm(h_dim)

  def forward(self, x):
    """Run attention and the MLP, normalizing after each residual sum."""
    # Attention sub-layer with residual, then LayerNorm.
    attended = self.ln1(self.attn(x) + x)
    # Feed-forward sub-layer with residual, then LayerNorm.
    return self.ln2(self.mlp(attended) + attended)

In [None]:
# Demo dimensions used by the cells below.
B = 4    # batch size of the example input
T = 8    # sequence length, also passed as max_T
D = 64   # multiplied by n_heads to obtain h_dim
n_heads = 12  # number of attention heads

In [None]:
# Build one decoder block whose hidden size is n_heads * D.
block = TransformerDecoderBlock(
    h_dim=n_heads * D,
    max_T=T,
    n_heads=n_heads,
    drop_p=0.1,
)

In [None]:
# Layer-by-layer summary for an input of shape (B, T, n_heads * D).
summary(
    block,
    input_size=(B, T, n_heads * D),
)

In [None]:
# Colab-only cleanup: ask Colab to unassign (release) the current runtime once
# the notebook has finished. Outside Colab the `google.colab` package is not
# installed, so the step is skipped instead of failing the final cell.
try:
  from google.colab import runtime
except ImportError:
  runtime = None

if runtime is not None:
  runtime.unassign()