In [1]:
import torch
from torchmetrics import Dice
from torchmetrics.functional import dice
from monai.losses import DiceLoss
from monai.metrics import GeneralizedDiceScore

In [2]:
import torch.nn as nn

class MemoryEfficientSoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.0, ddp: bool = True):
        """
        saves 1.6 GB on Dataset017 3d_lowres
        """
        super(MemoryEfficientSoftDiceLoss, self).__init__()

        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.ddp = ddp

    def forward(self, x, y, loss_mask=None):
        shp_x, shp_y = x.shape, y.shape

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            x = x[:, 1:]

        # make everything shape (b, c)
        axes = list(range(2, len(shp_x)))

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                y = y.view((shp_y[0], 1, *shp_y[1:]))

            if all([i == j for i, j in zip(shp_x, shp_y)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                gt = y.long()
                y_onehot = torch.zeros(shp_x, device=x.device, dtype=torch.bool)
                y_onehot.scatter_(1, gt, 1)

            if not self.do_bg:
                y_onehot = y_onehot[:, 1:]
            sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes)

        intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes)
        sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes)

        # if self.ddp and self.batch_dice:
        #    intersect = AllGatherGrad.apply(intersect)
        #    sum_pred = AllGatherGrad.apply(sum_pred)
        #    sum_gt = AllGatherGrad.apply(sum_gt)

        if self.batch_dice:
            intersect = intersect.sum(0)
            sum_pred = sum_pred.sum(0)
            sum_gt = sum_gt.sum(0)

        dc = (2 * intersect) / (torch.clip(sum_gt + sum_pred + self.smooth, 1e-8))  # 2* intersect + self.smooth

        dc = dc.mean()
        return dc  # originally negative


In [3]:
dsc_loss_w = 1.0

#Dice = torchmetrics.Dice(average='macro', num_classes=4)
DiceScore = Dice()
DiceFGScore = Dice(ignore_index=0) #ignore_index=0 means we ignore the background class
MonaiDiceLoss = DiceLoss()
MonaiDiceScore = GeneralizedDiceScore()
SoftDiceLoss = MemoryEfficientSoftDiceLoss(batch_dice=False, do_bg=True, smooth=1e-5, ddp=False)
SoftDiceLossFG = MemoryEfficientSoftDiceLoss(batch_dice=False, do_bg=False, smooth=1e-5, ddp=False)
# nnUnetDiceLoss = 

# Example Tensors 1

In [66]:
# Creating Dummy Test Tensors for testing of same format as in the model [B, C, H, W, D]
# B = Batch Size, C = Number of Classes, H = Height, W = Width, D = Depth
# Example with Batch Size 1, 4 Classes, Height 3, Width 3, Depth 3
n_classes = 4
probs = torch.tensor(
    [[[0.7, 0.1, 0.1],  # Class 0 (corresponding to preds == 0)
     [0.1, 0.05, 0.6],
     [0.1, 0.05, 0.6]],

    [[0.2, 0.6, 0.6],  # Class 1 (corresponding to preds == 1)
     [0.6, 0.1, 0.2],
     [0.6, 0.1, 0.2]],

    [[0.05, 0.2, 0.2],  # Class 2 (corresponding to preds == 2)
     [0.2, 0.7, 0.1],
     [0.2, 0.7, 0.1]],

    [[0.05, 0.1, 0.1],  # Class 3 (no predictions, so probabilities are lower)
     [0.1, 0.15, 0.1],
     [0.1, 0.15, 0.1]]]
)
preds = torch.argmax(probs, dim=0)
masks = torch.tensor([[1, 2, 3], [0, 1, 1], [0, 0, 1]])

print(f"Preds: {preds}")
print(f"Ground Truth: {masks}")

print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

Preds: tensor([[0, 1, 1],
        [1, 2, 0],
        [1, 2, 0]])
Ground Truth: tensor([[1, 2, 3],
        [0, 1, 1],
        [0, 0, 1]])
