In [4]:
import os
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from unet.unet_model import Unet,ResNeXtUnet
from cityscapes import build_datasets
from losses import SegmentationLoss
from utils import *


In [11]:
# All tensors in this notebook are moved to this device before inference.
device = torch.device("cpu")

# Display names for the classes; the position in this list is the class index
# printed next to each name in the per-class IoU report (the last entry, index
# 19, is 'void').
class_names = [
    'road',
    'sidewalk',
    'building',
    'wall',
    'fence',
    'pole',
    'traffic light',
    'traffic sign',
    'vegetation',
    'terrain',
    'sky',
    'person',
    'rider',
    'car',
    'truck',
    'bus',
    'train',
    'motorcycle',
    'bicycle',
    'void',
]

In [18]:
# Validation-only metric helper: pass class-index predictions (argmax over the class dimension), not raw logits.
def calculate_iou(pred, target, num_classes, ignore_index=None):
    """Compute the per-class intersection-over-union of hard predictions.

    Args:
        pred: tensor of predicted class indices, e.g. ``model(inp).argmax(dim=1)``.
            Any shape, but it must hold as many elements as ``target``.
        target: tensor of ground-truth class indices.
        num_classes: number of scored classes ``C``; valid indices are ``0..C-1``.
        ignore_index: optional label value; pixels whose target equals it are
            left out of the metric.

    Returns:
        A list of ``C`` floats. Entry ``c`` is ``|pred==c AND target==c| /
        |pred==c OR target==c|`` over the scored pixels, or NaN when class
        ``c`` appears in neither the predictions nor the targets.

    Raises:
        ValueError: if ``pred`` and ``target`` hold different numbers of elements.
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    pred = pred.long().reshape(-1)
    target = target.long().reshape(-1)
    if pred.numel() != target.numel():
        raise ValueError(
            f"pred has {pred.numel()} elements but target has {target.numel()}"
        )

    # Score only pixels whose label is a real class. Clamping out-of-range
    # labels (e.g. an unlisted ignore value such as 255) would silently count
    # them as class 0 or class C-1 and corrupt those classes' IoU.
    scored = (target >= 0) & (target < num_classes)
    if ignore_index is not None:
        scored &= target != ignore_index
    pred = pred[scored]
    target = target[scored]

    # A prediction outside 0..C-1 belongs to no scored class: it is never a
    # hit and adds nothing to any class's predicted area, but the pixel still
    # counts in its target class's union (i.e. as a miss).
    pred_scorable = (pred >= 0) & (pred < num_classes)
    hits = pred == target  # target is in range here, so a hit implies pred is too

    # Per-class pixel counts via bincount; this avoids materialising two
    # (pixels x classes) one-hot tensors.
    intersection = torch.bincount(target[hits], minlength=num_classes)
    target_area = torch.bincount(target, minlength=num_classes)
    pred_area = torch.bincount(pred[pred_scorable], minlength=num_classes)
    union = target_area + pred_area - intersection

    iou = intersection.float() / union.float()
    iou[union == 0] = float('nan')  # class absent from both pred and target

    return iou.tolist()

def evaluate(model, data_loader, num_classes=None, ignore_index=None):
    """Run ``model`` over ``data_loader`` and report segmentation metrics.

    The model is switched to eval mode and run without gradients. Inputs are
    moved to the module-level ``device``. Predictions are the argmax over
    dimension 1 of the model output and are compared element-wise with the
    labels.

    Args:
        model: network called as ``model(images)``.
        data_loader: iterable of ``(images, labels)`` batches.
        num_classes: number of classes to score. When ``None`` (default) it is
            taken from dimension 1 of the first model output, so the metric
            always matches the loaded model rather than a hard-coded constant.
        ignore_index: optional label value excluded from both the pixel
            accuracy and the IoU.

    Returns:
        ``(pixel_accuracy, total_miou, class_miou)`` where
        * ``pixel_accuracy`` is correct pixels / scored pixels over the whole
          loader,
        * ``total_miou`` is the mean, over batches, of each batch's mean IoU
          across the classes present in that batch (this is NOT the
          dataset-level mIoU computed from accumulated intersections/unions),
        * ``class_miou`` is a list with, per class, the mean of its per-batch
          IoUs (NaN if the class never appeared).
        Metrics that cannot be computed (e.g. an empty loader) are NaN.
    """
    model.eval()
    correct_pixels = 0
    scored_pixels = 0
    miou_sum = 0.0
    miou_batches = 0
    class_ious = None  # allocated once num_classes is known

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc='Validating', leave=False):
            images = images.to(device)
            labels = labels.to(device).long()

            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            if num_classes is None:
                num_classes = outputs.shape[1]
            if class_ious is None:
                class_ious = [[] for _ in range(num_classes)]

            # Accumulate raw counts so that a short final batch is not
            # weighted as heavily as a full one.
            hits = predicted == labels
            if ignore_index is not None:
                scored = labels != ignore_index
                hits = hits & scored
                scored_pixels += scored.sum().item()
            else:
                scored_pixels += labels.numel()
            correct_pixels += hits.sum().item()

            ious = calculate_iou(predicted, labels, num_classes, ignore_index)
            present = []
            for cls, value in enumerate(ious):
                if not np.isnan(value):
                    class_ious[cls].append(value)
                    present.append(value)
            # Skip batches with no scoreable class instead of letting a NaN
            # poison the running sum.
            if present:
                miou_sum += float(np.mean(present))
                miou_batches += 1

    nan = float('nan')
    pixel_accuracy = correct_pixels / scored_pixels if scored_pixels else nan
    total_miou = miou_sum / miou_batches if miou_batches else nan
    class_miou = [
        float(np.mean(cls_ious)) if cls_ious else nan
        for cls_ious in (class_ious or [])
    ]

    return pixel_accuracy, total_miou, class_miou

def get_model(model_name='resnextunet',device='cpu',checkpoint_path=None,train=False,act=None,n_classes=19):
    """Build a segmentation network and optionally restore its weights.

    Args:
        model_name: ``'unet'`` or ``'resnextunet'``.
        device: device the model is moved to before it is returned.
        checkpoint_path: path to a checkpoint whose ``'model'`` entry is a
            state dict for the chosen architecture. When ``None`` the model
            keeps its freshly initialised weights.
        train: if falsy, the model is put in eval mode.
        act: forwarded unchanged to the model constructor as ``act``.
        n_classes: forwarded to the model constructor as ``n_classes``.

    Returns:
        The constructed model, on ``device``.

    Raises:
        KeyError: if ``model_name`` is not a supported architecture.
    """
    sup_model ={
        'unet':Unet,
        'resnextunet':ResNeXtUnet,
    }
    try:
        model_cls = sup_model[model_name]
    except KeyError:
        raise KeyError(
            f"unknown model_name {model_name!r}; expected one of {sorted(sup_model)}"
        ) from None
    model = model_cls(in_channels=3,n_classes=n_classes,act=act)

    if checkpoint_path is not None:
        # SECURITY NOTE(review): torch.load unpickles the file; only load
        # checkpoints from a trusted source (or switch to weights_only=True
        # once the checkpoint contents are confirmed to be plain tensors).
        chkpt = torch.load(checkpoint_path,map_location='cpu')['model']
        model.load_state_dict(chkpt)
    elif not train:
        # The default checkpoint_path=None used to crash inside torch.load;
        # now it yields an untrained model, which is almost certainly not
        # what an inference caller wants, so say so loudly.
        import warnings
        warnings.warn(
            "get_model called with train=False and no checkpoint_path; "
            "returning a model with randomly initialised weights"
        )

    if not train:
        model.eval()
    model.to(device)
    return model


In [19]:
# NOTE(review): absolute, machine-specific path; move to a config cell or an
# environment variable before sharing this notebook.
checkpoint_path = '/workspace/UNET/checkpoints/unet_void/unet_ReLU_1_240_0.49.pth'

# Put the model on the same `device` that evaluate() moves its inputs to, and
# size the output head from class_names so the two cannot drift apart.
model = get_model(
    'unet',
    device=device,
    checkpoint_path=checkpoint_path,
    act=nn.ReLU(),
    n_classes=len(class_names),
)


In [20]:
# build_datasets returns a pair; only the second element is kept and used as the
# validation loader below. Presumably the first element is the training
# loader; confirm in cityscapes.build_datasets.
_, val_loader = build_datasets(batch_size=2)

In [None]:
# Run the validation pass: returns the pixel accuracy, the mean IoU and the
# list of per-class mean IoUs that the next cell prints.
pixel_accuracy, total_miou, class_miou = evaluate(model,data_loader=val_loader)

In [15]:
# Report each class as "name(index): IoU%" and finish with the overall mean
# IoU, all as percentages with two decimals. zip() stops at the shorter of
# class_names / class_miou.
# NOTE(review): this cell's execution count (In [15]) is lower than that of the
# cells defining evaluate() (In [18]), so the recorded output predates the
# current code; restart the kernel and run all cells to refresh it.
report_lines = [
    f"{name}({idx}): {iou*100:.2f}"
    for idx, (name, iou) in enumerate(zip(class_names, class_miou))
]
report_lines.append(f'{total_miou*100:.2f}')
print("\n".join(report_lines))

road(0): 87.76
sidewalk(1): 61.20
building(2): 79.13
wall(3): 14.52
fence(4): 14.16
pole(5): 39.27
traffic light(6): 21.84
traffic sign(7): 44.98
vegetation(8): 83.68
terrain(9): 25.73
sky(10): 79.21
person(11): 41.56
rider(12): 23.15
car(13): 80.62
truck(14): 11.22
bus(15): 21.23
train(16): 4.40
motorcycle(17): 6.11
bicycle(18): 36.26
void(19): 50.36
46.47
