# Imports

In [1]:
import torch
import torch.nn as nn
from torchmetrics import Dice
from torchmetrics.functional import dice
from monai.losses import DiceLoss, MaskedDiceLoss
from monai.metrics import GeneralizedDiceScore, DiceMetric

import os
import sys
from pathlib import Path

current_dir = os.getcwd()

file = Path(current_dir).resolve()
sys.path.append(str(file.parents[0]))
sys.path.append(str(file.parents[1]))
#sys.path.append(str(file.parents[2]))

from utils.dice import MemoryEfficientSoftDiceLoss
from utils.dice_loss import SoftDiceLoss
#import utils.dice_loss

# Setup of Dice Metrics and Losses

In [2]:
dsc_loss_w = 1.0

#Torchmetrics
#Dice = torchmetrics.Dice(average='macro', num_classes=4)
DiceScore = Dice()
DiceFGScore = Dice(ignore_index=0) #ignore_index=0 means we ignore the background class

#MONAI Losses and Scores
MonaiDiceLoss = DiceLoss()
MonaiDiceLogitLoss = DiceLoss(softmax=True, to_onehot_y=True)
MonaiMaskedDiceLoss = MaskedDiceLoss(softmax = True, to_onehot_y=True)

MonaiDiceScore = GeneralizedDiceScore()
MonaiDiceFGScore = GeneralizedDiceScore(include_background=False)
MonaiDiceMetric = DiceMetric(ignore_empty=False)
MonaiDiceMetricFG = DiceMetric(include_background=False, ignore_empty=False)

# Memory Efficient Soft Dice Loss
MESoftDiceLogitLoss = MemoryEfficientSoftDiceLoss(nn.Softmax(dim = 1), do_bg=True, smooth = 1e-5)
MESoftDiceLoss = MemoryEfficientSoftDiceLoss(do_bg=True, smooth = 1e-5)
MESoftDiceLossFG = MemoryEfficientSoftDiceLoss(do_bg=False, smooth = 1e-5)

#Soft Dice Loss
SoftDiceLogitLoss = SoftDiceLoss(nn.Softmax(dim = 1), do_bg=True, smooth = 1e-5)
SoftDiceL = SoftDiceLoss(smooth = 1e-5)


Softmax = nn.Softmax(dim = 0)

# Example Tensors 1

In [4]:
# Creating Dummy Test Tensors for testing of same format as in the model [B, C, H, W, D]
# B = Batch Size, C = Number of Classes, H = Height, W = Width, D = Depth
# Example with Batch Size 1, 4 Classes, Height 3, Width 3, Depth 3
n_classes = 4
og_probs = torch.tensor(
    [[[0.7, 0.1, 0.1],  # Class 0 (corresponding to preds == 0)
     [0.1, 0.05, 0.6],
     [0.1, 0.05, 0.6]],

    [[0.2, 0.6, 0.6],  # Class 1 (corresponding to preds == 1)
     [0.6, 0.1, 0.2],
     [0.6, 0.1, 0.2]],

    [[0.05, 0.2, 0.2],  # Class 2 (corresponding to preds == 2)
     [0.2, 0.7, 0.1],
     [0.2, 0.7, 0.1]],

    [[0.05, 0.1, 0.1],  # Class 3 (no predictions, so probabilities are lower)
     [0.1, 0.15, 0.1],
     [0.1, 0.15, 0.1]]]
)

logits = torch.log(og_probs / (1 - og_probs))

probs = Softmax(logits)

preds = torch.argmax(probs, dim=0)
masks = torch.tensor([[1, 2, 3], [0, 1, 1], [0, 0, 1]])

print(f"Preds: {preds}")
print(f"Ground Truth: {masks}")

# Uniform shape of B, C, H, W
logits.unsqueeze_(0)
probs.unsqueeze_(0)
preds.unsqueeze_(0).unsqueeze_(0) 
masks.unsqueeze_(0).unsqueeze_(0)

print(f"logits shape {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

Preds: tensor([[0, 1, 1],
        [1, 2, 0],
        [1, 2, 0]])
Ground Truth: tensor([[1, 2, 3],
        [0, 1, 1],
        [0, 0, 1]])
logits shape torch.Size([1, 4, 3, 3])
Probs Shape:  torch.Size([1, 4, 3, 3])
Preds Shape:  torch.Size([1, 1, 3, 3])
Masks Shape:  torch.Size([1, 1, 3, 3])


In [5]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds.squeeze(1), n_classes).permute(0, 3, 1, 2)
masks_oh = torch.nn.functional.one_hot(masks.squeeze(1), n_classes).permute(0, 3, 1, 2)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")


preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Example Tensors 2

In [89]:
# Testing with example tensors from terminal
# unique values in masks: [0 2 3]
# unique values in preds: [0 2]
n_classes = 4

og_probs = torch.tensor(
[
    # Class 0 probabilities (corresponding to preds == 0)
    [[0.7, 0.8, 0.1],   # Class 0 has highest probability for preds == 0
     [0.1, 0.1, 0.1],  
     [0.8, 0.1, 0.7]],

    # Class 1 probabilities
    [[0.1, 0.1, 0.1],   # Class 1 does not have the highest probability anywhere
     [0.1, 0.1, 0.1],  
     [0.05, 0.05, 0.1]],

    # Class 2 probabilities (corresponding to preds == 2)
    [[0.1, 0.05, 0.75],  # Class 2 has highest probability for preds == 2
     [0.75, 0.8, 0.1],  
     [0.1, 0.75, 0.1]],

    # Class 3 probabilities (corresponding to preds == 3)
    [[0.1, 0.05, 0.05],  # Class 3 has highest probability where preds == 3
     [0.05, 0.05, 0.7],  
     [0.05, 0.1, 0.1]]
]
)

logits = torch.log(og_probs / (1 - og_probs))

probs = Softmax(logits)

preds = torch.argmax(probs, dim=0)
masks = torch.tensor([[0, 0, 2], [2, 2, 0], [0, 2, 2]])

# Uniform shape of B, C, H, W
logits.unsqueeze_(0)
probs.unsqueeze_(0)
preds.unsqueeze_(0).unsqueeze_(0) 
masks.unsqueeze_(0).unsqueeze_(0)

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

logits Shape: torch.Size([1, 4, 3, 3])
Probs Shape:  torch.Size([1, 4, 3, 3])
Preds Shape:  torch.Size([1, 1, 3, 3])
Masks Shape:  torch.Size([1, 1, 3, 3])


In [90]:
print(preds)
print(masks)

tensor([[[[0, 0, 2],
          [2, 2, 3],
          [0, 2, 0]]]])
tensor([[[[0, 0, 2],
          [2, 2, 0],
          [0, 2, 2]]]])


In [91]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds.squeeze(1), n_classes).permute(0, 3, 1, 2)
masks_oh = torch.nn.functional.one_hot(masks.squeeze(1), n_classes).permute(0, 3, 1, 2)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")

preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Example Tensor 3 (Background heavy)

In [22]:
n_classes = 4

og_probs = torch.tensor(
[
    # Class 0 probabilities (background)
    [[0.9, 0.9, 0.9],  # Class 0 dominates where ground truth is 0
     [0.1, 0.1, 0.9],
     [0.9, 0.1, 0.9]],

    # Class 1 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 1 where it's not the ground truth
     [0.05, 0.05, 0.05],
     [0.05, 0.8, 0.05]],   # Higher probability for Class 1 at (2, 1)

    # Class 2 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 2 where it's not the ground truth
     [0.1, 0.8, 0.05],     # Higher probability for Class 2 at (1, 1)
     [0.05, 0.05, 0.05]],

    # Class 3 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 3 where it's not the ground truth
     [0.8, 0.05, 0.05],    # Higher probability for Class 3 at (1, 0)
     [0.05, 0.05, 0.05]]
]
)

logits = torch.log(og_probs / (1 - og_probs))

probs = Softmax(logits)

preds = torch.argmax(probs, dim=0)
masks = torch.tensor([[0, 0, 0 ], [0, 2, 3], [0, 1, 0]])

print(f"preds: {preds}")
print(f"ground truth: {masks}")

print()

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

print()
print(f"Shape transformation to Torch Lightning Format")
print()

# Uniform shape of B, C, H, W
logits.unsqueeze_(0)
probs.unsqueeze_(0)
preds.unsqueeze_(0).unsqueeze_(0) 
masks.unsqueeze_(0).unsqueeze_(0)

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

preds: tensor([[0, 0, 0],
        [3, 2, 0],
        [0, 1, 0]])
ground truth: tensor([[0, 0, 0],
        [0, 2, 3],
        [0, 1, 0]])

logits Shape: torch.Size([4, 3, 3])
Probs Shape:  torch.Size([4, 3, 3])
Preds Shape:  torch.Size([3, 3])
Masks Shape:  torch.Size([3, 3])

Shape transformation to Torch Lightning Format

