In [31]:
%load_ext autoreload
%autoreload 2


import torch
import torchvision
import fastai

# Fail fast when the environment differs from the exact library releases this notebook was written against.
assert torch.__version__ == '1.1.0'
assert torchvision.__version__ == '0.3.0'
assert fastai.__version__ == '1.0.55'


import copy
from collections import OrderedDict
from pathlib import Path

import torch.nn as nn
from torch import optim
from torch.nn.modules import Module
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data.dataset import Subset, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
from torchvision.datasets.folder import ImageFolder, default_loader, IMG_EXTENSIONS, make_dataset
from torchvision.models import resnet34 as resnet
from torchvision.utils import make_grid
from fastai.vision import imagenet_stats
from PIL import Image

from pytorch_utils.callbacks import MyLrFinder, RecordMetric
from pytorch_utils.hooks import hook_context_manager
from pytorch_utils.trainer import learn


class Furniture_Dataset(ImageFolder):
    '''
    ImageFolder optionally restricted to the classes named in `classes_for_consideration`.

    If `classes_for_consideration` is None (or empty), every class folder ImageFolder finds
    under `root` is used, exactly as in ImageFolder. Otherwise only images from the listed
    folders are kept, and the classes are re-indexed 0..k-1 in the order they are listed
    (duplicates are ignored after their first occurrence). If any listed class is not one of
    the folders found under `root`, a message is printed and all classes are chosen instead.

    Parameters
    ----------
    root : str or Path
        Directory containing one sub-folder per class.
    transform : callable or None
        Applied to each loaded image; passed straight through to ImageFolder.
    classes_for_consideration : sequence of str, optional
        Names of the class folders to keep.
    '''

    def __init__(self, root, transform, classes_for_consideration = None):
        super().__init__(root, transform = transform)

        if not classes_for_consideration:
            return

        # every requested class must be one of the folders ImageFolder discovered under root
        if not all(name in self.classes for name in classes_for_consideration):
            print("Certain class in classes_for_consideration is not available in possible classes. Check your classes_for_consideration. Choosing all classes instead")
            return

        # de-duplicate while keeping the caller's order, so class_to_idx and classes stay consistent
        classes = list(dict.fromkeys(classes_for_consideration))
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        samples = make_dataset(self.root, class_to_idx, extensions = IMG_EXTENSIONS)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # ImageFolder also exposes the sample list as `imgs`; keep that alias in sync with the filtered list
        self.imgs = samples
        self.targets = [target for _, target in samples]



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# Per-split image pipelines, keyed by split name ('train' / 'val').
# `imagenet_stats` must already be in scope when this cell runs (it is imported with the other
# libraries in the top import cell); importing it only in a later cell makes a fresh
# "Restart & Run All" fail here with NameError.
trfms = {}

# Training: random crop / flips / small rotation for augmentation, then tensor + ImageNet normalization.
trfms['train'] = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    # NOTE(review): vertical flips show furniture upside down; confirm this augmentation is intended.
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(*imagenet_stats),
])

# Validation: deterministic resize only (no augmentation), same tensor conversion and normalization.
trfms['val'] = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(*imagenet_stats),
])

In [33]:
from fastai.vision import  get_transforms, imagenet_stats



# Master dataset of every image in the selected classes, built WITHOUT a transform.
# `trfms` is a dict of per-split pipelines, and a dict is not callable, so passing it as
# `transform` made every `all_data[i]` raise TypeError. Items now come back as
# (loaded image, class index); the train/val pipelines in `trfms` have to be applied per split.
# To restrict the classes, pass e.g.
#   classes_for_consideration = ['arts_and_crafts', 'mid-century-modern', 'rustic', 'traditional']
all_data = Furniture_Dataset(root = Path('../Data'), transform = None,
                             classes_for_consideration = None)


# Random train/val split expressed as index tensors into all_data.
# NOTE(review): this uses a 25% validation fraction and is not used by the DataLoader cell
# visible here, which draws its own split; confirm which split the rest of the notebook relies on.
indices = {}
# torch.manual_seed(1)  # uncomment for a reproducible split
random_indices = torch.randperm(len(all_data))
valid_pct = .25

split = int((1 - valid_pct) * len(random_indices))

# Force the training part to an even number of samples.
# NOTE(review): the reason for the even size is not visible here; confirm it is still needed.
if split % 2:
    split = split - 1

indices['train'] = random_indices[:split]
indices['val'] = random_indices[split:]



In [34]:
datasets = {}
data = {}
batch_size = 32

val_train_ratio = 0.2
val_length = int(val_train_ratio * len(all_data))
train_length = len(all_data) - val_length

# random_split is used only to draw two disjoint random index sets over all_data.
train_subset, val_subset = random_split(all_data, [train_length, val_length])
split_indices = {'train': train_subset.indices, 'val': val_subset.indices}

for datatype in ['train', 'val']:
    # A Subset indexes its parent dataset, so subsets of one dataset share that dataset's single
    # transform. Give each split a shallow copy of all_data (same samples/classes/loader) that
    # carries the split's own pipeline: augmentation for train, plain resize for val.
    split_source = copy.copy(all_data)
    split_source.transform = trfms[datatype]
    datasets[datatype] = Subset(split_source, split_indices[datatype])

    # Dataloader
    data[datatype] = torch.utils.data.DataLoader(datasets[datatype], batch_size, num_workers=4, shuffle = True)

data['classes'] = all_data.classes

In [None]:

import math
import matplotlib.pyplot as plt
import numpy as np
def _dataset_classes(dataset):
    '''Return `classes` from `dataset` or from the first dataset it wraps via `.subset` / `.dataset`.'''
    while not hasattr(dataset, 'classes'):
        inner = getattr(dataset, 'subset', None)
        if inner is None:
            inner = getattr(dataset, 'dataset', None)
        if inner is None:
            raise AttributeError('no `classes` attribute found on {} or on any dataset it wraps'.format(type(dataset).__name__))
        dataset = inner
    return dataset.classes


def show_batch(dataloader, plot_num = None):
    '''
    Plot images from one batch of `dataloader` in a 4-column grid, titled with their class names.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Loader yielding (images, labels) batches. Images are assumed to be normalized with
        `imagenet_stats` and laid out as (batch, channels, H, W) -- TODO confirm for other loaders.
        Class names come from the first dataset in the wrapper chain (following `.subset` or
        `.dataset`, e.g. through a torch Subset) that has a `classes` attribute.
    plot_num : int, optional
        Number of images to plot. None, 0, or a value larger than the batch plots the whole batch.
    '''
    inputs, labels = next(iter(dataloader))
    labels = labels.tolist()
    class_names = _dataset_classes(dataloader.dataset)

    # undo Normalize(*imagenet_stats): x * std + mean, broadcast per channel
    mean, std = imagenet_stats
    std = torch.tensor(std)
    mean = torch.tensor(mean)
    inputs = inputs.mul(std[None, :, None, None]) + mean[None, :, None, None]
    # floating-point error can leave values just outside [0, 1]; imshow clips those with a warning
    inputs = inputs.clamp(0, 1)

    if (not plot_num) or (plot_num > inputs.shape[0]):
        plot_num = inputs.shape[0]

    ncols = 4
    nrows = math.ceil(plot_num / ncols)

    # squeeze=False keeps a 2-D axes array for any grid shape; size grows with the number of rows
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.ravel()

    for i in range(plot_num):
        axes[i].imshow(inputs[i, ...].permute(1, 2, 0))
        axes[i].set_title(class_names[labels[i]])

    # hide ticks on every cell, including unused ones in the last row
    for axis in axes:
        axis.axis('off')
  
