In [None]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# Experiment configuration: synthetic feature shapes and a random linear projector.
# Seed first so the synthetic features and PROJ's initialization are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

ENCODER_NUM = 1   # number of external encoders
B = 128           # batch size
L = 256           # number of patches (tokens) per sample
D_ENCODER = 768   # channel width of the encoder features
D_SIT = 1024      # channel width of the SiT features
PROJ = torch.nn.Linear(D_SIT, D_ENCODER)  # maps SiT features into the encoder feature width

def mean_flat(x):
    """Average ``x`` over every dimension except the leading one."""
    trailing_dims = list(range(1, x.dim()))
    return x.mean(dim=trailing_dims)

# Synthetic stand-ins for real features:
#   zs[k]       -> k-th encoder's features, (B, L, D_ENCODER)
#   zs_tilde[k] -> SiT features (B, L, D_SIT) pushed through PROJ into D_ENCODER
# All encoder draws happen before any SiT draw, so the RNG stream order is fixed.
zs = [torch.randn(B, L, D_ENCODER) for _ in range(ENCODER_NUM)]
zs_tilde = [PROJ(torch.randn(B, L, D_SIT)) for _ in range(ENCODER_NUM)]

print(zs_tilde[0].shape)
print(zs[0].shape)

In [None]:
### Check that the original feature-alignment loss works: negative cosine similarity between
### matching patches, averaged over patches, then averaged over encoders and samples.
proj_loss = 0.
bsz = zs[0].shape[0]
for z, z_tilde in zip(zs, zs_tilde):
    # z, z_tilde: B x L x D
    for z_j, z_tilde_j in zip(z, z_tilde):
        # z_j, z_tilde_j: L x D -- one sample at a time
        z_tilde_j = F.normalize(z_tilde_j, dim=-1)
        z_j = F.normalize(z_j, dim=-1)
        proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
print(f"Summed loss: {proj_loss}")

# Cross-check the per-sample loop against the same loss vectorized over the whole batch.
proj_loss_vec = sum(
    mean_flat(-(F.normalize(z, dim=-1) * F.normalize(z_tilde, dim=-1)).sum(dim=-1)).sum()
    for z, z_tilde in zip(zs, zs_tilde)
)
torch.testing.assert_close(proj_loss, proj_loss_vec)

proj_loss /= (len(zs) * bsz)
print(f"Averaged loss: {proj_loss}")

In [None]:
### Investigate how to make the semantic-relation loss (a regularizer) work.

### The relation can be built along any of the three dimensions (B, L, D); compare which one works best.
### For the L dimension we can drop the MLP projection layer: the (L x L) relation matrix contracts over D,
### so the encoder and SiT features no longer need matching widths.

def normalize_rows(a_mat):
    """Scale every row (last dimension) of ``a_mat`` to unit L2 norm.

    Same computation as ``F.normalize(a_mat, p=2, dim=-1)``: each row is divided by
    ``max(||row||_2, 1e-12)``, so an all-zero row stays zero instead of becoming NaN.
    """
    row_norms = a_mat.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
    return a_mat / row_norms.expand_as(a_mat)

def matrix_l1_norm(a_mat, a_tilde_mat):
    """Mean absolute error between two equally-shaped tensors.

    Averages |a_mat - a_tilde_mat| over every element, batch included; for equally
    sized samples this equals the batch mean of the per-sample mean absolute errors.
    """
    return torch.mean(torch.abs(a_mat - a_tilde_mat))

def matrix_l2_norm(a_mat, a_tilde_mat, eps=0.0):
    """Root-mean-square error between two equally-shaped tensors.

    Computes ``sqrt(mean((a_mat - a_tilde_mat) ** 2))`` with the mean taken over
    *every* element, batch included. For (B, N, N) inputs this is the Frobenius norm
    of the stacked difference divided by ``sqrt(B * N * N)``. It is NOT the batch
    average of per-sample RMS errors (``mean_b sqrt(mse_b)``), which is <= this value
    by Jensen's inequality and equal only when every sample has the same error.

    Args:
        a_mat: target tensor.
        a_tilde_mat: prediction tensor, same shape as ``a_mat``.
        eps: added inside the square root. ``sqrt`` has an infinite derivative at 0,
            so when both inputs coincide exactly the backward pass produces NaN; a
            small positive value (e.g. 1e-12) keeps the gradient finite. The default
            0.0 reproduces the plain RMS.

    Returns:
        A scalar tensor.
    """
    return torch.sqrt(F.mse_loss(a_mat, a_tilde_mat, reduction='mean') + eps)

