In [1]:
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Toy logits of shape (2, 2, 2). Dim 0 is the class axis: softmax is taken
# over it below, and after adding a batch dim it becomes the C axis that
# CrossEntropyLoss reduces over.
mock_inp = torch.tensor(
    [
        [[0.2406, -0.4713], [0.3808, 0.7049]],
        [[0.5114, 0.0699], [-0.3723, -0.6597]],
    ],
    dtype=torch.float32,
)

In [30]:
# Expect (2, 2, 2): 2 classes over a 2x2 grid.
mock_inp.size()

torch.Size([2, 2, 2])

In [10]:
# Per-pixel class probabilities: normalise over the class axis (dim 0).
probs = torch.softmax(mock_inp, dim=0)

In [45]:
# probs[c, h, w]; the two values along dim 0 sum to 1 at every pixel.
probs

tensor([[[0.4327, 0.3679],
         [0.6799, 0.7965]],

        [[0.5673, 0.6321],
         [0.3201, 0.2035]]])

In [12]:
# Prepend a batch axis of size 1 -> (1, 2, 2, 2), i.e. (N, C, H, W) for the loss.
mock_batch = mock_inp[None]

In [49]:
import numpy as np

In [54]:
# Hand check of CrossEntropyLoss (mean reduction): the loss is the average of
# -log p(target class) over all pixels. The per-pixel classes below are the
# same as `target` used with the loss further down.
# Values are read straight from `probs` rather than hand-copied, rounded
# numbers. The whole computation lives in one cell so that re-running it is
# idempotent and the builtin `sum` is not shadowed.
pixel_classes = torch.tensor([[0, 1], [0, 1]])
# gather along dim 0 picks probs[pixel_classes[h, w], h, w] for every pixel.
target_probs = probs.gather(0, pixel_classes.unsqueeze(0)).squeeze(0)
manual_ce = -target_probs.log().mean()
manual_ce.item()

0.8185792814510982

In [31]:
# Expect (1, 2, 2, 2): (N, C, H, W).
mock_batch.size()

torch.Size([1, 2, 2, 2])

In [36]:
# Per-pixel class labels for the single batch item: shape (1, 2, 2), int64 as
# CrossEntropyLoss requires for class-index targets. Uses torch.tensor with an
# explicit dtype instead of the legacy torch.LongTensor constructor.
target = torch.tensor([[[0, 1], [0, 1]]], dtype=torch.long)

In [37]:
# Class label of each pixel for the single batch item.
target

tensor([[[0, 1],
         [0, 1]]])

In [46]:
# Mean reduction (the default, spelled out) averages the per-pixel losses,
# which is what the hand computation above reproduces.
loss = nn.CrossEntropyLoss(reduction="mean")

In [47]:
# Logits (N, C, H, W) against class-index targets (N, H, W).
output = loss(input=mock_batch, target=target)

In [48]:
# Expected ~0.8186: the mean of -log p(target class) over the 4 pixels
# (~3.2743 / 4), i.e. the same value as the hand computation above.
assert abs(output.item() - 0.8186) < 1e-4
output

tensor(0.8186)

