In [1]:
import math

import torch as t
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange, reduce, repeat
from torchtyping import TensorType

import gpt_tests

In [2]:
# Notebook-wide torch device constant; fixed to the CPU.
DEVICE = t.device("cpu")

In [10]:
class Attention(nn.Module):
    """Multi-head causal (unidirectional) self-attention.

    A single fused linear layer produces queries, keys and values; each is
    split into ``num_heads`` heads of width ``head_size``. Scaled dot-product
    scores for positions where the key index is greater than the query index
    are overwritten with ``neg_inf`` before the softmax, so a position can
    only attend to itself and to earlier positions.

    Args:
        hidden_size: Width of the input and output embeddings.
        num_heads: Number of attention heads; must evenly divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        # Fail early: otherwise head_size is silently floored and the error
        # only shows up later as a shape mismatch inside forward().
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        # Fused projection: the output holds Q, K and V side by side on the last axis.
        self.attn_layer = nn.Linear(hidden_size, 3*hidden_size)
        self.output_layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, input: TensorType["batch_size", "seq_len", "hidden_size"], neg_inf=-1e4):
        """Apply causal self-attention to ``input``.

        Args:
            input: Tensor of shape (batch_size, seq_len, hidden_size).
            neg_inf: Value written into masked (future) score positions
                before the softmax.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        _, seq_len, hidden_size = input.shape

        # calculate attn_score
        queries, keys, values = t.split(self.attn_layer(input), hidden_size, dim=-1)
        queries = rearrange(queries, "b n (h p) -> b h n p", h=self.num_heads)
        keys = rearrange(keys, "b n (h p) -> b h n p", h=self.num_heads)
        values = rearrange(values, "b n (h p) -> b h n p", h=self.num_heads)
        attn_score = einsum("b h t p, b h f p -> b h t f", queries, keys) / math.sqrt(self.head_size)

        # Strict upper triangle marks the (query t, key f) pairs with f > t.
        # The mask is built on the input's own device so the layer does not
        # depend on a module-level device global, and it is broadcast over the
        # batch and head axes instead of being copied for each of them.
        mask = t.ones((seq_len, seq_len), dtype=t.bool, device=input.device).triu(diagonal=1)
        attn_score = attn_score.masked_fill(mask, neg_inf)

        # calculate output
        attn_pattern = t.softmax(attn_score, dim=-1)
        attn_concat = einsum("b h t f, b h f p -> b h t p", attn_pattern, values)
        # Merge the heads back into one hidden_size-wide vector per position.
        attn_concat = rearrange(attn_concat, "b h t p -> b t (h p)")

        return self.output_layer(attn_concat)

In [11]:
# Run the gpt_tests check for the Attention class defined above; the recorded
# output below shows the success message it printed.
gpt_tests.test_unidirectional_attn(Attention)

Congrats! You've passed the test!


In [15]:
class Block(nn.Module):
    """Transformer block: attention and MLP sub-layers, each with a residual.

    Layer normalisation is applied to the input of each sub-layer (before the
    attention and before the MLP). The MLP widens to ``4 * hidden_size``,
    applies GELU, and projects back; dropout is applied to the MLP output
    before it is added to the residual stream.

    Args:
        hidden_size: Width of the residual stream.
        num_heads: Number of heads passed to ``Attention``.
        dropout: Drop probability for the MLP output.
        layer_norm_epsilon: ``eps`` for both ``nn.LayerNorm`` modules.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float,
        layer_norm_epsilon: float,
    ):
        super().__init__()
        mlp_width = 4 * hidden_size

        # Construction order is kept as-is so seeded initialisation is unchanged.
        self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attention = Attention(hidden_size, num_heads)
        self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.linear_1 = nn.Linear(hidden_size, mlp_width)
        self.linear_2 = nn.Linear(mlp_width, hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input: TensorType["batch_size", "seq_len", "hidden_size"]):
        """Run the block on ``input`` and return a tensor of the same shape."""
        # Attention sub-layer with residual connection.
        stream = input + self.attention(self.layer_norm_1(input))

        # MLP sub-layer: normalise, expand, GELU, project back, then dropout.
        expanded = self.linear_1(self.layer_norm_2(stream))
        projected = self.linear_2(F.gelu(expanded))
        return stream + self.dropout(projected)


In [16]:
# Run the gpt_tests check for the Block class defined above; the recorded
# output below shows the success message it printed.
gpt_tests.test_gpt_block(Block)

Congrats! You've passed the test!
