In [47]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import seaborn as sns

from glob import glob
import os

from dataset import HLS_US
from model import MACUNet_model

In [13]:
file_list = [os.path.basename(f)[:-11] for f in glob('./data/hls/*')]

In [14]:
dataset = HLS_US('./data/', file_list, num_classes=13)
dataloader = DataLoader(dataset, batch_size=2)

model = MACUNet_model(6, 13)

In [79]:
for x, y in dataloader:
    logits = model.forward(x)
    pred = logits.argmax(dim=1)

x.shape, y.shape, logits.shape

(torch.Size([2, 6, 224, 224]),
 torch.Size([2, 224, 224]),
 torch.Size([2, 13, 224, 224]))

In [82]:
def calculate_metrics(y_true, y_pred, num_classes):
    # Flatten both true and predicted labels
    y_true_flat = y_true.view(-1)
    y_pred_flat = y_pred.view(-1)
    
    # Initialize variables to store metrics
    accuracy = torch.zeros(num_classes)
    recall = torch.zeros(num_classes)
    precision = torch.zeros(num_classes)
    iou = torch.zeros(num_classes)

    for cls in range(num_classes):
        # True positives
        TP = torch.sum((y_true_flat == cls) & (y_pred_flat == cls)).item()
        
        # False positives
        FP = torch.sum((y_true_flat != cls) & (y_pred_flat == cls)).item()
        
        # False negatives
        FN = torch.sum((y_true_flat == cls) & (y_pred_flat != cls)).item()
        
        # Intersection
        intersection = torch.sum((y_true_flat == cls) & (y_pred_flat == cls)).item()
        
        # Union
        union = TP + FP + FN
        
        # Calculate metrics
        if union != 0:
            accuracy[cls] = intersection / union
            recall[cls] = TP / (TP + FN) if (TP + FN) != 0 else 0
            precision[cls] = TP / (TP + FP)
            iou[cls] = intersection / (union - intersection)

    # Overall metrics
    overall_accuracy = torch.mean(accuracy)
    overall_recall = torch.mean(recall)
    overall_precision = torch.mean(precision)
    overall_iou = torch.mean(iou)

    return overall_accuracy, overall_recall, overall_precision, overall_iou, accuracy, recall, precision, iou

In [84]:
overall_accuracy, overall_recall, overall_precision, overall_iou, accuracy, recall, precision, iou = calculate_metrics(y, pred, 13)
overall_accuracy, overall_recall, overall_precision, overall_iou, accuracy, recall, precision, iou

(tensor(0.0091),
 tensor(0.0516),
 tensor(0.0562),
 tensor(0.0095),
 tensor([1.5254e-03, 2.6914e-02, 4.5212e-03, 0.0000e+00, 2.0208e-02, 8.5566e-03,
         2.0667e-03, 9.3729e-04, 1.3182e-03, 9.3484e-05, 0.0000e+00, 0.0000e+00,
         5.2586e-02]),
 tensor([0.0015, 0.0316, 0.0382, 0.0000, 0.0264, 0.0093, 0.0220, 0.2222, 0.2437,
         0.0137, 0.0000, 0.0000, 0.0615]),
 tensor([1.2469e-01, 1.5338e-01, 5.1020e-03, 0.0000e+00, 7.9475e-02, 9.7561e-02,
         2.2756e-03, 9.4038e-04, 1.3237e-03, 9.4118e-05, 0.0000e+00, 0.0000e+00,
         2.6553e-01]),
 tensor([1.5278e-03, 2.7658e-02, 4.5418e-03, 0.0000e+00, 2.0625e-02, 8.6304e-03,
         2.0710e-03, 9.3817e-04, 1.3200e-03, 9.3493e-05, 0.0000e+00, 0.0000e+00,
         5.5504e-02]))

In [89]:
import torchmetrics

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=13)
accuracy(pred, y)

tensor(0.0197)

In [95]:
iou = torchmetrics.JaccardIndex(task='multiclass', num_classes=13)
iou(pred, y)

tensor(0.0091)

In [101]:
torch.tensor(0).float()

tensor(0.)