In [60]:
import torch
import numpy as np

class CustomLogSumExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, belong=None, temp=1):
        # input : (M, N)
        # belong : (M, N)
        
        ctx.temp = temp
        # (M, 1)
        output = torch.logsumexp(input, dim=1, keepdim=True)
        ctx.save_for_backward(input, output, belong)
        return output.squeeze(1)  # output을 반환할 때는 차원을 줄입니다.

    @staticmethod
    def backward(ctx, grad_output):
        temp = ctx.temp
        # (M, N), (M, 1), (M, N)
        input, output, belong = ctx.saved_tensors
        # softmax 함수를 사용하여 그래디언트 계산을 수행합니다.
        # (M, N)
        if belong is None:
            softmax_result = torch.exp(input - output)
        else:
            inner_value = input * belong + torch.where(belong == 0, -np.inf, belong)
            softmax_result = torch.softmax(inner_value, dim=1)
        grad_input = softmax_result * grad_output.unsqueeze(1)
        return grad_input, None

# 함수를 사용하려면 apply 메서드를 사용하고, dim 매개변수를 전달합니다.
input = torch.randn(3, 4, requires_grad=True)
belong = torch.randint(0, 2, size=[*input.shape])
output = CustomLogSumExp.apply(input, belong)
output.backward(torch.ones_like(output))


tensor([[0, 1, 1, 1],
        [0, 1, 0, 1],
        [1, 0, 0, 0]])
tensor([[ 1.4569, -0.0574, -0.7526, -0.4921],
        [-1.2380,  2.0233, -0.4984,  1.9862],
        [ 0.0494,  0.9031,  1.1277,  0.2750]], requires_grad=True)
tensor([[  -inf, 0.9426, 0.2474, 0.5079],
        [  -inf, 3.0233,   -inf, 2.9862],
        [1.0494,   -inf,   -inf,   -inf]])
tensor([[0.0000, 0.4659, 0.2325, 0.3016],
        [0.0000, 0.5093, 0.0000, 0.4907],
        [1.0000, 0.0000, 0.0000, 0.0000]])


In [59]:
x = torch.randint(0, 2, size=(2, 3))
torch.where(x == 1, -np.inf, x)

tensor([[0., 0., 0.],
        [-inf, -inf, 0.]])

In [44]:
np.log(1e-8)

-18.420680743952364