In [1]:
import sys

# Make the parent directory importable -- presumably so the local `varian`
# package (imported in the next cell) can be found.
# Guarded so re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import numpy as np
import torch
import torch.nn as nn
import collections
import torchvision
import torchvision.transforms as transforms

from PIL import Image
from pathlib import Path
from catalyst.data.augmentor import Augmentor
from catalyst.utils.factory import UtilsFactory
from catalyst.models.segmentation import UNet

from varian.models.segnet import SegNet

In [3]:
# Reproducibility: seed numpy (mock-data generation) and torch (weight
# initialisation, DataLoader shuffling) from one shared constant.
seed = 1488
np.random.seed(seed)
torch.manual_seed(seed)

n_images = 500  # total number of mock (image, mask) samples
bs = 4          # batch size for both loaders
n_workers = 4   # DataLoader worker processes

In [4]:
# Synthetic dataset: uniform-noise 512x512 uint8 "images" plus independent
# coin-flip binary masks stored as float64 (0.0 / 1.0).
# Every image is drawn before any mask, so the RNG stream matches the seed.
images = [
    np.random.uniform(high=256, size=(512, 512)).astype(np.uint8)
    for _ in range(n_images)
]
masks = [
    (np.random.uniform(high=1, size=(512, 512)) > 0.5).astype(float)
    for _ in range(n_images)
]
data = [(img, msk) for img, msk in zip(images, masks)]

# First half of the samples trains the model, second half validates it.
train_data, valid_data = data[:n_images // 2], data[n_images // 2:]

In [5]:
# Optional: dump the mock dataset to disk (.npy stacks plus per-sample .tiff
# files). Disabled by default so Restart & Run All stays free of file-system
# side effects; set the flag to True to regenerate the files.
save_mock_data_to_disk = False


def save_mock_data(data_dir, images, masks, splits):
    """Write the mock dataset under `data_dir`.

    Args:
        data_dir: output root (a pathlib.Path); created if missing.
        images, masks: full lists of image / mask arrays, stacked and saved
            as `mock_data_images.npy` / `mock_data_masks.npy`.
        splits: mapping of split name -> list of (image, mask) pairs. Pair i
            is written as `<i>.tiff` under `mock_data_images/<split>/` and
            `mock_data_masks/<split>/`.
    """
    data_dir.mkdir(parents=True, exist_ok=True)
    np.save(data_dir / 'mock_data_images', np.stack(images))
    np.save(data_dir / 'mock_data_masks', np.stack(masks))

    # One loop over splits replaces the copy-pasted train/valid loops.
    for split, pairs in splits.items():
        image_dir = data_dir / 'mock_data_images' / split
        mask_dir = data_dir / 'mock_data_masks' / split
        image_dir.mkdir(parents=True, exist_ok=True)
        mask_dir.mkdir(parents=True, exist_ok=True)
        for i, (image, mask) in enumerate(pairs):
            Image.fromarray(image).save(image_dir / '{}.tiff'.format(i))
            Image.fromarray(mask).save(mask_dir / '{}.tiff'.format(i))


if save_mock_data_to_disk:
    save_mock_data(Path('..') / 'data', images, masks,
                   {'train': train_data, 'valid': valid_data})

In [6]:
def _image_to_tensor(x):
    """Turn one grayscale uint8 image into a 3-channel float32 tensor.

    Assumes `x` is a 2-D (H, W) array, as produced by the mock-data cell.
    Pixel values are divided by 256, so 0..255 maps into [0, 1). The single
    channel is repeated three times because the model's first convolution
    takes 3 input channels (see the printed repr: `Conv2d(3, 64, ...)`).
    Returns a (3, H, W) float32 torch tensor.
    """
    # astype() already copies, so the dataset array is never shared.
    image = torch.from_numpy(x.astype(np.float32) / 256.)
    # Tensor.repeat keeps the result a torch tensor. Calling np.vstack on a
    # tensor would silently hand back a numpy array instead.
    return image.unsqueeze(0).repeat(3, 1, 1)


def _mask_to_tensor(x):
    """Turn one (H, W) mask into a (1, H, W) float32 tensor (always a copy)."""
    return torch.tensor(x, dtype=torch.float32).unsqueeze(0)


data_transform = transforms.Compose([
    # TODO specify augmentations (e.g. histogram normalization)
    Augmentor(dict_key="features", augment_fn=_image_to_tensor),
    Augmentor(dict_key="targets", augment_fn=_mask_to_tensor),
])


def open_fn(x):
    """Map one (image, mask) pair to the dict keys the transforms read."""
    return {"features": x[0], "targets": x[1]}


def make_loader(split_data, shuffle):
    """Wrap one data split in a catalyst DataLoader with the shared settings.

    Both splits use the same `open_fn`, `data_transform`, batch size and
    worker count. Only shuffling differs between them.
    """
    return UtilsFactory.create_loader(
        split_data,
        open_fn=open_fn,
        dict_transform=data_transform,
        batch_size=bs,
        workers=n_workers,
        shuffle=shuffle)


train_loader = make_loader(train_data, shuffle=True)
valid_loader = make_loader(valid_data, shuffle=False)

# Ordered train -> valid; the runner presumably iterates them in this order
# (the training log shows a train pass followed by a valid pass each epoch).
loaders = collections.OrderedDict([
    ("train", train_loader),
    ("valid", valid_loader),
])

In [7]:
# A single output channel, presumably one logit per pixel, which pairs with
# the BCEWithLogitsLoss criterion used for training.
model = SegNet(
    num_classes=1,
)

In [8]:
# Binary per-pixel objective on raw logits, optimised with Adam.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# The callbacks cell drives lr/momentum with the OneCycleLR callback and
# leaves SchedulerCallback commented out. A torch scheduler passed here is
# therefore presumably never stepped (MultiStepLR milestones would silently
# do nothing), so use None, as intended for OneCycle.
scheduler = None
# To use a torch scheduler instead, enable SchedulerCallback in the callbacks
# cell and pick one of:
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 40], gamma=0.3)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)

