In [1]:
import torch

In [2]:
# Initialize the parameters here...
# NOTE(review): no RNG seed is set in the visible cells, so PROJ's initial weights and the
# random features built from these sizes will presumably differ between runs — confirm
# whether the recorded outputs are meant to be reproducible.
# Number of feature tensors generated for each of the two feature lists
ENCODER_NUM = 1
# First dimension of the random feature tensors (B in the "B x L x D" layout)
B = 128
# Second dimension of the random feature tensors (L in the "B x L x D" layout)
L = 256
# Last dimension of the encoder features
D_ENCODER = 768
# Last dimension of the SiT features before projection
D_SIT = 1024
# Linear layer mapping the last dimension from D_SIT to D_ENCODER
PROJ = torch.nn.Linear(D_SIT, D_ENCODER)

def mean_flat(x):
    """Average ``x`` over every dimension except the first one.

    For a 1-D input the list of reduced dimensions is empty, so the result is
    whatever ``mean`` returns for an empty ``dim`` list.
    """
    non_leading_dims = list(range(1, x.dim()))
    return x.mean(dim=non_leading_dims)

# Random stand-ins for the two feature sets, both laid out as (B, L, D):
#   zs       <-- features from encoders
#   zs_tilde <-- projected SiT features
# All encoder features are drawn before any SiT features.
zs = [torch.randn(B, L, D_ENCODER) for _ in range(ENCODER_NUM)]
zs_tilde = [PROJ(torch.randn(B, L, D_SIT)) for _ in range(ENCODER_NUM)]

# Show the shapes of the first projected / encoder feature tensors
print(zs_tilde[0].shape)
print(zs[0].shape)

torch.Size([128, 256, 768])
torch.Size([128, 256, 768])


In [3]:
### Check if the original augmentation loss is working...
# Negative cosine similarity between encoder features and projected features,
# accumulated one sample at a time, then averaged over samples and encoders.
bsz = zs[0].shape[0]
proj_loss = 0.
for enc_feats, proj_feats in zip(zs, zs_tilde):
    # each pair is B x L x D
    for enc_tokens, proj_tokens in zip(enc_feats, proj_feats):
        # each pair is L x D; unit-normalize along the last dimension
        proj_unit = torch.nn.functional.normalize(proj_tokens, dim=-1)
        enc_unit = torch.nn.functional.normalize(enc_tokens, dim=-1)
        neg_cosine = -(enc_unit * proj_unit).sum(dim=-1)
        proj_loss += mean_flat(neg_cosine)
print(f"Summed loss: {proj_loss}")
proj_loss /= (len(zs) * bsz)
print(f"Averaged loss: {proj_loss}")

Summed loss: -0.004410076420754194
Averaged loss: -3.445372203714214e-05


In [4]:
### Let's investigate how we can make the semantic relation loss (regularization) work...

### We can do it along three dimensions (B, L, D) and see which one works.
### For the L dimension, we can drop the MLP projection layer, since the D dimension is canceled out.

