# Since this competition is evaluated on "instance segmentation", this notebook cannot be used directly for submission.
Please use this notebook as a reference for semantic segmentation.

The code in this notebook is based on https://github.com/YutaroOgawa/pytorch_advanced/tree/master/3_semantic_segmentation, https://www.kaggle.com/inversion/run-length-decoding-quick-start and https://www.kaggle.com/ihelon/cell-segmentation-run-length-decoding

Please upvote those notebooks as well.

Copyright (c) 2019 Yutaro Ogawa

Released under the MIT license https://github.com/YutaroOgawa/pytorch_advanced/blob/master/LICENSE

# Training notebook is [here](https://www.kaggle.com/kurokia/semantic-segmentation-by-pspnet-train).

In [None]:
from PIL import Image, ImageOps
import cv2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch

import os
import sys

In [None]:
# Local folder that receives the helper modules copied in below; put it on
# sys.path so they can be imported by module name.
# makedirs(exist_ok=True) and the membership check make the cell safe to re-run
# (os.mkdir would raise FileExistsError, and append would add duplicates).
os.makedirs('./utils', exist_ok=True)
if './utils' not in sys.path:
    sys.path.append('./utils')

In [None]:
from shutil import copyfile
# Copy the helper modules from the "utils-inf" input dataset into ./utils,
# which the previous cell added to sys.path, so they can be imported below.
# (The "augumentation" spelling is the actual file name; keep it.)
copyfile(src = "../input/utils-inf/data_augumentation.py", dst = "./utils/data_augumentation.py")
copyfile(src = "../input/utils-inf/dataloader.py", dst = "./utils/dataloader.py")
copyfile(src = "../input/utils-inf/pspnet.py", dst = "./utils/pspnet.py")

from dataloader import make_datapath_list, DataTransform

In [None]:
# Competition data root. make_datapath_list (from utils/dataloader.py) turns it
# into a list of paths; only its first entry is used later, as the dummy
# annotation image passed to DataTransform in make_test_mask.
rootpath = "../input/sartorius-cell-instance-segmentation/"
val_anno_list = make_datapath_list(rootpath=rootpath)

In [None]:
from pspnet import PSPNet

# Single-output PSPNet; weights come from the training notebook.
net = PSPNet(n_classes=1)

# map_location="cpu" remaps tensors saved on *any* CUDA device onto the CPU.
# The previous {'cuda:0': 'cpu'} mapping only handled cuda:0, so a checkpoint
# saved from another GPU would fail to load on a CPU-only machine.
state_dict = torch.load("../input/weights/pspnet50_40.pth",
                        map_location="cpu")
net.load_state_dict(state_dict)

In [None]:
# Submission template: one row per test image id. The "predicted" column is
# filled in by make_test_mask below.
sample_sub_df = pd.read_csv(
    filepath_or_buffer='../input/sartorius-cell-instance-segmentation/sample_submission.csv'
)
sample_sub_df  # display the template in the notebook

In [None]:
# Test image ids in the same order as the rows of the submission template.
test_id_list = list(sample_sub_df['id'].to_numpy())
test_id_list  # display the ids in the notebook

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length-encoded mask into a numpy array.

    mask_rle: run-length as string formatted "start length start length ...",
              with 1-based starts. None or NaN (which is how pandas reads an
              empty CSV cell back) is treated as an empty mask.
    shape: (height, width) or (height, width, channels) of array to return
    color: value written into every masked pixel; a scalar, or a sequence of
           length `channels` for a coloured mask
    Returns numpy float32 array of `shape` (mask); runs are laid out in
    row-major (C) order.
    '''
    height, width = shape[0], shape[1]
    # Flatten only the spatial dimensions; any channel dimension is kept so a
    # per-channel `color` broadcasts across each run.
    img = np.zeros((height * width,) + tuple(shape[2:]), dtype=np.float32)

    # An empty prediction round-trips through CSV as NaN; .split() would fail.
    if mask_rle is None or (isinstance(mask_rle, float) and np.isnan(mask_rle)):
        return img.reshape(shape)

    s = mask_rle.split()
    starts = [int(x) - 1 for x in s[0::2]]  # RLE starts are 1-based
    lengths = [int(x) for x in s[1::2]]

    for start, length in zip(starts, lengths):
        img[start : start + length] = color

    return img.reshape(shape)

def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background; read in row-major (C) order
    Returns space-separated "start length" pairs with 1-based starts, or an
    empty string when no pixel is set.
    '''
    # A zero on each side guarantees every run has both a rising and a
    # falling edge, even when it touches the first or last pixel.
    padded = np.concatenate(([0], np.ravel(img), [0]))
    # 1-based positions (in the unpadded array) where the value changes.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts, stops = edges[0::2], edges[1::2]
    return ' '.join(str(v) for pair in zip(starts, stops - starts) for v in pair)

