In [1]:
from skimage import io
from skimage import img_as_float

In [2]:
# Automatically reload edited modules (e.g. the project's models.* package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
import os,sys,inspect
# Make the notebook directory's parent (the repository root) importable so the
# `models.*` imports below resolve. inspect.getfile(inspect.currentframe()) is
# not reliable inside a notebook cell: it names the cell's pseudo-file, which
# recent ipykernel versions can place in a temporary directory rather than next
# to the notebook. os.getcwd() is the notebook directory when the kernel is
# started from it (Jupyter's default).
currentdir = os.path.abspath(os.getcwd())
parentdir = os.path.dirname(currentdir)
if not sys.path or sys.path[0] != parentdir:  # don't stack duplicates on re-run
    sys.path.insert(0, parentdir)

In [4]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from models.prepare_train_val import get_split
from models.transforms import (DualCompose,
                        ImageOnly,
                        Normalize,
                        HorizontalFlip,
                        VerticalFlip,
                        Rotate,
                        RandomBrightness,
                        RandomContrast,
                        AddMargin)

from skimage.io import imread

from models.dataset import SaltDataset

from albumentations import (HorizontalFlip, VerticalFlip, Normalize,
    ShiftScaleRotate, Blur, OpticalDistortion,  GridDistortion, HueSaturationValue, IAAAdditiveGaussianNoise, GaussNoise, MotionBlur,
    MedianBlur, IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast, RandomBrightness,
    Flip, OneOf, Compose
)


def mask_overlay(image, mask, color=(0, 1, 0), alpha=0.5):
    """
    Helper function to visualize mask on the top of the image.

    Blends a solid colour into ``image`` wherever ``mask`` is positive.

    Args:
        image: (H, W, 3) array. ``color`` must be on the same intensity scale
            (e.g. 0-1 for float images).
        mask: (H, W) or (H, W, 1) array; pixels with value > 0 get the overlay.
        color: RGB colour, multiplied by the mask value to form the overlay.
        alpha: weight of the overlay in the blend; the image gets ``1 - alpha``.

    Returns:
        A blended copy of ``image`` (same dtype); pixels outside the mask are
        unchanged and the inputs are not modified.
    """
    mask = np.asarray(mask)
    colored = np.dstack((mask, mask, mask)) * np.array(color)
    # Plain numpy blend: equivalent to cv2.addWeighted(colored, alpha, image,
    # 1 - alpha, 0) for float64 inputs, but does not require both arrays to
    # share a dtype (cv2 raises for e.g. float32 image + float64 overlay).
    blended = alpha * colored + (1.0 - alpha) * image
    # Select pixels from the mask itself, not from one colour channel, so that
    # colours without a green component are still drawn.
    region = mask.reshape(image.shape[:2]) > 0
    img = image.copy()
    img[region] = blended[region]
    return img

