In [1]:
import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, RandomChoice

from utils import LandCoverDataset, Resize, ToTensor, Normalize, BrightnessJitter, ContrastJitter, SaturationJitter, HueJitter, EarlyStopping, validation
from models import UNet

from tensorboardX import SummaryWriter

In [4]:
class Config():
    DATA_FOLDER = 'data'
    RESIZE = (224,224)
    BATCH_SIZE = 16
    TRAIN_SPLIT = .8
    VAL_TEST_SPLIT = .5
    SUFFLE_DATASET = True
    RANDOM_SEED = 2137
    LR = 0.001
    EPOCHS = 200
    PRINT_EVERY_START = 10
    PRINT_EVERY = 100
    MODEL_NAME = "final_nn_model.pt"

# Load data 

## Calculate mean and standard deviation for normalization

In [3]:
# set train, valid and test indexes
dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=Resize(Config.RESIZE))
indexes = list(range(len(dataset)))
split_point = int(np.floor(Config.TRAIN_SPLIT * len(dataset)))
if Config.SUFFLE_DATASET:
    np.random.seed(Config.RANDOM_SEED)
    np.random.shuffle(indexes)
train_indexes, rest_indexes = indexes[:split_point], indexes[split_point:]
val_test_split_point = int(np.floor(Config.VAL_TEST_SPLIT * len(rest_indexes)))
valid_indexes, test_indexes = rest_indexes[:val_test_split_point], rest_indexes[val_test_split_point:]

# make dataset samplers
train_sampler = SubsetRandomSampler(train_indexes)
valid_sampler = SubsetRandomSampler(valid_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

# train loader (for calculating normalize parameters)
loader = DataLoader(dataset=dataset, batch_size=Config.BATCH_SIZE, shuffle=False, sampler=train_sampler)

# batch means and stds
batch_means = []
batch_stds = []
for i, sample in enumerate(loader):
    images = sample['image']    
    batch_means.append(np.mean(images.numpy(), axis=(0,1,2))) # batch, height, width
    batch_stds.append(np.std(images.numpy(), axis=(0,1,2), ddof=1)) # batch, height, width

# overall mean and std per channel
means = np.array(batch_means).mean(axis=0)
stds = np.array(batch_stds).mean(axis=0)

print(f'Means: {means}\nStds:  {stds}')

Means: [107.20428231 115.20819438  91.24265174]
Stds:  [28.27333931 21.92629058 20.89825626]


## Prepare dataloaders

In [5]:
# transformations
train_transform = Compose([
    Resize((224, 224)),
    RandomChoice([
        BrightnessJitter(brightness=.25),
        ContrastJitter(contrast=.15),
        SaturationJitter(saturation=.15),
        HueJitter(hue=.1),
        ]),
    Normalize(mean=means, std=stds),
    ToTensor(),
])

val_test_transform = Compose([
    Resize((224, 224)),
    Normalize(mean=means, std=stds),
    ToTensor(),
])

# datasets (using samplers from previous step to create train/valid/test split)
train_dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=train_transform)
train_dataset = Subset(dataset=train_dataset, indices=train_sampler.indices)

val_test_dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=val_test_transform)
valid_dataset = Subset(dataset=val_test_dataset, indices=valid_sampler.indices)
test_dataset = Subset(dataset=val_test_dataset, indices=test_sampler.indices)


# dataloaders
train_loader = DataLoader(dataset=train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=4)

# Model, loss, optimizer and other

## Model

In [6]:
writer = SummaryWriter(comment=Config.MODEL_NAME.split('.')[0])

In [7]:
model = UNet(in_channels=3, features=64, num_classes=24)
writer.add_graph(model, torch.randn((1,3,224,224)))
model

UNet(
  (encoder): Sequential(
    (0): Sequential(
      (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
    )
    (1): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
    )
    (2): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dil

## Loss and optimizer

In [8]:
criterion = nn.NLLLoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=Config.LR)

## Other helpers

In [9]:
early_stopping = EarlyStopping(patience=10, verbose=True, delta=0.001)

# Training

In [12]:
for epoch in range(Config.EPOCHS):
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    for i, batch in enumerate(train_loader):
        # forward and backward propagation
        images = batch['image'].float()
        labels = batch['label'].long()
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # save results
        running_loss += loss.item()
        _, predicted = torch.max(outputs, dim=1)
        acc = (labels == predicted).sum().item()/(Config.BATCH_SIZE * Config.RESIZE[0] * Config.RESIZE[1])
        running_acc += acc
        # tensorboardx
        writer.add_scalar('running_loss', loss.item(), (epoch+1)*(i+1))
        writer.add_scalar('running_acc', acc, (epoch+1)*(i+1))

        if (i != 0 and i % Config.PRINT_EVERY_START == 0) or (i != 0 and epoch > 3 and i % Config.PRINT_EVERY == 0):
            stats = f'\rEpoch: {epoch+1}/{Config.EPOCHS}, batch: {i}/{int(np.ceil(len(train_loader)/Config.BATCH_SIZE))}, ' \
                    f'train_loss: {running_loss/i:.5f}, train_acc: {running_acc/i:.4f}'
            print(stats, end='', flush=True)

    # calculcate loss and accuracy on validation dataset
    with torch.no_grad():
        val_loss, val_acc = validation(valid_loader, model, criterion, Config.BATCH_SIZE, Config.RESIZE)
    stats = f'Epoch: {epoch+1}/{Config.EPOCHS}, train_loss: {running_loss/i:.5f}, valid_loss: {val_loss:.5f}, ' \
            f'train_acc: {running_acc/i:.4f}, valid_acc: {val_acc:.4f}'
    print(stats)

    # tensorboardx
    writer.add_scalar('validation_loss', val_loss, epoch+1)
    writer.add_scalar('validation_acc', val_acc, epoch+1)
    
    # check for early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print('Early stopping.')
        break

model.load_state_dict(torch.load('early_stopping_checkpoint.pt'))
torch.save(model.state_dict(), Config.MODEL_NAME)
print('Finished training')

Epoch: 1/200, batch: 20/17280, train_loss: 3.39895, train_acc: 0.0844

KeyboardInterrupt: 