Probs Shape:  torch.Size([4, 3, 3])
Preds Shape:  torch.Size([3, 3])
Masks Shape:  torch.Size([3, 3])


In [67]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds, n_classes).permute(2, 0, 1).unsqueeze(0) 
masks_oh = torch.nn.functional.one_hot(masks, n_classes).permute(2, 0, 1).unsqueeze(0)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")


preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Example Tensors 2

In [80]:
# Testing with example tensors from terminal
# unique values in masks: [0 2 3]
# unique values in preds: [0 2]
n_classes = 4

probs = torch.tensor(
[
    # Class 0 probabilities (corresponding to preds == 0)
    [[0.7, 0.8, 0.1],   # Class 0 has highest probability for preds == 0
     [0.1, 0.1, 0.1],  
     [0.8, 0.1, 0.7]],

    # Class 1 probabilities
    [[0.1, 0.1, 0.1],   # Class 1 does not have the highest probability anywhere
     [0.1, 0.1, 0.1],  
     [0.05, 0.05, 0.1]],

    # Class 2 probabilities (corresponding to preds == 2)
    [[0.1, 0.05, 0.75],  # Class 2 has highest probability for preds == 2
     [0.75, 0.8, 0.1],  
     [0.1, 0.75, 0.1]],

    # Class 3 probabilities (corresponding to preds == 3)
    [[0.1, 0.05, 0.05],  # Class 3 has highest probability where preds == 3
     [0.05, 0.05, 0.7],  
     [0.05, 0.1, 0.1]]
]
)
preds = torch.argmax(probs, dim=0)
masks = torch.tensor([[0, 0, 2], [2, 2, 0], [0, 2, 2]])

print(f"preds: {preds}")
print(f"ground truth: {masks}")

print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

preds: tensor([[0, 0, 2],
        [2, 2, 3],
        [0, 2, 0]])
ground truth: tensor([[0, 0, 2],
        [2, 2, 0],
        [0, 2, 2]])
Probs Shape:  torch.Size([4, 3, 3])
Preds Shape:  torch.Size([3, 3])
Masks Shape:  torch.Size([3, 3])


In [81]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds, n_classes).permute(2, 0, 1).unsqueeze(0) 
masks_oh = torch.nn.functional.one_hot(masks, n_classes).permute(2, 0, 1).unsqueeze(0)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")

preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Example Tensor 3 (Background heavy)

In [4]:
n_classes = 4

probs = torch.tensor(
[
    # Class 0 probabilities (background)
    [[0.9, 0.9, 0.9],  # Class 0 dominates where ground truth is 0
     [0.1, 0.1, 0.9],
     [0.9, 0.1, 0.9]],

    # Class 1 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 1 where it's not the ground truth
     [0.05, 0.05, 0.05],
     [0.05, 0.8, 0.05]],   # Higher probability for Class 1 at (2, 1)

    # Class 2 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 2 where it's not the ground truth
     [0.1, 0.8, 0.05],     # Higher probability for Class 2 at (1, 1)
     [0.05, 0.05, 0.05]],

    # Class 3 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 3 where it's not the ground truth
     [0.8, 0.05, 0.05],    # Higher probability for Class 3 at (1, 0)
     [0.05, 0.05, 0.05]]
]
)
preds = torch.argmax(probs, dim=0)
masks = torch.tensor([[0, 0, 0 ], [0, 2, 3], [0, 1, 0]])

print(f"preds: {preds}")
print(f"ground truth: {masks}")

print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

preds: tensor([[0, 0, 0],
        [3, 2, 0],
        [0, 1, 0]])
ground truth: tensor([[0, 0, 0],
        [0, 2, 3],
        [0, 1, 0]])
Probs Shape:  torch.Size([4, 3, 3])
Preds Shape:  torch.Size([3, 3])
Masks Shape:  torch.Size([3, 3])


