## Satellite image segmentation

In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [1]:
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon as gluon
import mxnet.gluon.nn as nn

from mxnet.gluon.data import Dataset, DataLoader
from mxnet.gluon.loss import Loss
from mxnet import image

### Data preparation

In [None]:
from skimage.io import imsave, imread
from datetime import datetime

In [None]:
# Maps 'tulip_field_<year>' keys to the layer codes that appear in the
# mask file names ('tulip_<patch>_geopedia_<code>.png').
geopedia_layers = {
    'tulip_field_2016': 'ttl1904',
    'tulip_field_2017': 'ttl1905',
}

In [None]:
class ImageWithMaskDataset(Dataset):
    """
    A dataset for loading images (with masks).
    Based on: mxnet.incubator.apache.org/tutorials/python/data_augmentation_with_masks.html

    Every ``.png`` file in ``root`` whose name does not contain "geopedia"
    is treated as a base image. The matching mask file name is derived from
    the base file name: the second ``_``-separated token is used as the
    patch id and the fourth token is parsed as a ``%Y%m%d-%H%M%S`` timestamp
    whose year selects the layer code from the module-level
    ``geopedia_layers`` mapping.

    Parameters
    ----------
    root : str
        Path to root directory.
    transform : callable, default None
        A function that takes data and label and transforms them:
    ::
        transform = lambda data, label: (data.astype(np.float32)/255, label)
    """
    def __init__(self, root, transform=None):
        self._root = os.path.expanduser(root)
        self._transform = transform
        self._exts = ['.png']
        self._list_images(self._root)

    def _list_images(self, root):
        """
        Scan ``root`` and store one ``{"base": ..., "mask": ...}`` entry per
        base image in ``self._image_list`` (sorted by file name).

        Raises
        ------
        IndexError
            If a base file name has fewer than four ``_``-separated tokens.
        ValueError
            If the fourth token is not a ``%Y%m%d-%H%M%S`` timestamp.
        KeyError
            If the parsed year has no entry in ``geopedia_layers``.
        """
        images = {}
        for filename in sorted(os.listdir(root)):
            name, ext = os.path.splitext(filename)
            if ext.lower() not in self._exts:
                continue
            # Files carrying "geopedia" in their name are the masks themselves;
            # they are looked up through the base image, not listed separately.
            if "geopedia" in name:
                continue
            # Split the extension-less name so a timestamp in the last
            # position is not polluted by the trailing ".png".
            tokens = name.split('_')
            patch_id = tokens[1]
            year = datetime.strptime(tokens[3], "%Y%m%d-%H%M%S").year
            layer = geopedia_layers['tulip_field_{}'.format(year)]
            images[name] = {
                "base": filename,
                "mask": 'tulip_{}_geopedia_{}.png'.format(patch_id, layer),
            }
        self._image_list = list(images.values())

    def __getitem__(self, idx):
        """
        Load the base image and its mask for entry ``idx``.

        Returns ``self._transform(base, mask)`` when a transform was given,
        otherwise the ``(base, mask)`` pair as read by ``mx.image.imread``.
        """
        entry = self._image_list[idx]
        # Internal invariants: _list_images always fills in both keys.
        assert 'base' in entry, "Couldn't find base image for: " + str(entry.get("mask"))
        assert 'mask' in entry, "Couldn't find mask image for: " + str(entry.get("base"))
        base = mx.image.imread(os.path.join(self._root, entry["base"]))
        mask = mx.image.imread(os.path.join(self._root, entry["mask"]))
        if self._transform is not None:
            return self._transform(base, mask)
        return base, mask

    def __len__(self):
        """Return the number of base images found in the root directory."""
        return len(self._image_list)

In [None]:
def positional_augmentation(joint):
    """
    Apply the position-changing augmentations to a joint (base + mask) image.

    A random 224x224 crop is taken first, then the crop is passed through
    ``mx.image.ResizeAug(100)``. Both steps act on every channel of ``joint``
    at once, so base and mask stay aligned.
    """
    crop_w, crop_h = 224, 224
    target_size = 100
    steps = (
        # NOTE: the size tuple is (width, height), not (height, width).
        mx.image.RandomCropAug(size=(crop_w, crop_h)),
        # Deterministic resize.
        mx.image.ResizeAug(target_size),
        # Further translation/scale/rotation augmenters can be appended here.
    )
    result = joint
    for step in steps:
        result = step(result)
    return result


def color_augmentation(base):
    """
    Apply the colour augmentations to the base image.

    Meant for the base image only; the mask layers must not be passed
    through here. Currently a single brightness jitter (``brightness=0.2``);
    further colour augmenters can be chained after it.
    """
    jitter = mx.image.BrightnessJitterAug(brightness=0.2)
    return jitter(base)


def joint_transform(base, mask):
    """
    Augment a base image and its mask consistently.

    Both inputs are cast to float32 and divided by 255, stacked along the
    channel axis (dim 2), run through ``positional_augmentation`` as one
    array, and split back apart. ``color_augmentation`` is then applied to
    the base part only.

    Returns
    -------
    tuple
        ``(augmented_base, augmented_mask)``
    """
    # Type conversion and scaling.
    base, mask = (arr.astype('float32') / 255 for arr in (base, mask))

    # Stack on the channel axis; remember where the base channels end so the
    # joint array can be split again afterwards.
    split_at = base.shape[2]
    stacked = mx.nd.concat(base, mask, dim=2)

    # Positional augmentation on base and mask together.
    moved = positional_augmentation(stacked)

    # Undo the stacking.
    base_part = moved[:, :, :split_at]
    mask_part = moved[:, :, split_at:]

    # Colour augmentation touches the base image only.
    return color_augmentation(base_part), mask_part