logits Shape: torch.Size([1, 4, 3, 3])
Probs Shape:  torch.Size([1, 4, 3, 3])
Preds Shape:  torch.Size([1, 1, 3, 3])
Masks Shape:  torch.Size([1, 1, 3, 3])


In [4]:
print(logits)

tensor([[[[ 2.1972,  2.1972,  2.1972],
          [-2.1972, -2.1972,  2.1972],
          [ 2.1972, -2.1972,  2.1972]],

         [[-2.9444, -2.9444, -2.9444],
          [-2.9444, -2.9444, -2.9444],
          [-2.9444,  1.3863, -2.9444]],

         [[-2.9444, -2.9444, -2.9444],
          [-2.1972,  1.3863, -2.9444],
          [-2.9444, -2.9444, -2.9444]],

         [[-2.9444, -2.9444, -2.9444],
          [ 1.3863, -2.9444, -2.9444],
          [-2.9444, -2.9444, -2.9444]]]])


In [23]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds.squeeze(1), n_classes).permute(0, 3, 1, 2)
masks_oh = torch.nn.functional.one_hot(masks.squeeze(1), n_classes).permute(0, 3, 1, 2)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")

preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Example Tensor 4 (perfect overlap)

In [74]:
n_classes = 4

og_probs = torch.tensor(
[
    # Class 0 probabilities (background)
    [[0.9, 0.9, 0.9],  # Class 0 dominates where ground truth is 0
     [0.1, 0.1, 0.9],
     [0.9, 0.1, 0.9]],

    # Class 1 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 1 where it's not the ground truth
     [0.05, 0.05, 0.05],
     [0.05, 0.8, 0.05]],   # Higher probability for Class 1 at (2, 1)

    # Class 2 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 2 where it's not the ground truth
     [0.1, 0.8, 0.05],     # Higher probability for Class 2 at (1, 1)
     [0.05, 0.05, 0.05]],

    # Class 3 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 3 where it's not the ground truth
     [0.8, 0.05, 0.05],    # Higher probability for Class 3 at (1, 0)
     [0.05, 0.05, 0.05]]
]
)

logits = torch.log(og_probs / (1 - og_probs))

probs = Softmax(logits)

preds = torch.argmax(probs, dim=0)

masks = preds.detach().clone()

print(f"preds: {preds}")
print(f"ground truth: {masks}")

print()

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

print()
print(f"Shape transformation to Torch Lightning Format")
print()

# Uniform shape of B, C, H, W
logits.unsqueeze_(0)
probs.unsqueeze_(0)
preds.unsqueeze_(0).unsqueeze_(0) 
masks.unsqueeze_(0).unsqueeze_(0)

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

preds: tensor([[0, 0, 0],
        [3, 2, 0],
        [0, 1, 0]])
ground truth: tensor([[0, 0, 0],
        [3, 2, 0],
        [0, 1, 0]])

logits Shape: torch.Size([4, 3, 3])
Probs Shape:  torch.Size([4, 3, 3])
Preds Shape:  torch.Size([3, 3])
Masks Shape:  torch.Size([3, 3])

Shape transformation to Torch Lightning Format

logits Shape: torch.Size([1, 4, 3, 3])
Probs Shape:  torch.Size([1, 4, 3, 3])
Preds Shape:  torch.Size([1, 1, 3, 3])
Masks Shape:  torch.Size([1, 1, 3, 3])


In [75]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds.squeeze(1), n_classes).permute(0, 3, 1, 2)
masks_oh = torch.nn.functional.one_hot(masks.squeeze(1), n_classes).permute(0, 3, 1, 2)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")

preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Example Tensor 5 (No Overlap at all)

In [12]:
n_classes = 4

og_probs = torch.tensor(
[
    # Class 0 probabilities (background)
    [[0.9, 0.9, 0.9],  # Class 0 dominates where ground truth is 0
     [0.1, 0.1, 0.9],
     [0.9, 0.1, 0.9]],

    # Class 1 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 1 where it's not the ground truth
     [0.05, 0.05, 0.05],
     [0.05, 0.8, 0.05]],   # Higher probability for Class 1 at (2, 1)

    # Class 2 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 2 where it's not the ground truth
     [0.1, 0.8, 0.05],     # Higher probability for Class 2 at (1, 1)
     [0.05, 0.05, 0.05]],

    # Class 3 probabilities
    [[0.05, 0.05, 0.05],   # Low probability for Class 3 where it's not the ground truth
     [0.8, 0.05, 0.05],    # Higher probability for Class 3 at (1, 0)
     [0.05, 0.05, 0.05]]
]
)

