In [3]:
import unittest

import os
import time
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torchvision

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Compose, RandomCrop, \
Normalize, Resize, RandomHorizontalFlip, RandomVerticalFlip
from sklearn.model_selection import train_test_split
from PIL import Image, ImageFile, ImageFilter

In [4]:
 def test_validation_loader(ship_dir='../dev/', sample_image_id='000d26c17.jpg'):
        """Check that the validation pipeline yields a well-formed sample.

        Reads the segmentation CSV under ``ship_dir``, derives one row per image
        with a ``counts`` column, splits the rows into train/validation parts and
        loads a single known image through ``VesselDataset`` in ``'valid'`` mode.

        Args:
            ship_dir: Directory holding ``train_ship_segmentations_v2.csv`` and
                the ``imgs/`` sub-directory.
            sample_image_id: File name of an image that is expected to be in the
                validation split and present in ``imgs/``.

        Raises:
            ValueError: If ``sample_image_id`` is not part of the validation split.
            AssertionError: If the loaded sample has an unexpected type, shape
                or label.
        """
        seed = 0
        torch.manual_seed(seed)
        # RandomBlur draws from numpy's global RNG, and train_test_split is
        # called without random_state, so this seed makes both reproducible.
        np.random.seed(seed)


        class RandomBlur:
            """Apply a Gaussian blur of ``radius`` to a PIL image with probability ``p``."""

            def __init__(self, p=0.5, radius=2):
                self.p = p
                self.radius = radius


            def __call__(self, x):
                prob = np.random.rand(1)[0]
                if prob < self.p:
                    x = x.filter(ImageFilter.GaussianBlur(self.radius))
                return x


        class VesselDataset(Dataset):
            """Image dataset driven by the ``ImageId`` / ``counts`` columns of ``img_df``.

            ``mode`` selects both the directory images are read from and the
            transform applied: ``'train'``, ``'valid'``, or anything else
            (treated as test). Train and valid items are ``(tensor, label)``;
            test items are ``(tensor, file_name)``.
            """

            def __init__(self, img_df, train_image_dir=None, valid_image_dir=None, 
                         test_image_dir=None, transform=None, mode='train', binary=True):
                # Keep exactly one row per image so that image_ids and
                # image_labels stay index-aligned even if img_df repeats an id.
                per_image = img_df.drop_duplicates(subset='ImageId')
                self.image_ids = list(per_image['ImageId'])
                if binary:
                    self.image_labels = [1 if count > 1 else 0 for count in per_image['counts']]
                else:
                    self.image_labels = list(per_image['counts'] - 1) # Image with no mask has 'count' == 1 in df
                self.train_image_dir = train_image_dir
                self.valid_image_dir = valid_image_dir
                self.test_image_dir = test_image_dir

                # NOTE(review): presumably the ImageNet channel statistics - confirm.
                mean = [0.485, 0.456, 0.406]
                std = [0.229, 0.224, 0.225]
                # A caller-supplied transform only replaces the training pipeline;
                # the validation and test pipelines below are always the built-in ones.
                if transform is not None:
                    self.train_transform = transform
                else:
                    self.train_transform = Compose([
                        Resize(size=(299,299), interpolation=2),
                        RandomHorizontalFlip(p=0.5),
                        RandomVerticalFlip(p=0.5),
                        RandomBlur(p=0.5, radius=2),
                        ToTensor(),
                        Normalize(mean, std) # Apply to all input images
                    ])
                self.valid_transform = Compose([
                    Resize(size=(299,299), interpolation=2),
                    RandomBlur(p=1.0, radius=2), # Blur all images
                    ToTensor(),
                    Normalize(mean, std) # Apply to all input images
                ])
                self.test_transform = Compose([
                    Resize(size=(299,299), interpolation=2),
                    ToTensor(),
                    Normalize(mean, std) # Apply to all input images
                ])
                self.mode = mode


            def __len__(self):
                return len(self.image_ids)


            def __getitem__(self, idx):
                img_file_name = self.image_ids[idx]
                if self.mode == 'train':
                    image_dir, transform = self.train_image_dir, self.train_transform
                elif self.mode == 'valid':
                    image_dir, transform = self.valid_image_dir, self.valid_transform
                else:
                    image_dir, transform = self.test_image_dir, self.test_transform

                img_path = os.path.join(image_dir, img_file_name)
                # Close the file handle promptly and force three channels so the
                # 3-channel Normalize also works for grayscale/RGBA/palette files.
                with Image.open(img_path) as img_file:
                    img = img_file.convert('RGB')
                img = transform(img)

                if self.mode == 'train' or self.mode == 'valid':
                    return img, self.image_labels[idx]
                return img, img_file_name

        valid_image_dir = os.path.join(ship_dir, 'imgs/')
        masks = pd.read_csv(os.path.join(ship_dir,'train_ship_segmentations_v2.csv'))
        # NOTE(review): size() counts CSV rows per image, so an image with exactly
        # one row gets counts == 1 whether or not that row holds a mask - confirm
        # that 'count > 1' is the intended definition of the positive class.
        unique_img_ids = masks.groupby('ImageId').size().reset_index(name='counts')
        # Only the (large) validation part of the split is needed here.
        _, valid_ids = train_test_split(
            unique_img_ids, 
            test_size = .999, 
            stratify = unique_img_ids['counts'],
        )
        valid_df = pd.merge(unique_img_ids, valid_ids)

        vessel_valid_dataset = VesselDataset(valid_df, valid_image_dir=valid_image_dir, 
                                       mode='valid', binary=True)
        img_id = vessel_valid_dataset.image_ids.index(sample_image_id) # Image contained in valid_image_dir
        img, label = vessel_valid_dataset[img_id]

        # The validation transform resizes to 299x299 and converts to a tensor;
        # with binary=True every label is 0 or 1.
        assert isinstance(img, torch.Tensor), type(img)
        assert tuple(img.shape) == (3, 299, 299), tuple(img.shape)
        assert label in (0, 1), label

In [5]:
# Any exception raised by the check propagates with its full traceback;
# "Passed." is printed only when the check completed without raising.
test_validation_loader()
print("Passed.")

Passed.
