In [None]:
# Offline install of segmentation-models-pytorch and its dependencies from the attached
# Kaggle dataset (presumably because the kernel runs without internet access).
# Stage every archive/wheel in a local directory that pip can use as its only package source.
!mkdir -p /tmp/pip/cache/
# The dataset stores the source archives with a .xyz extension; copy them back under
# .tar.gz names so pip recognises them as sdists.
!cp ../input/segmentationmodelspytorch/segmentation_models/efficientnet_pytorch-0.6.3.xyz /tmp/pip/cache/efficientnet_pytorch-0.6.3.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/pretrainedmodels-0.7.4.xyz /tmp/pip/cache/pretrainedmodels-0.7.4.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/segmentation-models-pytorch-0.1.2.xyz /tmp/pip/cache/segmentation_models_pytorch-0.1.2.tar.gz
# Two timm wheel versions are staged; pip picks whichever satisfies the requirement
# (NOTE(review): confirm which one segmentation-models-pytorch 0.1.2 actually needs).
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.1.20-py3-none-any.whl /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.2.1-py3-none-any.whl /tmp/pip/cache/
# --no-index: never contact PyPI; --find-links: resolve packages from the staged directory.
!pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
!pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch

# TODO
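1. Decide how to read the data (the cells below use the pre-cut 256×256 tiles from `hubmap-256x256`).
2. Decide how to crop the test images into patches for inference.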

# Necessary imports

In [None]:
import os
import cv2
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold, KFold
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A

from skimage.color import label2rgb

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
SEED = 421

In [None]:
def seed_everything(seed=1234):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN pick deterministic kernels.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Only affects Python processes launched after this point; the hash seed of the
    # running interpreter was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode picks the fastest algorithm per input shape, and that choice can
    # differ between runs; disable it so `deterministic` yields repeatable results.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, before the data split and the model are created.
seed_everything(SEED)

# Important variables

In [None]:
# Experiment whose checkpoint is evaluated, and the settings it was trained with.
MODEL = 'unet-se_resnext50-cosineanneal'   # experiment name, part of the checkpoint file name
ENCODER = 'se_resnext50_32x4d'              # smp encoder; must match the architecture in the checkpoint
IMG_SIZE = 256                              # tile size; also part of the checkpoint file name
FOLD = 0                                    # which KFold split the checkpoint belongs to
NFOLDS = 4

# Derived from MODEL / IMG_SIZE / FOLD so that changing any of them selects the matching
# checkpoint instead of silently loading the fold-0 one.
# NOTE(review): assumes the training notebook uses this naming pattern for every fold.
MODEL_PATH = ("../input/kaggle-hubmap-segmentation-pytorch-training/"
              "{}-RES-{}-best-FOLD-{}-model.pth".format(MODEL, IMG_SIZE, FOLD))

TRAIN_IMGS = '../input/hubmap-256x256/train/'
MASKS = '../input/hubmap-256x256/masks'
LABELS = '../input/hubmap-kidney-segmentation/train.csv'
NUM_WORKERS = 4

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset

In [None]:
# Per-channel normalisation statistics published in
# https://www.kaggle.com/iafoss/256x256-images (presumably RGB order, values in [0, 1]).
# NOTE(review): the prediction cell builds HuBMAPDataset with get_preprocess_fn(encoder_name=ENCODER),
# which normalises with the encoder's pretrained statistics rather than with these values.
MEAN = np.asarray((0.65459856, 0.48386562, 0.69428385))
STD = np.asarray((0.15167958, 0.23584107, 0.13146145))

In [None]:
def img2tensor(img, dtype: np.dtype = np.float32):
    """Convert an (H, W, C) or (H, W) numpy image into a (C, H, W) torch tensor of ``dtype``.

    A 2-D input is treated as single-channel and gets a trailing channel axis first.
    """
    hwc = img[:, :, np.newaxis] if img.ndim == 2 else img
    chw = hwc.transpose(2, 0, 1)
    return torch.from_numpy(chw.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    """Image/mask tiles belonging to one KFold split of the HuBMAP training ids.

    Tiles are read from TRAIN_IMGS and their masks from MASKS under the same file name.
    A tile belongs to image id ``X`` when its file name starts with ``"X_"``.

    Args:
        fold: index of the KFold split to use (0 <= fold < NFOLDS).
        train: if True use the training part of the split, otherwise the held-out part.
        preprocess_input: optional albumentations-style transform applied to the image
            only, called as ``preprocess_input(image=img)['image']``.
        transforms: optional albumentations-style transform applied jointly, called as
            ``transforms(image=img, mask=mask)``.

    Raises:
        ValueError: if ``fold`` is outside ``[0, NFOLDS)``.
    """

    def __init__(self, fold=FOLD, train=True, preprocess_input=None, transforms=None):
        # A negative fold would otherwise silently index splits from the end.
        if not 0 <= fold < NFOLDS:
            raise ValueError("fold must be in [0, {}), got {}".format(NFOLDS, fold))
        ids = pd.read_csv(LABELS).id.values
        kf = KFold(n_splits=NFOLDS, random_state=SEED, shuffle=True)
        train_idx, val_idx = list(kf.split(ids))[fold]
        split_ids = set(ids[train_idx if train else val_idx])
        # os.listdir order is filesystem-dependent; sorting makes the dataset order
        # (and therefore the first DataLoader batch) reproducible.
        self.fnames = sorted(fname for fname in os.listdir(TRAIN_IMGS)
                             if fname.split('_')[0] in split_ids)
        self.train = train
        self.preprocess_input = preprocess_input
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as (C, H, W) float32 tensors for tile ``idx``.

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        fname = self.fnames[idx]
        img_path = os.path.join(TRAIN_IMGS, fname)
        mask_path = os.path.join(MASKS, fname)

        # cv2.imread returns None (instead of raising) on a missing/unreadable file.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError("could not read image: {}".format(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError("could not read mask: {}".format(mask_path))

        if self.transforms:
            augmented = self.transforms(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        if self.preprocess_input:
            # Normalise the image only; the mask keeps its raw label values.
            img = self.preprocess_input(image=img)['image']

        return img2tensor(img), img2tensor(mask)

    def __len__(self):
        return len(self.fnames)

# Transformations

In [None]:
# Adapted from https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter#Data
def get_preprocess_fn(encoder_name, pretrained='imagenet'):
    """Wrap smp's encoder-specific input normalisation as an albumentations transform.

    Args:
        encoder_name: smp encoder name (e.g. ENCODER).
        pretrained: name of the pretrained weights whose statistics are used.

    Returns:
        An ``A.Lambda`` whose only target function is for ``image``.
    """
    normalise = get_preprocessing_fn(encoder_name=encoder_name, pretrained=pretrained)
    return A.Lambda(image=normalise)


def get_train_transform():
    """Build the training-time augmentation pipeline.

    When called as ``transforms(image=..., mask=...)`` (as HuBMAPDataset does) the spatial
    ops are applied to the image and mask together. The prediction cell of this notebook
    builds its dataset with ``transforms=None``.
    """
    # At most one of these non-rigid distortions, 30% of the time.
    distortions = A.OneOf([
        A.OpticalDistortion(p=0.3),
        A.GridDistortion(p=.1),
        # NOTE(review): imgaug-backed IAA* transforms are deprecated/removed in newer
        # albumentations releases (PiecewiseAffine replaces this one) — check the installed version.
        A.IAAPiecewiseAffine(p=0.3),
    ], p=0.3)

    # At most one colour/contrast perturbation, 30% of the time.
    colour_jitter = A.OneOf([
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10),
        A.CLAHE(clip_limit=2),
        A.RandomBrightnessContrast(),
    ], p=0.3)

    pipeline = [
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9,
                           border_mode=cv2.BORDER_REFLECT),
        distortions,
        colour_jitter,
    ]
    return A.Compose(pipeline, p=1.)

# def get_val_transform():
#     return A.Compose([
#         A.Resize(IMG_SIZE, IMG_SIZE,always_apply=True),
#     ],p=1.)

# Model
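Define the network here. Its architecture must match the one used for training, otherwise the saved weights in `MODEL_PATH` (loaded in the prediction cell) will not fit.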

In [None]:
class HuBMAP(nn.Module):
    """Binary (FTU vs. non-FTU) segmentation network: an smp U-Net built on ENCODER.

    No pretrained encoder weights are requested (``encoder_weights=None``); the trained
    weights are loaded from MODEL_PATH in the prediction cell.
    """

    def __init__(self):
        super().__init__()
        # A single output channel of raw logits (no activation) for the binary task.
        self.cnn_model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=None,
            classes=1,
            activation=None,
        )

    def forward(self, imgs):
        """Return per-pixel logits from the wrapped U-Net."""
        return self.cnn_model(imgs)

In [None]:
def visualise_predictions(image_patches, pred_masks, gt_masks, figsize=(25, 25), alpha=0.2):
    """Show every patch in a near-square grid with its GT and predicted masks overlaid.

    Colour legend: red = GT only, blue = prediction only, green = overlap.

    Args:
        image_patches: N images, each (H, W, 3) or (H, W). Integer dtypes are scaled by
            1/255; float images are assumed to be in [0, 1] (TODO confirm for callers).
        pred_masks: N predicted masks, each reshapeable to (H, W); non-zero = foreground.
        gt_masks: N ground-truth masks with the same layout as ``pred_masks``.
        figsize: matplotlib figure size.
        alpha: opacity of the mask colours.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: if there are no patches, the counts differ, or a mask cannot be
            reshaped to its image's spatial size.
    """
    num_patches = len(image_patches)
    if num_patches == 0:
        raise ValueError("no patches to visualise")
    # Images have a channel axis and masks do not, so only the counts can be compared here.
    if not (len(pred_masks) == len(gt_masks) == num_patches):
        raise ValueError("image patches and masks should contain the same number of items")

    # Fixed label -> colour mapping for label = gt + 2 * pred.
    colours = {1: np.array([1.0, 0.0, 0.0]),   # GT only
               2: np.array([0.0, 0.0, 1.0]),   # prediction only
               3: np.array([0.0, 1.0, 0.0])}   # overlap

    # ceil() so that non-square counts still fit; squeeze=False keeps axs 2-D even for 1x1.
    ncols = int(np.ceil(np.sqrt(num_patches)))
    nrows = int(np.ceil(num_patches / ncols))
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for ax in axs.ravel():
        ax.axis('off')

    for ax, img_patch, pred_mask, gt_mask in zip(axs.ravel(), image_patches, pred_masks, gt_masks):
        img = np.asarray(img_patch)
        if np.issubdtype(img.dtype, np.integer):
            img = img / 255.0
        img = np.asarray(img, dtype=np.float64)
        if img.ndim == 2:
            img = np.repeat(img[..., np.newaxis], 3, axis=-1)

        # reshape (not squeeze) accepts (H, W), (H, W, 1) and (1, H, W) masks alike.
        gt = np.asarray(gt_mask).reshape(img.shape[:2]) > 0
        pred = np.asarray(pred_mask).reshape(img.shape[:2]) > 0
        label = gt.astype(np.uint8) + 2 * pred.astype(np.uint8)

        overlay = img.copy()
        for label_value, colour in colours.items():
            sel = label == label_value
            overlay[sel] = (1 - alpha) * overlay[sel] + alpha * colour
        ax.imshow(np.clip(overlay, 0.0, 1.0))

    return fig

# Make the prediction

In [None]:
# Build a loader over the held-out part of FOLD and take its first batch, so the model
# is judged on tiles it was not trained on.
# NOTE(review): assumes the training notebook used the same
# KFold(NFOLDS, shuffle=True, random_state=SEED) split — confirm.
ds = HuBMAPDataset(fold=FOLD, train=False,
                   preprocess_input=get_preprocess_fn(encoder_name=ENCODER), transforms=None)
dl = DataLoader(ds, batch_size=25, shuffle=False, num_workers=NUM_WORKERS)
imgs, masks = next(iter(dl))

# Instantiate the model and load the trained weights. map_location lets a checkpoint
# saved on GPU be loaded on a CPU-only machine.
model = HuBMAP()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# Predict the whole batch in one forward pass. The model outputs logits, so
# logits > 0 is the same as sigmoid(logits) > 0.5. Result: (N, 1, H, W) int tensor on CPU.
with torch.no_grad():
    logits = model(imgs.to(DEVICE))
predictions = (logits > 0).int().cpu()

# Visualize the predictions

In [None]:
# Overlay ground-truth and predicted masks on every image of the batch.
# Legend: red = GT only, blue = prediction only, green = overlap of GT and prediction.

# Undo the dataset's normalisation. HuBMAPDataset was given
# get_preprocess_fn(encoder_name=ENCODER), i.e. the encoder's pretrained mean/std,
# so those statistics (not the iafoss MEAN/STD constants) must be inverted here.
norm_params = smp.encoders.get_preprocessing_params(ENCODER, pretrained='imagenet')
norm_mean = np.asarray(norm_params['mean'], dtype=np.float32)
norm_std = np.asarray(norm_params['std'], dtype=np.float32)

# Colours are blended by hand with a fixed label -> colour mapping. label2rgb assigns
# its `colors` to the labels present in each patch, so e.g. a patch with predictions
# but no GT could be drawn red. NOTE(review): confirm against the installed skimage version.
OVERLAY_COLOURS = {1: np.array([1.0, 0.0, 0.0]),   # GT only
                   2: np.array([0.0, 0.0, 1.0]),   # prediction only
                   3: np.array([0.0, 1.0, 0.0])}   # overlap
OVERLAY_ALPHA = 0.2

# Near-square grid sized from the batch instead of a hard-coded 5x5.
n_show = len(imgs)
ncols = int(np.ceil(np.sqrt(n_show)))
nrows = int(np.ceil(n_show / ncols))
fig, axs = plt.subplots(nrows, ncols, figsize=(16, 16), squeeze=False)
for ax in axs.ravel():
    ax.axis('off')

for ax, img, pred_mask, gt_mask in zip(axs.ravel(), imgs, predictions, masks):
    # (C, H, W) normalised tensor -> (H, W, C) RGB in [0, 1]. Clip so rounding cannot
    # push values outside the displayable range.
    img = np.clip(img.permute(1, 2, 0).numpy() * norm_std + norm_mean, 0.0, 1.0)
    gt = gt_mask.squeeze().numpy() > 0
    pred = pred_mask.squeeze().numpy() > 0
    # 0 = background, 1 = GT only, 2 = prediction only, 3 = both
    full_mask = gt.astype(np.uint8) + 2 * pred.astype(np.uint8)

    overlay = img.copy()
    for label_value, colour in OVERLAY_COLOURS.items():
        sel = full_mask == label_value
        overlay[sel] = (1 - OVERLAY_ALPHA) * overlay[sel] + OVERLAY_ALPHA * colour
    ax.imshow(overlay)

plt.show()
    
# del ds,dl,imgs,masks