In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Latent(nn.Module):
    """Log-sum-exp loss between flattened latent vectors and embeddings.

    Squared Euclidean distances between every latent vector and every
    embedding are scaled by ``-1 / (2 * sigma**2)`` with
    ``sigma = exp(log_sigma)`` and reduced through ``CustomLogSumExp``.

    Args:
        init_log_sigma: Initial value of the one-element ``log_sigma`` parameter.
        const_sigma: If true, ``log_sigma`` is created with
            ``requires_grad=False``.
        **kwargs: Accepted for call-site compatibility; not used here.
    """

    def __init__(self, init_log_sigma, const_sigma, **kwargs):
        super().__init__()
        initial_value = torch.ones(1) * init_log_sigma
        self.log_sigma = nn.Parameter(initial_value, requires_grad=not const_sigma)

    def forward(self, data, **kwargs):
        """Compute the scaled distance matrix and the loss, storing both in ``data``.

        Args:
            data: Mapping with
                ``'z'``: latents, indexed as (N, c, H, W);
                ``'e'``: embeddings, (M, c);
                ``'belong'`` (optional): (N*H*W, M) membership mask handed to
                ``CustomLogSumExp``.
            **kwargs: ``latent_temp`` (default 1.0) divides the scaled
                distances and multiplies the log-sum-exp; ``softmax_temp``
                (default 1.0) is forwarded to ``CustomLogSumExp``.

        Returns:
            The same ``data`` object, with ``'matrix'`` ((N*H*W, M) scaled
            negative squared distances) and ``'lse_loss'`` added.
        """
        latents = data['z']
        codebook = data['e']
        channels = latents.shape[1]

        # Move channels last, then flatten: (N, c, H, W) -> (N*H*W, c).
        flat = latents.permute(0, 2, 3, 1).reshape(-1, channels)
        num_points = len(flat)

        latent_temp = kwargs.get('latent_temp', 1.0)
        softmax_temp = kwargs.get('softmax_temp', 1.0)

        # (N*H*W, 1, c) - (1, M, c) -> (N*H*W, M, c); reduce over c.
        pairwise = flat.unsqueeze(1) - codebook.unsqueeze(0)
        sq_dist = torch.norm(pairwise, dim=2) ** 2

        # -1 / (2 * sigma^2) with sigma = exp(log_sigma).
        scale = -1 / (2 * torch.exp(self.log_sigma) ** 2)
        logits = scale * sq_dist / latent_temp
        data['matrix'] = logits

        # The membership mask is optional; None selects the plain gradient.
        membership = None
        if 'belong' in data:
            membership = data['belong']

        # One log-sum-exp value per embedding (reduction over the points).
        lse = CustomLogSumExp.apply(logits, membership, softmax_temp)
        nll = -torch.mean(latent_temp * lse)

        # Additive terms that depend only on sigma, c and the number of points.
        sigma_term = 0.5 * channels * (2 * self.log_sigma - np.log(np.e))
        data['lse_loss'] = nll + sigma_term + np.log(num_points)

        return data
    
