# Exploring class-specific dice score calculation

In [1]:
import torch

In [2]:
from TissueLabeling.config import Configuration
from TissueLabeling.data.dataset import get_data_loader
from TissueLabeling.metrics.metrics import Dice
from TissueLabeling.models.segformer import Segformer
from TissueLabeling.models.unet import Unet
from TissueLabeling.models.simple_unet import SimpleUnet
from TissueLabeling.parser import get_args
from TissueLabeling.training.trainer import Trainer
from TissueLabeling.utils import init_cuda, init_fabric, init_wandb, set_seed, main_timer

2023-12-14 07:07:14.799136: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-14 07:07:14.799189: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-14 07:07:14.799221: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-14 07:07:14.805910: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Simple Example (test Matthias's code)

In [3]:
nr_of_classes = 2

In [12]:
y_true = torch.Tensor([[1,1,0,0],[1,1,0,0],[1,1,1,0],[1,1,0,0]])
pred = torch.Tensor([[1,1,1,0],[1,1,0,0],[0,0,0,0],[0,0,0,0]])
print(y_true.shape, pred.shape)

torch.Size([4, 4]) torch.Size([4, 4])


In [11]:
y_true = y_true.reshape((1,1,4,4)).long()
pred = pred.reshape((1,1,4,4))
y_pred = torch.concat(((pred == 0).dtype(torch.float32)),axis=1)
print(y_true.shape, y_pred.shape)

NameError: name 'y_pred' is not defined

In [None]:
y_true_oh = torch.nn.functional.one_hot(
            y_true.squeeze(1), num_classes=nr_of_classes
        ).permute(0, 3, 1, 2)
y_true_oh.shape

In [6]:
class_intersect = torch.sum(
            (y_true_oh * y_pred), axis=(2, 3)
        )
print(class_intersect) # expected = [6,4]

NameError: name 'y_true_oh' is not defined

In [79]:
class_union = torch.sum(
            (y_true_oh + y_pred), axis=(2, 3)
        )
print(class_union) # expected = [18,14]

tensor([[18., 14.]])


In [83]:
class_dice = 2 * class_intersect / class_union
print(class_dice) # expected = [0.667,0.5714]

tensor([[0.6667, 0.5714]])


In [None]:
import torchmetrics
f1 = torchmetrics.F1Score("multiclass",num_classes=2,average=None)
f1(preds,target)

## Exploring outputs of test_dice.py - Binary

In [16]:
import os

In [17]:
# load files
save_path = '/om2/user/sabeen/nobrainer_data_norm/test_dice_data/binary'
nr_of_classes = 51
# get image/mask/probs from each multi_gp (saved as _#.pt where # = gpu global rank during multi gpu run)
image_0,mask_0,probs_0 = torch.load(os.path.join(save_path,'image_mask_probs_0.pt'))
image_1,mask_1,probs_1 = torch.load(os.path.join(save_path,'image_mask_probs_1.pt'))

# get class_intersect/class_union
class_intersect_0, class_union_0 = torch.load(os.path.join(save_path,f'itersect_denom_0.pt'))
class_intersect_1, class_union_1 = torch.load(os.path.join(save_path,f'itersect_denom_1.pt'))

# get gathered
class_intersect_gather_0, class_union_gather_0 = torch.load(os.path.join(save_path,f'itersect_denom_gather_0.pt'))
class_intersect_gather_1, class_union_gather_1 = torch.load(os.path.join(save_path,f'itersect_denom_gather_1.pt'))

nr_of_classes = class_intersect_0.shape[1]
print(f'nr_of_classes = {nr_of_classes}')

nr_of_classes = 2


### Using torchmetrics F1

In [18]:
from torchmetrics import F1Score

In [25]:
f1 = F1Score(task="multiclass",num_classes=2,average=None)
# requires input to be (pred, target) where pred.shape = (num_classes,num_values) and target.shape = (num_values)

In [26]:
f1(probs_0,mask_0)

ValueError: ('The `preds` and `target` should have the same shape,', ' got `preds` with shape=torch.Size([1, 2, 162, 194]) and `target` with shape=torch.Size([1, 1, 162, 194]).')

In [31]:
preds_2 = probs_0.permute(1,2,3,0)
target_2 = mask_0.permute(1,2,3,0)
f1(preds_2,target_2)

ValueError: ('The `preds` and `target` should have the same shape,', ' got `preds` with shape=torch.Size([2, 162, 194, 1]) and `target` with shape=torch.Size([1, 162, 194, 1]).')

In [35]:
preds = probs_0.reshape(2,-1).T # to reshape into (C,N) (doesn't account for batch_size)
target = mask_0.reshape(-1) # to reshape into (N)(doesn't account for batch_size)
f1(preds,target)

tensor([0.3771, 0.1108])

In [28]:
2 * class_intersect_0 / class_union_0

tensor([[0.6110, 0.1137]], grad_fn=<DivBackward0>)

### Comparison v1 - Binary

In [79]:
# find class_intersect and class_union for batch_size = 2, single gpu case
image = torch.concat((image_0,image_1),axis=0)
mask = torch.concat((mask_0,mask_1),axis=0)
probs = torch.concat((probs_0,probs_1),axis=0)

# calculate intersect and union
y_true_oh = torch.nn.functional.one_hot(
    mask.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect = torch.sum(
    (y_true_oh * probs), axis=(2, 3)
    )
class_union = torch.sum(
    (y_true_oh + probs), axis=(2, 3)
)

In [81]:
# calculate class totals
single_intersect_sum = torch.sum(class_intersect,axis=0)
single_denom_sum = torch.sum(class_union,axis=0)

multi_intersect_sum = torch.sum(torch.concat((class_intersect_0,class_intersect_1)).reshape((2,1,nr_of_classes)),axis=0)
multi_denom_sum = torch.sum(torch.concat((class_union_0,class_union_1)).reshape((2,1,nr_of_classes)),axis=0)

gather_intersect_sum_0 = torch.sum(class_intersect_gather_0,axis=0)
gather_denom_sum_0 = torch.sum(class_union_gather_0,axis=0)

gather_intersect_sum_1 = torch.sum(class_intersect_gather_1,axis=0)
gather_denom_sum_1 = torch.sum(class_union_gather_1,axis=0)

In [82]:
# comparing single_gpu and multi_gpu
print('sanity_check: gather_0 should equal gather_1')
nr_of_classes = len(single_intersect_sum)
print(f'gather_intersect_sum_0 == gather_intersect_sum_1: {torch.sum(gather_intersect_sum_0 == gather_intersect_sum_1)} / {nr_of_classes} are equal')
print(f'gather_denom_sum_0 == gather_denom_sum_1: {torch.sum(gather_denom_sum_0 == gather_denom_sum_1)} / {nr_of_classes} are equal')
print()

print('sanity_check: multi should equal gather_0 and gather_1')
print(f'multi_intersect_sum == gather_intersect_sum_0: {torch.sum(multi_intersect_sum == gather_intersect_sum_0)} / {nr_of_classes} are equal')
print(f'multi_denom_sum == gather_denom_sum_0: {torch.sum(multi_denom_sum == gather_denom_sum_0)} / {nr_of_classes} are equal')
print()

print('Q: does single = gather?')
print(f'single_intersect_sum == gather_intersect_sum_0: {torch.sum(gather_intersect_sum_0 == single_intersect_sum)} / {nr_of_classes} are equal')
print(f'single_denom_sum == gather_denom_sum_0: {torch.sum(gather_denom_sum_0 == single_denom_sum)} / {nr_of_classes} are equal')
print()

print('Q: does single = multi?')
print(f'single_intersect_sum == multi_intersect_sum: {torch.sum(multi_intersect_sum == single_intersect_sum)} / {nr_of_classes} are equal')
print(f'single_denom_sum == multi_denom_sum: {torch.sum(multi_denom_sum == single_denom_sum)} / {nr_of_classes} are equal')

sanity_check: gather_0 should equal gather_1
gather_intersect_sum_0 == gather_intersect_sum_1: 2 / 2 are equal
gather_denom_sum_0 == gather_denom_sum_1: 2 / 2 are equal

sanity_check: multi should equal gather_0 and gather_1
multi_intersect_sum == gather_intersect_sum_0: 2 / 2 are equal
multi_denom_sum == gather_denom_sum_0: 2 / 2 are equal

Q: does single = gather?
single_intersect_sum == gather_intersect_sum_0: 1 / 2 are equal
single_denom_sum == gather_denom_sum_0: 2 / 2 are equal

Q: does single = multi?
single_intersect_sum == multi_intersect_sum: 1 / 2 are equal
single_denom_sum == multi_denom_sum: 2 / 2 are equal


### Comparison v2 - Binary

In [84]:
torch.sum(torch.concat((class_intersect_0,class_intersect_1)).reshape((2,1,nr_of_classes)) != class_intersect_gather_0)

tensor(0)

In [89]:
# find class_intersect and class_union for batch_size = 2, single gpu case
image = torch.concat((image_0,image_1),axis=0)
mask = torch.concat((mask_0,mask_1),axis=0)
probs = torch.concat((probs_0,probs_1),axis=0)

# calculate intersect and union
y_true_oh = torch.nn.functional.one_hot(
    mask.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect = torch.sum(
    (y_true_oh * probs), axis=(2, 3)
    )
class_union = torch.sum(
    (y_true_oh + probs), axis=(2, 3)
)

In [98]:
class_intersect.reshape((2,1,nr_of_classes))

tensor([[[13343.8438,  1089.5955]],

        [[13318.8857,  1417.1342]]], grad_fn=<ViewBackward0>)

In [100]:
class_intersect_gather_0 - class_intersect.reshape((2,1,nr_of_classes))

tensor([[[ 0.0000,  0.0000]],

        [[-0.0010, -0.0001]]], grad_fn=<SubBackward0>)

In [90]:
torch.sum(class_intersect.reshape((2,1,nr_of_classes)) != class_intersect_gather_0)

tensor(2)

In [101]:
y_true_oh_single_0 = torch.nn.functional.one_hot(
    mask_0.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect_single_0 = torch.sum(
    (y_true_oh_single_0 * probs_0), axis=(2, 3)
)
class_union_single_0 = torch.sum(
    (y_true_oh_single_0 + probs_0), axis=(2,3)
)
print(y_true_oh_single_0.shape, class_intersect_single_0.shape, class_union_single_0.shape)

torch.Size([1, 2, 162, 194]) torch.Size([1, 2]) torch.Size([1, 2])


In [102]:
y_true_oh_single_1 = torch.nn.functional.one_hot(
    mask_1.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect_single_1 = torch.sum(
    (y_true_oh_single_1 * probs_1), axis=(2, 3)
)
class_union_single_1 = torch.sum(
    (y_true_oh_single_1 + probs_1), axis=(2,3)
)
print(y_true_oh_single_1.shape, class_intersect_single_1.shape, class_union_single_1.shape)

torch.Size([1, 2, 162, 194]) torch.Size([1, 2]) torch.Size([1, 2])


In [104]:
a = torch.concat((class_intersect_0,class_intersect_1)).reshape((2,1,nr_of_classes))
a.shape

torch.Size([2, 1, 2])

In [106]:
b = torch.concat((class_intersect_single_0, class_intersect_single_1)).reshape((2,1,nr_of_classes))
b.shape

torch.Size([2, 1, 2])

In [107]:
a_sum = torch.sum(a,axis=0)
b_sum = torch.sum(b,axis=0)
torch.sum(a_sum != b_sum)

tensor(1)

In [108]:
a_sum - b_sum

tensor([[-0.0020,  0.0000]], grad_fn=<SubBackward0>)

## Exploring outputs of test_dice.py - 51 Classes

In [77]:
import os

In [109]:
# load files
save_path = '/om2/user/sabeen/nobrainer_data_norm/test_dice_data/51class'
nr_of_classes = 51
# get image/mask/probs from each multi_gp (saved as _#.pt where # = gpu global rank during multi gpu run)
image_0,mask_0,probs_0 = torch.load(os.path.join(save_path,'image_mask_probs_0.pt'))
image_1,mask_1,probs_1 = torch.load(os.path.join(save_path,'image_mask_probs_1.pt'))

# get class_intersect/class_union
class_intersect_0, class_union_0 = torch.load(os.path.join(save_path,f'itersect_denom_0.pt'))
class_intersect_1, class_union_1 = torch.load(os.path.join(save_path,f'itersect_denom_1.pt'))

# get gathered
class_intersect_gather_0, class_union_gather_0 = torch.load(os.path.join(save_path,f'itersect_denom_gather_0.pt'))
class_intersect_gather_1, class_union_gather_1 = torch.load(os.path.join(save_path,f'itersect_denom_gather_1.pt'))

nr_of_classes = class_intersect_0.shape[1]
print(f'nr_of_classes = {nr_of_classes}')

nr_of_classes = 51


### Comparison v1 - 51 clas

In [110]:
# find class_intersect and class_union for batch_size = 2, single gpu case
image = torch.concat((image_0,image_1),axis=0)
mask = torch.concat((mask_0,mask_1),axis=0)
probs = torch.concat((probs_0,probs_1),axis=0)

# calculate intersect and union
y_true_oh = torch.nn.functional.one_hot(
    mask.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect = torch.sum(
    (y_true_oh * probs), axis=(2, 3)
    )
class_union = torch.sum(
    (y_true_oh + probs), axis=(2, 3)
)

In [111]:
# calculate class totals
single_intersect_sum = torch.sum(class_intersect,axis=0)
single_denom_sum = torch.sum(class_union,axis=0)

multi_intersect_sum = torch.sum(torch.concat((class_intersect_0,class_intersect_1)).reshape((2,1,nr_of_classes)),axis=0)
multi_denom_sum = torch.sum(torch.concat((class_union_0,class_union_1)).reshape((2,1,nr_of_classes)),axis=0)

gather_intersect_sum_0 = torch.sum(class_intersect_gather_0,axis=0)
gather_denom_sum_0 = torch.sum(class_union_gather_0,axis=0)

gather_intersect_sum_1 = torch.sum(class_intersect_gather_1,axis=0)
gather_denom_sum_1 = torch.sum(class_union_gather_1,axis=0)

In [127]:
# comparing single_gpu and multi_gpu
print('sanity_check: gather_0 should equal gather_1')
nr_of_classes = len(single_intersect_sum)
print(f'gather_intersect_sum_0 == gather_intersect_sum_1: {torch.sum(gather_intersect_sum_0 == gather_intersect_sum_1)} / {nr_of_classes} are equal')
print(f'gather_denom_sum_0 == gather_denom_sum_1: {torch.sum(gather_denom_sum_0 == gather_denom_sum_1)} / {nr_of_classes} are equal')
print()

print('sanity_check: multi should equal gather_0 and gather_1')
print(f'multi_intersect_sum == gather_intersect_sum_0: {torch.sum(multi_intersect_sum == gather_intersect_sum_0)} / {nr_of_classes} are equal')
print(f'multi_denom_sum == gather_denom_sum_0: {torch.sum(multi_denom_sum == gather_denom_sum_0)} / {nr_of_classes} are equal')
print()

print('Q: does single = gather?')
print(f'single_intersect_sum == gather_intersect_sum_0: {torch.sum(gather_intersect_sum_0 == single_intersect_sum)} / {nr_of_classes} are equal')
print(f'single_denom_sum == gather_denom_sum_0: {torch.sum(gather_denom_sum_0 == single_denom_sum)} / {nr_of_classes} are equal')
print()

print('Q: does single = multi?')
print(f'single_intersect_sum == multi_intersect_sum: {torch.sum(multi_intersect_sum == single_intersect_sum)} / {nr_of_classes} are equal')
print(f'single_denom_sum == multi_denom_sum: {torch.sum(multi_denom_sum == single_denom_sum)} / {nr_of_classes} are equal')

sanity_check: gather_0 should equal gather_1
gather_intersect_sum_0 == gather_intersect_sum_1: 51 / 51 are equal
gather_denom_sum_0 == gather_denom_sum_1: 51 / 51 are equal

sanity_check: multi should equal gather_0 and gather_1
multi_intersect_sum == gather_intersect_sum_0: 51 / 51 are equal
multi_denom_sum == gather_denom_sum_0: 51 / 51 are equal

Q: does single = gather?
single_intersect_sum == gather_intersect_sum_0: 49 / 51 are equal
single_denom_sum == gather_denom_sum_0: 33 / 51 are equal

Q: does single = multi?
single_intersect_sum == multi_intersect_sum: 49 / 51 are equal
single_denom_sum == multi_denom_sum: 33 / 51 are equal


### Comparison v2

In [113]:
torch.sum(torch.concat((class_intersect_0,class_intersect_1)).reshape((2,1,nr_of_classes)) != class_intersect_gather_0)

tensor(0)

In [118]:
# find class_intersect and class_union for batch_size = 2, single gpu case
image = torch.concat((image_0,image_1),axis=0)
mask = torch.concat((mask_0,mask_1),axis=0)
probs = torch.concat((probs_0,probs_1),axis=0)

# calculate intersect and union
y_true_oh = torch.nn.functional.one_hot(
    mask.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect = torch.sum(
    (y_true_oh * probs), axis=(2, 3)
    )
class_union = torch.sum(
    (y_true_oh + probs), axis=(2, 3)
)

In [119]:
torch.sum(class_intersect.reshape((2,1,nr_of_classes)) != class_intersect_gather_0)

tensor(2)

In [120]:
y_true_oh_single_0 = torch.nn.functional.one_hot(
    mask_0.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect_single_0 = torch.sum(
    (y_true_oh_single_0 * probs_0), axis=(2, 3)
)
class_union_single_0 = torch.sum(
    (y_true_oh_single_0 + probs_0), axis=(2,3)
)
print(y_true_oh_single_0.shape, class_intersect_single_0.shape, class_union_single_0.shape)

torch.Size([1, 51, 162, 194]) torch.Size([1, 51]) torch.Size([1, 51])


In [121]:
y_true_oh_single_1 = torch.nn.functional.one_hot(
    mask_1.long().squeeze(1), num_classes=nr_of_classes
).permute(0, 3, 1, 2)
class_intersect_single_1 = torch.sum(
    (y_true_oh_single_1 * probs_1), axis=(2, 3)
)
class_union_single_1 = torch.sum(
    (y_true_oh_single_1 + probs_1), axis=(2,3)
)
print(y_true_oh_single_1.shape, class_intersect_single_1.shape, class_union_single_1.shape)

torch.Size([1, 51, 162, 194]) torch.Size([1, 51]) torch.Size([1, 51])


In [122]:
a = torch.concat((class_intersect_0,class_intersect_1)).reshape((2,1,nr_of_classes))
a.shape

torch.Size([2, 1, 51])

In [123]:
b = torch.concat((class_intersect_single_0, class_intersect_single_1)).reshape((2,1,nr_of_classes))
b.shape

torch.Size([2, 1, 51])

In [124]:
a_sum = torch.sum(a,axis=0)
b_sum = torch.sum(b,axis=0)
torch.sum(a_sum != b_sum)

tensor(2)

In [125]:
a_sum - b_sum

tensor([[1.2207e-04, 1.5259e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00]], grad_fn=<SubBackward0>)

In [129]:
single_intersect_sum - gather_intersect_sum_0

tensor([[-1.2207e-04, -1.5259e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00]], grad_fn=<SubBackward0>)

In [130]:
single_denom_sum - gather_denom_sum_0

tensor([[ 0.0000,  0.0000,  0.0001,  0.0000,  0.0000,  0.0001,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000, -0.0001, -0.0002,  0.0000,  0.0000,
          0.0000, -0.0001,  0.0000, -0.0001, -0.0001,  0.0000,  0.0001, -0.0002,
          0.0000,  0.0000,  0.0000,  0.0001,  0.0000,  0.0001,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000, -0.0001, -0.0001,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, -0.0001,  0.0000,  0.0000, -0.0001,  0.0000, -0.0001,
         -0.0001,  0.0001,  0.0000]], grad_fn=<SubBackward0>)