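# Exploring class-specific Dice score calculation

Checks the per-class Dice intersect/denominator computation on a small hand-computed example, then compares per-class totals computed on a single GPU against those from a 2-GPU run (per-rank and gathered) dumped by `test_dice.py`.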

In [1]:
import torch

In [2]:
from TissueLabeling.config import Configuration
from TissueLabeling.data.dataset import get_data_loader
from TissueLabeling.metrics.metrics import Dice
from TissueLabeling.models.segformer import Segformer
from TissueLabeling.models.unet import Unet
from TissueLabeling.models.simple_unet import SimpleUnet
from TissueLabeling.parser import get_args
from TissueLabeling.training.trainer import Trainer
from TissueLabeling.utils import init_cuda, init_fabric, init_wandb, set_seed, main_timer

2023-12-13 08:59:05.644129: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-13 08:59:05.644191: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-13 08:59:05.651740: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-13 08:59:06.136033: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Simple example (testing Matthias's code)

In [72]:
# The toy label maps below contain only the labels 0 and 1.
nr_of_classes: int = 2

In [73]:
# 4x4 toy ground-truth label map and hard prediction, labels in {0, 1}.
# Both are 2-D float32 tensors at this point.
y_true = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 0, 0]],
    dtype=torch.float32,
)
pred = torch.tensor(
    [[1, 1, 1, 0],
     [1, 1, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],
    dtype=torch.float32,
)

In [74]:
# Add batch and channel dims: (H, W) -> (1, 1, H, W). Labels become int64 so
# they can be one-hot encoded later. Re-running this cell is harmless.
y_true = y_true.reshape((1, 1, 4, 4)).long()
pred = pred.reshape((1, 1, 4, 4))
# One-hot encode the hard prediction along the channel axis:
# (1, 1, H, W) -> (1, nr_of_classes, H, W), a 0/1 stand-in for per-class probabilities.
y_pred = torch.concat(
    [(pred == c).type(torch.float32) for c in range(nr_of_classes)], dim=1
)
print(y_true.shape, y_pred.shape)

torch.Size([1, 1, 4, 4]) torch.Size([1, 2, 4, 4])


In [75]:
# One-hot encode the label map: drop the channel dim -> (B, H, W), one_hot appends
# the class axis -> (B, H, W, C), then move that axis next to batch -> (B, C, H, W).
y_true_oh = torch.movedim(
    torch.nn.functional.one_hot(y_true.squeeze(1), num_classes=nr_of_classes),
    -1,
    1,
)
y_true_oh.shape

torch.Size([1, 2, 4, 4])

In [78]:
# Per-sample, per-class intersection: sum of y_true_oh * y_pred over the H, W axes.
class_intersect = torch.sum((y_true_oh * y_pred), axis=(2, 3))
print(class_intersect)
# Hand-computed: 6 pixels are 0 in both maps, 4 pixels are 1 in both.
torch.testing.assert_close(class_intersect, torch.tensor([[6.0, 4.0]]))

tensor([[6., 4.]])


In [79]:
# Per-sample, per-class denominator: (true pixel count) + (predicted pixel count) over H, W.
class_denom = torch.sum((y_true_oh + y_pred), axis=(2, 3))
print(class_denom)
# Hand-computed: class 0 -> 7 true + 11 predicted, class 1 -> 9 true + 5 predicted.
torch.testing.assert_close(class_denom, torch.tensor([[18.0, 14.0]]))

tensor([[18., 14.]])


In [83]:
# Per-class Dice = 2 * intersection / (|true| + |pred|).
# NOTE: a class absent from both maps would have a zero denominator and give NaN here.
class_dice = 2 * class_intersect / class_denom
print(class_dice)
# Expected: 2*6/18 ~= 0.6667 and 2*4/14 ~= 0.5714.
torch.testing.assert_close(class_dice, torch.tensor([[12 / 18, 8 / 14]]))

tensor([[0.6667, 0.5714]])


## Exploring outputs of `test_dice.py`

In [3]:
import os

In [13]:
# Load the tensors dumped by test_dice.py during a 2-GPU run.
# Each file is saved as <prefix>_<rank>.pt, where <rank> is the GPU global rank.
# NOTE: "itersect" (sic) is the file name on disk; keep it as-is.
save_path = '/om2/user/sabeen/nobrainer_data_norm/test_dice_data/'


def _load_rank_file(prefix, rank):
    """Load ``<prefix>_<rank>.pt`` from ``save_path``, mapping all tensors to the CPU.

    Loading to CPU puts every rank's tensors on a single device. If they were saved
    from different GPUs, they can then still be concatenated and compared in later cells.
    """
    return torch.load(os.path.join(save_path, f'{prefix}_{rank}.pt'), map_location='cpu')


# image / mask / probs seen by each rank
image_0, mask_0, probs_0 = _load_rank_file('image_mask_probs', 0)
image_1, mask_1, probs_1 = _load_rank_file('image_mask_probs', 1)

# per-rank class_intersect / class_denom
class_intersect_0, class_denom_0 = _load_rank_file('itersect_denom', 0)
class_intersect_1, class_denom_1 = _load_rank_file('itersect_denom', 1)

# gathered class_intersect / class_denom, as saved by each rank
class_intersect_gather_0, class_denom_gather_0 = _load_rank_file('itersect_denom_gather', 0)
class_intersect_gather_1, class_denom_gather_1 = _load_rank_file('itersect_denom_gather', 1)

# Take the class count from the saved data (axis 1) rather than hard-coding it.
nr_of_classes = class_intersect_0.shape[1]
print(f'nr_of_classes = {nr_of_classes}')

nr_of_classes = 51


In [14]:
# Rebuild the "single GPU, batch_size = 2" case: stack both ranks' samples into one
# batch, then recompute the per-sample, per-class intersect and denominator.
image = torch.cat([image_0, image_1], dim=0)
mask = torch.cat([mask_0, mask_1], dim=0)
probs = torch.cat([probs_0, probs_1], dim=0)

# assumes mask is (B, 1, H, W): squeeze(1) drops the channel dim, one_hot appends the
# class axis (B, H, W, C), and movedim brings it next to the batch axis (B, C, H, W).
y_true_oh = torch.movedim(
    torch.nn.functional.one_hot(mask.long().squeeze(1), num_classes=nr_of_classes),
    -1,
    1,
)
class_intersect = (y_true_oh * probs).sum(dim=(2, 3))
class_denom = (y_true_oh + probs).sum(dim=(2, 3))

In [15]:
# Per-class totals: reduce each tensor over axis 0 (one row per sample).
# single: the batch recomputed from both ranks
single_intersect_sum = class_intersect.sum(dim=0)
single_denom_sum = class_denom.sum(dim=0)

# multi: stack the per-rank rows first, then reduce
multi_intersect_sum = torch.concat((class_intersect_0, class_intersect_1)).sum(dim=0)
multi_denom_sum = torch.concat((class_denom_0, class_denom_1)).sum(dim=0)

# gathered tensors, as saved by rank 0 and by rank 1
gather_intersect_sum_0, gather_denom_sum_0 = (
    class_intersect_gather_0.sum(dim=0),
    class_denom_gather_0.sum(dim=0),
)
gather_intersect_sum_1, gather_denom_sum_1 = (
    class_intersect_gather_1.sum(dim=0),
    class_denom_gather_1.sum(dim=0),
)

In [19]:
# comparing single_gpu and multi_gpu per-class totals
nr_of_classes = len(single_intersect_sum)


def _report_match(name_a, a, name_b, b):
    """Print how many per-class entries of ``a`` and ``b`` agree.

    Reports three things:
    - exact (bitwise) equality, as before;
    - equality up to rounding (``torch.isclose`` with default tolerances, on float64 copies
      so differing dtypes do not matter);
    - the largest absolute difference.

    If the totals are floating point (presumably, since they are built from ``probs``),
    summing the same numbers in a different order can change the last bits. Examples are
    one batch of 2 vs. two batches of 1, or a different device. Exact ``==`` alone cannot
    tell such rounding noise from a real mismatch.
    """
    n_total = a.numel()
    n_exact = int(torch.sum(a == b))
    n_close = int(torch.sum(torch.isclose(a.double(), b.double())))
    max_abs_diff = (a.double() - b.double()).abs().max().item()
    print(f'{name_a} == {name_b}: {n_exact} / {n_total} are equal '
          f'({n_close} / {n_total} close, max |diff| = {max_abs_diff:.3g})')


print('sanity_check: gather_0 should equal gather_1')
_report_match('gather_intersect_sum_0', gather_intersect_sum_0, 'gather_intersect_sum_1', gather_intersect_sum_1)
_report_match('gather_denom_sum_0', gather_denom_sum_0, 'gather_denom_sum_1', gather_denom_sum_1)
print()

print('sanity_check: multi should equal gather_0 and gather_1')
_report_match('multi_intersect_sum', multi_intersect_sum, 'gather_intersect_sum_0', gather_intersect_sum_0)
_report_match('multi_denom_sum', multi_denom_sum, 'gather_denom_sum_0', gather_denom_sum_0)
print()

print('Q: does single = gather?')
_report_match('single_intersect_sum', single_intersect_sum, 'gather_intersect_sum_0', gather_intersect_sum_0)
_report_match('single_denom_sum', single_denom_sum, 'gather_denom_sum_0', gather_denom_sum_0)
print()

print('Q: does single = multi?')
_report_match('single_intersect_sum', single_intersect_sum, 'multi_intersect_sum', multi_intersect_sum)
_report_match('single_denom_sum', single_denom_sum, 'multi_denom_sum', multi_denom_sum)

sanity_check: gather_0 should equal gather_1
gather_intersect_sum_0 == gather_intersect_sum_1: 51 / 51 are equal
gather_denom_sum_0 == gather_denom_sum_1: 51 / 51 are equal

sanity_check: multi should equal gather_0 and gather_1
multi_intersect_sum == gather_intersect_sum_0: 51 / 51 are equal
multi_denom_sum == gather_denom_sum_0: 51 / 51 are equal

Q: does single = gather?
single_intersect_sum == gather_intersect_sum_0: 49 / 51 are equal
single_denom_sum == gather_denom_sum_0: 34 / 51 are equal

Q: does single = multi?
single_intersect_sum == multi_intersect_sum: 49 / 51 are equal
single_denom_sum == multi_denom_sum: 34 / 51 are equal
