In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Latent(nn.Module):
    """Log-sum-exp loss between flattened latent vectors and embeddings.

    The module has no parameters of its own: ``forward`` reads its inputs
    from a dict and writes the scalar loss back into that same dict.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted but not used by this module.
        super().__init__()

    def forward(self, data, **kwargs):
        """Compute the loss and store it under ``data['lse_loss']``.

        Args:
            data: mapping with
                'z': latent map of shape (N, c, H, W).
                'e': embedding vectors of shape (M, c).
                'belong' (optional): weights multiplied element-wise with the
                    (NHW, M) distance matrix, so they must broadcast against
                    that shape.
            **kwargs:
                distance_p: exponent of the piecewise power warp, applied
                    only when 'belong' is present (default 1 = identity).
                eps: lower bound for the per-embedding normaliser, which
                    prevents division by zero (default 1e-12).

        Returns:
            The same ``data`` mapping, with 'lse_loss' set to a scalar tensor.
        """
        # data['z'] : (N, c, H, W)
        # data['e'] : (M, c)
        z_dim = data['z'].shape[1]
        # (NHW, c): one row per spatial position
        z = data['z'].permute(0, 2, 3, 1).reshape(-1, z_dim)
        distance_p = kwargs.get('distance_p', 1)
        eps = kwargs.get('eps', 1e-12)

        # (NHW, M) = norm((NHW, 1, c) - (1, M, c), dim=2)
        distance = torch.norm(z.unsqueeze(1) - data['e'].unsqueeze(0), dim=2)
        # (NHW, M)
        belong = data['belong'] if 'belong' in data else None

        if belong is not None:
            # (1, M): largest weighted distance per embedding.
            norm_factor = torch.max(belong * distance, dim=0).values.unsqueeze(0)
            # A zero normaliser would give x/0 followed by inf*0 = NaN.
            norm_factor = torch.clamp(norm_factor, min=eps)
            # Everything below is out-of-place on purpose: autograd keeps the
            # output of torch.norm for the backward pass, so overwriting
            # `distance` in place makes loss.backward() raise.
            scaled = distance / norm_factor
            warped = torch.where(scaled < 1,
                                 scaled ** distance_p,
                                 scaled ** (1 / distance_p))
            distance = warped * norm_factor

        # logsumexp over the NHW positions, then mean over the M embeddings.
        loss = -torch.mean(torch.logsumexp(-distance, dim=0))
        data['lse_loss'] = loss

        return data

In [21]:
# Instantiate the loss module; no keyword arguments are passed to Latent.__init__.
latent = Latent()

In [23]:
# Random latent map for a smoke test: batch 2, 512 channels, 4x4 grid.
z = torch.randn(2, 512, 4, 4)
# All-ones weight per flattened (batch, row, col) position: move channels last,
# keep channel 0 only, then flatten -> shape (32,).
# NOTE(review): Latent.forward multiplies `belong` with an (NHW, M) matrix; a
# 1-D (32,) tensor only broadcasts against that when M is 32 or 1 — confirm.
belong = torch.ones_like(z).permute(0, 2, 3, 1)[..., 0].flatten()
data = {'z': z,
        'belong': belong

torch.Size([32])


In [19]:
# Scratch check of a piecewise power map: entries below 1 get x**(1/2), all
# others get x**2. This is the same form Latent.forward applies when
# distance_p = 1/2.
# NOTE(review): `x` is not defined in any cell visible here, and this cell's
# execution count (19) precedes the class cell (20) — presumably leftover kernel
# state. The recorded output is consistent with x = tensor([0., .5, 1., 1.5, 2.]);
# define `x` explicitly so the notebook survives Restart & Run All.
torch.where(x<1, x**(1/2), x**2)

tensor([0.0000, 0.7071, 1.0000, 2.2500, 4.0000])