Научимся сегментировать изображения. Для этого воспользуемся библиотекой https://github.com/qubvel/segmentation_models.pytorch

## Neural networks for segmentation

Скачаем данные, которые будем сегментировать

In [None]:
# ! wget https://www.dropbox.com/s/jy34yowcf85ydba/data.zip?dl=0 -O data.zip
# ! unzip -q data.zip

Нужно натренировать сеть, которая будет сегментировать границы клеток. Ниже пример входных данных и таргета.

In [None]:
import scipy as sp
import scipy.misc
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword, with underscores
    turned into spaces and title-cased, is used as the panel title, and the
    value is handed to ``plt.imshow``. Axis ticks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(10, 8))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, total, position)
        # Hide the tick marks; only the picture and its title are wanted.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()

In [None]:
# One training image of human HT29 colon-cancer cells and its outline mask.
image, mask = (
    plt.imread('BBBC018_v1_images-fixed/train/00735-actin.DIB.bmp'),
    plt.imread('BBBC018_v1_outlines/train/00735-cells.png'),
)
visualize(image=image, mask=mask)

In [None]:
# How the metric is computed.
def calc_iou(prediction, ground_truth, empty_score=1.0):
    """Return the intersection-over-union pooled over a set of masks.

    A pixel counts as foreground when its value is strictly positive.
    Intersection and union pixel counts are summed over all image pairs
    first and divided once at the end, so larger masks weigh more.

    Args:
        prediction: sequence of predicted masks (array-likes supporting ``> 0``).
        ground_truth: sequence of reference masks, same length as ``prediction``.
        empty_score: value returned when the pooled union is empty (no images,
            or no foreground pixel in any prediction or reference mask).

    Returns:
        The pooled IoU as a Python float.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(prediction) != len(ground_truth):
        raise ValueError(
            'prediction and ground_truth must have the same length, got {} and {}'.format(
                len(prediction), len(ground_truth)
            )
        )

    # Count with integers: summing float32 values stops being exact once the
    # totals grow past what a float32 mantissa can represent.
    intersection, union = 0, 0
    for predicted, reference in zip(prediction, ground_truth):
        predicted_fg = predicted > 0
        reference_fg = reference > 0
        intersection += int(np.count_nonzero(np.logical_and(predicted_fg, reference_fg)))
        union += int(np.count_nonzero(np.logical_or(predicted_fg, reference_fg)))

    if union == 0:
        # Nothing to compare: avoid a division by zero / NaN result.
        return float(empty_score)
    return intersection / union

In [None]:
# !pip install -U git+https://github.com/qubvel/segmentation_models.pytorch --user
# !pip install -U git+https://github.com/albu/albumentations --user

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import cv2

In [None]:
import os
# Default to GPU index 5, but respect CUDA_VISIBLE_DEVICES when the caller's
# environment already sets it instead of overwriting that choice.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '5')

In [None]:
import albumentations as albu

# Data and Augmentations

In [None]:
# Root folders of the input images and of the outline masks.
DATA_DIR = './BBBC018_v1_images-fixed/'
MASK_DIR = './BBBC018_v1_outlines/'

# Per-split sub-folders. No mask folder is defined for the test split.
x_train_dir, x_val_dir, x_test_dir = (
    os.path.join(DATA_DIR, split) for split in ('train', 'val', 'test')
)
y_train_dir, y_val_dir = (
    os.path.join(MASK_DIR, split) for split in ('train', 'val')
)

##### можно заметить, что не все каналы несут полезную информацию

In [None]:
# Plot each of the first three image channels next to the mask.
image = plt.imread('BBBC018_v1_images-fixed/train/00735-actin.DIB.bmp')
mask = plt.imread('BBBC018_v1_outlines/train/00735-cells.png')
channel_panels = {'channel{}'.format(c): image[:, :, c] for c in range(3)}
visualize(**channel_panels, mask=mask)

In [None]:
class MyDataset(Dataset):
    """Image / binary-mask pairs read from two directories.

    Item ``i`` pairs the ``i``-th file name (in sorted order) of ``images_dir``
    with the ``i``-th file name (in sorted order) of ``masks_dir``, so both
    folders are expected to use names that sort in the same order.

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the masks.
        classes: accepted for interface compatibility; not used.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a mapping with
            ``'image'`` and ``'mask'`` entries.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        ValueError: if the two directories hold a different number of entries.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort both listings so
        # that image i and mask i refer to the same sample.
        self.images_ids = sorted(os.listdir(images_dir))
        self.masks_ids = sorted(os.listdir(masks_dir))
        if len(self.images_ids) != len(self.masks_ids):
            raise ValueError(
                'Found {} images in {!r} but {} masks in {!r}'.format(
                    len(self.images_ids), images_dir, len(self.masks_ids), masks_dir
                )
            )

        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.images_ids]
        self.masks_fps = [os.path.join(masks_dir, mask_id) for mask_id in self.masks_ids]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``.

        Raises:
            OSError: if the image or the mask file cannot be read.
        """
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise OSError('Cannot read image file {!r}'.format(self.images_fps[i]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Keep only channel 1 and replicate it into three identical channels.
        image = np.repeat(image[..., 1][..., np.newaxis], 3, axis=2)

        # Flag 0 loads the mask as a single-channel grayscale image.
        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise OSError('Cannot read mask file {!r}'.format(self.masks_fps[i]))
        mask = mask[..., np.newaxis]
        # Binarize: every positive pixel becomes 1.0, the rest 0.0.
        mask = np.where(mask > 0., 1., 0.)

        # The image is not rescaled here; any normalization is presumably
        # left to the `preprocessing` callable.

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.images_ids)

In [None]:
def get_training_augmentation():
    """Build the augmentation pipeline used for training samples.

    The pipeline applies, in order: horizontal flip, vertical flip and
    transpose (each with probability 0.5), an always-applied random
    shift/scale without rotation, padding up to at least 512x512, and a
    random 512x512 crop.
    """
    flips = [
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.Transpose(p=0.5),
    ]
    # rotate_limit=0 disables rotation; only shifting and scaling happen.
    geometry = [
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
    ]
    # Pad first so that the fixed-size crop always fits inside the sample.
    fixed_size = [
        albu.PadIfNeeded(min_height=512, min_width=512, always_apply=True, border_mode=0),
        albu.RandomCrop(height=512, width=512, always_apply=True),
    ]
    return albu.Compose(flips + geometry + fixed_size)


def get_validation_augmentation():
    """Build the validation/test pipeline: pad samples up to at least 512x512.

    Samples that are already 512 pixels or more in a dimension are not
    padded along that dimension.
    """
    pad_to_minimum = albu.PadIfNeeded(min_height=512, min_width=512)
    return albu.Compose([pad_to_minimum])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn: callable applied to the image only.

    Returns:
        An ``albu.Compose`` that first runs ``preprocessing_fn`` on the image
        and then converts both image and mask with ``to_tensor``.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    convert_layout = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, convert_layout])

# Model

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# The same encoder name selects both the U-Net backbone and the matching
# input preprocessing function.
encoder_name = 'se_resnext50_32x4d'

# Single output channel with a sigmoid activation.
model = smp.Unet(encoder_name=encoder_name, classes=1, activation='sigmoid')

preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name, 'imagenet')

In [None]:
# Training samples get the random augmentations; validation samples are only padded.
train_dataset = MyDataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
valid_dataset = MyDataset(
    images_dir=x_val_dir,
    masks_dir=y_val_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Shuffled batches of 8 for training; one sample at a time, in order, for validation.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=0)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# Fetch the same training sample three times and plot each result.
for _ in range(3):
    image, mask = train_dataset[1]
    # The dataset returns channels-first arrays; plotting needs channels last.
    image_hwc = image.transpose(1, 2, 0)
    mask_hw = mask.transpose(1, 2, 0).squeeze()
    visualize(image=image_hwc, mask=mask_hw)

In [None]:
loss = smp.utils.losses.BCEDiceLoss(eps=0.)
# The IoU metric is meant to match calc_iou.
metrics = [smp.utils.metrics.IoUMetric(eps=0.)]

# Decoder and encoder are kept as separate parameter groups so that their
# settings can be tuned independently.
# NOTE(review): both groups currently use the same learning rate (0.001);
# lower the encoder's value if the pre-trained weights should change more slowly.
adam_options = {'lr': 0.001, 'betas': (0.5, 0.999), 'amsgrad': True}
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), **adam_options},
    {'params': model.encoder.parameters(), **adam_options},
])

# Cosine annealing with T_max=10 and a floor of 1e-5; stepped by the training loop.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)

In [None]:
# Epoch runners: each one iterates once over the samples of a dataloader.
# Both share the model, loss, metrics, device and verbosity; only the
# training runner receives the optimizer.
shared_epoch_options = {
    'loss': loss,
    'metrics': metrics,
    'device': DEVICE,
    'verbose': True,
}

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **shared_epoch_options)
valid_epoch = smp.utils.train.ValidEpoch(model, **shared_epoch_options)

In [None]:
# Reuse a saved checkpoint when one exists; otherwise train for N_EPOCHS
# epochs, saving the weights each time the validation IoU improves.
CHECKPOINT_PATH = './segmentation_model.pth'
N_EPOCHS = 200

try:
    # map_location lets a checkpoint written on a GPU be opened on a
    # machine where DEVICE is the CPU.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
except FileNotFoundError:
    checkpoint = None

if checkpoint is None:
    model.train()
    max_score = 0

    for i in range(N_EPOCHS):

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        if max_score < valid_logs['iou']:
            max_score = valid_logs['iou']
            # Save only the weights so the file matches load_state_dict below.
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print('Model saved!')

        lr_scheduler.step()

    if max_score > 0:
        # Continue with the best validation weights rather than the last
        # epoch's, which is also what a later run of this cell would load.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

if checkpoint is not None:
    if isinstance(checkpoint, torch.nn.Module):
        # Checkpoints written by the earlier version of this cell stored the
        # whole module; extract its weights.
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)

# Both paths end with the model on DEVICE and in evaluation mode.
model.to(DEVICE)
model.eval()

In [None]:
# Predict and plot a mask for every test image.
# MyDataset requires a mask directory, so the validation masks are passed
# only to make construction possible; the masks it returns are ignored.
test_dataset = MyDataset(
    x_test_dir,
    y_val_dir,  # placeholder, just to create the dataset
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

for i, image_id in enumerate(test_dataset.images_ids):
    test_img, _ = test_dataset[i]
    # Add a batch dimension and place the input on DEVICE rather than
    # hard-coding CUDA, so the cell also runs on CPU-only machines.
    batch = torch.as_tensor(test_img[None], dtype=torch.float32, device=DEVICE)
    test_mask = model.predict(batch).squeeze().cpu().numpy()
    plt.imshow(test_mask)
    # The first five characters of the file name serve as the sample id.
    name = '{}-mask.jpg'.format(image_id[:5])
    plt.title(name)
    # To keep the figures, call plt.savefig('./test_masks/' + name) here.
    plt.show()