In [None]:
# Across patches, (B, L, D) x (B, D, L) = (B, L, L) (http://arxiv.org/abs/2104.15082)
proj_loss_l1 = 0.
proj_loss_l2 = 0.
for z, z_tilde in zip(zs, zs_tilde):
    # Semantic relation (activation) matrices A, A_tilde: (B, L, L), each row L2-normalized (dim=-1).
    a_mat = normalize_rows(z @ z.transpose(1, 2))
    a_tilde_mat = normalize_rows(z_tilde @ z_tilde.transpose(1, 2))
    # Element-wise losses. Both helpers already average over all B*L*L entries,
    # so no extra division by L or L^2 is applied here.
    proj_loss_l1 += matrix_l1_norm(a_mat, a_tilde_mat)
    proj_loss_l2 += matrix_l2_norm(a_mat, a_tilde_mat)

print(f"Summed loss: {proj_loss_l1}, {proj_loss_l2}")
# Average over the external encoders.
proj_loss_l1 /= len(zs)
proj_loss_l2 /= len(zs)
print(f"Averaged loss: {proj_loss_l1}, {proj_loss_l2}")

# TODO: Is an L1 loss reasonable here, given that the rows were L2-normalized?
# TODO: Check whether the paper's formulation has any issues.

In [None]:
# Sanity-check how the L2 relation loss should be normalized.
# Reuses `a_mat` / `a_tilde_mat` from the previous cell. Sizes are read from the tensors
# rather than the globals `B` / `L`, which a later cell may overwrite.
n_batch, n_rows, n_cols = a_mat.shape
sq_err = F.mse_loss(a_mat, a_tilde_mat, reduction='none')
per_sample_fro = torch.sqrt(sq_err.sum(dim=tuple(range(1, a_mat.ndim))))  # (B,) per-sample Frobenius norms

# Batch mean of per-sample RMS: the sqrt of a sum over L*L entries grows like L,
# so divide by sqrt(L*L) = L.
per_sample_rms = per_sample_fro.mean(0) / (n_rows * n_cols) ** 0.5
print(per_sample_rms.item())
# Over-normalized: dividing by L**2 ignores that the sqrt already reduced the L*L-entry sum to an O(L) quantity.
print((per_sample_fro.mean(0) / (n_rows * n_cols)).item())

# Global RMS over all B*L*L entries (what `matrix_l2_norm` computes); the two forms below are identical.
global_rms = torch.sqrt(F.mse_loss(a_mat, a_tilde_mat, reduction='mean'))
global_rms_from_sum = torch.sqrt(F.mse_loss(a_mat, a_tilde_mat, reduction='sum') / (n_batch * n_rows * n_cols))
print(global_rms.item())
print(global_rms_from_sum.item())
torch.testing.assert_close(global_rms, global_rms_from_sum)

# The per-sample and global variants are NOT identical: mean(sqrt(.)) <= sqrt(mean(.)) (Jensen),
# with equality only when every sample has the same error.
assert per_sample_rms.item() <= global_rms.item() * (1 + 1e-5)

In [None]:
# Across patches, but compare rows with a KL divergence.
# NOTE: An L2-normalized row is not a probability distribution: its entries can be negative
#       and need not sum to 1, so KL cannot be applied to it directly. Instead, each row of
#       the relation matrix is turned into a distribution with a softmax over its L entries.
#       The loss then measures how far the SiT row distributions are from the encoder ones.
KL_TEMPERATURE = 1.0  # softmax temperature (to tune); L2-normalized row entries lie in [-1, 1], so values < 1 sharpen the rows

proj_loss_kl = 0.
for z, z_tilde in zip(zs, zs_tilde):
    n_patches = z.shape[1]
    # Semantic relation matrices (B, L, L), each row L2-normalized (dim=-1), as in the L1/L2 variant.
    a_mat = normalize_rows(z @ z.transpose(1, 2))
    a_tilde_mat = normalize_rows(z_tilde @ z_tilde.transpose(1, 2))

    # Row-wise log-distributions, flattened to (B*L, L) so that reduction='batchmean'
    # averages the per-row KL over every row of every sample.
    log_p = F.log_softmax(a_mat / KL_TEMPERATURE, dim=-1).reshape(-1, n_patches)
    log_q = F.log_softmax(a_tilde_mat / KL_TEMPERATURE, dim=-1).reshape(-1, n_patches)
    # KL(P || Q): encoder rows (P) are the target, projected SiT rows (Q) the prediction.
    proj_loss_kl += F.kl_div(log_q, log_p, reduction='batchmean', log_target=True)

print(f"Summed loss: {proj_loss_kl}")
# Average over the external encoders.
proj_loss_kl /= len(zs)
print(f"Averaged loss: {proj_loss_kl}")

In [None]:
# Across features, (B, D, L) x (B, L, D) = (B, D, D) (http://arxiv.org/abs/2104.10602)
proj_loss = 0.
for z, z_tilde in zip(zs, zs_tilde):
    # Unlike the L x L patch relation, D x D Gram matrices are only comparable when both
    # feature sets share the same channel width (hence zs_tilde is the PROJ output).
    assert z.shape[-1] == z_tilde.shape[-1], (z.shape, z_tilde.shape)
    # Gram matrices (channel-wise self-correlation), each row L2-normalized (dim=-1).
    g_mat = normalize_rows(z.transpose(1, 2) @ z)
    g_tilde_mat = normalize_rows(z_tilde.transpose(1, 2) @ z_tilde)
    # Element-wise (Frobenius-style) L2 loss; matrix_l2_norm already averages over all B*D*D entries.
    proj_loss += matrix_l2_norm(g_mat, g_tilde_mat)