logits = torch.log(og_probs / (1 - og_probs))

probs = Softmax(logits)

preds = torch.argmax(probs, dim=0)

masks = torch.tensor([[1, 2, 3],
        [0, 0, 2],
        [1, 2, 3]])

print(f"preds: {preds}")
print(f"ground truth: {masks}")

print()

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

print()
print(f"Shape transformation to Torch Lightning Format")
print()

# Uniform shape of B, C, H, W
logits.unsqueeze_(0)
probs.unsqueeze_(0)
preds.unsqueeze_(0).unsqueeze_(0) 
masks.unsqueeze_(0).unsqueeze_(0)

print(f"logits Shape: {logits.shape}")
print("Probs Shape: ", probs.shape)
print("Preds Shape: ", preds.shape)
print("Masks Shape: ", masks.shape)

preds: tensor([[0, 0, 0],
        [3, 2, 0],
        [0, 1, 0]])
ground truth: tensor([[1, 2, 3],
        [0, 0, 2],
        [1, 2, 3]])

logits Shape: torch.Size([4, 3, 3])
Probs Shape:  torch.Size([4, 3, 3])
Preds Shape:  torch.Size([3, 3])
Masks Shape:  torch.Size([3, 3])

Shape transformation to Torch Lightning Format

logits Shape: torch.Size([1, 4, 3, 3])
Probs Shape:  torch.Size([1, 4, 3, 3])
Preds Shape:  torch.Size([1, 1, 3, 3])
Masks Shape:  torch.Size([1, 1, 3, 3])


In [13]:
# one-hot encoding, permuting and unsqueezing to match the format BNHW[D], where B = Batch Size, N = Number of Classes, H = Height, W = Width, D = Depth
preds_oh = torch.nn.functional.one_hot(preds.squeeze(1), n_classes).permute(0, 3, 1, 2)
masks_oh = torch.nn.functional.one_hot(masks.squeeze(1), n_classes).permute(0, 3, 1, 2)

if probs.dim() == 3:
    probs.unsqueeze_(0)

print(f"preds_oh.shape: {preds_oh.shape}")
print(f"masks_oh.shape: {masks_oh.shape}")
print(f"probs.shape: {probs.shape}")

preds_oh.shape: torch.Size([1, 4, 3, 3])
masks_oh.shape: torch.Size([1, 4, 3, 3])
probs.shape: torch.Size([1, 4, 3, 3])


# Test with actual GT and dummy logits, preds and probs

In [6]:
# to be implemented

# Calculation of Scores

In [76]:
dice_p_cls = dice(logits, masks, average=None, num_classes=n_classes) # average=None returns dice per class
print(dice_p_cls)

# -> dice for class 1 should be 1 as predictions didn't contain the class which was also absent in the ground truth

tensor([1., 1., 1., 1.])


In [77]:
# Dice Scores:
dsc = dice(preds, masks)
dscwlogits = dice(logits, masks)
dscwprobs = dice(probs, masks)

print(f"Torchmetrics Dice Scores")
print(f"dsc: {dsc}")
print(f"dsc w logits {dscwlogits}")
print(f"dsc w probs {dscwprobs}")

Torchmetrics Dice Scores
dsc: 1.0
dsc w logits 1.0
dsc w probs 1.0


In [78]:
diceFG = dice(preds, masks, ignore_index=0)
dice_p_cls = dice(preds, masks, average=None, num_classes=n_classes) # average=None returns dice per class

print("Dice Score FG: ", diceFG)
print("Dice Score per Class: ", dice_p_cls)

Dice Score FG:  tensor(1.)
Dice Score per Class:  tensor([1., 1., 1., 1.])


In [79]:
# Monai Dice Score
monaidicescore = MonaiDiceScore(preds_oh, masks_oh)
monaidicescoreFG = MonaiDiceFGScore(preds_oh, masks_oh)

monaidicemetric = MonaiDiceMetric(preds_oh, masks_oh)
monaidicemetricFG = MonaiDiceMetricFG(preds_oh, masks_oh)


print(f"Monai Generalized Dice Score")
print(f"Monai Dice Score: {monaidicescore}")
print(f"Monai Dice FG Score {monaidicescoreFG}")
print(f"Monai Dice Metric: {monaidicemetric}")
print(f"Monai Dice Metric FG: {monaidicemetricFG}")
print(f"Monai Dice Metric Mean: {monaidicemetric.mean()}")
print(f"Monai Dice Metric FG Mean: {monaidicemetricFG.mean()}")