In [5]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds, n_classes).permute(2, 0, 1).unsqueeze(0) 
masks_oh = torch.nn.functional.one_hot(masks, n_classes).permute(2, 0, 1).unsqueeze(0)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")

preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Calculation of Scores

In [6]:
dice_p_cls = dice(preds, masks, average=None, num_classes=n_classes) # average=None returns dice per class
print(dice_p_cls)

# -> dice for class 1 should be 1 as predictions didn't contain the class which was also absent in the ground truth

tensor([0.8333, 1.0000, 1.0000, 0.0000])


In [7]:
# Dice Scores:
dsc = dice(preds, masks)
diceFG = dice(preds, masks, ignore_index=0)
dice_p_cls = dice(preds, masks, average=None, num_classes=n_classes) # average=None returns dice per class

SoftDiceScore = 1 - SoftDiceLoss(probs, masks_oh)


# ET (Enhancing Tumor): label 3
dice_ET = DiceScore((preds == 3), (masks == 3))
dice_FG_ET = DiceFGScore((preds == 3), (masks == 3))

# TC(Tumor Core): ET + NCR = label 1 + label 3
dice_TC = DiceScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3))
dice_FG_TC = DiceFGScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3))

# WT (Whole Tumor): TC + ED = label 1 + label 2 + label 3
dice_WT = DiceScore((preds > 0), (masks > 0))
dice_FG_WT = DiceFGScore((preds > 0), (masks > 0))

print("\n")
print(f"BraTs Region Scores:")
print("Dice Score ET: ", dice_ET)
print("Dice FG Score ET: ", dice_FG_ET)
print("Dice Score TC: ", dice_TC)
print("Dice FG Score TC: ", dice_FG_TC)
print("Dice Score WT: ", dice_WT)
print("Dice FG Score WT: ", dice_FG_WT)

print("\n")
print(f"Dice Scores:")
print("Dice Score: ", dsc)
print(f"Soft Dice Score: {SoftDiceScore}")
print("Dice Score FG: ", diceFG)
print("Dice Score per Class: ", dice_p_cls)



BraTs Region Scores:
Dice Score ET:  tensor(0.7778)
Dice FG Score ET:  tensor(0.)
Dice Score TC:  tensor(0.7778)
Dice FG Score TC:  tensor(0.5000)
Dice Score WT:  tensor(0.7778)
Dice FG Score WT:  tensor(0.6667)


Dice Scores:
Dice Score:  tensor(0.7778)
Soft Dice Score: 0.43246108293533325
Dice Score FG:  tensor(0.6667)
Dice Score per Class:  tensor([0.8333, 1.0000, 1.0000, 0.0000])


In [8]:
#Using One-Hot Encoded Predictions and Masks
print(f"Dice Score with One-Hot Encoded Predictions and Masks:")
print(f"Dice with oh preds: {DiceScore(preds_oh, masks_oh)}")
print(f"DiceFG with oh preds: {DiceFGScore(preds_oh, masks_oh)}")

# Use DiceFG when using one-hot encoding and non-binary case

Dice Score with One-Hot Encoded Predictions and Masks:
Dice with oh preds: 0.8888888955116272
DiceFG with oh preds: 0.7777777910232544


# Losses

In [9]:
#Dice Losses
# Brats Dice Loss
dice_ET_loss = (1 - DiceScore((preds == 3), (masks == 3))) * dsc_loss_w
dice_TC_loss = (1 - DiceScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3)) ) * dsc_loss_w
dice_WT_loss = (1 - DiceScore((preds > 0), (masks > 0))) * dsc_loss_w

# Brats FG Dice Loss
diceFG_ET_loss = (1 - DiceFGScore((preds == 3), (masks == 3))) * dsc_loss_w
diceFG_TC_loss = (1 - DiceFGScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3)) ) * dsc_loss_w
diceFG_WT_loss = (1 - DiceFGScore((preds > 0), (masks > 0))) * dsc_loss_w