print(f"Summed loss: {proj_loss}")
# Average over the external encoders.
proj_loss /= len(zs)
print(f"Averaged loss: {proj_loss}")

In [None]:
# Across samples, (B, LD) x (LD, B) = (B, B) (http://arxiv.org/abs/1907.09682)
proj_loss = 0.
for z, z_tilde in zip(zs, zs_tilde):
    bs = z.shape[0]
    # Flatten each sample to one vector: (B, L*D). reshape (unlike view) also accepts non-contiguous inputs.
    q_mat = z.reshape(bs, -1)
    q_tilde_mat = z_tilde.reshape(bs, -1)
    # Within-batch self-correlation (B x B), each row L2-normalized (dim=-1).
    g_mat = normalize_rows(q_mat @ q_mat.transpose(0, 1))
    g_tilde_mat = normalize_rows(q_tilde_mat @ q_tilde_mat.transpose(0, 1))
    # Element-wise (Frobenius-style) L2 loss; matrix_l2_norm already averages over all B*B entries.
    # NOTE(review): the cited paper's loss appears to be the *squared* Frobenius norm divided by B^2;
    #               here the square root is taken instead -- confirm this is intended.
    proj_loss += matrix_l2_norm(g_mat, g_tilde_mat)

print(f"Summed loss: {proj_loss}")
# Average over the external encoders.
proj_loss /= len(zs)
print(f"Averaged loss: {proj_loss}")

-----------------------------

In [None]:
# Check the effect of normalizing the loss with linear / quadratic terms...

def kernel_align_loss_patch(zs, zs_tilde, pow=2):
    """Patch-relation (L x L) L2 loss divided by ``n_patches ** pow``, averaged over encoders.

    Args:
        zs: list of encoder feature tensors indexed as (B, L, ...) -- L is read from dim 1.
        zs_tilde: list of projected SiT feature tensors, paired element-wise with ``zs``.
        pow: exponent of the extra 1 / L**pow normalization (0 disables it).

    Returns:
        The loss as a Python float.
    """
    # Only a float is returned, so no autograd graph is needed; skipping it avoids
    # keeping the large (B, L, L) intermediates alive.
    with torch.no_grad():
        proj_loss = 0.
        for z, z_tilde in zip(zs, zs_tilde):
            n_patches = z.shape[1]
            # Semantic relation matrices (B, L, L), each row L2-normalized (dim=-1).
            a_mat = normalize_rows(z @ z.transpose(1, 2))
            a_tilde_mat = normalize_rows(z_tilde @ z_tilde.transpose(1, 2))
            proj_loss += matrix_l2_norm(a_mat, a_tilde_mat) / (n_patches ** pow)
        return (proj_loss / len(zs)).item()

# Sweep-specific names so this cell does not overwrite B / L / zs / zs_tilde / PROJ from earlier cells.
SEED = 0
ENCODER_NUM = 1
D_ENCODER = 768
D_SIT = 1024
SWEEP_B = 256                      # batch size for the sweep
SWEEP_LS = [256, 512, 1024, 2048]  # numbers of patches to sweep over
# NOTE: at B = 256, L = 2048 a single float32 (B, L, L) relation matrix already takes ~4.3 GB.

torch.manual_seed(SEED)
sweep_proj = torch.nn.Linear(D_SIT, D_ENCODER)

pow1 = []
pow2 = []
for seq_len in SWEEP_LS:
    with torch.no_grad():  # only loss values are needed; don't keep the SiT inputs alive for backward
        sweep_zs = [torch.randn(SWEEP_B, seq_len, D_ENCODER) for _ in range(ENCODER_NUM)]
        sweep_zs_tilde = [sweep_proj(torch.randn(SWEEP_B, seq_len, D_SIT)) for _ in range(ENCODER_NUM)]
    # The two normalizations differ only by a constant factor, so compute the raw loss once.
    raw_loss = kernel_align_loss_patch(sweep_zs, sweep_zs_tilde, 0)
    pow1.append(raw_loss / seq_len)
    pow2.append(raw_loss / seq_len ** 2)
    del sweep_zs, sweep_zs_tilde  # free the multi-GB tensors before allocating the next, larger L

fig, ax = plt.subplots()
ax.plot(SWEEP_LS, pow1, marker='o', label='pow1 (loss / L)')
ax.plot(SWEEP_LS, pow2, marker='o', label='pow2 (loss / L^2)')
ax.set_xscale('log', base=2)
ax.set_yscale('log')  # the curves differ by a factor of L, so a log axis keeps both readable
ax.set(xlabel='number of patches L', ylabel='normalized patch-relation loss',
       title='Effect of 1/L vs 1/L^2 loss normalization')
ax.legend()
plt.show()