In [51]:
import torch

# Define test parameters
b, t, c = 2, 4, 2  # Batch size, sequence length, embedding dimension

# Create dummy tensors
torch.manual_seed(42)  # For reproducibility
z_q = torch.randn(b, t, c)  # Random latent vectors (b, t, c)
z = torch.randn(b, t, c)  # Random latent vectors (b, t, c)
mask = torch.randint(0, 2, (b, t, 1), dtype=torch.float32)  # Random binary mask (b, t, 1)

# Create a dummy padding embedding
pad_e = torch.randn(1, 1, c)  # (1, 1, c)

# Print dummy inputs
print("z_q:\n", z_q)
print("mask:\n", mask)
print("pad_e:\n", pad_e)


z_q:
 tensor([[[ 1.9269,  1.4873],
         [ 0.9007, -2.1055],
         [ 0.6784, -1.2345],
         [-0.0431, -1.6047]],

        [[-0.7521,  1.6487],
         [-0.3925, -1.4036],
         [-0.7279, -0.5594],
         [-0.7688,  0.7624]]])
mask:
 tensor([[[1.],
         [1.],
         [1.],
         [0.]],

        [[1.],
         [0.],
         [0.],
         [0.]]])
pad_e:
 tensor([[[1.3525, 0.6863]]])


In [52]:
torch.where(mask == 0, pad_e, z_q)  # Replace masked positions with padding embeddings

tensor([[[ 1.9269,  1.4873],
         [ 0.9007, -2.1055],
         [ 0.6784, -1.2345],
         [ 1.3525,  0.6863]],

        [[-0.7521,  1.6487],
         [ 1.3525,  0.6863],
         [ 1.3525,  0.6863],
         [ 1.3525,  0.6863]]])

In [53]:
mask.shape

torch.Size([2, 4, 1])

In [54]:
z_q.shape

torch.Size([2, 4, 2])

/bin/bash: line 1: mask: command not found


In [81]:
mask


tensor([[[1.],
         [1.],
         [1.],
         [0.]],

        [[1.],
         [0.],
         [0.],
         [0.]]])

In [85]:
mse_loss

tensor([[[0.0810, 2.7122],
         [1.9547, 6.4776],
         [2.0637, 5.3493],
         [0.0000, 0.0000]],

        [[4.1260, 0.1241],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]]])

In [None]:
import torch.nn.functional as F
mse_loss= F.mse_loss(z , z_q.detach(), reduction='none')  * mask  # (b, t, c)
valid_count = mask.sum() * z.shape[-1] # Total number of valid (non-masked) elements

In [112]:
mask.numel()

8

In [115]:
# Avoid division by zero
commitment_loss = mse_loss.sum() / valid_count 


print( valid_count, mse_loss.sum(), commitment_loss, mse_loss.sum() / (mask.numel()* z.shape[-1]))
print(F.mse_loss(z *mask, z_q.detach()*mask))

tensor(8.) tensor(22.8887) tensor(2.8611) tensor(1.4305)
tensor(1.4305)


In [78]:
mse_loss = F.mse_loss(z * mask, z_q.detach() * mask, reduction='none')  # (b, t, c)
valid_count = mask.sum()  # Count of valid (non-masked) elements

commitment_loss = mse_loss.sum() / valid_count.clamp(min=1)  # Avoid division by zero

print(valid_count, mse_loss.sum(), commitment_loss)


tensor(4.) tensor(22.8887) tensor(5.7222)


In [79]:
# Corrected global mean loss (dividing by all elements, masked and unmasked)
print(mse_loss.sum() / mask.numel())  


tensor(2.8611)