Monai Generalized Dice Score
Monai Dice Score: tensor([1.])
Monai Dice FG Score tensor([1.])
Monai Dice Metric: tensor([[1., 1., 1., 1.]])
Monai Dice Metric FG: tensor([[1., 1., 1.]])
Monai Dice Metric Mean: 1.0
Monai Dice Metric FG Mean: 1.0


In [80]:
# Soft Dice Scores

mesoftdicescore = MESoftDiceLoss(probs, masks)
mesoftdicelogitscore = MESoftDiceLogitLoss(logits, masks)
softdicescore = -SoftDiceLogitLoss(logits, masks)
mesoftdicescorefg = MESoftDiceLossFG(probs, masks)

print(f"ME Soft Dice Score w Probs: {mesoftdicescore}")
print(f"ME Soft Dice Score w Logits {mesoftdicelogitscore}")
print(f"Soft Dice Score w Logits: {softdicescore}")
print(f"ME soft dice score FG {mesoftdicescorefg}")

ME Soft Dice Score w Probs: 0.9515506625175476
ME Soft Dice Score w Logits 0.9515506625175476
Soft Dice Score w Logits: 0.951554536819458
ME soft dice score FG 0.9404711723327637


In [81]:
# Calculate the intersection and union for Soft Dice Score
def soft_dice_score(probabilities, one_hot_masks, smooth=1e-5):
    num_classes = probabilities.size(1)

    intersection = (probabilities * one_hot_masks).sum(dim=(0, 2, 3))
    union = probabilities.sum(dim=(0, 2, 3)) + one_hot_masks.sum(dim=(0, 2, 3))

    dice_score = (2 * intersection + smooth) / (union + smooth)

    return dice_score

# Calculate the Soft Dice Score
soft_dice = soft_dice_score(probs, masks_oh)
print(soft_dice)
print(soft_dice.mean())

tensor([0.9848, 0.9449, 0.9385, 0.9380])
tensor(0.9516)


In [82]:
# ET (Enhancing Tumor): label 3
dice_ET = DiceScore((preds == 3), (masks == 3))
dice_FG_ET = DiceFGScore((preds == 3), (masks == 3))

# TC(Tumor Core): ET + NCR = label 1 + label 3
dice_TC = DiceScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3))
dice_FG_TC = DiceFGScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3))

# WT (Whole Tumor): TC + ED = label 1 + label 2 + label 3
dice_WT = DiceScore((preds > 0), (masks > 0))
dice_FG_WT = DiceFGScore((preds > 0), (masks > 0))

print("\n")
print(f"BraTs Region Scores:")
print("Dice Score ET: ", dice_ET)
print("Dice FG Score ET: ", dice_FG_ET)
print("Dice Score TC: ", dice_TC)
print("Dice FG Score TC: ", dice_FG_TC)
print("Dice Score WT: ", dice_WT)
print("Dice FG Score WT: ", dice_FG_WT)



BraTs Region Scores:
Dice Score ET:  tensor(1.)
Dice FG Score ET:  tensor(1.)
Dice Score TC:  tensor(1.)
Dice FG Score TC:  tensor(1.)
Dice Score WT:  tensor(1.)
Dice FG Score WT:  tensor(1.)


In [83]:
#Using One-Hot Encoded Predictions and Masks
print(f"Dice Score with One-Hot Encoded Predictions and Masks:")
print(f"Dice with oh preds: {DiceScore(preds_oh, masks_oh)}")
print(f"DiceFG with oh preds: {DiceFGScore(preds_oh, masks_oh)}")

# Use DiceFG when using one-hot encoding and non-binary case

Dice Score with One-Hot Encoded Predictions and Masks:
Dice with oh preds: 1.0
DiceFG with oh preds: 1.0


# Losses

## Overall

### Soft 

In [53]:
# Soft Dice Losses
mesoftdicelogitloss = 1 - MESoftDiceLogitLoss(logits, masks)
mesoftdiceloss = 1 - MESoftDiceLoss(probs, masks) 

softdicelogitloss = 1 + SoftDiceLogitLoss(logits, masks)
softdiceloss = 1 + SoftDiceL(probs, masks)

print(f"soft dice logit loss {softdicelogitloss}")
print(f"soft dice loss {softdiceloss}")
print()
print(f"ME soft dice logit loss {mesoftdicelogitloss}")
print(f"ME soft dice loss {mesoftdiceloss}")
print()
print(f"Custom soft dice loss: {1- soft_dice_score(probs, masks_oh).mean()}")

