In [1]:
import json
import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, RandomChoice

from utils import LandCoverDataset, Resize, ToTensor, Normalize, BrightnessJitter, ContrastJitter, SaturationJitter, HueJitter, EarlyStopping, validation
from models import UNet

In [2]:
class Config():
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    DATA_FOLDER = 'data'
    RESIZE = (224,224)
    BATCH_SIZE = 16
    TRAIN_SPLIT = .8
    VAL_TEST_SPLIT = .5
    SUFFLE_DATASET = True
    RANDOM_SEED = 2137
    try:
        with open('norm_params.json', 'r') as f:
            d = json.load(f)
            MEANS = np.array(d['means'])
            STDS = np.array(d['stds'])
    except:
        MEANS, STDS = None, None
    LR = 0.001
    EPOCHS = 200
    PRINT_EVERY_START = 10
    PRINT_EVERY = 100
    MODEL_NAME = "final_nn_model.pt"
    IN_CHANNELS=3
    FEATURES=64
    NUM_CLASSES=24

In [3]:
print(f'Using: {Config.DEVICE}')

Using: cuda


# Load data 

## Calculate mean and standard deviation for normalization

In [4]:
# set train, valid and test indexes
dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=Resize(Config.RESIZE))
indexes = list(range(len(dataset)))
split_point = int(np.floor(Config.TRAIN_SPLIT * len(dataset)))
if Config.SUFFLE_DATASET:
    np.random.seed(Config.RANDOM_SEED)
    np.random.shuffle(indexes)
train_indexes, rest_indexes = indexes[:split_point], indexes[split_point:]
val_test_split_point = int(np.floor(Config.VAL_TEST_SPLIT * len(rest_indexes)))
valid_indexes, test_indexes = rest_indexes[:val_test_split_point], rest_indexes[val_test_split_point:]

# make dataset samplers
train_sampler = SubsetRandomSampler(train_indexes)
valid_sampler = SubsetRandomSampler(valid_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

# train loader (for calculating normalize parameters)
loader = DataLoader(dataset=dataset, batch_size=Config.BATCH_SIZE, shuffle=False, sampler=train_sampler)

# batch means and stds
if Config.MEANS is None or Config.STDS is None:
    batch_means = []
    batch_stds = []
    for i, sample in enumerate(loader):
        images = sample['image']    
        batch_means.append(np.mean(images.numpy(), axis=(0,1,2))) # batch, height, width
        batch_stds.append(np.std(images.numpy(), axis=(0,1,2), ddof=1)) # batch, height, width

    # overall mean and std per channel
    Config.MEANS = np.array(batch_means).mean(axis=0)
    Config.STDS = np.array(batch_stds).mean(axis=0)

    # save to file
    with open('norm_params.json', 'w') as f:
        json.dump({'means': Config.MEANS.tolist(), 'stds': Config.STDS.tolist()}, f)

print(f'Means: {Config.MEANS}\nStds:  {Config.STDS}')

Means: [107.20428231 115.20819438  91.24265174]
Stds:  [28.26958329 21.92561284 20.89934165]


## Prepare dataloaders

In [5]:
# transformations
train_transform = Compose([
    Resize((224, 224)),
    RandomChoice([
        BrightnessJitter(brightness=.25),
        ContrastJitter(contrast=.15),
        SaturationJitter(saturation=.15),
        HueJitter(hue=.1),
        ]),
    Normalize(mean=Config.MEANS, std=Config.STDS),
    ToTensor(),
])

val_test_transform = Compose([
    Resize((224, 224)),
    Normalize(mean=Config.MEANS, std=Config.STDS),
    ToTensor(),
])

# datasets (using samplers from previous step to create train/valid/test split)
train_dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=train_transform)
train_dataset = Subset(dataset=train_dataset, indices=train_sampler.indices)

val_test_dataset = LandCoverDataset(root_dir=Config.DATA_FOLDER, transform=val_test_transform)
valid_dataset = Subset(dataset=val_test_dataset, indices=valid_sampler.indices)
test_dataset = Subset(dataset=val_test_dataset, indices=test_sampler.indices)


# dataloaders
train_loader = DataLoader(dataset=train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=4)

# Model, loss, optimizer and other

## Model