dice_loss = (1-DiceScore(preds, masks)) * dsc_loss_w

dice_loss_p_class = (1 - dice_p_cls) * dsc_loss_w

softdice_loss = SoftDiceLoss(probs, masks_oh) * dsc_loss_w

monai_dice_loss = MonaiDiceLoss(preds_oh, masks_oh) * dsc_loss_w

print("Dice ET Loss: ", dice_ET_loss)
print(f"Dice ET FG Loss: {diceFG_ET_loss}")
print("Dice TC Loss: ", dice_TC_loss)
print(f"Dice TC FG Loss: {diceFG_TC_loss}")
print("Dice WT Loss: ", dice_WT_loss)
print(f"Dice WT FG Loss: {diceFG_WT_loss}")
print("Dice Loss: ", dice_loss)
print(f"Soft Dice Loss: {softdice_loss}")
print("Monai Dice Loss: ", monai_dice_loss)
print("Dice Loss per Class: ", dice_loss_p_class)

Dice ET Loss:  tensor(0.2222)
Dice ET FG Loss: 1.0
Dice TC Loss:  tensor(0.2222)
Dice TC FG Loss: 0.5
Dice WT Loss:  tensor(0.2222)
Dice WT FG Loss: 0.3333333134651184
Dice Loss:  tensor(0.2222)
Soft Dice Loss: 0.5675389170646667
Monai Dice Loss:  tensor(0.2917)
Dice Loss per Class:  tensor([0.1667, 0.0000, 0.0000, 1.0000])


In [10]:
# Soft Dice Loss on Brats Regions
masks_ET = (masks == 3).unsqueeze(0).unsqueeze(0)
masks_TC = (masks == 1) | (masks == 3).unsqueeze(0).unsqueeze(0)
masks_WT = (masks > 0).unsqueeze(0).unsqueeze(0)

probs_ET  = probs[:, 3].unsqueeze(0)
probs_TC  = torch.maximum(probs[:, 1], probs[:, 3]).unsqueeze(0)
probs_WT  = torch.maximum(probs_TC.squeeze(0), probs[:, 2]).unsqueeze(0)

softdiceL_ET = SoftDiceLoss(probs_ET, masks_ET)
softdiceL_TC = SoftDiceLoss(probs_TC, masks_TC)
softdiceL_WT = SoftDiceLoss(probs_WT, masks_WT)

print(f"softdice_ET: {softdiceL_ET}")
print(f"softdice_TC: {softdiceL_TC}")
print(f"softdice_WT: {softdiceL_WT}")

# Only on Foreground -> nan because our probs are binary
softdiceL_ET = SoftDiceLossFG(probs_ET, masks_ET)
softdiceL_TC = SoftDiceLossFG(probs_TC, masks_TC)
softdiceL_WT = SoftDiceLossFG(probs_WT, masks_WT)

print(f"softdice_ET: {softdiceL_ET}")
print(f"softdice_TC: {softdiceL_TC}")
print(f"softdice_WT: {softdiceL_WT}")

softdice_ET: 0.04545434191823006
softdice_TC: 0.43037867546081543
softdice_WT: 0.5789464116096497
softdice_ET: nan
softdice_TC: nan
softdice_WT: nan


In [11]:
# Soft Dice Loss on Brats Regions with two channels each (0 = BG, 1 = FG)
masks_ET = (masks == 3).unsqueeze(0).unsqueeze(0)
masks_TC = (masks == 1) | (masks == 3).unsqueeze(0).unsqueeze(0)    
masks_WT = (masks > 0).unsqueeze(0).unsqueeze(0)

probs_ET_FG  = probs[:, 3].unsqueeze(0)
probs_ET_BG = torch.maximum(torch.maximum(probs[:, 0], probs[:, 1]), probs[:, 2]).unsqueeze(0)
probs_ET = torch.cat((probs_ET_BG, probs_ET_FG), dim=1)

