In [130]:
import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, RandomChoice

from utils import LandCoverDataset, Resize, ToTensor, Normalize, BrightnessJitter, ContrastJitter, SaturationJitter, HueJitter
from models import UNet

from tensorboardX import SummaryWriter
writer = SummaryWriter()

In [131]:
class Config():
    DATA_FOLDER = 'data'
    BATCH_SIZE = 16
    TRAIN_SPLIT = .8
    VAL_TEST_SPLIT = .5
    SUFFLE_DATASET = True
    RANDOM_SEED = 2137

# Load data 

## Calculate mean and standard deviation for normalization

In [132]:
# set train, valid and test indexes
dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=Resize((224,224)))
indexes = list(range(len(dataset)))
split_point = int(np.floor(Config.TRAIN_SPLIT * len(dataset)))
if Config.SUFFLE_DATASET:
    np.random.seed(Config.RANDOM_SEED)
    np.random.shuffle(indexes)
train_indexes, rest_indexes = indexes[:split_point], indexes[split_point:]
val_test_split_point = int(np.floor(Config.VAL_TEST_SPLIT * len(rest_indexes)))
valid_indexes, test_indexes = rest_indexes[:val_test_split_point], rest_indexes[val_test_split_point:]

# make dataset samplers
train_sampler = SubsetRandomSampler(train_indexes)
valid_sampler = SubsetRandomSampler(valid_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

# train loader (for calculating normalize parameters)
loader = DataLoader(dataset=dataset, batch_size=Config.BATCH_SIZE, shuffle=False, sampler=train_sampler)

# batch means and stds
batch_means = []
batch_stds = []
for i, sample in enumerate(loader):
    images = sample['image']    
    batch_means.append(np.mean(images.numpy(), axis=(0,1,2))) # batch, height, width
    batch_stds.append(np.std(images.numpy(), axis=(0,1,2), ddof=1)) # batch, height, width

# overall mean and std per channel
means = np.array(batch_means).mean(axis=0)
stds = np.array(batch_stds).mean(axis=0)

print(f'Means: {means}\nStds:  {stds}')

Means: [107.20428231 115.20819438  91.24265174]
Stds:  [28.28229905 21.92473988 20.90197824]


## Prepare dataloaders

In [129]:
# transformations
train_transform = Compose([
    Resize((224, 224)),
    Normalize(mean=means, std=stds),
    RandomChoice([
        BrightnessJitter(brightness=.25),
        ContrastJitter(contrast=.15),
        SaturationJitter(saturation=.15),
        HueJitter(hue=.1),
        ]),
    ToTensor(),
])

val_test_transform = Compose([
    Resize((224, 224)),
    Normalize(mean=means, std=stds),
    ToTensor(),
])

# datasets (using samplers from previous step to create train/valid/test split)
train_dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=train_transform)
train_dataset = Subset(dataset=train_dataset, indices=train_sampler.indices)

val_test_dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=val_test_transform)
valid_dataset = Subset(dataset=val_test_dataset, indices=valid_sampler.indices)
test_dataset = Subset(dataset=val_test_dataset, indices=test_sampler.indices)


# dataloaders
train_loader = DataLoader(dataset=train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=4)