# PyTorch IO 介紹
* 梯度下降時, 需要隨機讀取數張影像進來, 並且對影像做 data augmentation, 如果這些事卡住訓練流程就不好了
* torch.utils.data.Dataset 跟 torch.utils.data.DataLoader 就是在解決這問題

In [None]:
import PIL.Image as Image
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchvision.transforms as transforms

In [None]:
class MnistImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over digit images stored as ``<img_dir>/<digit>/<file>``.

    The constructor scans the ten sub-directories ``0`` .. ``9`` of
    ``img_dir`` and records every entry together with its digit label.
    Images are only opened when an item is requested.
    """

    def __init__(self, img_dir, transform=None):
        """Index every file found under the ten digit sub-directories.

        Args:
            img_dir: Root directory containing the sub-directories ``0``-``9``.
            transform: Optional callable applied to the opened PIL image
                before it is converted to a tensor.

        Raises:
            OSError: If one of the ten digit sub-directories cannot be listed.
        """
        self.img_dir = img_dir
        self.transform = transform

        # List of (file name, label) pairs, grouped by label.
        self.img_names_label = []
        for cls in range(10):
            cls_dir = os.path.join(self.img_dir, str(cls))
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible across runs and machines.
            self.img_names_label.extend(
                (name, cls) for name in sorted(os.listdir(cls_dir))
            )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_names_label)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position in the index built by ``__init__``.

        Returns:
            A ``(image, label)`` pair where ``image`` is the output of
            ``transforms.ToTensor()`` and ``label`` is the integer digit class.
        """
        img_name, label = self.img_names_label[idx]
        img_path = os.path.join(self.img_dir, str(label), img_name)

        # Read the pixel data while the file is open so the handle is released
        # deterministically instead of whenever the image is garbage collected.
        with Image.open(img_path) as image:
            image.load()

        if self.transform is not None:
            image = self.transform(image)

        image = transforms.ToTensor()(image)
        return image, label

In [None]:
# Un-augmented training set, served in shuffled batches of 4 by a single
# persistent worker process.
train_dataset = MnistImageDataset(img_dir='/Data/dataset_zoo/mnist/train')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
    pin_memory=False,
    prefetch_factor=2,
    persistent_workers=True,
)

# Only the first batch is needed for the demo: show its shape and the label
# of its first sample.
x, y = next(iter(train_dataloader))
print(x.shape, y[0], sep='\n')

# Display the first image of the batch with singleton dimensions removed.
first_image = x[0].squeeze()
plt.imshow(first_image)

# Data Augmentation
* 對 training data 做一些影像處理, 可增加訓練資料的多樣性

In [None]:
# Augmentation pipeline: always flip vertically, then flip horizontally with
# probability 0.9.
flip_steps = [
    transforms.RandomVerticalFlip(p=1),
    transforms.RandomHorizontalFlip(p=0.9),
]
transform = transforms.Compose(flip_steps)

# Same training folder as before, now with the flips applied to every image.
train_aug_dataset = MnistImageDataset(
    img_dir='/Data/dataset_zoo/mnist/train',
    transform=transform,
)
train_aug_dataloader = torch.utils.data.DataLoader(
    train_aug_dataset,
    batch_size=4,
    shuffle=True,
)

# Fetch one batch and show its shape, the first label and the first image.
x, y = next(iter(train_aug_dataloader))
print(x.shape, y[0], sep='\n')
first_image = x[0].squeeze()
plt.imshow(first_image)

In [None]:
# Blur followed by rotation; RandomApply below applies this whole group with
# probability 0.5.
transform_set = [
    transforms.GaussianBlur(kernel_size=7, sigma=3),
    transforms.RandomRotation(degrees=30),
]

# Full pipeline: the two flips first, then the optional blur + rotation group.
optional_group = transforms.RandomApply(transform_set, p=0.5)
transform = transforms.Compose([
    transforms.RandomVerticalFlip(p=1),
    transforms.RandomHorizontalFlip(p=0.9),
    optional_group,
])

train_aug_dataset = MnistImageDataset(
    img_dir='/Data/dataset_zoo/mnist/train',
    transform=transform,
)
train_aug_dataloader = torch.utils.data.DataLoader(
    train_aug_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
    pin_memory=False,
    prefetch_factor=2,
    persistent_workers=True,
)

# Fetch one batch and show its shape, the first label and the first image.
x, y = next(iter(train_aug_dataloader))
print(x.shape, y[0], sep='\n')
first_image = x[0].squeeze()
plt.imshow(first_image)