probs_TC_FG  = torch.maximum(probs[:, 1], probs[:, 3]).unsqueeze(0)
probs_TC_BG = torch.maximum(probs[:, 0], probs[:, 2]).unsqueeze(0)
probs_TC = torch.cat((probs_TC_BG, probs_TC_FG), dim=1)

probs_WT_FG  = torch.maximum(torch.maximum(probs[: , 1], probs[:, 2]), probs[:, 3]).unsqueeze(0)
probs_WT_BG = probs[:, 0].unsqueeze(0)

probs_WT = torch.cat((probs_WT_BG, probs_WT_FG), dim=1)

softdiceL_ET = SoftDiceLoss(probs_ET, masks_ET)
softdiceL_TC = SoftDiceLoss(probs_TC, masks_TC)
softdiceL_WT = SoftDiceLoss(probs_WT, masks_WT)

print(f"softdiceloss_ET: {softdiceL_ET}")
print(f"softdiceloss_TC: {softdiceL_TC}")
print(f"softdiceloss_WT: {softdiceL_WT}")

# Only on Foreground -> nan because our probs are binary
softdiceFGL_ET = SoftDiceLossFG(probs_ET, masks_ET)
softdiceFGL_TC = SoftDiceLossFG(probs_TC, masks_TC)
softdiceFGL_WT = SoftDiceLossFG(probs_WT, masks_WT)

print("\n")

print(f"Foreground softdiceloss_ET: {softdiceFGL_ET}")
print(f"Foreground softdiceloss_TC: {softdiceFGL_TC}")
print(f"Foreground softdiceloss_WT: {softdiceFGL_WT}")

softdiceloss_ET: 0.433322936296463
softdiceloss_TC: 0.618174135684967
softdiceloss_WT: 0.6826353073120117


Foreground softdiceloss_ET: 0.04545434191823006
Foreground softdiceloss_TC: 0.43037867546081543
Foreground softdiceloss_WT: 0.5789464116096497


In [12]:
# Soft Dice Loss on Brats Regions with two channels each (0 = BG, 1 = FG) -> doesn't make a difference
masks_ET = (masks == 3)
masks_TC = (masks == 1) | (masks == 3)
masks_WT = (masks > 0)

masks_ET_oh = torch.nn.functional.one_hot(masks_ET.long(), 2).permute(2, 0, 1).unsqueeze(0)
masks_TC_oh = torch.nn.functional.one_hot(masks_TC.long(), 2).permute(2, 0, 1).unsqueeze(0)
masks_WT_oh = torch.nn.functional.one_hot(masks_WT.long(), 2).permute(2, 0, 1).unsqueeze(0)

masks_ET.unsqueeze_(0).unsqueeze_(0)
masks_TC.unsqueeze_(0).unsqueeze_(0)
masks_WT.unsqueeze_(0).unsqueeze_(0)

probs_ET_FG  = probs[:, 3].unsqueeze(0)
probs_ET_BG = torch.maximum(torch.maximum(probs[:, 0], probs[:, 1]), probs[:, 2]).unsqueeze(0)
probs_ET = torch.cat((probs_ET_BG, probs_ET_FG), dim=1)

probs_TC_FG  = torch.maximum(probs[:, 1], probs[:, 3]).unsqueeze(0)
probs_TC_BG = torch.maximum(probs[:, 0], probs[:, 2]).unsqueeze(0)
probs_TC = torch.cat((probs_TC_BG, probs_TC_FG), dim=1)

probs_WT_FG  = torch.maximum(torch.maximum(probs[: , 1], probs[:, 2]), probs[:, 3]).unsqueeze(0)
probs_WT_BG = probs[:, 0].unsqueeze(0)

probs_WT = torch.cat((probs_WT_BG, probs_WT_FG), dim=1)

