In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Latent(nn.Module):
    def __init__(self, init_log_sigma, const_sigma, clipping=0, **kwargs):
        super().__init__()
        self.log_sigma = nn.Parameter(torch.ones(1) * init_log_sigma, requires_grad=not const_sigma)
        self.clipping = clipping
                
    def forward(self, data, **kwargs):
        # data['z'] : (N, c, H, W)
        # data['e'] : (M, c)
        
        z_dim = data['z'].shape[1]
        # (NHW, c)
        z = data['z'].permute(0, 2, 3, 1).reshape(-1, z_dim)
        N = len(z)
        T = kwargs['latent_temp'] if 'latent_temp' in kwargs else 1.0
        
        # (NHW, M) = sum((NHW, 1, z) - (1, M, z), dim=2)
        distance = torch.norm(z.unsqueeze(1) - data['e'].unsqueeze(0), dim=2) ** 2
        if self.clipping > 0:
            inf = torch.full_like(distance, 1e+3)
            index = (distance < self.clipping).float()
            print(index)
            distance = (1-index) * distance + index * inf
        alpha = -1/(2*torch.exp(self.log_sigma)**2)
        matrix = alpha*distance/T
        data['matrix'] = matrix
        loss = -torch.mean(T*torch.logsumexp(matrix, dim=0))
        loss = loss + 0.5*z_dim*(2*self.log_sigma-np.log(np.e)) + np.log(N)        
        data['lse_loss'] = loss
        
        return data

In [36]:
data = {'z': torch.zeros(1000, 2, 8, 8),
        'e': torch.zeros(100, 2)}
data = Latent(-3, False, 1e-3)(data)
data.keys()

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])


dict_keys(['z', 'e', 'matrix', 'lse_loss'])

In [37]:
data['lse_loss']

tensor([201707.4062], grad_fn=<AddBackward0>)

In [38]:
1e+3

1000.0