In [None]:

import torch

def cdist(x, y, eps=1e-8):
    '''
    Batched pairwise Euclidean distances via the expansion
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>.

    Args:
        x: [B, C1, D]
        y: [B, C2, D]
        eps: lower bound applied to the *squared* distance before the square
            root, so every returned distance is >= sqrt(eps). Coincident points
            therefore get sqrt(eps) instead of 0, which keeps the sqrt
            differentiable.
    Returns: [B, C1, C2]

    Close to torch.cdist (up to the eps floor and float cancellation), but faster.
    '''
    cross = torch.einsum("bnd,bmd->bnm", x, y)
    sq_x = torch.einsum("bnd,bnd->bn", x, x).unsqueeze(-1)  # [B, C1, 1]
    sq_y = torch.einsum("bmd,bmd->bm", y, y).unsqueeze(-2)  # [B, 1, C2]
    # Cancellation can make this slightly negative; the clamp guards the sqrt.
    sq_dist = sq_x + sq_y - 2 * cross
    return torch.sqrt(sq_dist.clamp(min=eps))
def sg(x):
    '''Stop-gradient: return `x` detached from the autograd graph (same storage).'''
    detached = x.detach()
    return detached
def attn_contra_loss(target, recon, return_info=False, sample_norm=True, weight_r=None, scale_dist=False, no_R_norm=False, no_global_norm=False, new_R_norm=False, scale_dist_normed=True, R_list=(0.2,), coord_norm=False, softmax=True, gen_attend_data=True, data_attend_gen=True, softmax_p=0.5, norm_R_force=False, return_grads=True, verbose=False):
    '''
    Attraction/repulsion objective between data samples (target) and generated
    samples (recon). Generated samples are attracted to data samples and
    repelled from each other, with affinities derived from pairwise distances.

    Best recommendation:
        sample_norm: True
        scale_dist_normed: True
        Other: False
        This will align different features. Other settings are mostly for exploration.
    Args:
        target: [batch_size, C1, S]
        recon: [batch_size, C2, S]
        return_info: whether to also return the info dict (in both modes).
        sample_norm: whether to normalize distances and forces per batch element;
            if enabled (and return_grads=False): loss will have shape (B, ).
        weight_r:
            None or shape (B, C2). When given, every affinity towards a generated
            sample (affinity columns C1:) is multiplied by weight_r, which
            reweights the repulsion term.
        scale_dist: divide coordinates by the mean pairwise distance.
            Mutually exclusive with scale_dist_normed.
        no_R_norm: skip the per-R normalization of the force.
        no_global_norm: skip normalizing the summed force to unit RMS.
        new_R_norm: use mean(affinity * dist) as the per-R normalizer.
        scale_dist_normed: divide coordinates by mean pairwise distance / sqrt(S).
        R_list: temperatures; the (normalized) forces of each R are summed.
        coord_norm: standardize each coordinate (mean 0, std 1) over batch and points.
        softmax: whether to use softmax to normalize the affinity matrix
            (otherwise the affinity is exp(-dist / R)).
        gen_attend_data / data_attend_gen: include the row-wise (dim=-1) /
            column-wise (dim=-2) softmax factor in the affinity.
        softmax_p: exponent applied to each softmax factor.
        norm_R_force: additionally multiply the per-R normalizer by the mean affinity.
        return_grads: if True (default, matching the previous behaviour), return
            the force field instead of the loss.
        verbose: print the per-R normalizer (debug output that used to be
            printed unconditionally).
    Returns:
        return_grads=True: grads [batch_size, C2, S] -- the summed force
            multiplied by the derivative of the coordinate rescaling
            (1 when no rescaling is enabled).
        return_grads=False: loss [batch_size] if sample_norm else a scalar;
            minimizing it moves recon along the (normalized) force.
        If return_info: a tuple (result, info), where info holds diagnostics
            such as 'scale', 'f_norm', 'dist_std', f'norm_{R}', f'pos_ker_{R}'.
    Raises:
        ValueError: if scale_dist and scale_dist_normed are both enabled.
    '''
    if scale_dist and scale_dist_normed:
        raise ValueError("scale_dist and scale_dist_normed are mutually exclusive")

    B, C1, S = target.shape
    B, C2, S = recon.shape
    if coord_norm:
        # normalize, to make sure every coordinate has mean 0 & var 1
        both = torch.cat([target, recon], dim=1)
        coord_mean = both.mean(dim=(0, 1)).detach()
        coord_std = both.std(dim=(0, 1)).detach()
        target = (target - coord_mean) / (coord_std + 1e-3)
        recon = (recon - coord_mean) / (coord_std + 1e-3)

    with torch.no_grad():
        pos_neg = torch.cat([target, recon], dim=1)
        dist = cdist(pos_neg, pos_neg)  # [B, C1 + C2, C1 + C2]
        # Distance scale: one global mean, or one mean per batch element ([B, 1, 1]).
        scale = dist.mean()
        if sample_norm:
            scale = dist.mean(dim=(-1, -2), keepdim=True)
    # Derivative of the rescaled coordinates w.r.t. the inputs; applied to the
    # returned grads so they refer to the caller's coordinate frame.
    coord_scale_fac = 1
    if scale_dist:
        target, recon, pos_neg = target / scale, recon / scale, pos_neg / scale
        coord_scale_fac = 1 / scale
    if scale_dist_normed:
        scale_2 = scale / (S ** 0.5)
        target, recon, pos_neg = target / scale_2, recon / scale_2, pos_neg / scale_2
        coord_scale_fac = 1 / scale_2
    with torch.no_grad():
        # Inflate self-distances so no point attends to itself.
        dist = dist + torch.eye(C1 + C2, device=dist.device) * 100 * scale

        info = {"nn_data": dist[:, :C1].min(dim=-1).values.mean(), "nn_samples": dist[:, C1:].min(dim=-1).values.mean(), "scale": scale.mean(), "scale_normed": scale.mean() / (S ** 0.5)}

        dist = dist / scale
        info['dist_std'] = (dist - torch.eye(C1 + C2, device=dist.device) * 100).std(dim=(-1, -2)).mean()

        total_pos = torch.zeros_like(dist)

        for R in R_list:
            if softmax:
                affinity = torch.ones_like(dist)
                if gen_attend_data:
                    affinity = affinity * ((-dist / R).softmax(dim=-1) ** softmax_p)
                if data_attend_gen:
                    affinity = affinity * ((-dist / R).softmax(dim=-2) ** softmax_p)
            else:
                affinity = (-dist / R).exp()

            if weight_r is not None:
                affinity[:, :, C1:] = affinity[:, :, C1:] * weight_r[:, None, :]

            if sample_norm:
                norm_est = ((affinity * dist) ** 2).mean(dim=(-1, -2), keepdim=True).clamp_min(1e-8).sqrt()
                if norm_R_force:
                    norm_est = norm_est * affinity.mean(dim=(-1, -2), keepdim=True)
            cur_force = torch.zeros_like(dist)
            info[f'pos_ker_{R}'] = affinity[:, C1:, :C1].sum(dim=-1).mean()
            info[f'neg_ker_{R}'] = affinity[:, C1:, C1:].sum(dim=-1).mean()
            # Rows C1: are generated samples: repelled by other generated samples
            # (weighted by their total data affinity), attracted to data samples
            # (weighted by their total generated affinity).
            cur_force[:, C1:, C1:] = -affinity[:, C1:, C1:] * (affinity[:, C1:, :C1].sum(dim=-1, keepdim=True))
            cur_force[:, C1:, :C1] = affinity[:, C1:, :C1] * (affinity[:, C1:, C1:].sum(dim=-1, keepdim=True))
            if not sample_norm:
                norm_est = ((cur_force * dist) ** 2).mean().clamp_min(1e-8).sqrt()

            if new_R_norm:
                norm_est = (affinity * dist).mean().clamp_min(1e-8)
                if norm_R_force:
                    norm_est = norm_est * affinity.mean(dim=(-1, -2), keepdim=True)
            if verbose:
                print("! N est", norm_est)
            if not no_R_norm:
                cur_force = cur_force / norm_est
            total_pos = total_pos + cur_force
            info[f'norm_{R}'] = norm_est.mean()
            info[f'norm_{R}_std'] = norm_est.std() if (len(norm_est.shape) > 0 and norm_est.shape[0] > 1) else 0

        # Force on each generated sample: affinity-weighted sum of positions.
        sum_forces = torch.einsum("biy,byx->bix", total_pos[:, C1:], pos_neg)  # [B, C2, S]
        info['f_norm'] = ((sum_forces ** 2).mean())
        if not no_global_norm:
            sum_forces = sum_forces / ((sum_forces ** 2).mean().clamp_min(1e-8).sqrt())

    if return_grads:
        grads = sum_forces * coord_scale_fac
        return (grads, info) if return_info else grads

    # Regression target is recon shifted by the force, with gradients stopped.
    goal = sg(recon + sum_forces)
    if sample_norm:
        loss = ((recon - goal) ** 2).mean(dim=(-1, -2))
    else:
        loss = ((recon - goal) ** 2).mean()
    if return_info:
        return loss, info
    return loss

In [None]:
# Demo: force magnitude on two toy 1-D batches.
#   batch 0: every point at 1, except one data point at 0.5 and two generated points at 0
#   batch 1: data and generated points drawn from a standard normal
info_dict = dict(
    sample_norm=True,
    scale_dist_normed=True,
    norm_R_force=False,
    new_R_norm=False,
    no_R_norm=False,
    softmax=True,
    no_global_norm=False,
)
target = torch.ones((2, 100, 1))
recon = torch.ones((2, 20, 1))
target[0, :1] = 0.5
recon[0, :2] = 0
target[1] = torch.randn(100).unsqueeze(-1)
recon[1] = torch.randn(20).unsqueeze(-1)
# With these arguments attn_contra_loss returns the force field ([2, 20, 1]),
# not a scalar loss; the printout is its mean squared magnitude per batch element.
loss = attn_contra_loss(target, recon, **info_dict)
print("NORMS", loss.pow(2).mean(dim=(-1, -2)))

! tensor(0.0004)
NORMS tensor([0.6034, 4.5403])