soft dice logit loss 0.04844546318054199
soft dice loss 0.04844546318054199

ME soft dice logit loss 0.04844933748245239
ME soft dice loss 0.04844933748245239

Custom soft dice loss: 0.04844540357589722


### Hard

In [54]:
# Hard Losses
diceloss = 1- DiceScore(logits, masks)
MonaiDiceLogitLoss = DiceLoss(softmax = True, to_onehot_y=True)
monaidiceloss = MonaiDiceLogitLoss(logits, masks)

print(f" dice loss w torchmetrics {diceloss}")
print(f" monai dice loss:  {monaidiceloss}")

 dice loss w torchmetrics 0.0
 monai dice loss:  0.04844541847705841


## Brats Regions

In [55]:
brats_regions = {'ET': [3], 'TC': [1, 3], 'WT': [1, 2, 3]}

def prepare_region_logits(logits, target_classes):
    """
    Efficiently sum the logits for the given target classes.
    """
    return logits[:, target_classes].sum(dim=1, keepdim=True)

def prepare_region_gt(gt, target_classes):
    """
    Efficiently create a binary mask for the relevant target classes.
    """
    return torch.isin(gt, torch.tensor(target_classes))

### Hard

In [56]:
#Dice Losses
# Brats Dice Loss
dice_ET_loss = (1 - DiceScore((preds == 3), (masks == 3))) * dsc_loss_w
dice_TC_loss = (1 - DiceScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3)) ) * dsc_loss_w
dice_WT_loss = (1 - DiceScore((preds > 0), (masks > 0))) * dsc_loss_w

# Brats FG Dice Loss
diceFG_ET_loss = (1 - DiceFGScore((preds == 3), (masks == 3))) * dsc_loss_w
diceFG_TC_loss = (1 - DiceFGScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3)) ) * dsc_loss_w
diceFG_WT_loss = (1 - DiceFGScore((preds > 0), (masks > 0))) * dsc_loss_w


print(f"Old Approach")
print(f"Dice ET Loss: {dice_ET_loss}")
print(f"Dice ET FG Loss: {diceFG_ET_loss}")
print(f"Dice TC Loss: {dice_TC_loss}")
print(f"Dice TC FG Loss: {diceFG_TC_loss}")
print(f"Dice WT Loss: {dice_WT_loss}")
print(f"Dice WT FG Loss: {diceFG_WT_loss}")

Old Approach
Dice ET Loss: 0.0
Dice ET FG Loss: 0.0
Dice TC Loss: 0.0
Dice TC FG Loss: 0.0
Dice WT Loss: 0.0
Dice WT FG Loss: 0.0


### Soft

In [57]:
########

#LOOKING AT RESULTS FOR PERFECT OVERLAP EXAMPLE THIS CAN'T BE RIGHT, OR CAN IT?

########

logits_ET = prepare_region_logits(logits, brats_regions['ET'])
logits_TC = prepare_region_logits(logits, brats_regions['TC'])
logits_WT = prepare_region_logits(logits, brats_regions['WT'])

gt_ET = prepare_region_gt(masks, brats_regions['ET'])
gt_TC = prepare_region_gt(masks, brats_regions['TC'])
gt_WT = prepare_region_gt(masks, brats_regions['WT'])

softdicelogitlossET = -SoftDiceLogitLoss(logits_ET, gt_ET)
softdicelogitlossTC = -SoftDiceLogitLoss(logits_TC, gt_TC)
softdicelogitlossWT = -SoftDiceLogitLoss(logits_WT, gt_WT)

mesoftdicelogitlossET = MESoftDiceLogitLoss(logits_ET, gt_ET)
mesoftdicelogitlossTC = MESoftDiceLogitLoss(logits_TC, gt_TC)
mesoftdicelogitlossWT = MESoftDiceLogitLoss(logits_WT, gt_WT)

print(f"Soft Dice Logit Loss ET: {softdicelogitlossET}")
print(f"Soft Dice Logit Loss TC: {softdicelogitlossTC}")
print(f"Soft Dice Logit Loss WT: {softdicelogitlossWT}")
print()
print(f"ME Soft Dice Logit Loss ET: {mesoftdicelogitlossET}")
print(f"ME Soft Dice Logit Loss TC: {mesoftdicelogitlossTC}")
print(f"ME Soft Dice Logit Loss WT: {mesoftdicelogitlossWT}")

Soft Dice Logit Loss ET: 0.2000008076429367
Soft Dice Logit Loss TC: 0.36363697052001953
Soft Dice Logit Loss WT: 0.5000004172325134