###  U-Net

The IoU metric tends to have a "squaring" effect on the errors relative to the Dice score (aka F score). So the F score tends to measure something closer to average performance, while the IoU score measures something closer to the worst case performance.

https://stats.stackexchange.com/questions/273537/f1-dice-score-vs-iou

Since we are not very confident in the quality of the training data/ground truth, let's go for the Dice coefficient.

In [None]:
def dice_coef(y_true, y_pred):
    """
    Symbolic Dice coefficient with +1 smoothing.

    Computes ``(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``
    where every sum runs over axes 1, 2 and 3.
    """
    reduce_axes = (1, 2, 3)
    overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=reduce_axes)
    numerator = 2. * overlap + 1.
    true_total = mx.sym.sum(y_true, axis=reduce_axes)
    pred_total = mx.sym.sum(y_pred, axis=reduce_axes)
    denominator = true_total + pred_total + 1.
    return mx.sym.broadcast_div(numerator, denominator)


def dice_coef_loss(y_true, y_pred):
    """
    Symbolic negative Dice coefficient with +1 smoothing, for use as a loss.

    Computes ``-(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``
    where every sum runs over axis 1 only.
    """
    overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=1)
    numerator = 2. * overlap + 1.
    totals = mx.sym.broadcast_add(
        mx.sym.sum(y_true, axis=1),
        mx.sym.sum(y_pred, axis=1),
    )
    denominator = totals + 1.
    dice = mx.sym.broadcast_div(numerator, denominator)
    return -dice

In [16]:
def conv_block(data, filters, kernel_size):
    '''
    Build one Convolution -> BatchNorm -> ReLU stage on top of `data`.

    The convolution uses a square `kernel_size` x `kernel_size` kernel with
    `filters` output channels and a fixed padding of (1, 1).
    '''
    kernel = (kernel_size, kernel_size)
    features = mx.sym.Convolution(data, num_filter=filters, kernel=kernel, pad=(1, 1))
    normalized = mx.sym.BatchNorm(features)
    return mx.sym.Activation(normalized, act_type='relu')

In [18]:
def down_block(data, filters):
    '''
    Stack two 3x3 convolutional blocks (see conv_block), each with
    `filters` output channels, on top of `data`.
    '''
    stage = data
    for _ in range(2):
        stage = conv_block(stage, filters, 3)
    return stage

In [None]:
def up_block(data, concat, filters):
    '''
    Build one decoder stage: deconvolve `data`, join it with the skip
    connection `concat` along dim 1, then apply two 3x3 convolutional blocks
    with `filters` output channels.
    '''
    # 2x2 bias-free deconvolution with stride 1.
    upsampled = mx.sym.Deconvolution(
        data, num_filter=filters, kernel=(2, 2), stride=(1, 1), no_bias=True
    )
    merged = mx.sym.concat(upsampled, concat, dim=1)
    for _ in range(2):
        merged = conv_block(merged, filters, 3)
    return merged

In [19]:
def build_unet():
    '''
    Build the symbolic U-Net graph.

    The network takes the variables 'data' and 'label', runs five encoder
    stages (32 to 512 filters) joined by 2x2 max pooling, four decoder stages
    (256 to 32 filters) with skip connections, and a final 1x1 convolution
    followed by a sigmoid.

    Returns
    -------
    mx.sym.Group
        Two outputs: the Dice loss (wrapped in MakeLoss) and the
        gradient-blocked sigmoid mask named 'mask'.
    '''
    # Inputs
    data = mx.sym.Variable(name='data')
    label = mx.sym.Variable(name='label')

    # Down blocks
    # NOTE(review): no stride is passed to Pooling. With MXNet's documented
    # default stride of 1 each pool shrinks the map by one pixel instead of
    # halving it, which is what the stride-1 deconvolution in up_block
    # undoes. Confirm whether stride=(2,2) was intended in both places.
    down1 = down_block(data, 32)
    pool1 = mx.sym.Pooling(down1, kernel=(2,2), pool_type='max')

    down2 = down_block(pool1, 64)
    pool2 = mx.sym.Pooling(down2, kernel=(2,2), pool_type='max')

    down3 = down_block(pool2, 128)
    pool3 = mx.sym.Pooling(down3, kernel=(2,2), pool_type='max')

    down4 = down_block(pool3, 256)
    pool4 = mx.sym.Pooling(down4, kernel=(2,2), pool_type='max')

    down5 = down_block(pool4, 512)

    # Up blocks: each stage consumes the previous decoder output together
    # with the encoder output of the same depth.
    up4 = up_block(down5, down4, 256)
    up3 = up_block(up4, down3, 128)
    up2 = up_block(up3, down2, 64)
    up1 = up_block(up2, down1, 32)

    # Final layers: single-channel sigmoid mask. The symbol keeps its
    # historical name 'softmax' so existing references to it still resolve.
    conv = mx.sym.Convolution(up1, num_filter=1, kernel=(1,1), name='conv10_1')
    conv = mx.sym.sigmoid(conv, name='softmax')

    # The loss is computed on the flattened prediction.
    net = mx.sym.Flatten(conv)
    loss = mx.sym.MakeLoss(dice_coef_loss(label, net), normalization='batch')
    # Expose the predicted mask as an output without letting gradients flow
    # through it.
    mask_output = mx.sym.BlockGrad(conv, name='mask')
    out = mx.sym.Group([loss, mask_output])

    # NOTE: alternatives tried earlier were a custom
    # 'weighted_logistic_regression' op and mx.sym.LogisticRegressionOutput.
    return out