In [None]:
%load_ext autoreload
# Mode 2: re-import all modules (e.g. the local `models` package) before each cell runs,
# so edits to the .py files are picked up without restarting the kernel.
%autoreload 2

In [97]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from models.prepare_train_val import get_split
from models.transforms import (DualCompose,
                        ImageOnly,
                        Normalize,
                        HorizontalFlip,
                        VerticalFlip,
                        Rotate,
                        RandomBrightness,
                        RandomContrast,
                        AddMargin)


from models.dataset import SaltDataset


def mask_overlay(image, mask, color=(0, 1, 0), alpha=0.5):
    """
    Helper function to visualize mask on the top of the image.

    Pixels where the colored mask is non-zero in any channel are replaced by an
    ``alpha``-weighted blend of the mask color and the image; every other pixel
    keeps its original value.

    Args:
        image: (H, W, 3) array. In this notebook it is a float image clipped
            to [0, 1] (see ``imshow``).
        mask: (H, W) or (H, W, 1) array; positive entries mark the region to
            highlight. It is stacked to 3 channels and multiplied by ``color``.
        color: per-channel overlay color.
        alpha: weight of the mask color in the blend; the image gets
            ``1 - alpha``. The default of 0.5 is the previously fixed weight.

    Returns:
        A new array with the same shape and dtype as ``image``; ``image``
        itself is not modified.
    """
    colored_mask = np.dstack((mask, mask, mask)) * np.array(color)
    # NumPy blend instead of cv2.addWeighted: identical result for float
    # inputs, but it does not require image and mask to share a dtype
    # (cv2 raises e.g. for a float32 image combined with a float64 mask).
    blended = alpha * colored_mask + (1 - alpha) * image
    # Look at every channel, not only green, so that overlay colors without a
    # green component (e.g. red) are still drawn.
    highlighted = (colored_mask > 0).any(axis=2)
    result = image.copy()
    result[highlighted] = blended[highlighted]
    return result

def imshow(img, mask, title=None,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalized image tensor next to the same image with its mask overlaid.

    Args:
        img: channel-first (C, H, W) image tensor with 3 channels (to match
            ``mean``/``std``). It is un-normalized with ``std * x + mean`` and
            clipped to [0, 1] for display.
        mask: channel-first (1, H, W) mask tensor; values are clipped to [0, 1].
        title: optional title for the whole figure.
        mean, std: per-channel statistics used to undo normalization. The
            defaults are the common ImageNet statistics that were previously
            hard-coded here. NOTE(review): they must match what
            ``models.transforms.Normalize()`` uses; confirm there.
    """
    # detach().cpu() lets tensors that require grad or live on a GPU be
    # converted as well; it is harmless for plain CPU tensors.
    img = img.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.clip(np.asarray(std) * img + np.asarray(mean), 0, 1)
    mask = mask.detach().cpu().numpy().transpose((1, 2, 0))
    mask = np.clip(mask, 0, 1)

    fig, (ax_overlay, ax_image) = plt.subplots(1, 2, figsize=(12, 6))
    ax_overlay.imshow(mask_overlay(img, mask))
    ax_overlay.set_title('image + mask')
    ax_image.imshow(img)
    ax_image.set_title('image')
    if title is not None:
        # Title the figure as a whole, not just the last subplot.
        fig.suptitle(title)
    plt.show()


In [65]:
# Index of the train/validation split passed to get_split
# (presumably a cross-validation fold; confirm in models.prepare_train_val).
FOLD = 0
train_ids, val_ids = get_split(FOLD)

In [66]:
# Training-time augmentation pipeline, applied in list order.
# Transforms not wrapped in ImageOnly presumably act on image and mask together
# (keeping them aligned); ImageOnly-wrapped ones touch the image alone.
train_transform = DualCompose([
    AddMargin(128),  # NOTE(review): 128 presumably the padded side length; confirm in AddMargin
    HorizontalFlip(),
    VerticalFlip(),
    Rotate(),
    ImageOnly(RandomBrightness()),
    ImageOnly(RandomContrast()),
    ImageOnly(Normalize()),
])

In [101]:
 # Training DataLoader; shuffle=True with no fixed seed, so the batch drawn below differs between runs.
BATCH_SIZE = 10
NUM_WORKERS = 1

train_loader = DataLoader(
    dataset=SaltDataset(train_ids, transform=train_transform),
    shuffle=True,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    # Pinned host memory only pays off when batches are copied to a GPU.
    pin_memory=torch.cuda.is_available(),
)

In [115]:
# Draw one (shuffled) batch and display its first sample with the mask overlay.
first_batch = next(iter(train_loader))
img1, mask1 = first_batch
imshow(img1[0], mask1[0])

In [76]:
# Shape of the first sample's mask (channel-first).
mask1[0].size()

torch.Size([1, 101, 101])

In [78]:
# Number of positive pixels in the first mask. `> 0` (as in mask_overlay) rather
# than `== 1`, so values that are not exactly 1 are not silently missed.
(mask1[0] > 0).sum()

tensor(0)