ME Soft Dice Logit Loss ET: 0.19999980926513672
ME Soft Dice Logit Loss TC: 0.3636360466480255
ME Soft Dice Logit Loss WT: 0.49999961256980896


### Masked

In [66]:
########

#LOOKING AT RESULTS FOR PERFECT OVERLAP EXAMPLE THIS CAN'T BE RIGHT, WAAAAY TO HIGH

########

maskeddicelossET = MonaiMaskedDiceLoss(logits, masks, (masks == 3))
maskeddicelossTC = MonaiMaskedDiceLoss(logits, masks, ((masks == 1) | (masks == 3)))
maskeddicelossWT = MonaiMaskedDiceLoss(logits, masks, (masks > 0))

print (f"Masked Monai Dice Loss ET: {maskeddicelossET}")
print (f"Masked Monai Dice Loss TC: {maskeddicelossTC}")
print (f"Masked Monai Dice Loss WT: {maskeddicelossWT}")

Masked Monai Dice Loss ET: 0.7813824415206909
Masked Monai Dice Loss TC: 0.6462633013725281
Masked Monai Dice Loss WT: 0.49325984716415405


### Manual Masking

In [67]:
et_mask = (masks == 3).to(torch.uint8)
tc_mask = (((masks == 1) | (masks == 3))).to(torch.uint8)
wt_mask = ((masks > 0)).to(torch.uint8)

In [68]:
# Soft Dice Region Loss

SoftDiceLogitScore = MemoryEfficientSoftDiceLoss(nn.Softmax(dim=1), smooth= 1e-5)

mesoftdicelogitlossET = 1 - SoftDiceLogitScore(logits*et_mask, masks*et_mask)
mesoftdicelogitlossTC = 1 - SoftDiceLogitScore(logits*tc_mask, masks*tc_mask)
mesoftdicelogitlossWT = 1 - SoftDiceLogitScore(logits*wt_mask, masks*wt_mask)

print(f" ME Soft Dice Logit Loss ET: {mesoftdicelogitlossET}")
print(f" ME Soft Dice Logit Loss TC: {mesoftdicelogitlossTC}")
print(f" ME Soft Dice Logit Loss WT: {mesoftdicelogitlossWT}")

 ME Soft Dice Logit Loss ET: 0.7813857793807983
 ME Soft Dice Logit Loss TC: 0.6462664008140564
 ME Soft Dice Logit Loss WT: 0.4932623505592346


In [69]:
# Hard Dice Region Loss

MonaiDiceLogitLoss = DiceLoss(softmax = True, to_onehot_y=True)

monaidicelogitlossET = MonaiDiceLogitLoss(logits*et_mask, masks*et_mask)
monaidicelogitlossTC = MonaiDiceLogitLoss(logits*tc_mask, masks*tc_mask)
monaidicelogitlossWT = MonaiDiceLogitLoss(logits*wt_mask, masks*wt_mask)

print(f"Monai Dice Loss")
print(f"Monai Dice Logit Loss ET: {monaidicelogitlossET}")
print(f"Monai Dice Logit Loss TC: {monaidicelogitlossTC}")
print(f"Monai Dice Logit Loss WT: {monaidicelogitlossWT}")

Monai Dice Loss
Monai Dice Logit Loss ET: 0.7813824415206909
Monai Dice Logit Loss TC: 0.6462633013725281
Monai Dice Logit Loss WT: 0.49325984716415405


In [70]:
# Hard Dice FG Region Loss

MonaiDiceFGLogitLoss = DiceLoss(softmax = True, to_onehot_y=True, include_background=False)

monaidicelogitlossFG_ET = MonaiDiceFGLogitLoss(logits*et_mask, masks*et_mask)
monaidicelogitlossFG_TC = MonaiDiceFGLogitLoss(logits*tc_mask, masks*tc_mask)
monaidicelogitlossFG_WT = MonaiDiceFGLogitLoss(logits*wt_mask, masks*wt_mask)

print(f"Monai Foreground Dice Loss:")
print(f"Monai Dice Logit Loss ET: {monaidicelogitlossFG_ET}")
print(f"Monai Dice Logit Loss TC: {monaidicelogitlossFG_TC}")
print(f"Monai Dice Logit Loss WT: {monaidicelogitlossFG_WT}")

Monai Foreground Dice Loss:
Monai Dice Logit Loss ET: 0.8414978384971619
Monai Dice Logit Loss TC: 0.6608918309211731
Monai Dice Logit Loss WT: 0.45629552006721497


