In [1]:
import sys
# Put the parent directory on the import path before the project imports run.
# NOTE(review): presumably this lets the local `catalyst` package be imported
# when the notebook lives one level below the repo root — confirm.
sys.path.append('..')

In [2]:
import numpy as np
import torch
import torch.nn as nn
import collections
import torchvision
import torchvision.transforms as transforms

from PIL import Image
from pathlib import Path
from catalyst.data.augmentor import Augmentor
from catalyst.utils.factory import UtilsFactory
from catalyst.models.segmentation import UNet

In [3]:
# Reproducibility: seed every RNG this notebook draws from, not just numpy.
SEED = 1488
np.random.seed(SEED)     # drives the mock images/masks generated with np.random below
torch.manual_seed(SEED)  # torch's global RNG (e.g. weight init, loader shuffling) — previously unseeded

n_images = 500  # total mock samples; split half/half into train and valid
bs = 4          # batch size handed to the data loaders
n_workers = 4   # worker count handed to the data loaders

In [4]:
# Mock dataset: uniform-noise uint8 images paired with random 0/1 masks.
# Every image is drawn before any mask, so the np.random stream is consumed
# in a fixed order for a given seed.
images = [
    np.random.uniform(high=256, size=(512, 512)).astype(np.uint8)
    for _ in range(n_images)
]
masks = [
    (np.random.uniform(high=1, size=(512, 512)) > 0.5).astype(float)
    for _ in range(n_images)
]

# Pair each image with its mask, then use the first half for training
# and the remainder for validation.
data = list(zip(images, masks))
train_data, valid_data = data[:n_images // 2], data[n_images // 2:]

In [7]:
# Saving mock data
data_dir = Path('..') / 'data'
# np.save does not create missing directories, so make sure the target
# directory (and any missing parents) exists before writing anything.
data_dir.mkdir(parents=True, exist_ok=True)

# Full arrays as single files; np.save appends the ".npy" extension, so these
# do not collide with the same-named directories created below.
np.save(data_dir / 'mock_data_images', np.stack(images))
np.save(data_dir / 'mock_data_masks', np.stack(masks))

# One TIFF per sample, laid out as <kind>/<split>/<index>.tiff.
for split_name, split_data in (('train', train_data), ('valid', valid_data)):
    image_dir = data_dir / 'mock_data_images' / split_name
    mask_dir = data_dir / 'mock_data_masks' / split_name
    image_dir.mkdir(parents=True, exist_ok=True)
    mask_dir.mkdir(parents=True, exist_ok=True)

    for i, (image, mask) in enumerate(split_data):
        Image.fromarray(image).save(image_dir / (str(i) + '.tiff'))
        Image.fromarray(mask).save(mask_dir / (str(i) + '.tiff'))

In [8]:
def _features_to_tensor(x):
    """Copy the image, cast to float32, divide by 256 and add a leading axis."""
    return torch.from_numpy(x.copy().astype(np.float32) / 256.).unsqueeze_(0).float()


def _targets_to_tensor(x):
    """Copy the mask and return it as a float tensor with a leading axis."""
    return torch.from_numpy(x.copy()).unsqueeze_(0).float()


def open_fn(x):
    """Turn an (image, mask) pair into the dict the transforms below index by key."""
    return {"features": x[0], "targets": x[1]}


data_transform = transforms.Compose([
    # TODO specify augmentations (e.g. histogram normalization)
    Augmentor(dict_key="features", augment_fn=_features_to_tensor),
    Augmentor(dict_key="targets", augment_fn=_targets_to_tensor),
])

# Same loader settings for both splits; only the training loader shuffles.
train_loader = UtilsFactory.create_loader(
    train_data,
    open_fn=open_fn,
    dict_transform=data_transform,
    batch_size=bs,
    workers=n_workers,
    shuffle=True)

valid_loader = UtilsFactory.create_loader(
    valid_data,
    open_fn=open_fn,
    dict_transform=data_transform,
    batch_size=bs,
    workers=n_workers,
    shuffle=False)

# Ordered so that "train" comes before "valid".
loaders = collections.OrderedDict([
    ("train", train_loader),
    ("valid", valid_loader),
])

In [9]:
# One input channel: the "features" transform adds a single leading axis
# (unsqueeze_(0)) to each 2-D image.
# NOTE(review): presumably that axis is the channel axis UNet expects — confirm.
model = UNet(in_channels=1)

In [10]:
# Loss and optimizer for the model defined above.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Learning-rate schedule. Alternatives kept for reference:
# scheduler = None  # for OneCycle usage
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
# NOTE(review): the callbacks cell registers OneCycleLR under "scheduler" while
# this MultiStepLR is still constructed; per the note above, OneCycle usage
# expects `scheduler = None` — confirm which one is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10, 20, 40],
    gamma=0.3,
)

In [11]:
from catalyst.dl.callbacks import (
    BaseMetrics,
    CheckpointCallback,
    ClassificationLossCallback,
    Logger,
    OneCycleLR,
    OptimizerCallback,
    PrecisionCallback,
    SchedulerCallback,
    TensorboardLogger,
)

n_epochs = 50
logdir = "./logs/segmentation_notebook"

# Callbacks are registered by name; the OrderedDict keeps them in the order
# listed here.
callbacks = collections.OrderedDict([
    ("loss", ClassificationLossCallback()),
    ("optimizer", OptimizerCallback()),
    ("metrics", BaseMetrics()),
    # OneCycle custom scheduler callback, with the cycle spanning all epochs.
    # Pytorch scheduler callback alternative:
    #   SchedulerCallback(reduce_metric="loss_main")
    ("scheduler", OneCycleLR(
        cycle_len=n_epochs,
        div=3, cut_div=4, momentum_range=(0.95, 0.85))),
    ("saver", CheckpointCallback()),
    ("logger", Logger()),
    ("tflogger", TensorboardLogger()),
])

In [None]:
from catalyst.dl.runner import ClassificationRunner

# Hand the model, loss, optimizer and scheduler built above to the runner.
runner = ClassificationRunner(
    model=model, criterion=criterion,
    optimizer=optimizer, scheduler=scheduler)

# Run one training stage over the "train"/"valid" loaders for n_epochs,
# passing the callbacks and the log directory configured earlier.
runner.train_stage(
    loaders=loaders,
    callbacks=callbacks,
    logdir=logdir,
    epochs=n_epochs,
    verbose=True)

UNet(
  (encoder): Encoder(
    (block1): EncoderBlock(
      (block): Sequential(
        (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu1): ReLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu2): ReLU()
      )
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (block2): EncoderBlock(
      (block): Sequential(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu1): ReLU()
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu2): ReLU()
      )
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (block3): EncoderBlock(
      (block): Sequential(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu1): ReLU()
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1,

0 * Epoch (train):  43% 27/63 [00:20<00:27,  1.30it/s, batch time=0.13417, data time=0.00308, loss_main=0.69315, lr_main=0.00036, momentum_main=0.94657, sample per second=29.81275]