In [11]:
import os
import sys
ROOT_DIR = os.path.abspath("..")
sys.path.insert(0, ROOT_DIR)

In [8]:
from src.encoder import Encoder
import torch
import torch.nn.functional as F

In [15]:
encoder = Encoder(
    num_embeddings=30000,
    d_model=512,
    max_len=512,
    heads=8,
    d_ff=2048,
    dropout_p=0.1,
    num_layers=6
)

total_params = sum(p.numel() for p in encoder.parameters())
print(f"Total parameters: {total_params:,}")

Total parameters: 34,274,304


In [25]:
encoder = Encoder(
    num_embeddings=100,
    d_model=512,
    max_len=512,
    heads=8,
    d_ff=2048,
    dropout_p=0.1,
    num_layers=3
)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

print(encoder.embedding.weight.device)
encoder.to(device)
print(encoder.embedding.weight.device)

cpu
mps:0


In [31]:
B, L, vocab_size = 2, 10, 100
input_ids = torch.randint(1, vocab_size, (B,L), dtype=torch.long)
input_ids[0, -2:] = 0
input_ids[1, -3:] = 0

attention_mask = (input_ids != 0).long()

print(input_ids)
print(attention_mask)

tensor([[72, 25, 90, 98, 41, 89, 20, 35,  0,  0],
        [34, 60, 81, 94, 14, 42,  4,  0,  0,  0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])


In [32]:
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)

In [34]:
encoder.train(True)
output = encoder(input_ids, mask=attention_mask)

In [36]:
print("output.shape:", output.shape)
print("output.device:", output.device)

output.shape: torch.Size([2, 10, 512])
output.device: mps:0


In [38]:
loss = output.mean()
print(loss)
encoder.zero_grad()
loss.backward()

print(encoder.embedding.weight.grad)
print(encoder.embedding.weight.grad.norm())

tensor(-2.0489e-09, device='mps:0', grad_fn=<MeanBackward0>)
tensor([[ 1.9975e-13,  8.1393e-14,  1.4412e-13,  ..., -2.4508e-13,
         -3.1938e-13,  2.1931e-13],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.8544e-13, -1.2116e-13, -4.7565e-13,  ...,  5.8469e-14,
          1.0935e-14, -2.8820e-13],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='mps:0')
tensor(4.7769e-11, device='mps:0')


In [40]:
pad_idx = 0
pad_grad = encoder.embedding.weight.grad[pad_idx]
print(pad_grad.abs().sum())

tensor(9.0138e-11, device='mps:0')


In [42]:
encoder.eval()
output2 = encoder(input_ids, mask=attention_mask)
output3 = encoder(input_ids, mask=attention_mask)

print("eval mode deterministic:", torch.allclose(output2, output3))

eval mode deterministic: True
