In [45]:
import os
import sys
ROOT_DIR = os.path.abspath("..")
sys.path.insert(0, ROOT_DIR)

In [46]:
from src.encoder import Encoder
import torch
import torch.nn.functional as F

In [47]:
encoder = Encoder(
    num_embeddings=30000,
    d_model=512,
    max_len=512,
    heads=8,
    d_ff=2048,
    dropout_p=0.1,
    num_layers=6
)

total_params = sum(p.numel() for p in encoder.parameters())
print(f"Total parameters: {total_params:,}")

Total parameters: 34,274,304


In [48]:
encoder = Encoder(
    num_embeddings=100,
    d_model=512,
    max_len=512,
    heads=8,
    d_ff=2048,
    dropout_p=0.1,
    num_layers=3
)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

print(encoder.embedding.weight.device)
encoder.to(device)
print(encoder.embedding.weight.device)

cpu
mps:0


In [49]:
B, L, vocab_size = 2, 10, 100
input_ids = torch.randint(1, vocab_size, (B,L), dtype=torch.long)
input_ids[0, -2:] = 0
input_ids[1, -3:] = 0

attention_mask = (input_ids != 0).long()

print(input_ids)
print(attention_mask)

tensor([[70, 52, 25, 83, 66, 66, 86, 18,  0,  0],
        [92,  3, 88, 67, 76, 67,  1,  0,  0,  0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])


In [50]:
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)

In [51]:
encoder.train(True)
output = encoder(input_ids, mask=attention_mask)

In [52]:
print("output.shape:", output.shape)
print("output.device:", output.device)

output.shape: torch.Size([2, 10, 512])
output.device: mps:0


In [53]:
loss = output.mean()
print(loss)
encoder.zero_grad()
loss.backward()

print(encoder.embedding.weight.grad)
print(encoder.embedding.weight.grad.norm())

tensor(4.1910e-09, device='mps:0', grad_fn=<MeanBackward0>)
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [5.6027e-14, 2.7794e-15, 2.8424e-13,  ..., 1.2529e-14, 4.3345e-14,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], device='mps:0')
tensor(5.4930e-11, device='mps:0')


In [54]:
pad_idx = 0
pad_grad = encoder.embedding.weight.grad[pad_idx]
print(pad_grad.abs().sum())

tensor(0., device='mps:0')


In [55]:
encoder.eval()
output2 = encoder(input_ids, mask=attention_mask)
output3 = encoder(input_ids, mask=attention_mask)

print("eval mode deterministic:", torch.allclose(output2, output3))

eval mode deterministic: True