In [None]:
# Per-channel colour statistics handed to DataTransform in make_test_mask,
# one value for each of the three RGB channels.
# NOTE(review): presumably used as (x - mean) / std normalisation inside
# DataTransform; confirm they match what the network was trained with.
color_mean = (1.0,) * 3
color_std = (1.0,) * 3

In [None]:
def make_test_mask(test_id):
    """Predict a mask for one test image and store its RLE in sample_sub_df.

    Runs the module-level ``net`` on ``test/{test_id}.png``, resizes the
    prediction back to the original image size, run-length encodes it with
    ``rle_encode`` and writes the string into the "predicted" column of the
    row(s) of ``sample_sub_df`` whose id equals ``test_id``.

    Args:
        test_id: file stem of a test image.

    Returns:
        None; the result is written into ``sample_sub_df`` in place.
    """
    image_file_path = f"../input/sartorius-cell-instance-segmentation/test/{test_id}.png"

    # Context managers close the underlying file handles. convert() returns a
    # fully loaded copy, so the image stays usable after the file is closed.
    with Image.open(image_file_path) as raw_img:
        img = raw_img.convert("RGB")
    img_width, img_height = img.size

    transform = DataTransform(
        input_size=520, color_mean=color_mean, color_std=color_std)

    # The transform is called with an (image, annotation) pair. The first
    # entry of val_anno_list serves as a dummy annotation and as the source of
    # the palette re-applied to the prediction below.
    anno_file_path = val_anno_list[0]
    with Image.open(anno_file_path) as raw_anno:
        anno_class_img = raw_anno.convert("L")
    anno_class_img = ImageOps.invert(anno_class_img)
    anno_class_img = anno_class_img.quantize()
    p_palette = anno_class_img.getpalette()
    phase = "val"
    img, anno_class_img = transform(phase, img, anno_class_img)

    net.eval()
    x = img.unsqueeze(0)  # add a leading (batch) dimension
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = net(x)
    y = outputs[0]

    y = y.detach().numpy()[0][0]
    # NOTE(review): np.uint8() truncates the raw float output, and NumPy leaves
    # the result undefined for values outside [0, 255] (e.g. negative logits).
    # The final mask is "non-zero after the palette-based 'I' conversion and
    # clip", so the effective threshold presumably depends on p_palette.
    # Consider an explicit threshold on y; confirm what net actually outputs.
    anno_class_img = Image.fromarray(np.uint8(y), mode="P")
    anno_class_img = anno_class_img.resize((img_width, img_height), Image.NEAREST)
    anno_class_img.putpalette(p_palette)

    anno_class_img = anno_class_img.convert('I')
    n = np.array(anno_class_img).astype(np.uint8)
    n = np.clip(n, 0, 1)

    test_mask = rle_encode(n)

    sample_sub_df.loc[sample_sub_df['id'] == test_id, "predicted"] = test_mask

In [None]:
# Predict every test image. make_test_mask writes each RLE string into the
# "predicted" column of sample_sub_df.
for test_id in test_id_list:
    make_test_mask(test_id)

In [None]:
# Persist the filled-in template as the submission file (no index column).
sample_sub_df.to_csv(path_or_buf='submission.csv', index=False)

In [None]:
def plot_masks(image_id, colors=True):
    """Show a test image, its predicted mask overlaid on it, and the mask alone.

    Args:
        image_id: id of a test image. The PNG is read from the competition's
            ``test`` directory and its RLE strings from the "predicted" column
            of ``sample_sub_df``.
        colors: if True, paint each RLE row for ``image_id`` in its own random
            RGB colour; otherwise draw a single-channel union of the masks.

    Raises:
        FileNotFoundError: if the test image cannot be read.
    """
    image_path = f"../input/sartorius-cell-instance-segmentation/test/{image_id}.png"
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Size the mask from the image itself rather than hard-coding 520 x 704.
    height, width = image.shape[:2]

    labels = sample_sub_df[sample_sub_df["id"] == image_id]["predicted"].tolist()

    if colors:
        mask = np.zeros((height, width, 3))
        for label in labels:
            mask += rle_decode(label, shape=(height, width, 3), color=np.random.rand(3))
    else:
        mask = np.zeros((height, width, 1))
        for label in labels:
            mask += rle_decode(label, shape=(height, width, 1))
    # Overlapping masks sum above 1; clip back to a displayable range.
    mask = mask.clip(0, 1)

    plt.figure(figsize=(16, 32))
    plt.subplot(3, 1, 1)
    plt.imshow(image)
    plt.axis("off")
    plt.subplot(3, 1, 2)
    plt.imshow(image)
    plt.imshow(mask, alpha=0.5)
    plt.axis("off")
    plt.subplot(3, 1, 3)
    plt.imshow(mask)
    plt.axis("off")

    plt.show()

In [None]:
# Visual sanity check on one test image, drawing the single-channel mask.
plot_masks(image_id="7ae19de7bc2a", colors=False)