In [1]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as scp

In [2]:
dataset_name = 'Flowers102'

# Per-channel normalization statistics shared by the training and the
# evaluation pipelines.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random rotation, random resized crop and random flip for
# augmentation, followed by tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline (val + test): deterministic resize and center crop,
# then the same tensor conversion and normalization.
testval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# One dataset per split, all rooted at ./data with download=True; only the
# training split gets the augmenting transform.
train_dataset, val_dataset, test_dataset = (
    torchvision.datasets.Flowers102(root='./data', split=split_name,
                                    download=True, transform=split_transform)
    for split_name, split_transform in (
        ('train', train_transform),
        ('val', testval_transform),
        ('test', testval_transform),
    )
)

In [3]:
BATCH_SIZE = 64

# Only the training loader shuffles; validation and test batches keep the
# dataset order so evaluation passes are repeatable.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=is_train)
    for split_dataset, is_train in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


In [4]:
def image_preprocessing(pil_image, resize_size=256, crop_size=224,
                        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    '''
    Turn a PIL image into a normalized, channels-first numpy array.

    Follows the same steps as ``testval_transform``: resize so the shorter
    side equals ``resize_size`` (aspect ratio kept), take a centered
    ``crop_size`` x ``crop_size`` crop, scale pixels to [0, 1], normalize
    each channel with ``mean``/``std`` and move channels to the first axis.
    Resizing uses Pillow's default resampling filter, which may differ
    slightly from the one torchvision's ``Resize`` uses.

    Parameters
    ----------
    pil_image : PIL.Image.Image
        Input image. It is NOT modified; non-RGB modes are converted to RGB.
    resize_size : int
        Length of the shorter side after resizing.
    crop_size : int
        Side of the square center crop; must not exceed ``resize_size``.
    mean, std : sequence of 3 floats
        Per-channel normalization statistics.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (3, crop_size, crop_size).

    Raises
    ------
    ValueError
        If ``crop_size`` is larger than ``resize_size``.
    '''
    if crop_size > resize_size:
        raise ValueError(
            f"crop_size ({crop_size}) must not exceed resize_size ({resize_size})")

    # 'L' or 'RGBA' images would give 1 or 4 channels, which cannot be
    # normalized with 3-channel statistics.
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    # -------- Resize shorter side, aspect ratio maintained --------- #
    # resize() returns a new image (thumbnail() would shrink the caller's
    # image in place) and also upsamples small images, so the center crop
    # below always lies fully inside the picture instead of being padded.
    width, height = pil_image.size
    if width <= height:
        new_size = (resize_size, int(resize_size * height / width))
    else:
        new_size = (int(resize_size * width / height), resize_size)
    pil_image = pil_image.resize(new_size)

    # ---------Center crop----------- #
    # Integer box coordinates give exactly crop_size pixels per side.
    width, height = pil_image.size
    left = (width - crop_size) // 2
    upper = (height - crop_size) // 2
    pil_image = pil_image.crop((left, upper, left + crop_size, upper + crop_size))

    # --------- Convert to np then Normalize ----------- #
    np_image = np.asarray(pil_image) / 255.0
    np_image = (np_image - np.asarray(mean)) / np.asarray(std)

    # --------- Transpose [H, W, Ch] -> [Ch, H, W] to fit PyTorch axes ----------#
    return np_image.transpose(2, 0, 1)

def imshow(pt_image, ax = None, title = None,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    '''
    Display a normalized [Ch, H, W] image on a matplotlib axis.

    Accepts a numpy array (e.g. the output of ``image_preprocessing``) or a
    torch tensor (e.g. one image from a DataLoader batch). The image is moved
    to channels-last [H, W, Ch], the ``mean``/``std`` normalization is undone,
    values are clipped to [0, 1], and the result is drawn with ``ax.imshow``.

    Parameters
    ----------
    pt_image : numpy.ndarray or torch.Tensor
        Image of shape [Ch, H, W], normalized with ``mean``/``std``.
    ax : matplotlib Axes, optional
        Axis to draw on; a new figure and axis are created when omitted.
    title : str, optional
        Title set on the axis.
    mean, std : sequence of floats
        Per-channel statistics of the normalization to undo.

    Returns
    -------
    The axis the image was drawn on.
    '''
    if ax is None:
        _, ax = plt.subplots()

    # torch.Tensor.transpose swaps only two dims and would reject (1, 2, 0),
    # so tensors are converted to numpy first; detach()/cpu() allow tensors
    # that require grad or live on another device.
    if isinstance(pt_image, torch.Tensor):
        pt_image = pt_image.detach().cpu().numpy()

    # --------- Transpose [Ch, H, W] -> [H, W, Ch] ----------- #
    plt_image = np.asarray(pt_image).transpose((1, 2, 0))

    # --------- Undo the preprocessing --------- #
    plt_image = plt_image * np.asarray(std) + np.asarray(mean)

    if title is not None:
        ax.set_title(title)

    # Image need to be clipped between 0 and 1 or it looks noisy
    plt_image = np.clip(plt_image, 0, 1)

    # Axes.imshow (matplotlib), not this function
    ax.imshow(plt_image)

    return ax

In [12]:
label_path = './data/flowers-102/imagelabels.mat'
# loadmat returns a dict of MATLAB variables; the 'labels' entry is a single
# row (1 x N uint8, see output) — presumably one class id per image.
labels_mat = scp.loadmat(label_path)
label_arr = labels_mat['labels']
label_arr

array([[77, 77, 77, ..., 62, 62, 62]], dtype=uint8)

In [14]:
split_path = './data/flowers-102/setid.mat'
data_splits = scp.loadmat(split_path)


def _split_ids(key):
    '''Return the index array stored under `key` in data_splits, printing its shape.'''
    ids = data_splits[key]
    print(ids.shape)
    return ids


# setid.mat keys: 'trnid' (train), 'valid' (validation), 'tstid' (test)
train_split = _split_ids('trnid')
val_split = _split_ids('valid')
test_split = _split_ids('tstid')

(1, 1020)
(1, 1020)
(1, 6149)


In [15]:
# Grab the first shuffled training batch. Seeding the global torch RNG first
# makes the shuffle order (and presumably the random augmentations) repeatable.
seed = 0
torch.manual_seed(seed)
first_batch = next(iter(train_loader))
images, labels = first_batch
print(images.shape)
# `images` has already been through train_transform (normalized, [B, Ch, H, W]),
# so image_preprocessing (which expects a PIL image) does not apply here; a
# single sample can be displayed with imshow(images[0].numpy()).

torch.Size([64, 3, 224, 224])


In [21]:
# Class-balance check (currently disabled): counts training samples per class
# and plots the counts as a bar chart.
# NOTE(review): confirm whether `train_dataset._labels` (a private attribute)
# is 0- or 1-based before enabling; if it is 0-based, range(1, 103) skips
# class 0 and reports an always-empty class 102.
# class_sample_counts = {}

# for c in range(1, 103):
#     class_sample_counts[c] = train_dataset._labels.count(c)

# # Plot showing the class imbalance
# plt.figure(figsize=(22, 5))    
# plt.bar(range(len(class_sample_counts)), list(class_sample_counts.values()), align='center')
# plt.xticks(range(len(class_sample_counts)), list(class_sample_counts.keys()))
# plt.xticks(rotation=80)
# plt.title('Samples per Class')
# plt.show()