# Setup: install dependencies and import libraries

In [None]:
# Installing segmentation_models_pytorch
!mkdir -p /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/efficientnet_pytorch-0.6.3.xyz /tmp/pip/cache/efficientnet_pytorch-0.6.3.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/pretrainedmodels-0.7.4.xyz /tmp/pip/cache/pretrainedmodels-0.7.4.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/segmentation-models-pytorch-0.1.2.xyz /tmp/pip/cache/segmentation_models_pytorch-0.1.2.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.1.20-py3-none-any.whl /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.2.1-py3-none-any.whl /tmp/pip/cache/
!pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
!pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GroupKFold

from tqdm import tqdm
import os, gc
import random
from PIL import Image
import tifffile as tiff
import cv2
import zipfile
import collections
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from skimage import exposure
from bokeh.plotting import figure as bokeh_figure
from bokeh.io import output_notebook, show, output_file
from bokeh.models import ColumnDataSource, HoverTool, Panel
from bokeh.models.widgets import Tabs
from PIL import Image
from sklearn import preprocessing
from random import randint

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

from segmentation_models_pytorch.unet import Unet
from segmentation_models_pytorch.encoders import get_preprocessing_fn

import torchvision
from torchvision import transforms
from albumentations import *
import albumentations as A
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR

import warnings
warnings.filterwarnings("ignore")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def seed_everything(seed):
    """Seed every RNG the notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and every
    visible CUDA device), exports ``PYTHONHASHSEED``, and configures cuDNN to
    use deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes launched after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick algorithms per run, which can
    # defeat the `deterministic` flag above.
    torch.backends.cudnn.benchmark = False
    
 
# Global configuration shared by the cells below.
seed = 2020
seed_everything(seed)  # seed all RNGs before anything stochastic runs

sz = 512      # square side length (px) images are resized to before inference
NFOLDS = 5    # number of fold checkpoints averaged at inference time

# ImageNet per-channel statistics, shaped (1, 1, 3) to broadcast over HWC images.
# NOTE(review): not referenced by any visible cell — normalization is done by
# preprocessing_fn; confirm before relying on these.
mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)

# Test image ids

In [None]:
# Sample submission: its first column holds the test image ids the dataset reads.
test_df = pd.read_csv(
    '../input/sartorius-cell-instance-segmentation/sample_submission.csv',
)
test_df

# Dataset

In [None]:
class Sartorius_Seg_Dataset(Dataset):
    """Inference dataset yielding resized, preprocessed test images.

    Args:
        df: DataFrame whose first column holds image ids; each id is read from
            ``{image_dir}/{image_id}.png``.
        preprocess_input: optional albumentations-style callable, invoked as
            ``preprocess_input(image=img)['image']`` (e.g. encoder normalization).
        transform: optional albumentations-style transform, applied before
            ``preprocess_input``.
        image_dir: directory holding the PNG files; defaults to the competition
            test folder.

    Items are ``(img, image_id)`` where ``img`` is a CHW tensor of size ``sz x sz``.
    """

    def __init__(self, df, preprocess_input=None, transform=None,
                 image_dir='../input/sartorius-cell-instance-segmentation/test'):
        self.df = df
        self.preprocess_input = preprocess_input
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, resize and preprocess the image at row ``idx``.

        Returns:
            (img, image_id): ``img`` is a floating-point ``torch.Tensor`` in CHW
            layout with spatial size ``sz x sz``; ``image_id`` is the row's
            first-column value.
        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_id = self.df.iloc[idx, 0]
        path = os.path.join(self.image_dir, f'{image_id}.png')
        # cv2.imread signals a missing/unreadable file by returning None, not raising.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        # Channels are in OpenCV's BGR order.
        # NOTE(review): presumably the checkpoints were trained on the same BGR input — confirm.
        # Cast before resizing so interpolation happens in float32 (cv2.resize keeps dtype).
        img = cv2.resize(img.astype(np.float32), (sz, sz))

        if self.transform is not None:
            img = self.transform(image=img)['image']
        if self.preprocess_input is not None:
            img = self.preprocess_input(image=img)['image']

        # HWC -> CHW, the layout torch conv layers expect.
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        return img, image_id

# Model

In [None]:
ENCODER_NAME = 'efficientnet-b0'
# Encoder-specific ImageNet normalization, wrapped so it can be called
# albumentations-style as preprocessing_fn(image=img)['image'].
preprocessing_fn = A.Lambda(image=get_preprocessing_fn(encoder_name=ENCODER_NAME, pretrained='imagenet'))


class Model(nn.Module):
    """Thin wrapper around a segmentation_models_pytorch U-Net.

    The encoder defaults to ``ENCODER_NAME`` so the architecture always matches
    ``preprocessing_fn`` above. No pretrained weights are downloaded
    (``encoder_weights=None``); presumably trained fold checkpoints are loaded
    into it afterwards.

    Args:
        encoder_name: segmentation_models_pytorch encoder identifier.
    """

    def __init__(self, encoder_name=ENCODER_NAME):
        super().__init__()
        self.model = Unet(
            encoder_name=encoder_name,
            encoder_weights=None,
            classes=1,         # single foreground channel
            activation=None,   # return raw logits; no activation applied here
        )

    def forward(self, images):
        """Return the U-Net's raw logits (one output channel) for a batch of images."""
        return self.model(images)

