In [13]:
import torch
from torchmetrics import Dice
from torchmetrics.functional import dice
from monai.losses import DiceLoss
from monai.metrics import GeneralizedDiceScore

In [23]:
# Weight applied to every Dice-based loss term computed in the loss cell.
dsc_loss_w = 1.0

# NOTE: `Dice` (imported from torchmetrics) is the metric *class*. Scores are
# obtained by calling the instances created here, never the class itself.
# An earlier draft bound the name `Dice` to an instance configured with
# average='macro', num_classes=4; that binding is no longer used.

# torchmetrics metrics: one over all classes, one that skips class 0.
DiceFGScore = Dice(ignore_index=0)  # ignore_index=0 -> background class is ignored
DiceScore = Dice()

# MONAI counterparts, constructed with the library defaults.
MonaiDiceScore = GeneralizedDiceScore()
MonaiDiceLoss = DiceLoss()

# TODO(review): an nnU-Net Dice loss was apparently planned for comparison but
# never defined here - confirm whether it is still wanted.

# Example Tensors 1

In [3]:
# Toy example 1: two integer label maps of shape [H, W] = [3, 3] whose entries
# are class ids in 0..3. They are NOT yet in the model layout; the one-hot
# cell that follows turns them into [B, C, H, W] = [1, 4, 3, 3]
# (B = batch size, C = number of classes, H = height, W = width).
n_classes = 4  # class ids 0, 1, 2, 3

# Ground-truth label map.
masks = torch.tensor(
    [
        [1, 2, 3],
        [0, 1, 1],
        [0, 0, 1],
    ]
)

# Predicted label map.
preds = torch.tensor(
    [
        [0, 1, 1],
        [1, 2, 0],
        [1, 2, 0],
    ]
)

In [4]:
# Bring both label maps into the BNHW[D] layout (B = batch size, N = number of
# classes, H = height, W = width, D = depth): one_hot appends the class axis
# last, permute(2, 0, 1) moves it to the front, and unsqueeze(0) adds a leading
# batch axis of size 1.
preds_oh, masks_oh = (
    torch.nn.functional.one_hot(label_map, num_classes=n_classes).permute(2, 0, 1).unsqueeze(0)
    for label_map in (preds, masks)
)

# Example Tensors 2

In [15]:
# Toy example 2: label maps modelled on a case seen in a terminal run, where
# the reported unique values were [0 2 3] for masks and [0 2] for preds.
# NOTE(review): in the tensors below it is the other way round - `preds`
# contains {0, 2, 3} and `masks` contains {0, 2}. Confirm which was intended.
# Class 1 is absent from both tensors.
n_classes = 4  # class ids 0, 1, 2, 3

# Ground-truth label map.
masks = torch.tensor(
    [
        [0, 0, 2],
        [2, 2, 0],
        [0, 2, 2],
    ]
)

# Predicted label map.
preds = torch.tensor(
    [
        [0, 0, 2],
        [2, 2, 3],
        [0, 2, 0],
    ]
)

In [16]:
# Bring both label maps into the BNHW[D] layout (B = batch size, N = number of
# classes, H = height, W = width, D = depth): one_hot appends the class axis
# last, permute(2, 0, 1) moves it to the front, and unsqueeze(0) adds a leading
# batch axis of size 1.
preds_oh, masks_oh = (
    torch.nn.functional.one_hot(label_map, num_classes=n_classes).permute(2, 0, 1).unsqueeze(0)
    for label_map in (preds, masks)
)

# Calculation of Scores

In [17]:
# Per-class Dice on the label maps: average=None makes the functional API
# return one score per class instead of a single aggregated value.
dice_p_cls = dice(preds, masks, num_classes=n_classes, average=None)
print(dice_p_cls)

# Observation: class 1 occurs in neither the predictions nor the ground truth,
# so a Dice of 1 would be the expected score for it - the printed entry for
# that class is nan instead.

tensor([0.7500,    nan, 0.8889, 0.0000])


In [11]:
# Dice scores on the current example tensors: overall, foreground-only,
# per class, and for the three BraTS evaluation regions (ET, TC, WT).
dsc = dice(preds, masks)
diceFG = dice(preds, masks, ignore_index=0)  # class 0 (background) is ignored
dice_p_cls = dice(preds, masks, average=None, num_classes=n_classes) # average=None returns dice per class

# The region scores go through the metric *instances* `DiceScore` and
# `DiceFGScore` from the setup cell. `Dice` itself is the torchmetrics class:
# calling `Dice(pred, target)` would construct a new metric object from the
# tensors instead of returning a score, so it must not be used here.

# ET (Enhancing Tumor): label 3
dice_ET = DiceScore((preds == 3), (masks == 3))
dice_FG_ET = DiceFGScore((preds == 3), (masks == 3))

# TC (Tumor Core): ET + NCR = label 1 + label 3
dice_TC = DiceScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3))
dice_FG_TC = DiceFGScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3))

# WT (Whole Tumor): TC + ED = label 1 + label 2 + label 3
dice_WT = DiceScore((preds > 0), (masks > 0))
dice_FG_WT = DiceFGScore((preds > 0), (masks > 0))

