In [None]:
import torch
import torch.nn.functional as F
from torch import nn as nn

class _AbstractDiceLoss(nn.Module):
    """
    Base class for different implementations of Dice loss.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super(_AbstractDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        # The output from the network during training is assumed to be un-normalized probabilities and we would
        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
        # However if one would like to apply Softmax in order to get the proper probability distribution from the
        # output, just specify `normalization=Softmax`
        assert normalization in ['sigmoid', 'softmax', 'none']
        if normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        elif normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        else:
            self.normalization = lambda x: x

    def dice(self, input, target, weight):
        # actual Dice score computation; to be implemented by the subclass
        raise NotImplementedError

    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)

        # compute per channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)

        # average Dice score across all channels/classes
        return 1. - torch.mean(per_channel_dice)

def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)

class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, input, target, weight):
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
        print('input value after forward',input[1, :, 0, 0, 0])

        input = flatten(input)
        target = flatten(target)
        target = target.float()

        if input.size(0) == 1:
            print('resizing input and target')
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            input = torch.cat((input, 1 - input), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        print('input value after resizing',input.shape)
        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
        w_l = target.sum(-1)
        print('w_l value after sum',w_l.shape)
        print('w_l value after sum',w_l)
        
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        print('w_l value after clamp',w_l.shape)
        print('w_l value after clamp',w_l)
        w_l.requires_grad = False

        intersect = (input * target).sum(-1)
        intersect = intersect * w_l

        denominator = (input + target).sum(-1)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 2 * (intersect.sum() / denominator.sum())
    
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given  a multi channel input and target.
    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.

    Args:
         input (torch.Tensor): NxCxSpatial input tensor
         target (torch.Tensor): NxCxSpatial target tensor
         epsilon (float): prevents division by zero
         weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """
    # Apply softmax to the input
    input = F.softmax(input, dim=1)

    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target)
    target = target.float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        print(' weight used')
        intersect = weight * intersect

    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
    denominator = (input + target).sum(-1)
    
    # denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))

class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
    Input is expected to be probabilities instead of logits.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        self.epsilon = epsilon

    def __call__(self, input, target):
        # Average across channels in order to get the final score
        return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon))

In [9]:
input_tensor  = torch.randn(16, 2, 48, 48, 24)          # batch=16, channel=2
target_tensor = torch.randint(0, 2, (16, 2, 48, 48, 24))
print('input_tensor max', input_tensor.max())
print('input_tensor min', input_tensor.min())
print('target_tensor max', target_tensor.max())
print('target_tensor min', target_tensor.min())
print('inpute tensor value', input_tensor[1, :, 0, 0, 0])

input_tensor max tensor(4.6797)
input_tensor min tensor(-5.2074)
target_tensor max tensor(1)
target_tensor min tensor(0)
inpute tensor value tensor([0.8078, 0.7103])


In [17]:
gdl = GeneralizedDiceLoss(normalization='softmax', epsilon=1e-6)
# gdl = GeneralizedDiceLoss(normalization='sigmoid', epsilon=1e-6)
gdl.register_forward_hook(
    lambda module, inp, out: print(">> GeneralizedDiceLoss.forward 被调用")
)

# 再给 dice 方法也注册一个简单的打印（可选）
orig_dice = gdl.dice
def hooked_dice(input, target, weight):
    print(">> GeneralizedDiceLoss.dice 被调用")
    return orig_dice(input, target, weight)
gdl.dice = hooked_dice

# 调用前向，观察打印
loss = gdl(input_tensor, target_tensor)
print("Computed loss:", loss.item())

>> GeneralizedDiceLoss.dice 被调用
input value after forward tensor([0.5244, 0.4756])
input value after resizing torch.Size([2, 884736])
w_l value after sum torch.Size([2])
w_l value after sum tensor([442172., 441973.])
w_l value after clamp torch.Size([2])
w_l value after clamp tensor([5.1147e-12, 5.1193e-12])
>> GeneralizedDiceLoss.forward 被调用
Computed loss: 0.5000306963920593


In [5]:
dsc = DiceCoefficient(epsilon=1e-6)
dicescore = dsc(input_tensor, target_tensor)
print("Computed loss:", dicescore.item())

Computed loss: 0.49981674551963806
