In [1]:
import os
import sys
ROOT_DIR = os.path.abspath("..")
sys.path.insert(0, ROOT_DIR)

In [2]:
import torch
import math
from typing import Optional
from src.encoder import Encoder
from src.decoder import Decoder

In [3]:
torch.manual_seed(0)

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print("device:", device)

device: mps


In [None]:
class Transformer:
    def __init__(
        self,
        encoder_num_embeddings: int = 30000,
        decoder_num_embeddings: int = 30000,
        d_model: int = 512,
        max_len: int = 512,
        heads: int = 8,
        d_ff: int = 2048,
        dropout_p: float = 0.1,
        ln_bias: bool = True,
        attn_bias: bool = True,
        ffn_bias: bool = True,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        num_layers: int = 6,
    ):
        self.encoder = Encoder(
            num_embeddings=encoder_num_embeddings,
            d_model=d_model,
            max_len=max_len,
            heads=heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_layers,   
        )
        self.decoder = Decoder(
            num_embeddings=decoder_num_embeddings,
            d_model=d_model,
            max_len=max_len,
            heads=heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_layers,
        )
        self.linear = torch.randn(d_model, decoder_num_embeddings) / math.sqrt(d_model)
        self.linear.requires_grad_()
        self.bias = torch.zeros(decoder_num_embeddings, requires_grad=True)
        self.training = True
    
    def __call__(
        self,
        src_ids: torch.Tensor,
        tgt_ids: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.forward(
            src_ids=src_ids,
            tgt_ids=tgt_ids,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )

    def forward(
        self,
        src_ids: torch.Tensor,
        tgt_ids: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        _, L_tgt = tgt_ids.shape

        tgt_attn_mask = torch.tril(
            torch.ones(L_tgt, L_tgt, device=tgt_ids.device, dtype=torch.long)
        )

        memory = self.encoder(src_ids, key_padding_mask=src_key_padding_mask)
        
        x = self.decoder(
            tgt_ids,
            memory=memory,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
            tgt_attn_mask=tgt_attn_mask,
        )

        logits = torch.matmul(x, self.linear) + self.bias
        return logits
    
    def parameters(self):
        params = []
        params.extend(self.encoder.parameters())
        params.extend(self.decoder.parameters())
        params.append(self.linear)
        params.append(self.bias)
        return params
    
    def zero_grad(self):
        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()

    def train(self, mode: bool = True):
        self.training = mode
        self.encoder.train(mode)
        self.decoder.train(mode)
        return self
    
    def eval(self):
        return self.train(False)
    
    def to(self, device: torch.device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.linear = self.linear.to(device).detach().requires_grad_(True)
        self.bias = self.bias.to(device).detach().requires_grad_(True)
        return self

In [4]:
from src.transformer import Transformer

In [9]:
B, L_src, L_tgt = 4, 16, 24
src_vocab, tgt_vocab = 100, 140

pad_idx = 0

src_ids = torch.randint(1, src_vocab, (B, L_src))
tgt_ids = torch.randint(1, tgt_vocab, (B, L_tgt))

src_key_padding_mask = torch.ones(B, L_src, dtype=torch.long)
tgt_key_padding_mask = torch.ones(B, L_tgt, dtype=torch.long)

src_ids[0, -5:] = pad_idx
src_key_padding_mask[0, -5:] = 0

tgt_ids[1, -2:] = pad_idx
tgt_key_padding_mask[1, -2:] = 0

src_ids, tgt_ids, src_key_padding_mask, tgt_key_padding_mask

(tensor([[19, 41,  3, 19, 52, 70, 81, 35, 42, 82, 66,  0,  0,  0,  0,  0],
         [ 7, 88, 77, 76, 36,  4, 99, 74, 81,  2, 11, 80, 55, 45, 54, 54],
         [ 5, 88, 80, 55, 92,  7, 32, 10, 87, 42,  4, 28, 95, 56, 89, 46],
         [52, 26,  4, 71, 25,  5, 29, 80,  2, 88, 69, 45, 34, 23, 87, 77]]),
 tensor([[  2,  38, 118, 124,  51,  89,  62, 139,  32, 123, 118, 107,  83,  37,
           82,  50, 104, 136,  38,  75, 128,  25,  28, 112],
         [ 68,  14,  24, 109, 106, 110,  78,  77,  77,  47, 109, 118,  48,  21,
           96,  37,  35,  52,   9,  67,  41,  67,   0,   0],
         [ 47,  33,  86, 135,  86,  33, 139,  68,  62,  91,  74, 128,  90, 111,
           84,  48, 131,  69,  89,  76,  96,  65,  16,  55],
         [ 65,  24,  67,  92, 105,  78,   5,  52,  45,  56,  25,  72,  52, 100,
           41,  12,  77,  50,  20, 113,  31, 135, 120, 135]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1

In [24]:
model = Transformer(
    encoder_num_embeddings=src_vocab,
    decoder_num_embeddings=tgt_vocab,
    d_model=64,
    max_len=64,
    heads=8,
    d_ff=128,
    dropout_p=0.1,
    num_layers=2,
).to(device).train()

src_ids = src_ids.to(device)
tgt_ids = tgt_ids.to(device)
src_key_padding_mask = src_key_padding_mask.to(device)
tgt_key_padding_mask = tgt_key_padding_mask.to(device)

logits = model(
    src_ids=src_ids,
    tgt_ids=tgt_ids,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
)

print("logits shape:", logits.shape)

logits shape: torch.Size([4, 24, 140])


In [21]:
L_tgt = tgt_ids.size(1)
tgt_attn_mask = torch.tril(torch.ones(L_tgt, L_tgt, device=device, dtype=torch.long))

print("tgt_attn_mask shape:", tgt_attn_mask.shape)
print(tgt_attn_mask)

tgt_attn_mask shape: torch.Size([24, 24])
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,

In [22]:
import torch.nn.functional as F

labels = torch.randint(0, tgt_vocab, (B, L_tgt), device=device)

loss = F.cross_entropy(logits.reshape(-1, tgt_vocab), labels.reshape(-1))
print("loss:", float(loss))

loss: 5.49536657333374


In [23]:
loss.backward()
print("backward ok")

for p in model.parameters():
    if p.grad is not None:
        print(p.grad.abs().mean().item())
        break

backward ok
0.000299648119835183
