In [1]:
import dataset_utils
import os
import random
import torch
import matplotlib.pyplot as plt
import numpy as np
from config import cfg
from datasets import LandCoverDataset
from models.deeplab import get_model as get_deeplab_model
from PIL import Image
from torch.utils.data import DataLoader, Subset
from utils import get_validation_augmentation, get_training_augmentation, get_preprocessing, save_history, save_model, reverse_one_hot, colour_code_segmentation, visualize

In [2]:
# Root directory of the land-cover dataset. The original hard-coded location is
# kept as the default, but it can be redirected without editing the notebook by
# setting the LANDCOVER_ROOT environment variable.
LANDCOVER_ROOT = os.environ.get('LANDCOVER_ROOT', '/root/deepglobe')
# Merge the experiment-specific YAML settings into the imported `cfg` object.
cfg.merge_from_file('cfg/deeplab_resnet50_advance_aug.yaml')

In [3]:
# Build the train/validation frames; the split is seeded from the config so it
# is the same on every run.
train_df, val_df = dataset_utils.get_landcover_train_val_df(
    LANDCOVER_ROOT,
    random_state=cfg.SEED,
)

# Class metadata for the dataset.
# NOTE(review): include_unknow=False presumably leaves the "unknown" class out
# of the selected classes -- confirm against dataset_utils.
dataset_info = dataset_utils.get_landcover_info(LANDCOVER_ROOT, include_unknow=False)
class_names, class_rgb_values, select_class_rgb_values = (
    dataset_info[key]
    for key in ('class_names', 'class_rgb_values', 'select_class_rgb_values')
)

In [4]:
# One class per selected RGB value.
num_classes = len(select_class_rgb_values)

# The factory returns a pair; only the second item is kept here (it is later
# handed to get_preprocessing). The first item -- presumably the model itself --
# is discarded.
(_, preprocessing_fn) = get_deeplab_model(
    num_classes,
    cfg.MODEL.encoder,
)

In [5]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU instead of failing on the first
# `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
# Validation dataset built from the validation frame, using the validation
# augmentation and the preprocessing wrapper around `preprocessing_fn`.
# return_path=True is passed so that, presumably, each sample also carries its
# source path -- confirm against LandCoverDataset.
valid_dataset = LandCoverDataset(
    val_df,
    class_rgb_values=select_class_rgb_values,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    return_path=True,
)

In [None]:
# Validation loader: one sample per batch, in dataset order, loaded in the main
# process (num_workers=0) with pinned host memory.
val_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
)