In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Latent(nn.Module):
    def __init__(self, n_prior_embeddings, **kwargs):
        super().__init__()
        self.max_distance = nn.Parameter(torch.ones(1, n_prior_embeddings), requires_grad=False)
                
    def forward(self, data, **kwargs):
        # data['z'] : (N, c, H, W)
        # data['e'] : (M, c)
        
        z_dim = data['z'].shape[1]
        # (NHW, c)
        z = data['z'].permute(0, 2, 3, 1).reshape(-1, z_dim)
        N = len(z)
        distance_p = kwargs['distance_p'] if 'distance_p' in kwargs else 1
        
        # (NHW, M) = sum((NHW, 1, z) - (1, M, z), dim=2)
        distance = torch.norm(z.unsqueeze(1) - data['e'].unsqueeze(0), dim=2)
        # (NHW, M)
        belong = data['belong'] if 'belong' in data else None
        
        if belong is not None:
            # (1, M)
            max_distance = torch.max(belong * distance, dim=0).values.unsqueeze(0)
            self.max_distance.data = 0.999 * self.max_distance + (1-0.999) * max_distance
            max_distance = torch.clamp(self.max_distance, 1e-8)
            distance = torch.where(distance < max_distance, distance ** (2/distance_p), distance ** (2*distance_p))
        loss = -torch.mean(torch.logsumexp(-distance, dim=0))
        data['lse_loss'] = loss
        
        return data

In [13]:
latent = Latent(100)
data = {'z': torch.randn(2, 512, 8, 8),
        'e': torch.randn(100, 512),
        'belong': torch.randint(0, 2, size=(128, 100))}
latent(data)
                         

torch.Size([1, 100]) torch.Size([1, 100])


{'z': tensor([[[[ 2.7956e+00, -2.7753e-01, -2.1402e-01,  ..., -2.3862e-01,
            -8.8758e-02, -1.7456e+00],
           [-3.4922e-01, -6.6686e-01, -1.6332e+00,  ...,  4.8552e-01,
             4.5820e-01, -2.1743e+00],
           [ 3.0604e+00,  2.2054e+00,  3.8421e-01,  ..., -1.5560e+00,
             3.4537e-01, -2.0184e+00],
           ...,
           [ 6.6592e-01, -2.1014e+00,  6.0068e-01,  ...,  7.5445e-01,
            -6.0734e-01, -8.9665e-01],
           [-2.9645e-01, -9.9076e-01, -2.3870e-01,  ..., -1.8589e-01,
            -1.8437e+00, -1.1748e+00],
           [-8.3086e-01,  1.3871e+00, -8.9438e-01,  ...,  6.7656e-01,
            -2.4760e-01, -9.1896e-01]],
 
          [[-6.8885e-02, -1.3255e+00,  3.6481e-01,  ...,  1.3183e-01,
             8.6723e-01,  1.6122e-01],
           [ 1.0089e+00, -7.7465e-01,  5.1216e-01,  ...,  8.9956e-01,
             9.3167e-01,  2.5005e-02],
           [-8.2946e-01, -1.1601e+00, -2.2664e-02,  ..., -1.0739e-01,
             2.0560e+00, -1.2470e+

In [12]:
2*8*8

128