def imshow(img, mask, title=None,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Imshow for Tensor: show the image with its mask overlaid, next to the plain image.

    Args:
        img: channel-first image tensor, presumably (3, H, W), normalized as
            ``(x - mean) / std``; the normalization is undone before display.
        mask: mask tensor, channel-first (presumably (1, H, W)) or (H, W);
            values are clipped to [0, 1].
        title: optional title for the whole figure.
        mean, std: per-channel statistics used to undo the normalization. The
            defaults are the values this helper always used (they look like
            the standard ImageNet statistics — confirm they match the
            dataset's Normalize step).
    """
    # detach()/cpu() let this accept tensors that require grad or live on a
    # GPU; a bare .numpy() raises for both.
    img = img.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.clip(np.asarray(std) * img + np.asarray(mean), 0, 1)
    mask = mask.detach().cpu().numpy()
    if mask.ndim == 3:
        mask = mask.transpose((1, 2, 0))
    mask = np.clip(mask, 0, 1)

    fig, (ax_overlay, ax_image) = plt.subplots(1, 2, figsize=(12, 6))
    ax_overlay.imshow(mask_overlay(img, mask))
    ax_overlay.set_title('image + mask')
    ax_image.imshow(img)
    ax_image.set_title('image')
    if title is not None:
        # Title the figure as a whole rather than only the last panel.
        fig.suptitle(title)
    plt.pause(0.001)


In [5]:
import os

# Work from the notebook directory's parent (the repository root) so relative
# paths such as 'data/...' resolve. The starting directory is remembered on the
# first run, so re-executing this cell does not keep climbing the directory tree
# the way a bare os.chdir('../') does.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [6]:
# Split the training ids; the argument 0 is presumably the fold index — confirm in models/prepare_train_val.py.
[train_ids, val_ids] = get_split(0)

In [7]:
"""
train_transform = DualCompose([
        AddMargin(128),
        #HorizontalFlip(),
        #VerticalFlip(),
        #Rotate(),
        #ImageOnly(RandomBrightness()),
        #ImageOnly(RandomContrast()),
        ImageOnly(Normalize())
    ])
"""
def train_transform(p=1, min_height=128, min_width=128):
    """Build the albumentations augmentation pipeline used for training.

    This is a factory: call it (``train_transform()``) to obtain the
    ``Compose`` object that should be handed to the dataset.

    Args:
        p: probability that the whole ``Compose`` is applied.
        min_height: ``PadIfNeeded`` target height.
        min_width: ``PadIfNeeded`` target width.
            These two replace ``args.train_crop_height`` /
            ``args.train_crop_width``: no ``args`` object exists in this
            notebook, so referencing it raised ``NameError``. The default of
            128 mirrors the legacy ``AddMargin(128)`` pipeline kept in the
            string above — TODO confirm it matches the training crop size.

    Returns:
        A ``Compose`` that pads, then randomly flips / adds noise / blurs /
        shifts-scales / distorts / changes intensity, and finally normalizes.
    """
    # PadIfNeeded is used below but was never imported in this notebook.
    from albumentations import PadIfNeeded

    # HorizontalFlip, Normalize, RandomContrast and RandomBrightness here are
    # the albumentations classes: that import comes after, and shadows, the
    # models.transforms import of the same names.
    # NOTE(review): the IAA* transforms and RandomContrast/RandomBrightness are
    # deprecated or removed in newer albumentations releases — confirm the
    # version this notebook is pinned to.
    return Compose([
        PadIfNeeded(min_height=min_height, min_width=min_width, p=1),
        HorizontalFlip(p=0.5),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            MotionBlur(p=0.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=0, p=0.2),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            IAASharpen(),
            IAAEmboss(),
            RandomContrast(),
            RandomBrightness(),
        ], p=0.3),
        Normalize(p=1),
    ], p=p)


In [None]:
 train_loader = DataLoader(
        dataset=SaltDataset(train_ids, transform=train_transform),
        shuffle=True,
        num_workers=1,
        batch_size=10,
        pin_memory=torch.cuda.is_available())

In [None]:
# Draw one shuffled batch and display its first sample with the mask overlaid.
(img1, mask1) = next(iter(train_loader))
imshow(img=img1[0], mask=mask1[0])

    Found GPU0 Quadro 2000D which is of cuda capability 2.1.
    PyTorch no longer supports this GPU because it is too old.
    


In [None]:
# Shape of the first mask in the batch (imshow transposes it as channel-first — confirm).
mask1[0].shape

In [None]:
# Number of pixels in the first mask whose value equals 1.
(mask1[0]==1).sum()

In [None]:
pred_dir = 'data/predictions/AlbuNet/OOF/'
train_dir = 'data/train/'


def show_image(file_name, threshold=0.2):
    """Plot a training image, its ground-truth mask, the out-of-fold prediction
    and that prediction binarized at ``threshold``.

    Side effect: stores the loaded arrays in the globals ``mask`` and
    ``mask_pred`` so the following inspection cells can examine them.
    """
    global mask, mask_pred
    img = imread(os.path.join(train_dir, 'images', file_name))
    mask = imread(os.path.join(train_dir, 'masks', file_name))
    # img_as_float maps uint8 PNGs into [0, 1], so `threshold` is on that scale.
    mask_pred = img_as_float(imread(os.path.join(pred_dir, file_name)))
    panels = [
        ('image', img),
        ('ground truth', mask),
        ('prediction', mask_pred),
        ('prediction > {}'.format(threshold), mask_pred > threshold),
    ]
    _, axes = plt.subplots(1, len(panels), figsize=(18, 6))
    for ax, (label, data) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(label)


#show_image('6c793e5879.png')
#show_image('6a1fe1a81e.png')
show_image('fd1be18f7d.png')

In [None]:
# Dimensions of the ground-truth mask loaded by show_image.
np.shape(mask)

In [None]:
# How many pixels the prediction marks with any positive value.
np.sum(mask_pred > 0)

In [None]:
# Element type of the prediction after img_as_float.
np.asarray(mask_pred).dtype

In [None]:
# Largest value stored in the ground-truth mask.
np.max(mask)

In [None]:
# Peak value of the prediction.
np.amax(mask_pred)

In [None]:
# Test-set predictions. Paths are relative to the repository root (the working
# directory after the chdir cell near the top), matching the 'data/...' paths
# used for the OOF predictions above. Separate names keep the OOF paths intact.
test_pred_dir = 'data/predictions/test/'
test_dir = 'data/test/'


def show_image(file_name, threshold=0.4):
    """Plot a test image, its predicted mask and that prediction binarized at ``threshold``.

    Redefines the OOF ``show_image`` from the earlier cell; this version reads
    no ground-truth mask. Stores the prediction in the global ``mask_pred`` so
    later cells can inspect it.
    """
    global mask_pred
    img = imread(os.path.join(test_dir, 'images', file_name))
    # img_as_float maps uint8 PNGs into [0, 1], so `threshold` is on that scale.
    mask_pred = img_as_float(imread(os.path.join(test_pred_dir, file_name)))
    panels = [
        ('image', img),
        ('prediction', mask_pred),
        ('prediction > {}'.format(threshold), mask_pred > threshold),
    ]
    _, axes = plt.subplots(1, len(panels), figsize=(18, 6))
    for ax, (label, data) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(label)


show_image('009d3365bc.png')
show_image('00801127b0.png')
#show_image('feaae39fc4.png')
#show_image('fb56c30236.png')
#show_image('f7c8709aad.png')

In [None]:
# Display the last loaded prediction array (numpy truncates large reprs).
np.asarray(mask_pred)