softdiceL_ET = SoftDiceLoss(probs_ET, masks_ET_oh)
softdiceL_TC = SoftDiceLoss(probs_TC, masks_TC_oh)
softdiceL_WT = SoftDiceLoss(probs_WT, masks_WT_oh)

print(f"softdiceloss_ET: {softdiceL_ET}")
print(f"softdiceloss_TC: {softdiceL_TC}")
print(f"softdiceloss_WT: {softdiceL_WT}")

# Only on Foreground -> nan because our probs are binary
softdiceFGL_ET = SoftDiceLossFG(probs_ET, masks_ET_oh)
softdiceFGL_TC = SoftDiceLossFG(probs_TC, masks_TC_oh)
softdiceFGL_WT = SoftDiceLossFG(probs_WT, masks_WT_oh)

print("\n")

print(f"Foreground softdiceloss_ET: {softdiceFGL_ET}")
print(f"Foreground softdiceloss_TC: {softdiceFGL_TC}")
print(f"Foreground softdiceloss_WT: {softdiceFGL_WT}")

softdiceloss_ET: 0.433322936296463
softdiceloss_TC: 0.618174135684967
softdiceloss_WT: 0.6826353073120117


Foreground softdiceloss_ET: 0.04545434191823006
Foreground softdiceloss_TC: 0.43037867546081543
Foreground softdiceloss_WT: 0.5789464116096497


# Test ChatGPT approach

In [17]:
def prepare_region_logits(logits, target_classes):
    """
    Efficiently sum the logits for the given target classes.
    """
    return logits[:, target_classes].sum(dim=1, keepdim=True)

def prepare_region_gt(gt, target_classes):
    """
    Efficiently create a binary mask for the relevant target classes.
    """
    return torch.isin(gt, torch.tensor(target_classes))

In [30]:
brats_regions = {'ET': [3], 'TC': [1, 3], 'WT': [1, 2, 3]}

probs_tc = prepare_region_logits(probs, brats_regions['TC'])
gt_tc = prepare_region_gt(masks.unsqueeze(0).unsqueeze(0), brats_regions['TC'])
probs_et = prepare_region_logits(probs, brats_regions['ET'])
gt_et = prepare_region_gt(masks.unsqueeze(0).unsqueeze(0), brats_regions['ET'])
probs_wt = prepare_region_logits(probs, brats_regions['WT'])
gt_wt = prepare_region_gt(masks.unsqueeze(0).unsqueeze(0), brats_regions['WT'])

print(f"shape probs_tc: {probs_tc.shape}")
print(f"shape gt tc: {gt_tc.shape}")

shape probs_tc: torch.Size([1, 1, 3, 3])
shape gt tc: torch.Size([1, 1, 3, 3])


In [32]:
softdiceloss_et = SoftDiceLoss(probs_et, gt_et)
softdicelossFG_et = SoftDiceLossFG(probs_et, gt_et)

softdiceloss_tc = SoftDiceLoss(probs_tc, gt_tc)
softdicelossFG_tc = SoftDiceLossFG(probs_tc, gt_tc)

softdiceloss_wt = SoftDiceLoss(probs_wt, gt_wt)
softdicelossFG_wt = SoftDiceLossFG(probs_wt, gt_wt)

print(f"Soft Dice Loss ET: {softdiceloss_et}")
print(f"Soft Dice Loss TC: {softdiceloss_tc}")
print(f"Soft Dice Loss WT: {softdiceloss_wt}")

print(f"Soft Dice FG Loss  ET: {softdicelossFG_et}")
print(f"Soft Dice FG Loss TC: {softdicelossFG_tc}")
print(f"Soft Dice FG Loss WT: {softdicelossFG_wt}")

Soft Dice Loss ET: 0.04545434191823006
Soft Dice Loss TC: 0.4318172037601471
Soft Dice Loss WT: 0.5864652395248413
Soft Dice FG Loss  ET: nan
Soft Dice FG Loss TC: nan
Soft Dice FG Loss WT: nan
