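# Data Loader

Loads the `halftone32` test images (`x_test/*.png`) and the train/test label arrays (`y_train.npy`, `y_test.npy`), then wraps the test split in a PyTorch `Dataset` and `DataLoader`.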

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, utils
from PIL import Image

In [2]:
def img_load(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the converted copy has been produced; the returned image does not keep
    the file open.

    Parameters
    ----------
    path : str or path-like
        Location of the image file.

    Returns
    -------
    PIL.Image.Image
        A new image in ``'RGB'`` mode.
    """
    with Image.open(path) as img:
        # convert() reads the pixel data and returns an independent image,
        # so it remains valid after the source file is closed.
        return img.convert('RGB')

In [3]:
def img_show(img, title=None):
    """Draw ``img`` on a fresh figure and display it.

    Parameters
    ----------
    img : image-like
        Anything ``Axes.imshow`` accepts, e.g. the PIL image returned by
        ``img_load`` or an array.
    title : str, optional
        Axes title; no title is set when ``None`` (the default).
    """
    # Explicit Axes interface instead of the implicit pyplot state machine.
    _, ax = plt.subplots()
    ax.imshow(img)
    if title is not None:
        ax.set_title(title)
    plt.show()

In [4]:
# Load every test image into memory as an RGB PIL image.
# Files are named 0.png ... (N_TEST_IMAGES - 1).png under DATA_DIR/x_test;
# N_TEST_IMAGES should equal the number of labels in DATA_DIR/y_test.npy.
DATA_DIR = 'halftone32'
N_TEST_IMAGES = 10000

img_list = [img_load(f'{DATA_DIR}/x_test/{i}.png') for i in range(N_TEST_IMAGES)]

In [5]:
# Training labels, read from disk and turned into a plain Python list.
# NOTE(review): no visible cell below uses y_train; only the test split is wrapped.
y_train = np.load('halftone32/y_train.npy').tolist()

In [6]:
# Test labels as a plain Python list, index-aligned with img_list.
y_test = np.load('halftone32/y_test.npy').tolist()

In [7]:
# Per-sample preprocessing: only ToTensor (PIL image -> tensor); no
# normalization or augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

In [8]:
class ListData(Dataset):
    """Map-style Dataset over two parallel, in-memory lists.

    ``X_list[i]`` is paired with ``y_list[i]``; items are returned as
    ``(X, y)`` tuples with the optional ``transform`` applied to ``X`` only.

    Parameters
    ----------
    X_list : sequence
        Inputs (e.g. PIL images).
    y_list : sequence
        Labels, one per input.
    transform : callable or None, optional
        Applied to each input on access; ``None`` returns inputs unchanged.

    Raises
    ------
    ValueError
        If ``X_list`` and ``y_list`` have different lengths.
    """

    def __init__(self, X_list, y_list, transform=None):
        if len(X_list) != len(y_list):
            # A mismatch would mis-pair samples or raise IndexError mid-epoch
            # (__len__ follows y_list), so reject it up front.
            raise ValueError(
                'len(X_list) != len(y_list): %d != %d' % (len(X_list), len(y_list))
            )
        self.X_list = X_list
        self.y_list = y_list
        self.transform = transform
        print('len(X_list) == len(y_list)')
        print('transform: %s' % self.transform)

    def __getitem__(self, index):
        X = self.X_list[index]
        # Previously X and y were only bound when a transform was set, so a
        # None transform raised UnboundLocalError.
        if self.transform is not None:
            X = self.transform(X)
        return X, self.y_list[index]

    def __len__(self):
        return len(self.y_list)

In [9]:
# Wrap the in-memory test images and labels in a Dataset.
data_test = ListData(X_list=img_list, y_list=y_test, transform=transform)

len(X_list) == len(y_list)
transform: Compose(
    ToTensor()
)


In [10]:
# Batches of 64 test samples. Shuffling is driven by a fixed-seed generator so
# the batch order is identical on every Restart & Run All.
# NOTE(review): evaluation loaders usually use shuffle=False; shuffling is kept
# here as in the original.
data_loader_test = torch.utils.data.DataLoader(
    dataset=data_test,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [12]:
# Peek at a single batch. Dedicated names keep the label list `y_test`
# (loaded from y_test.npy) intact; assigning the batch to `y_test` would
# replace it with a 64-element tensor and break a re-run of the dataset cell.
X_batch, y_batch = next(iter(data_loader_test))
print(X_batch, y_batch)

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0471, 0.0196, 0.0000],
          [0.0000, 0.0000, 0.0941,  ..., 0.0980, 0.1216, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0667, 0.0000],
          ...,
          [0.0000, 0.1961, 0.0000,  ..., 0.0000, 0.2078, 0.0000],
          [0.0078, 0.3020, 0.0431,  ..., 0.0353, 0.3451, 0.0000],
          [0.0000, 0.0078, 0.0000,  ..., 0.0000, 0.1098, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0902, 0.0784, 0.0000],
          [0.0000, 0.0000, 0.1490,  ..., 0.1176, 0.1804, 0.0196],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0824, 0.0000],
          ...,
          [0.0000, 0.2078, 0.0000,  ..., 0.0000, 0.2353, 0.0000],
          [0.0510, 0.3176, 0.0392,  ..., 0.0667, 0.4353, 0.0431],
          [0.0000, 0.0314, 0.0000,  ..., 0.0000, 0.1647, 0.0196]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0078, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0235,  ..., 0.0235, 0.0627, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0