In [9]:
from catalyst.dl.callbacks import (
    BaseMetrics,
    CheckpointCallback,
    ClassificationLossCallback,
    Logger,
    OneCycleLR,
    OptimizerCallback,
    PrecisionCallback,
    SchedulerCallback,
    TensorboardLogger,
)

n_epochs = 50
logdir = "./logs/segmentation_notebook"

# Callbacks keyed by role, in insertion order (presumably the order catalyst
# invokes them). Learning rate and momentum come from the custom OneCycle
# callback. To use a PyTorch scheduler instead, replace the "scheduler" entry
# with:
#     ("scheduler", SchedulerCallback(reduce_metric="loss_main")),
callbacks = collections.OrderedDict([
    ("loss", ClassificationLossCallback()),
    ("optimizer", OptimizerCallback()),
    ("metrics", BaseMetrics()),
    ("scheduler", OneCycleLR(
        cycle_len=n_epochs,
        div=3, cut_div=4, momentum_range=(0.95, 0.85))),
    ("saver", CheckpointCallback()),
    ("logger", Logger()),
    ("tflogger", TensorboardLogger()),
])

In [10]:
from catalyst.dl.runner import ClassificationRunner

# Bundle the model, loss, optimizer and (possibly None) scheduler.
runner = ClassificationRunner(model=model, criterion=criterion,
                              optimizer=optimizer, scheduler=scheduler)

# Train for `n_epochs` epochs over the train/valid loaders. `logdir` is where
# the saver/logger callbacks presumably write their output.
runner.train_stage(loaders=loaders, callbacks=callbacks, logdir=logdir,
                   epochs=n_epochs, verbose=True)

SegNet(
  (enc1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (enc3): Sequential

0 * Epoch (train): 100% 63/63 [00:58<00:00,  1.08it/s, batch time=11.01256, data time=0.00289, loss_main=0.69993, lr_main=0.00039, momentum_main=0.94199, sample per second=0.36322]
0 * Epoch (valid): 100% 63/63 [00:12<00:00,  4.95it/s, batch time=0.12282, data time=0.00302, loss_main=0.69861, lr_main=0.00000, momentum_main=0.00000, sample per second=32.56858]
[2018-10-20 10:31:11,179] 0 * Epoch (train) metrics: data time: 0.00874 | batch time: 0.77687 | sample per second: 10.72851 | lr_main: 0.00036 | momentum_main: 0.94593 | loss_main: 0.70824
[2018-10-20 10:31:11,181] 0 * Epoch (valid) metrics: data time: 0.01236 | batch time: 0.20201 | sample per second: 20.57005 | lr_main: 0.00000 | momentum_main: 0.00000 | loss_main: 0.70087
[2018-10-20 10:31:11,182] 

1 * Epoch (train): 100% 63/63 [00:32<00:00,  1.92it/s, batch time=0.21946, data time=0.00299, loss_main=0.69725, lr_main=0.00044, momentum_main=0.93386, sample per second=18.22640]
1 * Epoch (valid): 100% 63/63 [00:12<00:00,  4.93it

KeyboardInterrupt: 