In [71]:
# Hard Dice Torchmetric Region Loss

DiceScore = Dice()

dicelogitscoreET = DiceScore(logits*et_mask, masks*et_mask)
dicelogitscoreTC = DiceScore(logits*tc_mask, masks*tc_mask)
dicelogitscoreWT = DiceScore(logits*wt_mask, masks*wt_mask)

print(f"Torchmetrics Dice Score:")
print(f" Torchmetrics Dice Logit Score ET: {dicelogitscoreET}")
print(f" Torchmetrics Dice Logit Score TC: {dicelogitscoreTC}")
print(f" Torchmetrics Dice Logit Score WT: {dicelogitscoreWT}")
print()
print(f"Torchmetrics Dice Loss:")
print(f" Torchmetrics Dice Logit Loss ET: {1 -dicelogitscoreET}")
print(f" Torchmetrics Dice Logit Loss TC: {1 -dicelogitscoreTC}")
print(f" Torchmetrics Dice Logit Score WT: {1- dicelogitscoreWT}")

Torchmetrics Dice Score:
 Torchmetrics Dice Logit Score ET: 1.0
 Torchmetrics Dice Logit Score TC: 1.0
 Torchmetrics Dice Logit Score WT: 1.0

Torchmetrics Dice Loss:
 Torchmetrics Dice Logit Loss ET: 0.0
 Torchmetrics Dice Logit Loss TC: 0.0
 Torchmetrics Dice Logit Score WT: 0.0


In [72]:
# Hard Dice Torchmetric FG Region Loss

DiceFGScore = Dice(ignore_index=0)

dicelogitscoreFG_ET = DiceFGScore(logits*et_mask, masks*et_mask)
dicelogitscoreFG_TC = DiceFGScore(logits*tc_mask, masks*tc_mask)
dicelogitscoreFG_WT = DiceFGScore(logits*wt_mask, masks*wt_mask)

print(f"Torchmetrics Dice FG Score:")
print(f" Torchmetrics Dice Logit Score ET: {dicelogitscoreFG_ET}")
print(f" Torchmetrics Dice Logit Score TC: {dicelogitscoreFG_TC}")
print(f" Torchmetrics Dice Logit Score WT: {dicelogitscoreFG_WT}")
print()
print(f"Torchmetrics Dice FG Loss:")
print(f" Torchmetrics Dice Logit Loss ET: {1 - dicelogitscoreFG_ET}")
print(f" Torchmetrics Dice Logit Loss TC: {1 - dicelogitscoreFG_TC}")
print(f" Torchmetrics Dice Logit Loss WT: {1 - dicelogitscoreFG_WT}")

Torchmetrics Dice FG Score:
 Torchmetrics Dice Logit Score ET: 1.0
 Torchmetrics Dice Logit Score TC: 1.0
 Torchmetrics Dice Logit Score WT: 1.0

Torchmetrics Dice FG Loss:
 Torchmetrics Dice Logit Loss ET: 0.0
 Torchmetrics Dice Logit Loss TC: 0.0
 Torchmetrics Dice Logit Loss WT: 0.0


In [73]:
#Old Approach for Comparison
# Brats Dice Loss
dice_ET_loss = (1 - DiceScore((preds == 3), (masks == 3))) * dsc_loss_w
dice_TC_loss = (1 - DiceScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3)) ) * dsc_loss_w
dice_WT_loss = (1 - DiceScore((preds > 0), (masks > 0))) * dsc_loss_w

# Brats FG Dice Loss
diceFG_ET_loss = (1 - DiceFGScore((preds == 3), (masks == 3))) * dsc_loss_w
diceFG_TC_loss = (1 - DiceFGScore((preds == 1) | (preds == 3), (masks == 1) | (masks == 3)) ) * dsc_loss_w
diceFG_WT_loss = (1 - DiceFGScore((preds > 0), (masks > 0))) * dsc_loss_w


print(f"Old Approach")
print(f"Dice ET Loss: {dice_ET_loss}")
print(f"Dice ET FG Loss: {diceFG_ET_loss}")
print(f"Dice TC Loss: {dice_TC_loss}")
print(f"Dice TC FG Loss: {diceFG_TC_loss}")
print(f"Dice WT Loss: {dice_WT_loss}")
print(f"Dice WT FG Loss: {diceFG_WT_loss}")

Old Approach
Dice ET Loss: 0.0
Dice ET FG Loss: 0.0
Dice TC Loss: 0.0
Dice TC FG Loss: 0.0
Dice WT Loss: 0.0
Dice WT FG Loss: 0.0
