In [1]:
import os
import sys
# Resolve the parent of the current working directory and put it at the front
# of the import search path so that `src.*` modules can be imported below.
ROOT_DIR = os.path.abspath(os.pardir)
sys.path.insert(0, ROOT_DIR)

In [2]:
from src.decoder import Decoder
import torch
import torch.nn.functional as F

In [3]:
# Seed the default generator so the random tensors created later are repeatable.
torch.manual_seed(0)

# Use the MPS backend when it is available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("device:", device)

device: mps


In [4]:
# Full-size configuration, instantiated here only to report its parameter count.
# It is never moved to `device`; a smaller `decoder` is built in a later cell.
decoder = Decoder(
    num_embeddings=30000, d_model=512, max_len=512,
    heads=8, d_ff=2048, dropout_p=0.1, num_layers=6,
)

# Sum the element counts of every parameter tensor exposed by the model.
total_params = sum(weight.numel() for weight in decoder.parameters())
print(f"Total parameters: {total_params:,}")

Total parameters: 40,584,192


In [5]:
# Toy problem sizes: batch size, source length, target length.
B = 2
src_L = 7
tgt_L = 6

# Small hyperparameters for the smoke-test model built in the next cell.
vocab, d_model, heads, d_ff, num_layers = 50, 32, 4, 64, 2
pad_id = 0  # token id that marks padded positions in `tgt_ids`

# Random (B, src_L, d_model) input passed to the decoder as `memory`
# (presumably standing in for encoder output). requires_grad=True lets the
# gradient w.r.t. this tensor be inspected after backward().
memory = torch.randn(B, src_L, d_model, device=device, requires_grad=True)

# Two right-padded target sequences of length tgt_L.
tgt_ids = torch.tensor(
    [
        [1, 2, 3, 4, pad_id, pad_id],
        [5, 6, 7, pad_id, pad_id, pad_id],
    ],
    dtype=torch.long,
    device=device,
)

# NOTE(review): the masks below use 1 for "real / allowed" and 0 for "masked".
# This is the opposite polarity of torch.nn.MultiheadAttention's
# key_padding_mask (True = ignore) -- confirm it matches what src.decoder expects.

# 1 where the target token differs from pad_id, 0 on padded positions.
tgt_key_padding_mask = tgt_ids.ne(pad_id).long()
# All ones: every source position carries the same mask value.
memory_key_padding_mask = torch.ones((B, src_L), dtype=torch.long, device=device)
# Lower-triangular ones: row i has ones in columns 0..i and zeros afterwards.
tgt_attn_mask = torch.ones((tgt_L, tgt_L), dtype=torch.long, device=device).tril()

In [6]:
# Build a small decoder from the toy hyperparameters and move it to `device`.
decoder = Decoder(
    num_embeddings=vocab, d_model=d_model, max_len=128,
    heads=heads, d_ff=d_ff, dropout_p=0.1, num_layers=num_layers,
).to(device)

# Explicitly select training mode (presumably this keeps dropout active, given
# dropout_p=0.1 -- depends on the Decoder implementation).
decoder.train(True)

# Single forward pass over the toy batch with all three masks supplied.
out = decoder(
    x=tgt_ids, memory=memory,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
    tgt_attn_mask=tgt_attn_mask,
)

print("out shape:", out.shape)

out shape: torch.Size([2, 6, 32])


In [7]:
# Gradient-flow smoke test: reduce the decoder output to a scalar, backpropagate,
# and verify that every tensor expected to receive a gradient got a finite one.
loss = out.mean()

# Clear any gradients left over from a previous backward pass.
decoder.zero_grad()
if memory.grad is not None:
    memory.grad.zero_()

# NOTE: backward() frees the autograd graph by default, so re-running this cell
# on its own requires re-running the forward-pass cell first to rebuild `out`.
loss.backward()

# Check the `memory` input as well as the model parameters: it was created with
# requires_grad=True, so a missing gradient there means nothing flowed back
# into it. Report every offending tensor by name instead of stopping at the
# first failure, so a broken run is easy to diagnose.
grads_ok = True
for name, tensor in [("memory", memory), *decoder.named_parameters()]:
    grad = tensor.grad
    if grad is None:
        grads_ok = False
        print(f"grad None: {name}")
    elif not torch.isfinite(grad).all():
        grads_ok = False
        print(f"non-finite grad found: {name}")

print("grads ok:", grads_ok)

grads ok: True