In [59]:
class HeadSegmentationCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Cross-entropy loss with separate weights for class 0 and class 1.

    Class index 0 is weighted by ``background_weight`` and class index 1 by
    ``head_weight`` (the order of the ``weight`` tensor handed to
    ``torch.nn.CrossEntropyLoss``).

    Args:
        background_weight: Weight applied to class 0.
        head_weight: Weight applied to class 1.
        **kwargs: Any other ``torch.nn.CrossEntropyLoss`` keyword argument,
            e.g. ``ignore_index`` or ``reduction``.
    """

    def __init__(self, background_weight: float = 1.0, head_weight: float = 1.0, **kwargs):
        # CrossEntropyLoss registers `weight` as a buffer, which must be a
        # Tensor; passing a plain list raises TypeError.
        class_weights = torch.tensor([background_weight, head_weight], dtype=torch.float32)
        super().__init__(weight=class_weights, **kwargs)

In [138]:
from torchmetrics import Metric
from torchmetrics import ConfusionMatrix

class Metric(Metric):
    """Binary segmentation metrics: accuracy, precision, recall and mIoU.

    Accumulates a 2x2 confusion matrix across ``update`` calls and derives
    the metrics in ``compute``. ``confmat[t, p]`` counts elements whose
    target class is ``t`` and whose predicted class is ``p``.

    This class subclasses (and then shadows) the ``Metric`` name imported
    from torchmetrics in the same cell, so the import and the class
    definition must be run together.

    Args:
        dist_sync_on_step: Forwarded to the torchmetrics ``Metric`` base.
        positive_class: Class index (0 or 1) treated as positive for
            precision and recall. Defaults to 0, matching the original
            indexing. NOTE(review): if class 1 is the foreground of interest,
            pass ``positive_class=1`` -- confirm against the label encoding.
        threshold: Floating-point ``preds`` are binarised as
            ``preds >= threshold`` (presumably probabilities of class 1 --
            confirm with the caller); integer/bool preds are used as labels.
    """

    # update() only adds counts, so torchmetrics need not recompute the full
    # state on every forward() call.
    full_state_update = False

    def __init__(self, dist_sync_on_step=False, positive_class: int = 0, threshold: float = 0.5):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        if positive_class not in (0, 1):
            raise ValueError(f"positive_class must be 0 or 1, got {positive_class}")
        self.positive_class = positive_class
        self.threshold = threshold
        # The counts are registered as metric state (instead of a child
        # ConfusionMatrix module) so that reset() and the cross-process
        # "sum" reduction apply to them.
        self.add_state("confmat", default=torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Add the element-wise (target, pred) counts of one batch.

        Args:
            preds: Predicted labels (0/1), or floating-point scores that are
                binarised with ``self.threshold``. Same shape as ``target``.
            target: Ground-truth labels (0 or 1).

        Raises:
            ValueError: If the shapes differ or labels fall outside {0, 1}.
        """
        if preds.shape != target.shape:
            raise ValueError(
                f"preds and target shapes differ: {tuple(preds.shape)} vs {tuple(target.shape)}"
            )
        if preds.is_floating_point():
            preds = preds >= self.threshold
        preds = preds.reshape(-1).long()
        target = target.reshape(-1).long()
        if target.numel() and (
            preds.min() < 0 or preds.max() > 1 or target.min() < 0 or target.max() > 1
        ):
            raise ValueError("preds and target must contain only class labels 0 and 1")
        # Encode each (target, pred) pair as one index in [0, 4) and count them.
        pair_counts = torch.bincount(target * 2 + preds, minlength=4)
        self.confmat += pair_counts.reshape(2, 2)
        self.total += target.numel()

    def compute(self) -> dict:
        """Return accuracy, precision, recall and mIoU as 0-dim float tensors.

        Ratios with a zero denominator are reported as 0. mIoU averages the
        per-class IoU over the classes that appear in preds or target, and is
        NaN if nothing has been accumulated.
        """
        cm = self.confmat
        pos = self.positive_class
        neg = 1 - pos
        tp, fn = cm[pos, pos], cm[pos, neg]
        fp, tn = cm[neg, pos], cm[neg, neg]

        # Per-class IoU = |pred ∩ target| / |pred ∪ target|.
        intersection = cm.diagonal()
        union = cm.sum(dim=0) + cm.sum(dim=1) - intersection
        present = union > 0
        per_class_iou = intersection[present].float() / union[present].float()

        return {
            "accuracy": self._safe_div(tp + tn, self.total),
            "precision": self._safe_div(tp, tp + fp),
            "recall": self._safe_div(tp, tp + fn),
            # mean of an empty selection is NaN
            "mIoU": per_class_iou.mean(),
        }

    @staticmethod
    def _safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
        """Divide counts as floats, returning 0 where ``den`` is 0."""
        num = num.float()
        den = den.float()
        # clamp keeps the unused branch of torch.where finite (counts are
        # integers, so it never changes a positive denominator).
        return torch.where(den > 0, num / den.clamp(min=1), torch.zeros_like(num))