def normalize_rows(a_mat, eps=1e-8):
    """Scale every last-dimension vector ("row") of ``a_mat`` to unit L2 norm.

    Args:
        a_mat: Floating-point tensor; normalization runs along ``dim=-1``.
        eps: Small constant added to each row norm before dividing, so an
            all-zero row is divided by ``eps`` (and stays zero) instead of
            producing NaN from a 0 / 0 division.

    Returns:
        Tensor with the same shape as ``a_mat``.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm is its documented replacement.
    row_norms = torch.linalg.vector_norm(a_mat, ord=2, dim=-1, keepdim=True)
    return a_mat / (row_norms + eps)

def element_wise_l1_loss(a_mat, a_tilde_mat):
    """Batch-averaged element-wise L1 distance.

    Absolute differences are summed over every dimension except the first,
    and the resulting per-sample sums are averaged over dimension 0.
    """
    abs_diff = (a_mat - a_tilde_mat).abs()
    non_batch_dims = list(range(1, a_mat.dim()))
    per_sample = abs_diff.sum(dim=non_batch_dims)
    return per_sample.mean(dim=0)

def element_wise_l2_loss(a_mat, a_tilde_mat):
    """Batch-averaged element-wise L2 distance.

    Squared differences are summed over every dimension except the first,
    the square root of each per-sample sum is taken, and the resulting
    values are averaged over dimension 0.
    """
    sq_diff = (a_mat - a_tilde_mat).pow(2)
    non_batch_dims = list(range(1, a_mat.dim()))
    per_sample = sq_diff.sum(dim=non_batch_dims).sqrt()
    return per_sample.mean(dim=0)

In [5]:
# D dimension reduction, (B, L, D) x (B, D, L) = (B, L, L) (http://arxiv.org/abs/2104.15082)
proj_loss_l1 = 0.
proj_loss_l2 = 0.
for enc_feats, proj_feats in zip(zs, zs_tilde):
    # Semantic relation activation matrices A -> (B x L x L),
    # with every row (dim=-1) rescaled to unit L2 norm
    rel_enc = normalize_rows(enc_feats @ enc_feats.transpose(1, 2))
    rel_proj = normalize_rows(proj_feats @ proj_feats.transpose(1, 2))
    # Accumulate both the element-wise L1 and the element-wise L2 distance
    proj_loss_l1 += element_wise_l1_loss(rel_enc, rel_proj)
    proj_loss_l2 += element_wise_l2_loss(rel_enc, rel_proj)

print(f"Summed loss: {proj_loss_l1}, {proj_loss_l2}")
# External-encoder-averaged
proj_loss_l1 /= len(zs)
proj_loss_l2 /= len(zs)
print(f"Averaged loss: {proj_loss_l1}, {proj_loss_l2}")

### Think: Is it reasonable to use L1 loss here although we are using L2 normalization?
### Think: Is there any problem with the paper...

Summed loss: 2578.52978515625, 12.611856460571289
Averaged loss: 2578.52978515625, 12.611856460571289


In [6]:
# L dimension reduction, (B, D, L) x (B, L, D) = (B, D, D) (http://arxiv.org/abs/2104.10602)
proj_loss = 0.
for enc_feats, proj_feats in zip(zs, zs_tilde):
    # Gram matrices (channel-wise self-correlation matrices) -> (B x D x D),
    # with every row (dim=-1) rescaled to unit L2 norm
    gram_enc = normalize_rows(enc_feats.transpose(1, 2) @ enc_feats)
    gram_proj = normalize_rows(proj_feats.transpose(1, 2) @ proj_feats)
    # Element-wise (Frobenius) L2 loss, scaled by 1 / D
    channels = enc_feats.shape[-1]
    proj_loss += element_wise_l2_loss(gram_enc, gram_proj) / channels

print(f"Summed loss: {proj_loss}")
# External-encoder-averaged
proj_loss /= len(zs)
print(f"Averaged loss: {proj_loss}")

Summed loss: 0.04479004070162773
Averaged loss: 0.04479004070162773


In [7]:
# B dimension reduction, (B, LD) x (LD, B) = (B, B) (http://arxiv.org/abs/1907.09682)
# NOTE(review): element_wise_l2_loss takes the square root of the summed squares; confirm
# against the cited paper whether the squared norm is intended before dividing by B^2.
proj_loss = 0.
for enc_feats, proj_feats in zip(zs, zs_tilde):
    n_samples = enc_feats.shape[0]
    # Flatten every sample into a single vector: (B, L, D) -> (B, LD)
    enc_flat = enc_feats.view(n_samples, -1)
    proj_flat = proj_feats.view(n_samples, -1)
    # Self-correlation within the mini-batch -> (B x B),
    # with every row (dim=-1) rescaled to unit L2 norm
    sim_enc = normalize_rows(enc_flat @ enc_flat.t())
    sim_proj = normalize_rows(proj_flat @ proj_flat.t())
    # A leading singleton dimension makes the loss reduce the whole (B x B) matrix
    # as one element-wise (Frobenius) L2 norm; the result is then scaled by 1 / B^2
    proj_loss += element_wise_l2_loss(sim_enc[None], sim_proj[None]) / (n_samples * n_samples)

print(f"Summed loss: {proj_loss}")
# External-encoder-averaged
proj_loss /= len(zs)
print(f"Averaged loss: {proj_loss}")

Summed loss: 3.0201676054275595e-05
Averaged loss: 3.0201676054275595e-05
