In [17]:
import numpy as np
import torch

import logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s — %(name)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
    force=True
)

# Collect metrics of several volumes using SMP

In [None]:
from segmentation_models_pytorch import metrics

In [2]:
# lets assume we have multilabel prediction for 3 classes
output = np.random.randint(0,3, size=(2,32,32,32))
target = np.random.randint(0,3, size=(2,32,32,32))

In [3]:
output_tc=torch.from_numpy(output).long()
target_tc=torch.from_numpy(target).long()

In [4]:
# first compute statistics for true positives, false positives, false negative and
# true negative "pixels"
# mode='multiclass' shape (N,C, ...) and torch.LongTensor
# mode='multilabel' shape (N, ...) and torch.LongTensor
tp, fp, fn, tn = metrics.get_stats(output_tc, target_tc, mode='multilabel')

In [5]:
tp

tensor([[ 999, 1004, 1099,  963, 1053, 1025,  969, 1024, 1021, 1010, 1022, 1164,
         1043, 1026, 1042, 1029, 1041, 1041, 1104, 1001,  985, 1037, 1007, 1024,
          967, 1057, 1104, 1084, 1018, 1060, 1029, 1043],
        [1032,  984, 1014,  942,  978, 1025,  966, 1004, 1010, 1017, 1048,  924,
          992, 1028, 1092,  934, 1063, 1011, 1008, 1033,  994,  962, 1047, 1017,
         1101, 1042, 1097, 1060, 1053,  969, 1002, 1068]])

In [6]:
print(tp,fp,fn,tn)

tensor([[ 999, 1004, 1099,  963, 1053, 1025,  969, 1024, 1021, 1010, 1022, 1164,
         1043, 1026, 1042, 1029, 1041, 1041, 1104, 1001,  985, 1037, 1007, 1024,
          967, 1057, 1104, 1084, 1018, 1060, 1029, 1043],
        [1032,  984, 1014,  942,  978, 1025,  966, 1004, 1010, 1017, 1048,  924,
          992, 1028, 1092,  934, 1063, 1011, 1008, 1033,  994,  962, 1047, 1017,
         1101, 1042, 1097, 1060, 1053,  969, 1002, 1068]]) tensor([[   2,   14,  -24,   51,  -38,    5,   43,   28,    4,   12,  -19, -119,
          -13,   -4,    2,  -52,   15,   -5,  -45,   39,   40,   12,   26,    6,
           39,  -24,  -47,  -54,   10,  -63,  -20,   31],
        [ -12,  -12,    4,   39,   45,    2,   47,   55,   21,  -16,   10,   53,
           26,   -4,  -68,   55,  -43,   19,   39,    4,    0,   39,    7,   18,
          -59,   18,  -34,  -19,  -35,   28,    7,  -37]]) tensor([[ 22, -21, -78,   6,   4, -14,   1,  -4,  -8,  10,   7, -66,   0, -21,
         -31,  47, -17, -43, -55, -15, 

In [7]:
# then compute metrics with required reduction (see metric docs)
#iou_score = metrics.iou_score(tp, fp, fn, tn, reduction="micro")
iou_score = metrics.iou_score(tp, fp, fn, tn, reduction="macro")
iou_score

tensor(1.0011)

In [8]:
metrics.iou_score(tp, fp, fn, tn, reduction="micro")

tensor(1.0007)

In [10]:
metrics.iou_score(tp, fp, fn, tn, reduction=None)

tensor([[0.9765, 1.0070, 1.1023, 0.9441, 1.0334, 1.0089, 0.9566, 0.9771, 1.0039,
         0.9787, 1.0119, 1.1890, 1.0126, 1.0250, 1.0286, 1.0049, 1.0019, 1.0483,
         1.0996, 0.9766, 0.9363, 0.9886, 0.9951, 0.9865, 0.9307, 1.0655, 1.0866,
         1.0916, 0.9760, 1.0516, 0.9990, 1.0019],
        [1.0198, 0.9714, 1.0060, 0.9190, 0.9140, 1.0049, 0.9191, 0.9314, 0.9912,
         1.0020, 1.0214, 0.8571, 0.9367, 0.9828, 1.1269, 0.8904, 1.0630, 0.9902,
         0.9582, 1.0127, 0.9567, 0.9304, 1.0106, 0.9883, 1.0752, 0.9943, 1.1014,
         1.0222, 1.0426, 0.9604, 0.9872, 1.0220]])

In [9]:
f1_score = metrics.f1_score(tp, fp, fn, tn, reduction="macro")
f1_score

tensor(1.0002)

In [37]:
accuracy = metrics.accuracy(tp, fp, fn, tn, reduction="macro")
accuracy

tensor(0.9948)

This is not appropriate to what I wish to do. There is no option to seperate metrics for each class

# My metrics

In [28]:
def get_metric_stats_per_class(data_lbls,target, nclasses=None):
    """
    Gets true positive, false positive, false negative and true negative sums (tp,fp,fn,tn)
    for each class

    If nclasses is None (default) then it gets for all the classes based in the maximum
    value of both data_lbls and target

    """
    if isinstance(data_lbls, np.ndarray):
        data_lbls = torch.from_numpy(data_lbls)

    if isinstance(target, np.ndarray):
        target = torch.from_numpy(target)
    
    data_flat = torch.ravel(data_lbls)
    target_flat = torch.ravel(target)
    
    #Get max class
    if nclasses is None:
        nclasses = max(int(torch.max(data_flat)), int(torch.max(target_flat)) )+1
        logging.info(f"nclasses:{nclasses}")

    stats_per_class=[]
    for iclass in range(nclasses):
        pred_bin = data_flat==iclass
        gnd_bin =  target_flat==iclass
        tp = int(torch.sum(torch.bitwise_and(pred_bin,gnd_bin)))
        tn = int(torch.sum(torch.bitwise_and(torch.bitwise_not(pred_bin),torch.bitwise_not(gnd_bin))))
        fp = int(torch.sum(torch.bitwise_and(pred_bin,torch.bitwise_not(gnd_bin))))
        fn = int(torch.sum(torch.bitwise_and(torch.bitwise_not(pred_bin),gnd_bin)))
        metr_t= (tp,fp,fn,tn)  # Same order as used in SMP get_stats
        stats_per_class.append(metr_t)

    return stats_per_class

In [29]:
def accuracy_from_stats(tp,fp,fn,tn):
    return (tp+tn)/(tp+tn+fp+fn)

def iou_from_stats(tp,fp,fn,tn):
    return tp/(tp+fp+fn)

def f1dice_from_stats(tp,fp,fn,tn):
    return 2*tp/(2*tp+fp+fn)

def recall_from_stats(tp,fp,fn,tn):
    return tp/(tp+fn)

def precision_from_stats(tp,fp,fn,tn):
    return tp/(tp+fp)

In [30]:
stats = get_metric_stats_per_class(output, target)
stats

2024-06-28 14:11:14,977 — root — INFO — get_metric_stats_per_class:22 — nclasses:3


[(7255, 14528, 14511, 29242),
 (7386, 14459, 14637, 29054),
 (7257, 14651, 14490, 29138)]

In [31]:
for i,st0 in enumerate(stats):
    print(f"class:{i}, f1dice = {f1dice_from_stats(*stats[i])}, accuracy={accuracy_from_stats(*stats[i])}")

class:0, f1dice = 0.3331879032813612, accuracy=0.5569000244140625
class:1, f1dice = 0.3367374851828212, accuracy=0.5560302734375
class:2, f1dice = 0.3324705073874699, accuracy=0.5553436279296875
