In [1]:
from unet import UNet

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.utils as utils
import torchvision.datasets as dataset
import torchvision.transforms as transforms

from pathlib import Path

In [2]:
from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [3]:
dataset = NumbersDataset()
print(len(dataset))
print(dataset[100])
print(dataset[122:361])

TypeError: __init__() missing 1 required positional argument: 'root_dir'

In [2]:
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image

import numpy as np
import h5py

In [5]:
class BasicDataset(Dataset):
    
    def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
        
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.mask_suffix = mask_suffix
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')
    
    def __len__(self):
        
        return len(self.ids)
    
    def __getitem__(self, i):
        
        idx = self.ids[i]
        mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
        img_file = glob(self.imgs_dir + idx + '.*')

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'
        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(img, self.scale)
        mask = self.preprocess(mask, self.scale)

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

In [3]:
ds_file_name  = "/media/philipp/ed7d22ba-5a3b-4d31-bf6c-6add6e106b3d/test/256x256/1m/dataset_256.hdf5";
x = 'ortho'
y = 'ground_truth'

In [9]:
# read from hdf5
def read_hdf5(hdf5_file_name):
    """ Reads image from HDF5.
        Parameters:
        ---------------
        num_images   number of images to read

        Returns:
        ----------
        images      images array, (N, 32, 32, 3) to be stored
        labels      associated meta data, int label (N, 1)
    """
    images, labels = [], []

    # Open the HDF5 file
    file = h5py.File(hdf5_file_name, "r+")

    images = np.array(file["/ortho"]).astype("uint8")
    labels = np.array(file["/ground_truth"]).astype("uint8")

    return images, labels

In [10]:
images, labels = read_hdf5(ds_file_name)

In [11]:
images.shape

(20928, 256, 256, 4)

In [21]:
class BasicDataset(Dataset):
    
    def __init__(self, file_name, scale=1):
        
        self.hdf5_file_name = file_name
        self.scale = scale
        
        # load dataset
        self.images, self.labels = self.read_hdf5(self.hdf5_file_name)
        
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        logging.info(f'Creating dataset with {len(self.images)} examples')
    
    def __len__(self):
        
        return len(self.images)
    
    @classmethod
    def preprocess(cls, img_nd):

        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans
    
    def __getitem__(self, i):
        
        mask = self.labels[i]
        img = self.images[i]

        #assert img.size == mask.size, \
        #    f'Image and mask should be the same size, but are {img[:,:,1].size} and {mask.size}'

        img = self.preprocess(img)
        mask = self.preprocess(mask)

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }
    
    def read_hdf5(self, hdf5_file_name):
        """ Reads image from HDF5.
            Parameters:
            ---------------
            num_images   number of images to read

            Returns:
            ----------
            images      images array, (N, 32, 32, 3) to be stored
            labels      associated meta data, int label (N, 1)
        """
        images, labels = [], []

        # Open the HDF5 file
        file = h5py.File(hdf5_file_name, "r+")

        images = np.array(file["/ortho"]).astype("uint8")
        labels = np.array(file["/ground_truth"]).astype("uint8")

        return images, labels

In [22]:
dataset = BasicDataset(ds_file_name)

In [23]:
dataset[1]

{'image': tensor([[[0.3882, 0.4392, 0.3765,  ..., 0.5608, 0.5725, 0.5529],
          [0.3922, 0.3843, 0.3843,  ..., 0.5294, 0.5412, 0.5529],
          [0.3725, 0.3961, 0.3804,  ..., 0.5725, 0.5725, 0.5804],
          ...,
          [0.2588, 0.2588, 0.2588,  ..., 0.2667, 0.3725, 0.4745],
          [0.2471, 0.2549, 0.2471,  ..., 0.3804, 0.4392, 0.4078],
          [0.2549, 0.2549, 0.2627,  ..., 0.4235, 0.3765, 0.4431]],
 
         [[0.3922, 0.4588, 0.3765,  ..., 0.5843, 0.6000, 0.5765],
          [0.4039, 0.3922, 0.3882,  ..., 0.5529, 0.5725, 0.5765],
          [0.3804, 0.4039, 0.3882,  ..., 0.6078, 0.6039, 0.6039],
          ...,
          [0.2627, 0.2745, 0.2902,  ..., 0.2745, 0.3882, 0.5176],
          [0.2588, 0.2745, 0.2745,  ..., 0.4000, 0.4784, 0.4078],
          [0.2667, 0.2784, 0.2863,  ..., 0.4588, 0.3843, 0.4706]],
 
         [[0.3176, 0.3882, 0.2941,  ..., 0.4745, 0.5059, 0.4667],
          [0.3216, 0.3098, 0.3020,  ..., 0.4275, 0.4471, 0.4627],
          [0.3059, 0.3255, 0.31

In [None]:
import matplotlib.pyplot as plt

In [None]:
def plot_img_and_mask(img, mask):
    classes = mask.shape[2] if len(mask.shape) > 2 else 1
    fig, ax = plt.subplots(1, classes + 1)
    ax[0].set_title('Input image')
    ax[0].imshow(img)
    if classes > 1:
        for i in range(classes):
            ax[i+1].set_title(f'Output mask (class {i+1})')
            ax[i+1].imshow(mask[:, :, i])
    else:
        ax[1].set_title(f'Output mask')
        ax[1].imshow(mask)
    plt.xticks([]), plt.yticks([])
    plt.show()