In [6]:
model = UNet(in_channels=Config.IN_CHANNELS, features=Config.FEATURES, num_classes=Config.NUM_CLASSES)
model.to(Config.DEVICE)
model

UNet(
  (encoder): Sequential(
    (0): Sequential(
      (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
    )
    (1): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
    )
    (2): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dil

## Loss and optimizer

In [7]:
criterion = nn.NLLLoss(reduction='mean').to(Config.DEVICE)
optimizer = optim.Adam(model.parameters(), lr=Config.LR)

## Other helpers

In [8]:
early_stopping = EarlyStopping(patience=5, verbose=True, delta=0.001)

# Training

In [9]:
# start from last checkpoint
try:
    checkpoint = torch.load('unet_after_epoch.pt')
    model.load_state_dict(checkpoint['state_dict'])
    next_epoch = checkpoint['epoch'] + 1
except:
    next_epoch = 0

# train
for epoch in range(next_epoch, Config.EPOCHS):
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    for i, batch in enumerate(train_loader):
        # forward and backward propagation
        images = batch['image'].float().to(Config.DEVICE)
        labels = batch['label'].long().to(Config.DEVICE)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # save results
        running_loss += loss.item()
        _, predicted = torch.max(outputs, dim=1)
        acc = (labels == predicted).sum().item()/(Config.BATCH_SIZE * Config.RESIZE[0] * Config.RESIZE[1])
        running_acc += acc

        if (i > 0 and i % Config.PRINT_EVERY_START == 0) or (i != 0 and epoch > 3 and i % Config.PRINT_EVERY == 0):
            stats = f'Epoch: {epoch+1}/{Config.EPOCHS}, batch: {i}/{len(train_loader)}, ' \
                    f'train_loss: {running_loss/i:.5f}, train_acc: {running_acc/i:.4f}'
            print('\r'+stats, end='', flush=True)
            with open('stats.log', 'a') as f:
                print(stats, file=f)


    # calculcate loss and accuracy on validation dataset
    with torch.no_grad():
        val_loss, val_acc = validation(valid_loader, model, criterion, Config.DEVICE, Config.BATCH_SIZE, Config.RESIZE)
    stats = f'Epoch: {epoch+1}/{Config.EPOCHS}, train_loss: {running_loss/i:.5f}, valid_loss: {val_loss:.5f}, ' \
            f'train_acc: {running_acc/i:.4f}, valid_acc: {val_acc:.4f}'
    print('\r'+stats)
    with open('stats.log', 'a') as f:
        print(stats, file=f)

    # save after each epoch
    torch.save({
        'in_channels': Config.IN_CHANNELS, 
        'features': Config.FEATURES, 
        'num_classes': Config.NUM_CLASSES,
        'state_dict': model.state_dict(),
        'epoch': epoch,
    }, 'unet_after_epoch.pt')
    
    # check for early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print('Early stopping.')
        break

model.load_state_dict(torch.load('early_stopping_checkpoint.pt'))
torch.save({
        'in_channels': Config.IN_CHANNELS, 
        'features': Config.FEATURES, 
        'num_classes': Config.NUM_CLASSES,
        'state_dict': model.state_dict()
    }, Config.MODEL_NAME)
print('\nFinished training')

Epoch: 1/200, train_loss: 1.88841, valid_loss: 1.45123, train_acc: 0.4933, valid_acc: 0.5665
Validation loss decreased (inf --> 1.451230).  Saving model ...
Epoch: 2/200, train_loss: 1.49527, valid_loss: 1.33700, train_acc: 0.5412, valid_acc: 0.5808
Validation loss decreased (1.451230 --> 1.336995).  Saving model ...
Epoch: 3/200, train_loss: 1.38385, valid_loss: 1.27715, train_acc: 0.5610, valid_acc: 0.5896
Validation loss decreased (1.336995 --> 1.277152).  Saving model ...
Epoch: 4/200, train_loss: 1.31374, valid_loss: 1.37288, train_acc: 0.5795, valid_acc: 0.5354
EarlyStopping counter: 1 out of 5
Epoch: 5/200, train_loss: 1.26705, valid_loss: 1.36046, train_acc: 0.5915, valid_acc: 0.5674
EarlyStopping counter: 2 out of 5


Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 