In [1]:
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob

import _path

# Base directory prefix taken from the project-local `_path` module; the dataset
# glob patterns further down are formatted with this value.
PATH_HEAD = _path.PATH_HEAD

In [2]:
# Class names; a name's position in this list is its integer label.
label_name = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Map each class name to its index in `label_name`.
label_dict = {class_name: class_idx for class_idx, class_name in enumerate(label_name)}
label_dict

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [3]:
def default_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    Args:
        path: Filesystem path of the image file to read.

    Returns:
        The result of ``Image.open(...).convert('RGB')`` for that file.
    """
    # Open the file ourselves so the handle is closed as soon as the pixels
    # have been decoded; `convert` forces the (otherwise lazy) load.
    with open(path, 'rb') as image_file:
        return Image.open(image_file).convert('RGB')

# Augmentation pipeline applied to every training image before batching.
# NOTE(review): RandomResizedCrop(size=28) makes training tensors 28x28 while the
# test dataset is built with a bare ToTensor() (no resize) — confirm the model
# accepts both input sizes.
train_transform = transforms.Compose(
    [
        # Geometric augmentations.
        transforms.RandomResizedCrop(size=28),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=90),
        # Colour augmentations.
        transforms.RandomGrayscale(p=0.1),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, hue=0.2),
        # Convert the PIL image to a tensor as the final step.
        transforms.ToTensor(),
    ]
)

In [4]:
class MyDataset(Dataset):
    """Image dataset whose labels come from each file's parent directory name.

    Every path in ``im_list`` is expected to look like
    ``.../<class_name>/<file>``; ``<class_name>`` is looked up in the
    module-level ``label_dict`` to obtain the integer label.
    """

    def __init__(self, im_list, transform=None, loader = default_loader):
        """Index the given image paths together with their labels.

        Args:
            im_list: Iterable of image file paths.
            transform: Optional callable applied to each loaded image.
            loader: Callable that takes a path and returns the loaded image.

        Raises:
            KeyError: If a parent directory name is not a key of ``label_dict``.
            IndexError: If a path has no parent directory component.
        """
        super().__init__()
        imgs = []

        for im_item in im_list:
            # Accept both Windows ('\\') and POSIX ('/') separators so the
            # label extraction does not depend on the platform that produced
            # the path.
            im_label_name = im_item.replace('\\', '/').split('/')[-2]
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is read with ``self.loader`` and, when a transform was
        supplied, passed through it before being returned.
        """
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.imgs)

In [5]:
# Collect every training PNG (one sub-directory per class). The pattern is joined
# with the platform separator instead of hard-coded backslashes so it also works
# off Windows, and the result is sorted because glob returns matches in
# arbitrary order.
im_train_list = sorted(glob.glob(os.sep.join(
    [PATH_HEAD, '_data_', 'data_ml', 'cifar-10-picture', 'TRAIN', '*', '*.png']
)))
len(im_train_list)

50000

In [6]:
# Collect every test PNG (one sub-directory per class). The pattern is joined
# with the platform separator instead of hard-coded backslashes so it also works
# off Windows, and the result is sorted because glob returns matches in
# arbitrary order.
im_test_list = sorted(glob.glob(os.sep.join(
    [PATH_HEAD, '_data_', 'data_ml', 'cifar-10-picture', 'TEST', '*', '*.png']
)))
len(im_test_list)

10000

In [7]:
# Training samples go through the augmentation pipeline.
train_dataset = MyDataset(im_list=im_train_list, transform=train_transform)
# Test samples are only converted to tensors.
test_dataset = MyDataset(im_list=im_test_list, transform=transforms.ToTensor())

In [8]:
# Training batches: shuffling enabled.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# Test batches: shuffling disabled.
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)

# Report how many samples each split contains.
print("num_of_train {}".format(len(train_dataset)))
print("num_of_test {}".format(len(test_dataset)))

num_of_train 50000
num_of_test 10000
