## 필요한 모듈 import

In [1]:
import torch
from torchvision import datasets, transforms

## Transformation 설정

In [2]:
# Preprocessing pipeline applied to every image that is loaded.
to_tensor = transforms.ToTensor()  # convert the input image to a Tensor
# Normalize with mean 0.5 and std 0.5 (single-channel form).
# For 3-channel images use transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).
normalize = transforms.Normalize(mean=(0.5,), std=(0.5,))
transform = transforms.Compose([to_tensor, normalize])

## PyTorch 내장 샘플 데이터셋 로드

- `root`: 데이터셋을 다운로드(저장)할 경로 설정
- `download`: 다운로드 받을지 여부 설정
- `train`: True로 설정한 경우는 training set. False인 경우 test set.
- `transform`: transformation 적용

[PyTorch 내장 데이터셋 리스트](https://pytorch.org/vision/stable/datasets.html)

In [3]:
# Display the built-in datasets: every attribute of the `datasets` module
# whose name starts with an uppercase letter, in the module's own order.
module_attributes = datasets.__dir__()
[attr_name for attr_name in module_attributes if attr_name[0].isupper()]

['LSUN',
 'LSUNClass',
 'ImageFolder',
 'DatasetFolder',
 'CocoCaptions',
 'CocoDetection',
 'CIFAR10',
 'CIFAR100',
 'STL10',
 'MNIST',
 'EMNIST',
 'FashionMNIST',
 'KMNIST',
 'QMNIST',
 'SVHN',
 'PhotoTour',
 'FakeData',
 'SEMEION',
 'Omniglot',
 'SBU',
 'Flickr8k',
 'Flickr30k',
 'VOCSegmentation',
 'VOCDetection',
 'Cityscapes',
 'ImageNet',
 'Caltech101',
 'Caltech256',
 'CelebA',
 'SBDataset',
 'VisionDataset',
 'USPS',
 'Kinetics400',
 'HMDB51',
 'UCF101']

In [None]:
# Build the Fashion-MNIST training and test sets.
# Both share the same storage root, download flag and transform pipeline;
# only the `train` flag differs (True -> training set, False -> test set).
shared_kwargs = dict(root='./data', download=True, transform=transform)
trainset = datasets.FashionMNIST(train=True, **shared_kwargs)
testset = datasets.FashionMNIST(train=False, **shared_kwargs)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

## DataLoader 변환

- `dataset`: 앞에서 다운로드 및 로드한 데이터셋 (`trainset`, `testset`)
- `batch_size`: 배치 사이즈 정의
- `shuffle`: 셔플 여부 설정

In [None]:
# Wrap the datasets in DataLoaders that yield mini-batches of 32 samples.
# Only the training loader shuffles; the test loader keeps the dataset order
# so that evaluation batches are the same on every run.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

In [None]:
# Take a single batch (images and their labels) from the training loader.
batch_iterator = iter(trainloader)
image, label = next(batch_iterator)

In [None]:
# Display the shapes of the image batch and the label batch as a pair.
tuple(batch_tensor.shape for batch_tensor in (image, label))

## 배치에서 10개의 이미지 시각화

In [None]:
import matplotlib.pyplot as plt

# Show the first 10 images of the batch on a 2x5 grid, each titled with its label.
fig, ax = plt.subplots(2, 5)
fig.set_size_inches(12, 5)
img = image.numpy()

# ax.flat yields the 10 axes row by row, so zip stops after 10 samples.
for axis, sample, target in zip(ax.flat, img, label):
    axis.imshow(sample[0])  # index 0 picks the first channel of the sample
    # Hide the frame drawn around each image.
    for spine in axis.spines.values():
        spine.set_visible(False)
    # Remove ticks and tick labels by clearing the tick positions.
    # (Setting tick labels without fixing the tick positions first is
    # discouraged by matplotlib, and the labels argument should be a
    # sequence rather than a bare string.)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_title(target.item())

plt.tight_layout()
plt.show()