# Inference

In [None]:
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns a space-separated string of "start length" pairs, where starts are
    1-indexed positions in the row-major (C-order) flattened mask. An empty
    mask yields ''.
    '''
    # Zero padding on both ends guarantees every run has a start and an end edge.
    padded = np.concatenate(([0], img.ravel(), [0]))
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts, stops = edges[0::2], edges[1::2]
    return ' '.join(f'{start} {length}' for start, length in zip(starts, stops - starts))

In [None]:
TH = 0.5                 # probability threshold for foreground pixels
ORIG_SIZE = (704, 520)   # (width, height) of the raw test images, as cv2.resize expects
test_ds = Sartorius_Seg_Dataset(df=test_df, preprocess_input=preprocessing_fn)
test_dl = DataLoader(dataset=test_ds, batch_size=16, shuffle=False, num_workers=2)
pred_dict = {}
seed_everything(seed)

# Load every fold checkpoint once up front instead of re-reading all of them
# for each batch. map_location lets GPU-saved weights load on a CPU-only box.
fold_models = []
for fold in range(NFOLDS):
    print(f'===============Fold:{fold}===============')
    model = Model()
    model.load_state_dict(torch.load(f'../input/sartorius-efnetb0-unet-ver001/fold_{fold}.pth',
                                     map_location=device))
    model.to(device)
    model.eval()
    fold_models.append(model)

for img_btch, image_id in tqdm(test_dl):
    img_btch = img_btch.to(device, dtype=torch.float)

    with torch.no_grad():
        # Equal-weight ensemble: mean of the per-fold sigmoid probabilities.
        pred_mask_btch = sum(torch.sigmoid(m(img_btch)) / NFOLDS for m in fold_models)

    preds = (pred_mask_btch >= TH).cpu().numpy().astype(np.uint8)

    del img_btch, pred_mask_btch
    gc.collect()

    for b in range(preds.shape[0]):
        # Back to the original resolution; nearest keeps the mask binary.
        mask = cv2.resize(np.squeeze(preds[b]), ORIG_SIZE, interpolation=cv2.INTER_NEAREST)
        # Each connected component becomes one instance; label 0 is background.
        n_labels, labels = cv2.connectedComponents(mask.astype(np.uint8))

        for p in range(1, n_labels):
            instance = np.where(labels == p, 1, 0)
            ys, xs = np.nonzero(instance)
            # Skip degenerate components that are a single pixel wide or tall.
            if xs.min() == xs.max() or ys.min() == ys.max():
                continue
            # NOTE(review): images with no surviving instance get no entry at all.
            pred_dict[f'{image_id[b]}_{p}'] = rle_encode(instance)

In [None]:
# Keys were built as f'{image_id}_{p}'; stripping the suffix leaves one row per
# predicted instance with the image id repeated. Building from explicit pairs
# keeps both columns even when pred_dict is empty.
# NOTE(review): this reuses the name test_df, so re-running the inference cell
# afterwards would iterate over submission rows instead of test image ids.
test_df = pd.DataFrame(list(pred_dict.items()), columns=['id', 'predicted'])
test_df['id'] = test_df['id'].str.split('_').str[0]
# Stable sort keeps each image's instances in prediction order.
test_df = test_df.sort_values('id', kind='mergesort').reset_index(drop=True)
test_df.head()

In [None]:
# Write the submission file (columns id, predicted) without the index column.
test_df.to_csv(path_or_buf='submission.csv', index=False)