# MulticlassJaccardIndex

**Semantic Segmentation**

**Inputs** (batch size, channels, height, width)

**Preds** (batch size, classes, height, width): one score per class for every pixel; the predicted class is the argmax over the class dimension.

**Ground Truths** (batch size, height, width): one integer class index per pixel.

In [107]:
import torch
from torchmetrics.classification import MulticlassJaccardIndex

## Single image

In [108]:
# Fixed ground-truth mask: 1 image of 2x2 pixels, one class index (0..2) per
# pixel. It stands in for a random torch.randint(0, 3, (1, 2, 2)) draw so the
# outputs below are reproducible.
labels = torch.tensor([[[0, 2], [2, 1]]])
print(labels)

tensor([[[0, 2],
         [2, 1]]])


In [109]:
# Fixed per-class scores of shape (1, 3, 2, 2): one 2x2 score map per class.
# They stand in for a random torch.rand((1, 3, 2, 2)) draw.
preds = torch.tensor([[
    [[0.0751, 0.0239], [0.2764, 0.3437]],  # class 0
    [[0.2993, 0.4008], [0.3562, 0.1968]],  # class 1
    [[0.9602, 0.6309], [0.1658, 0.0777]],  # class 2
]])
print(preds)

tensor([[[[0.0751, 0.0239],
          [0.2764, 0.3437]],

         [[0.2993, 0.4008],
          [0.3562, 0.1968]],

         [[0.9602, 0.6309],
          [0.1658, 0.0777]]]])


In [110]:
preds.argmax(dim=1).eq(2).int()  # 1 where the predicted class is 2

tensor([[[1, 1],
         [0, 0]]], dtype=torch.int32)

In [111]:
labels.eq(2).int()  # 1 where the ground-truth class is 2

tensor([[[0, 1],
         [1, 0]]], dtype=torch.int32)

In [112]:
# "micro": TP, FP and FN are pooled over every class, then a single
# IoU = TP / (TP + FP + FN) is computed from the pooled counts.
metric = MulticlassJaccardIndex(average="micro", num_classes=3)
metric(preds, labels)

tensor(0.1429)

In [113]:
1 / (1 + 3 + 3)  # pooled over classes: TP=1, FP=3, FN=3

0.14285714285714285

In [114]:
# "none": no reduction; one IoU per class, in class-index order.
metric = MulticlassJaccardIndex(average="none", num_classes=3)
metric(preds, labels)

tensor([0.0000, 0.0000, 0.3333])

In [115]:
1 / (1 + 1 + 1)  # class 2: TP=1, FP=1, FN=1

0.3333333333333333

In [116]:
# "macro": unweighted mean of the per-class IoUs.
metric = MulticlassJaccardIndex(average="macro", num_classes=3)
metric(preds, labels)

tensor(0.1111)

In [117]:
1 / (3 * 3)  # mean of per-class IoUs [0, 0, 1/3] = (1/3) / 3

0.1111111111111111

## Batch of images

In [118]:
# Batch size and number of classes used by the batch example.
batch, classes = 2, 3

In [119]:
# Fixed ground truth for 2 images of 2x2 pixels (stands in for
# torch.randint(0, 3, (batch, 2, 2))).
labels = torch.tensor([
    [[0, 1], [1, 2]],  # image 0
    [[0, 1], [2, 2]],  # image 1
])
print(labels)

tensor([[[0, 1],
         [1, 2]],

        [[0, 1],
         [2, 2]]])


In [120]:
# Fixed per-class scores of shape (batch, classes, 2, 2) = (2, 3, 2, 2);
# they stand in for torch.rand((batch, classes, 2, 2)).
preds = torch.tensor([
    [  # image 0
        [[0.1552, 0.0311], [0.2162, 0.9957]],  # class 0
        [[0.0857, 0.8358], [0.3440, 0.7529]],  # class 1
        [[0.6019, 0.6657], [0.6416, 0.6647]],  # class 2
    ],
    [  # image 1
        [[0.3775, 0.0812], [0.6923, 0.8075]],  # class 0
        [[0.2147, 0.0133], [0.2500, 0.9218]],  # class 1
        [[0.9361, 0.2890], [0.8075, 0.6322]],  # class 2
    ],
])
print(preds)

tensor([[[[0.1552, 0.0311],
          [0.2162, 0.9957]],

         [[0.0857, 0.8358],
          [0.3440, 0.7529]],

         [[0.6019, 0.6657],
          [0.6416, 0.6647]]],


        [[[0.3775, 0.0812],
          [0.6923, 0.8075]],

         [[0.2147, 0.0133],
          [0.2500, 0.9218]],

         [[0.9361, 0.2890],
          [0.8075, 0.6322]]]])


In [121]:
preds.argmax(dim=1)  # predicted class index per pixel, shape (batch, 2, 2)

tensor([[[2, 1],
         [2, 0]],

        [[2, 2],
         [2, 1]]])

In [122]:
# Per-class IoU ("none"), computed over all pixels of both images together.
# num_classes reuses `classes` so the metric stays in sync with the class
# dimension of `preds` if the example is regenerated with another size.
metric = MulticlassJaccardIndex(num_classes=classes, average="none")
metric(preds, labels)

tensor([0.0000, 0.2500, 0.1429])

In [123]:
# Micro IoU: TP, FP and FN are pooled over all classes and every pixel of the
# batch, then a single TP / (TP + FP + FN) is computed.
# num_classes reuses `classes` so it stays in sync with the data.
metric = MulticlassJaccardIndex(num_classes=classes, average="micro")
metric(preds, labels)

tensor(0.1429)

In [124]:
# Macro IoU: unweighted mean of the per-class IoUs.
# num_classes reuses `classes` so it stays in sync with the data.
metric = MulticlassJaccardIndex(num_classes=classes, average="macro")
metric(preds, labels)

tensor(0.1310)

In [125]:
# Macro IoU is the mean of the per-class IoUs. The exact intersection/union
# ratios are counted by hand from the argmax predictions above, rather than
# copying the rounded values from the printed output:
#   class 0: 0/3, class 1: 1/4, class 2: 1/7
per_class_iou = torch.tensor([0 / 3, 1 / 4, 1 / 7])
per_class_iou.mean()

tensor(0.1310)

In [126]:
preds.size()  # (batch, classes, height, width)

torch.Size([2, 3, 2, 2])

# MultilabelJaccardIndex

**Semantic Segmentation**

**Inputs** (batch size, channels, height, width)

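**Preds** (batch size, classes, height, width): an independent probability per class (label) for every pixel, binarized at 0.5 (the metric's default threshold).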

**Ground Truths** (batch size, classes, height, width): a 0/1 mask per class (label); a pixel may belong to several classes at once.

In [127]:
import torch
from torchmetrics.classification import MultilabelJaccardIndex

In [128]:
# Fixed ground truth of shape (1, 3, 2, 2): one binary 2x2 mask per label.
# It stands in for torch.randint(0, 2, (1, 3, 2, 2)).
labels = torch.tensor([[
    [[0, 0], [1, 0]],  # label 0
    [[1, 0], [1, 1]],  # label 1
    [[1, 1], [1, 1]],  # label 2
]])
print(labels)

tensor([[[[0, 0],
          [1, 0]],

         [[1, 0],
          [1, 1]],

         [[1, 1],
          [1, 1]]]])


In [129]:
# Fixed per-label probabilities of shape (1, 3, 2, 2); they stand in for
# torch.rand((1, 3, 2, 2)).
preds = torch.tensor([[
    [[0.9709, 0.4751], [0.6475, 0.9009]],  # label 0
    [[0.7615, 0.1242], [0.0036, 0.6726]],  # label 1
    [[0.0765, 0.1898], [0.2989, 0.2798]],  # label 2
]])
print(preds)

tensor([[[[0.9709, 0.4751],
          [0.6475, 0.9009]],

         [[0.7615, 0.1242],
          [0.0036, 0.6726]],

         [[0.0765, 0.1898],
          [0.2989, 0.2798]]]])


In [138]:
# Show the ground truth next to the binarized predictions. torchmetrics
# binarizes multilabel probabilities with a strict `preds > threshold`
# (threshold defaults to 0.5), so mirror that comparison here. `>=` would
# disagree with the metric for a score of exactly 0.5.
print(labels)
print((preds > 0.5).int())

tensor([[[[0, 0],
          [1, 0]],

         [[1, 0],
          [1, 1]],

         [[1, 1],
          [1, 1]]]])
tensor([[[[1, 0],
          [1, 1]],

         [[1, 0],
          [0, 1]],

         [[0, 0],
          [0, 0]]]], dtype=torch.int32)


In [131]:
# "none": one IoU per label, in label-index order.
metric = MultilabelJaccardIndex(average="none", num_labels=3)
metric(preds, labels)

tensor([0.3333, 0.6667, 0.0000])

In [132]:
# "micro": TP, FP and FN pooled over all labels, then one TP / (TP + FP + FN).
metric = MultilabelJaccardIndex(average="micro", num_labels=3)
metric(preds, labels)

tensor(0.3000)

In [133]:
# "macro": unweighted mean of the per-label IoUs.
metric = MultilabelJaccardIndex(average="macro", num_labels=3)
metric(preds, labels)

tensor(0.3333)

In [139]:
# Macro IoU is the mean of the per-label IoUs. The exact intersection/union
# ratios are counted by hand from the binarized predictions, rather than
# copying the rounded values from the printed output:
#   label 0: 1/3, label 1: 2/3, label 2: 0/4
per_label_iou = torch.tensor([1 / 3, 2 / 3, 0 / 4])
per_label_iou.mean()

tensor(0.3333)