print("Dice Score: ", dsc)
print("Dice Score FG: ", diceFG)
print("Dice Score per Class: ", dice_p_cls)
print("Dice Score ET: ", dice_ET)
print("Dice FG Score ET: ", dice_FG_ET)
print("Dice Score TC: ", dice_TC)
print("Dice FG Score TC: ", dice_FG_TC)
print("Dice Score WT: ", dice_WT)
print("Dice FG Score WT: ", dice_FG_WT)

Dice Score:  tensor(0.7778)
Dice Score FG:  tensor(0.8000)
Dice Score per Class:  tensor([0.7500,    nan, 0.8889, 0.0000])
Dice Score ET:  tensor(0.8889)
Dice FG Score ET:  tensor(0.)
Dice Score TC:  tensor(0.8889)
Dice FG Score TC:  tensor(0.)
Dice Score WT:  tensor(0.7778)
Dice FG Score WT:  tensor(0.8000)


In [25]:
# Dice scores for the current example tensors, torchmetrics next to MONAI.
dsc = dice(preds, masks)
monai_dice = MonaiDiceScore(preds_oh, masks_oh)  # the MONAI metric is fed the one-hot tensors
diceFG = dice(preds, masks, ignore_index=0)  # class 0 (background) is ignored
dice_p_cls = dice(preds, masks, num_classes=n_classes, average=None)  # one value per class

# Binary region masks, built once and shared by both metric instances.
# ET (Enhancing Tumor): label 3
et_pred, et_true = preds == 3, masks == 3
# TC (Tumor Core): ET + NCR = label 1 + label 3
tc_pred = (preds == 1) | (preds == 3)
tc_true = (masks == 1) | (masks == 3)
# WT (Whole Tumor): TC + ED = label 1 + label 2 + label 3
wt_pred, wt_true = preds > 0, masks > 0

# Each region is scored by the plain metric first, then by the one created
# with ignore_index=0.
dice_ET = DiceScore(et_pred, et_true)
dice_FG_ET = DiceFGScore(et_pred, et_true)

dice_TC = DiceScore(tc_pred, tc_true)
dice_FG_TC = DiceFGScore(tc_pred, tc_true)

dice_WT = DiceScore(wt_pred, wt_true)
dice_FG_WT = DiceFGScore(wt_pred, wt_true)

# Report every score, one "<caption> <value>" line each.
for caption, score in (
    ("Dice Score: ", dsc),
    ("Monai Dice Score: ", monai_dice),
    ("Dice Score FG: ", diceFG),
    ("Dice Score per Class: ", dice_p_cls),
    ("Dice Score ET: ", dice_ET),
    ("Dice FG Score ET: ", dice_FG_ET),
    ("Dice Score TC: ", dice_TC),
    ("Dice FG Score TC: ", dice_FG_TC),
    ("Dice Score WT: ", dice_WT),
    ("Dice FG Score WT: ", dice_FG_WT),
):
    print(caption, score)

Dice Score:  tensor(0.7778)
Monai Dice Score:  tensor([0.7534])
Dice Score FG:  tensor(0.8000)
Dice Score per Class:  tensor([0.7500,    nan, 0.8889, 0.0000])
Dice Score ET:  tensor(0.8889)
Dice FG Score ET:  tensor(0.)
Dice Score TC:  tensor(0.8889)
Dice FG Score TC:  tensor(0.)
Dice Score WT:  tensor(0.7778)
Dice FG Score WT:  tensor(0.8000)


# Losses

In [12]:
# Dice losses
# BraTS-style Dice loss: every region contributes (1 - Dice) * dsc_loss_w / 3.
# The three values below are those already-scaled summands; their sum is not
# formed in this cell.

# ET (Enhancing Tumor): label 3
et_pred, et_true = preds == 3, masks == 3
# TC (Tumor Core): label 1 + label 3
tc_pred = (preds == 1) | (preds == 3)
tc_true = (masks == 1) | (masks == 3)
# WT (Whole Tumor): every non-zero label
wt_pred, wt_true = preds > 0, masks > 0

dice_ET_loss = (1 - dice(et_pred, et_true)) * dsc_loss_w / 3
dice_TC_loss = (1 - dice(tc_pred, tc_true)) * dsc_loss_w / 3
dice_WT_loss = (1 - dice(wt_pred, wt_true)) * dsc_loss_w / 3

# Multi-class alternative: one minus the macro-averaged Dice over the label maps.
dice_loss = (1 - dice(preds, masks, num_classes=n_classes, average='macro')) * dsc_loss_w

# MONAI loss, evaluated on the one-hot encoded tensors.
monai_dice_loss = MonaiDiceLoss(preds_oh, masks_oh) * dsc_loss_w

# Report every loss, one "<caption> <value>" line each.
for caption, loss_value in (
    ("Dice ET Loss: ", dice_ET_loss),
    ("Dice TC Loss: ", dice_TC_loss),
    ("Dice WT Loss: ", dice_WT_loss),
    ("Dice Loss: ", dice_loss),
    ("Monai Dice Loss: ", monai_dice_loss),
):
    print(caption, loss_value)

Dice ET Loss:  tensor(0.0370)
Dice TC Loss:  tensor(0.0370)
Dice WT Loss:  tensor(0.0741)
Dice Loss:  tensor(0.4537)
Monai Dice Loss:  tensor(0.3403)