class CustomLogSumExp(torch.autograd.Function):
    """Log-sum-exp over dim 0 with an optional membership-aware gradient.

    The forward pass is a plain ``torch.logsumexp(input, dim=0)``. The
    backward pass returns a softmax over dim 0; when ``belong`` is given,
    entries whose ``belong`` value is 0 have ``log(temp)`` added to their
    logit before that softmax, which changes only the gradient, not the
    forward value.
    """

    @staticmethod
    def forward(ctx, input, belong=None, temp=1):
        """Return ``logsumexp(input, dim=0)``.

        Args:
            input: (N, M) tensor of logits.
            belong: Optional (N, M) membership mask (bool, integer or float),
                used only in ``backward``.
            temp: Temperature applied to non-member logits in ``backward``.

        Returns:
            Tensor of shape (M,).
        """
        ctx.temp = temp
        # (1, M): keep the reduced axis so backward can broadcast against input.
        output = torch.logsumexp(input, dim=0, keepdim=True)
        ctx.save_for_backward(input, output, belong)
        # Drop the leading singleton dimension when returning the output.
        return output.squeeze(0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; ``belong`` and ``temp`` get None."""
        temp = ctx.temp
        # (N, M), (1, M), (N, M) or None
        input, output, belong = ctx.saved_tensors
        # The gradient of log-sum-exp is a softmax over the reduced axis.
        if belong is None:
            # (N, M): exp(x - logsumexp(x)) is softmax(x) along dim 0.
            softmax_result = torch.exp(input - output)
        else:
            # Cast the mask to the logits' dtype: `1 - mask` is rejected by
            # PyTorch for bool tensors, and the cast keeps the gradient dtype
            # equal to the input dtype.
            non_member = 1 - belong.to(input.dtype)
            # Clamp so temp == 0 gives a large negative offset instead of -inf.
            log_temp = float(np.log(max(temp, 1e-15)))
            # Algebraically equal to
            #   belong * input + (1 - belong) * (log_temp + input).
            inner_value = input + non_member * log_temp
            softmax_result = torch.softmax(inner_value, dim=0)
        # Broadcast the (M,) upstream gradient over the N rows.
        grad_input = softmax_result * grad_output.unsqueeze(0)
        return grad_input, None, None

In [64]:
# Smoke test: run Latent once on a tiny random batch -- four 1-channel 2x2
# latent maps, ten 1-D embeddings and a random 0/1 membership mask with one
# row per flattened latent position (4*2*2) and one column per embedding.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same tensors
latent = Latent(init_log_sigma=0, const_sigma=True)
data = {'z': torch.randn(4, 1, 2, 2),
        'e': torch.randn(10, 1),
        'belong': torch.randint(0, 2, size=(4*2*2, 10))}
# forward() writes 'matrix' and 'lse_loss' into `data` and returns the same dict.
latent(data)

{'z': tensor([[[[ 0.3048, -1.0783],
           [ 0.6670,  0.8956]]],
 
 
         [[[ 0.4635, -0.5543],
           [-0.9337,  0.7969]]],
 
 
         [[[-2.0398,  1.2159],
           [ 0.2800, -2.3699]]],
 
 
         [[[-1.1778, -1.7241],
           [-2.3137,  0.8922]]]]),
 'e': tensor([[ 1.9999],
         [-0.1768],
         [-0.3227],
         [ 0.4121],
         [ 0.0367],
         [ 0.7077],
         [-1.1629],
         [ 0.6970],
         [-1.0937],
         [ 1.7306]]),
 'belong': tensor([[0, 0, 0, 0, 1, 0, 1, 0, 1, 1],
         [0, 0, 1, 1, 1, 0, 1, 0, 1, 1],
         [1, 1, 0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 1, 1, 0, 1, 1, 1, 0],
         [0, 1, 1, 1, 1, 1, 1, 0, 1, 0],
         [0, 0, 1, 0, 1, 0, 0, 1, 1, 1],
         [0, 1, 1, 0, 1, 0, 1, 0, 1, 1],
         [1, 0, 1, 0, 1, 1, 0, 1, 1, 1],
         [1, 1, 0, 1, 1, 0, 1, 0, 0, 0],
         [0, 0, 1, 1, 1, 1, 0, 1, 1, 1],
         [1, 0, 0, 1, 1, 1, 0, 0, 1, 0],
         [0, 0, 1, 0, 1, 1, 1, 0, 0, 0],
         [0, 1, 

In [65]:
# Exercise the custom autograd function: it is invoked through `apply` with the
# logits, a 0/1 membership mask and the temperature. There is no `dim`
# argument; the reduction is always over dim 0.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same gradient
input = torch.randn(3, 4, requires_grad=True)
belong = torch.randint(0, 2, size=[*input.shape])
# temp=0 is clamped inside backward, so entries with belong == 0 receive an
# almost-zero share of the gradient.
output = CustomLogSumExp.apply(input, belong, 0)
output.backward(torch.ones_like(output))
# Show the gradient produced by the custom backward pass.
input.grad


tensor([[1, 0, 0, 1],
        [1, 1, 1, 1],
        [1, 1, 0, 1]])
tensor([[ 0.9080,  1.1270, -1.4178, -0.1093],
        [-0.4860,  0.5325, -0.0897,  0.7186],
        [ 0.9674, -1.4297, -0.4424, -0.6589]], requires_grad=True)
tensor([[  0.9080, -33.4118, -35.9566,  -0.1093],
        [ -0.4860,   0.5325,  -0.0897,   0.7186],
        [  0.9674,  -1.4297, -34.9812,  -0.6589]])
tensor([[4.3304e-01, 1.5888e-15, 2.6498e-16, 2.5869e-01],
        [1.0743e-01, 8.7677e-01, 1.0000e+00, 5.9200e-01],
        [4.5953e-01, 1.2323e-01, 7.0277e-16, 1.4931e-01]])
