In [1]:
import torch
import torch.nn as nn
import math
from typing import Optional


class SimpleDecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer.

    Causal multi-head self-attention is followed by a position-wise
    feed-forward network. Each sub-block computes ``LayerNorm(X + SubBlock(X))``.
    Causality is always enforced: position ``i`` only attends to positions ``<= i``.

    Args:
        hidden_dim: Model width; must be a positive multiple of ``head_num``.
        head_num: Number of attention heads.
        dropout_rate: Dropout applied to the attention weights and to the
            feed-forward output.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide ``hidden_dim``.
    """

    def __init__(
        self, hidden_dim: int, head_num: int, dropout_rate: float = 0.1
    ) -> None:
        super().__init__()
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"head_num ({head_num})"
            )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        # Attention projections.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.drop_att = nn.Dropout(dropout_rate)
        self.layer_norm_att = nn.LayerNorm(hidden_dim, eps=0.00001)

        # Feed-forward network with the conventional 4x expansion.
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)

        self.drop_ffn = nn.Dropout(dropout_rate)
        self.layer_norm_ffn = nn.LayerNorm(hidden_dim, eps=0.00001)
        self.act_fn = nn.ReLU()

    def attention_block(
        self, X: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Causal multi-head self-attention with residual connection and LayerNorm.

        Args:
            X: Input of shape ``(batch_size, seq_len, hidden_dim)``.
            attention_mask: Optional mask in which 0 marks a blocked key position.
                It must broadcast to ``(batch_size, head_num, seq_len, seq_len)``.
                Examples are a full mask of that shape, or a key-padding mask of
                shape ``(batch_size, 1, 1, seq_len)``. It is combined with the
                causal mask, so it does not need to be causal itself.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_dim)``.
        """
        batch_size, seq_len, _ = X.shape

        # (batch_size, seq_len, head_num * head_dim) -> (batch_size, head_num, seq_len, head_dim)
        q = self.q_proj(X).view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k = self.k_proj(X).view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v = self.v_proj(X).view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # (batch_size, head_num, seq_len, seq_len)
        attention_weight = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)

        # Build the causal mask on its own and AND it with the caller's mask.
        # Calling .tril() on the caller's mask would corrupt broadcastable masks:
        # tril of a (1, seq_len) matrix keeps only column 0.
        keep = torch.ones(seq_len, seq_len, dtype=torch.bool, device=X.device).tril()
        if attention_mask is not None:
            keep = keep & (attention_mask.to(device=X.device) != 0)

        # Fill with the most negative finite value rather than -inf. A query row
        # whose keys are all masked (e.g. a left-padded position) then gets a
        # uniform, NaN-free distribution instead of softmax(all -inf) = NaN.
        # Outputs at such padded positions carry no meaning.
        attention_weight = attention_weight.masked_fill(
            ~keep, torch.finfo(attention_weight.dtype).min
        )

        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.drop_att(attention_weight)

        # (batch_size, head_num, seq_len, head_dim) -> (batch_size, seq_len, hidden_dim)
        output_mid = (
            (attention_weight @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        )
        return self.layer_norm_att(X + self.out_proj(output_mid))

    def ffn_block(self, X: torch.Tensor) -> torch.Tensor:
        """Position-wise feed-forward network with residual connection and LayerNorm."""
        output = self.up_proj(X)
        output = self.act_fn(output)
        output = self.down_proj(output)
        output = self.drop_ffn(output)
        return self.layer_norm_ffn(X + output)

    def forward(
        self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run attention then feed-forward.

        Input shape ``(batch_size, seq_len, hidden_dim)`` is preserved.
        ``attention_mask`` is optional; causality is applied regardless.
        """
        att_output = self.attention_block(X, attention_mask)
        return self.ffn_block(att_output)


In [2]:
torch.manual_seed(0)  # make the random inputs and weight init reproducible

batch_size, seq_len, hidden_dim, head_num = 3, 4, 64, 8
x = torch.rand(batch_size, seq_len, hidden_dim)
net = SimpleDecoderLayer(hidden_dim, head_num)

# Key-padding pattern: 1 = real token, 0 = padding; shape (batch_size, seq_len).
padding_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
# Expand (batch_size, seq_len) -> (batch_size, head_num, seq_len, seq_len) so every
# head and every query row sees the same padding; the layer applies causality itself.
mask = padding_mask[:, None, None, :].repeat(1, head_num, seq_len, 1)

output = net(x, mask)
assert output.shape == (batch_size, seq_len, hidden_dim)
print(output.shape)

torch.Size([3, 4, 64])


In [3]:
class SimpleDecoder(nn.Module):
    """Toy decoder-only language model.

    The model is a token embedding, a stack of ``SimpleDecoderLayer``s, and a
    projection to vocabulary probabilities. The defaults reproduce the original
    hard-coded configuration: 12-token vocabulary, width 64, 8 heads, 6 layers.

    NOTE(review): there is no positional embedding, so the model sees token
    order only through the layers' causal masking. Confirm this is intended.
    """

    def __init__(
        self,
        vocab_size: int = 12,
        hidden_dim: int = 64,
        head_num: int = 8,
        layer_num: int = 6,
    ) -> None:
        super().__init__()
        # Modules are created in the original order (layers, embedding, output),
        # so a given random seed yields the same initial weights as before.
        self.layer_list = nn.ModuleList(
            [SimpleDecoderLayer(hidden_dim, head_num) for _ in range(layer_num)]
        )
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(
        self, X: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Map token ids to next-token probabilities.

        Args:
            X: Integer token ids of shape ``(batch_size, seq_len)``, each in
                ``[0, vocab_size)``.
            mask: Optional attention mask forwarded unchanged to every layer.

        Returns:
            Softmax probabilities of shape ``(batch_size, seq_len, vocab_size)``.
            NOTE: these are probabilities, not logits. ``nn.CrossEntropyLoss``
            expects raw logits, so do not feed this output to it directly.
        """
        hidden = self.emb(X)  # (batch_size, seq_len, hidden_dim)
        for layer in self.layer_list:
            hidden = layer(hidden, mask)
        logits = self.out(hidden)  # (batch_size, seq_len, vocab_size)
        return torch.softmax(logits, dim=-1)

In [4]:
# SimpleDecoder() has a 12-token vocabulary, so ids are drawn from [0, 12).
# This reuses `mask` from the SimpleDecoderLayer demo cell, so the batch size (3)
# and sequence length (4) here must match that mask.
x = torch.randint(low=0, high=12, size=(3, 4))
net = SimpleDecoder()
output = net(x, mask)
assert output.shape == (3, 4, 12)
# The decoder returns softmax probabilities: each position's distribution sums to 1.
assert torch.allclose(output.sum(dim=-1), torch.ones(3, 4))
print(output.shape)  # (3, 4, 12)

torch.Size([3, 4, 12])
