In [1]:
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_loader.cityscapes import CityscapesDataset
import utils.transforms
from model.deeplabv3plus import DeepLabv3Plus

from model.metric import iou

In [2]:
writer = SummaryWriter()
checkpoint_file = os.path.join('./checkpoints', 'best_model.pt')
img_root = './data/leftImg8bit'
mask_root = './data/gtFine'
transform = utils.transforms.Compose(
    [utils.transforms.Resize((512, 512)), utils.transforms.ToTensor()]
)

train_set = CityscapesDataset("train", img_root, mask_root, transform=transform)
val_set = CityscapesDataset("val", img_root, mask_root, transform=transform)

batch_size = 4
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 20  # 0 -> 19
ignore_idx = train_set.ignoreId  # ignore 19
learning_rate = 0.001
model = DeepLabv3Plus(num_classes).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=ignore_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [4]:
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [5]:
def evaluate_performance(epoch, model, criterion, data_loader, num_classes, device):
    model.eval()

    total_loss = 0
    n_batches = len(data_loader)
    n_samples = len(data_loader.dataset)
    preds = torch.zeros((n_samples, 512, 512))
    gts = torch.zeros((n_samples, 512, 512))

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            images = sample["image"].to(device)
            masks = sample["mask"].to(device)
            b_size = len(masks)

            pred = model(images)
            loss = criterion(pred, masks)
            total_loss += loss.item()

            gts[i : i + b_size, :, :] = masks.cpu().detach()
            pred = torch.argmax(pred, dim=1).cpu().detach()
            preds[i : i + b_size, :, :] = pred

    ious, mIoU = iou(preds, gts)
    return ious, mIoU

In [6]:
ious, mIoU = evaluate_performance(32, model, criterion, val_loader, 20, device)
print(ious)
print(mIoU)

100%|██████████| 125/125 [00:47<00:00,  2.61it/s]


torch.Size([131072000])
torch.Size([131072000])
[0.97133951 0.5742225  0.73734696 0.1426262  0.12089531 0.21887813
 0.14401666 0.37732962 0.82021093 0.43701157 0.79479526 0.38161614
 0.03652031 0.75085983 0.19854661 0.39220037 0.46746474 0.00454752
 0.31206627]
0.41486812731557604
