In [1]:
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_loader.cityscapes import CityscapesDataset
import utils.transforms
from model.deeplabv3plus import DeepLabv3Plus

from model.metric import iou, SegmentationMetrics

In [2]:
# --- Evaluation configuration: which run's checkpoint to load and where Cityscapes lives ---
writer = SummaryWriter()  # NOTE(review): never written to in this notebook; creates a log dir as a side effect
run = 'run1'
checkpoint_file = os.path.join('./checkpoints', run, 'best_model.pt')

img_root, mask_root = './data/leftImg8bit', './data/gtFine'

# Every image/mask pair is resized to 512x512 and converted to tensors.
transform = utils.transforms.Compose([
    utils.transforms.Resize((512, 512)),
    utils.transforms.ToTensor(),
])

# Build the train split first, then val (same order as before).
train_set, val_set = (
    CityscapesDataset(split, img_root, mask_root, transform=transform)
    for split in ("train", "val")
)

# Only the training loader is shuffled.
batch_size = 4
train_loader, val_loader = (
    DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    for dataset, shuffle in ((train_set, True), (val_set, False))
)

In [3]:
# --- Model, loss and optimizer ---
# Fall back to CPU when no GPU is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

num_classes = 20  # labels 0..19
learning_rate = 1e-3

# Dataset-defined ignore label (original note: 19); excluded from the loss.
ignore_idx = train_set.ignoreId
criterion = nn.CrossEntropyLoss(ignore_index=ignore_idx)

model = DeepLabv3Plus(num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [4]:
# Restore the trained weights. map_location remaps stored tensors onto `device`, so a
# checkpoint saved on GPU still loads when this notebook falls back to CPU.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# The optimizer state is not used for evaluation. It is restored only so the optimizer
# matches the checkpoint if training is resumed from here.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [5]:
def evaluate_performance(epoch, model, criterion, data_loader, num_classes, device,
                         ignore_idx=None, return_loss=False):
    """Run ``model`` over ``data_loader`` without gradients and compute per-class IoU.

    Puts ``model`` into eval mode and leaves it there.

    Args:
        epoch: Used only to label the progress bar; it does not affect the result.
        model: Segmentation network, called as ``model(images)``. Predictions are
            taken with ``argmax`` over dim 1, so that is assumed to be the class axis.
        criterion: Loss, applied as ``criterion(logits, masks)``.
        data_loader: Iterable of dicts with ``"image"`` and ``"mask"`` tensors.
        num_classes: Passed to ``SegmentationMetrics``.
        device: Device each batch is moved to.
        ignore_idx: Label excluded from the metrics. Defaults to
            ``criterion.ignore_index``, so metrics and loss skip the same label.
        return_loss: If True, also return the mean per-batch loss.

    Returns:
        ``(ious, mIoU)`` from ``SegmentationMetrics.iou()``, or
        ``(ious, mIoU, mean_loss)`` when ``return_loss`` is True. ``mean_loss`` is
        NaN when ``data_loader`` yields no batches.

    Raises:
        ValueError: If ``ignore_idx`` is omitted and ``criterion`` has no
            ``ignore_index`` attribute.
    """
    if ignore_idx is None:
        # Previously this silently read a notebook-level global; derive it from the
        # loss instead so the function is self-contained.
        ignore_idx = getattr(criterion, "ignore_index", None)
        if ignore_idx is None:
            raise ValueError(
                "ignore_idx was not given and criterion has no 'ignore_index' attribute"
            )

    model.eval()
    metrics = SegmentationMetrics(num_classes=num_classes, ignore_idx=ignore_idx)
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for sample in tqdm(data_loader, desc=f"eval (epoch {epoch})"):
            images = sample["image"].to(device)
            masks = sample["mask"].to(device)

            logits = model(images)
            total_loss += criterion(logits, masks).item()
            n_batches += 1

            # Hard per-pixel predictions: index of the highest score along dim 1.
            metrics.update(torch.argmax(logits, dim=1), masks)

    ious, mIoU = metrics.iou()
    if return_loss:
        mean_loss = total_loss / n_batches if n_batches else float("nan")
        return ious, mIoU, mean_loss
    return ious, mIoU

In [6]:
# Evaluate the restored checkpoint on the validation split.
# NOTE(review): the epoch argument (32) does not affect the metrics; presumably it is
# the epoch the checkpoint was saved at — confirm.
ious, mIoU = evaluate_performance(32, model, criterion, val_loader, num_classes, device)
print(ious)
print(mIoU)

100%|██████████| 125/125 [01:40<00:00,  1.24it/s]

[0.76731085 0.5687753  0.74858997 0.1736465  0.13302853 0.24040474
 0.1629812  0.36271731 0.82386181 0.41362495 0.81984696 0.40960035
 0.02410235 0.78240072 0.18086475 0.41780845 0.19681325 0.04706755
 0.34496439 0.        ]
0.40096894348639733





In [7]:
# Rich (numpy array repr) display of the per-class IoU computed by evaluate_performance above;
# same values as the printed output.
ious

array([0.76731085, 0.5687753 , 0.74858997, 0.1736465 , 0.13302853,
       0.24040474, 0.1629812 , 0.36271731, 0.82386181, 0.41362495,
       0.81984696, 0.40960035, 0.02410235, 0.78240072, 0.18086475,
       0.41780845, 0.19681325, 0.04706755, 0.34496439, 0.        ])