In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Latent(nn.Module):
    """Log-sum-exp loss between latent vectors and a set of code vectors.

    The module has no parameters. ``forward`` flattens the latent map,
    measures the Euclidean distance from every latent vector to every code
    vector, optionally warps those distances per code, and stores
    ``-mean_m(logsumexp_n(-distance[n, m]))`` under ``data['lse_loss']``.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for interface compatibility and ignored.
        super().__init__()

    def forward(self, data, **kwargs):
        """Compute the loss and store it in ``data['lse_loss']``.

        Args:
            data: dict-like with
                ``'z'``: latent map of shape (N, c, H, W);
                ``'e'``: code vectors of shape (M, c);
                ``'belong'`` (optional): weights of shape (N*H*W, M) that
                multiply the distances before the per-code maximum is taken.
            **kwargs: ``distance_p`` (default 1) is the exponent of the warp
                applied when ``'belong'`` is present.

        Returns:
            The same ``data`` object, mutated to contain ``'lse_loss'``
            (a scalar tensor).
        """
        # data['z'] : (N, c, H, W)
        # data['e'] : (M, c)
        z_dim = data['z'].shape[1]
        # (NHW, c)
        z = data['z'].permute(0, 2, 3, 1).reshape(-1, z_dim)
        distance_p = kwargs.get('distance_p', 1)

        # (NHW, M) = norm((NHW, 1, c) - (1, M, c), dim=2)
        distance = torch.norm(z.unsqueeze(1) - data['e'].unsqueeze(0), dim=2)
        # (NHW, M)
        belong = data['belong'] if 'belong' in data else None

        if belong is not None:
            # (1, M): largest weighted distance per code.
            norm_factor = torch.max(belong * distance, dim=0).values.unsqueeze(0)
            # A code whose weights are all zero has a zero factor; clamping
            # avoids the division by zero that would turn the loss into NaN.
            norm_factor = norm_factor.clamp_min(torch.finfo(distance.dtype).eps)
            # Out-of-place arithmetic on purpose: autograd keeps the output of
            # torch.norm for its backward pass, so modifying `distance` in
            # place makes loss.backward() fail.
            scaled = distance / norm_factor
            # Below 1 the value is raised to distance_p, at or above 1 to
            # 1/distance_p; the result is then mapped back to the original scale.
            warped = torch.where(scaled < 1,
                                 scaled ** distance_p,
                                 scaled ** (1 / distance_p))
            distance = warped * norm_factor

        loss = -torch.mean(torch.logsumexp(-distance, dim=0))
        data['lse_loss'] = loss

        return data

In [21]:
# Instantiate the loss module; Latent.__init__ takes no required arguments.
latent = Latent()

In [26]:
# Smoke test: N=2, c=512, H=W=4 gives N*H*W = 32 flattened latents against M=10 codes.
z = torch.randn(2, 512, 4, 4)
e = torch.randn(10, 512)
# All-ones weights of shape (N*H*W, M) = (32, 10).
belong = torch.ones(32, 10)
distance_p = 2
data = {'z': z,
        'e': e,
        'belong': belong}
# forward() reads the exponent from its keyword arguments, so it has to be
# passed explicitly; otherwise the default of 1 is used and distance_p is ignored.
latent(data, distance_p=distance_p)

{'z': tensor([[[[-0.6370,  0.4413, -1.6222,  0.1365],
           [-0.3392, -0.9638, -0.2849,  0.2511],
           [ 0.5685,  1.6154, -2.2267,  0.0689],
           [ 2.3972,  0.1228, -0.6323,  0.1561]],
 
          [[-0.3138, -0.8191,  0.9409,  0.9433],
           [-0.9637, -0.7515,  1.1043, -1.1919],
           [-0.3034, -0.1115, -1.2882, -0.2467],
           [ 0.7775,  1.1563,  0.2742, -1.9066]],
 
          [[-0.1660, -2.1303, -0.3707,  0.0668],
           [ 1.1863,  1.1630,  1.2622,  0.1851],
           [-1.9704, -0.3641,  1.2678,  0.3567],
           [-1.4183, -0.2792, -0.8364,  1.1103]],
 
          ...,
 
          [[-0.6843, -1.0787,  0.0937, -0.7699],
           [-0.9051,  0.3880,  0.5267, -0.3764],
           [ 1.6540,  0.2144,  1.5670,  0.2223],
           [-0.4773, -1.7565,  1.1343, -0.2026]],
 
          [[ 0.6445,  0.5554, -0.2026, -0.3253],
           [ 0.4909,  0.4335,  0.1496, -1.3628],
           [ 1.1633, -1.5559,  1.0321, -1.1998],
           [-0.6010, -1.9488, -0.10

In [19]:
# Demo of the piecewise warp: square root below 1, square at or above 1
# (the same form as the warp in Latent.forward with distance_p = 1/2).
# x is defined here so the cell does not depend on leftover kernel state;
# the values are the ones that reproduce the recorded output.
x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
torch.where(x < 1, x ** (1 / 2), x ** 2)

tensor([0.0000, 0.7071, 1.0000, 2.2500, 4.0000])