In [1]:
import dataset_utils
import os
import random
import torch
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
from config import cfg
from datasets import LandCoverDataset
from models.deeplab import get_model as get_deeplab_model
from PIL import Image
from torch.utils.data import DataLoader, Subset
from utils import get_validation_augmentation, get_training_augmentation, get_preprocessing, save_history, save_model, reverse_one_hot, colour_code_segmentation, visualize

In [2]:
LANDCOVER_ROOT = '/root/deepglobe'
cfg.merge_from_file('cfg/deeplab_resnet50_advance_aug.yaml')

In [3]:
train_df, val_df = dataset_utils.get_landcover_train_val_df(LANDCOVER_ROOT, random_state=cfg.SEED)
dataset_info = dataset_utils.get_landcover_info(LANDCOVER_ROOT, include_unknow=False)
class_names = dataset_info['class_names']
class_rgb_values = dataset_info['class_rgb_values']
select_class_rgb_values = dataset_info['select_class_rgb_values']

In [5]:
device = torch.device('cuda:0')
num_classes = len(select_class_rgb_values)
model, preprocessing_fn = get_deeplab_model(num_classes, cfg.MODEL.encoder)
weight_path = '/root/rtml/project/weights/deeplabv3_resnet50_advance_aug/best_model.pth'
model.load_state_dict(torch.load(weight_path, map_location=device))
model.eval()

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia

In [6]:
model = model.to(device)

In [7]:
# model = torch.load('best_model_max.pth', map_location=device)
# model.eval()

In [8]:
valid_dataset = LandCoverDataset(
    val_df,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
    return_path=True
)

In [9]:
val_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)

In [10]:
iou_func = smp.utils.metrics.IoU(threshold=0.5).to(device)

In [11]:
iou_list = []

with torch.no_grad():
    for image_tensor, mask_tensor, path in val_loader:
        image_tensor = image_tensor.to(device)
        mask_tensor = mask_tensor.to(device)
        pred = model(image_tensor)
        iou = iou_func(pred, mask_tensor).cpu().item()
        iou_list.append({'image_path': path, 'iou': iou})

In [13]:
mean = sum([i['iou'] for i in sorted(iou_list, key=lambda i: i['iou'], reverse=True)]) / len(iou_list)

In [14]:
mean

0.545038360320179