In [None]:
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

In [None]:
# Dataset root; each split lives in its own sub-directory underneath it.
data_dir = './flower_data'
train_dir, valid_dir = (f'{data_dir}/{split}' for split in ('train', 'valid'))

In [None]:
# Per-channel normalization statistics (the ImageNet mean/std documented for
# torchvision's pretrained models). Shared by both pipelines so they cannot drift.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: random augmentation on top of the same base scale as validation.
    'train': transforms.Compose([
        # Match the 'valid' pipeline's scale first (shorter side -> 256 px).
        # Without this, CenterCrop(224) on a full-resolution photo keeps only a
        # small zoomed-in patch, so train and valid images differ in scale.
        transforms.Resize(256),
        transforms.RandomRotation(45),  # rotate by a random angle in [-45, 45] degrees
        transforms.CenterCrop(224),  # keep the central 224x224 region
        transforms.RandomHorizontalFlip(p=0.5),  # flip left-right with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),  # flip top-bottom with probability 0.5
        # Randomly jitter brightness, contrast, saturation and hue by the given amounts.
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # With probability 0.025, convert to grayscale while keeping 3 channels (R=G=B).
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),  # per channel: (x - mean) / std
    ]),
    # Validation: deterministic resize + center crop, no augmentation.
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]),
}

In [None]:
batch_size = 8

# One ImageFolder per split, reading <data_dir>/<split> with that split's transform pipeline.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'valid')
}
# Both loaders reshuffle on every pass over the data.
data_loaders = {
    split: DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for split, dataset in image_datasets.items()
}
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}
# Label names of the training split; position i corresponds to integer label i.
class_names = image_datasets['train'].classes

In [None]:
# Display the per-split ImageFolder datasets.
image_datasets

In [None]:
# Mini-batches per split. The default DataLoader repr only shows a memory
# address, which carries no information and changes on every run.
{split: len(loader) for split, loader in data_loaders.items()}

In [None]:
# Number of images in each split.
dataset_sizes

In [None]:
# Load the mapping from category-id strings to human-readable flower names.
# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the read does
# not depend on the platform's default locale encoding.
with open('cat_to_name.json', 'r', encoding='utf-8') as f:
    cat_to_name = json.load(f)

In [None]:
# Mapping from category-id strings to display names, loaded from cat_to_name.json.
cat_to_name

In [None]:
def im_convert(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Convert a normalized CHW image tensor into an HWC numpy array for plotting.

    Undoes ``transforms.Normalize(mean, std)`` channel-wise and clips the result
    to [0, 1] so it can be passed straight to ``plt.imshow``.

    Args:
        tensor: Image tensor of shape (C, H, W), or (1, C, H, W) for a batch of
            one. May live on any device and may require grad.
        mean: Per-channel means used by the Normalize transform.
        std: Per-channel standard deviations used by the Normalize transform.

    Returns:
        numpy.ndarray of shape (H, W, C) with values clipped to [0, 1].
    """
    # detach()/cpu() make numpy() legal for grad-tracking or GPU tensors. The
    # arithmetic below allocates a new array, so the caller's tensor is untouched.
    image = tensor.detach().cpu().numpy()
    # Drop only a leading batch dimension of size 1; squeezing every singleton
    # dim would also collapse a 1-pixel-high/wide image and break the transpose.
    if image.ndim == 4 and image.shape[0] == 1:
        image = image[0]
    image = image.transpose(1, 2, 0)  # CHW -> HWC, as matplotlib expects
    image = image * np.asarray(std) + np.asarray(mean)
    return image.clip(0, 1)

In [None]:
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# Take one batch from the (shuffled) validation loader.
data_iter = iter(data_loaders['valid'])
inputs, classes = next(data_iter)

# A batch may hold fewer images than grid cells (batch_size < columns * rows, or
# a split smaller than one batch), so draw only the images that are available.
n_shown = min(columns * rows, len(inputs))
for idx in range(n_shown):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Integer label -> class folder name (a numeric id string) -> display name.
    folder_name = class_names[classes[idx].item()]
    ax.set_title(cat_to_name[str(int(folder_name))])
    ax.imshow(im_convert(inputs[idx]))
plt.show()