In [1]:
import glob
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from matplotlib.image import imread
from PIL import Image
from skimage import io
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate

## Dataset

In [8]:
class ImageDataset(Dataset):
    """Map-style dataset yielding (lightness, ab) Lab channel pairs for colorization.

    Each item is read from disk with ``imread``, converted RGB -> CIE Lab with
    ``rgb2lab``, and rescaled per channel so values land roughly in [0, 1]
    (L is divided by 100; a and b are shifted by 128 and divided by 255).

    ``__getitem__`` returns a tuple ``(lightness, ab)`` with shapes
    ``(1, H, W)`` and ``(2, H, W)``, or ``None`` when the image is unusable:
      * it does not decode to an ``(H, W, 3)`` array (e.g. a grayscale JPEG,
        which decodes to a 2-D array), or
      * the file cannot be read/decoded (a warning names the file).
    Use a ``collate_fn`` that drops ``None`` items when batching.

    An out-of-range index raises ``IndexError`` (it is not swallowed), so
    plain iteration over the dataset terminates.
    """

    # Per-channel (L, a, b) shift and scale used to map Lab values into ~[0, 1].
    LAB_OFFSET = np.array([0.0, 128.0, 128.0])
    LAB_SCALE = np.array([100.0, 255.0, 255.0])

    def __init__(self, files):
        # Paths are kept in a numpy array rather than the caller's list.
        self.files = np.array(files)
        self.length = len(files)

    def __getitem__(self, idx):
        # Indexing stays outside the try so a bad index surfaces as IndexError.
        path = self.files[idx]
        try:
            img = imread(path)
        except (OSError, ValueError) as exc:
            # Missing/corrupt files are skipped, but not silently.
            warnings.warn(f"Skipping unreadable image {path}: {exc}")
            return None
        # Lab conversion needs RGB; grayscale (2-D) or other channel counts are rejected.
        if img.ndim != 3 or img.shape[2] != 3:
            return None
        img_lab = (rgb2lab(img) + self.LAB_OFFSET) / self.LAB_SCALE
        # HWC -> CHW, keeping the channel axis: L -> (1, H, W), ab -> (2, H, W).
        img_lightness = img_lab[:, :, 0:1].transpose(2, 0, 1)
        img_ab = img_lab[:, :, 1:3].transpose(2, 0, 1)
        return img_lightness, img_ab

    def __len__(self):
        return self.length

In [9]:
def collate_fn(batch):
    """Collate a batch from ``ImageDataset``, dropping items that came back as ``None``.

    The dataset returns ``None`` for images it could not load or convert
    (e.g. grayscale files). Those items are removed before the rest is handed
    to ``default_collate``, so a collated batch may be smaller than
    ``batch_size``.

    Args:
        batch: list of dataset items, each a ``(lightness, ab)`` pair or ``None``.

    Returns:
        Whatever ``default_collate`` produces for the surviving items.

    Raises:
        ValueError: if every item in the batch was ``None``; ``default_collate``
            would otherwise fail on the empty list with an unhelpful error.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        raise ValueError(
            f"collate_fn: all {len(batch)} items in this batch were filtered out "
            "(unreadable or grayscale images)"
        )
    return default_collate(kept)

In [10]:
# Two-file smoke-test set:
#   * 00000032 is a grayscale image and should be filtered out;
#   * 00000031 loads normally.
test = [
    "data/val/Places365_val_00000032.jpg",
    "data/val/Places365_val_00000031.jpg",
]

In [11]:
batch_size = 20  # samples per batch, shared by every DataLoader below
test_ds = ImageDataset(test)
test_dl = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [12]:
# test[0] is the grayscale file, so the dataset should hand back None for it
# (collate_fn drops such items when batching).
assert test_ds[0] is None, "expected the grayscale sample to be rejected"

In [14]:
# The color sample should decode into a (lightness, ab) pair; display only the
# shapes instead of dumping the full arrays.
sample_lightness, sample_ab = test_ds[1]
sample_lightness.shape, sample_ab.shape

In [20]:
# Pull the first (and only) batch from the smoke-test loader.
test_batches = iter(test_dl)
x, y = next(test_batches)

In [21]:
# Only the color sample survives collate_fn, so the batch holds exactly one image.
assert x.shape[0] == 1, "grayscale sample was not filtered out"
x.shape, y.shape

(torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 256, 256]))

In [16]:
# glob returns files in filesystem-dependent order; sort so the file list (and
# anything derived from it, e.g. a seeded shuffle) is reproducible.
val_files = sorted(glob.glob("data/val/*.jpg"))
len(val_files)

36500

In [17]:
# Seeded generator so the shuffle order, and thus the batch drawn below, is
# reproducible across kernel restarts.
SEED = 0
ds = ImageDataset(val_files)
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

In [18]:
# Draw one shuffled batch from the full validation loader.
val_batches = iter(dl)
x, y = next(val_batches)

In [19]:
# Inputs (L) and targets (ab) must describe the same images: one vs. two
# channels, with identical batch and spatial dims. The batch may be smaller
# than batch_size when collate_fn drops rejected files.
assert x.shape[1] == 1 and y.shape[1] == 2
assert x.shape[0] == y.shape[0] and x.shape[2:] == y.shape[2:]
assert 1 <= x.shape[0] <= batch_size
x.shape, y.shape

(torch.Size([20, 1, 256, 256]), torch.Size([20, 2, 256, 256]))