In [None]:
import os
import numpy as np
import csv
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.models as models
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def get_label(dir_name):
    """Return the class label encoded in a Fruits-360 directory name.

    The label is everything before the first space, e.g. ``"Apple Braeburn"``
    becomes ``"Apple"``. A name containing no space is returned unchanged.
    """
    head, _sep, _rest = dir_name.partition(' ')
    return head

class Fruits360Dataset(ImageFolder):
    """ImageFolder that merges Fruits-360 variety folders into one class per label.

    Every sub-directory of the dataset root is mapped to the label returned by
    ``get_label`` (the text before the first space), so several directories can
    share a single class index.
    """

    def find_classes(self, directory):
        """Map each sub-directory of ``directory`` to a merged class index.

        Returns:
            classes: sorted list of unique labels; ``classes[i]`` names class ``i``.
            class_to_idx: dict from *directory name* to class index, used by
                ImageFolder to label every image inside that directory.
        """
        with os.scandir(directory) as entries:
            dirs = sorted(entry.name for entry in entries if entry.is_dir())
        dir_to_label = {d: get_label(d) for d in dirs}
        # Sort the unique labels: iterating a set of strings yields an order that
        # can change between interpreter runs (string hash randomisation), which
        # would silently reshuffle class indices between runs. Sorted, the indices
        # line up between two roots whenever both contain the same set of labels.
        classes = sorted(set(dir_to_label.values()))
        label_to_idx = {label: i for i, label in enumerate(classes)}
        class_to_idx = {d: label_to_idx[dir_to_label[d]] for d in dirs}
        return classes, class_to_idx

class DeviceLoader:
    """Wrap a data loader so every batch it yields is moved to ``device``.

    Iteration delegates to the wrapped loader and passes each batch through
    ``to_device``; ``len()`` reports the wrapped loader's length.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        return len(self.dl)

def get_device():
    """Return the CUDA device when PyTorch can see a GPU, otherwise the CPU device."""
    name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)
    
def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are processed recursively and always come back as a list;
    any other object must provide ``.to(device, non_blocking=True)`` (for
    example a tensor or a module).
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))

In [None]:
class Model(nn.Module):
    """Small two-stage CNN classifier plus training/evaluation helpers.

    Layout: [conv3x3 -> ReLU -> BN] x2 -> maxpool -> dropout(0.25)
            -> [conv3x3 -> ReLU -> BN] x2 -> maxpool -> dropout(0.4)
            -> flatten -> linear(128) -> ReLU -> BN -> linear(output_size).

    Args:
        input_size: accepted for backward compatibility; no layer uses it.
        output_size: number of classes, i.e. the width of the output logits.
        image_size: spatial size of the input images, either an int (square)
            or a ``(height, width)`` pair. It fixes the width of ``fc1``; the
            default of 32 reproduces the previously hard-coded 4096 inputs
            (64 channels * 8 * 8).
    """

    def __init__(self, input_size, output_size, image_size=32):
        super().__init__()
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()  # unused by forward(); kept for attribute compatibility
        self.pool = nn.MaxPool2d((2, 2), stride=2)

        self.drop1 = nn.Dropout(.25)
        self.drop2 = nn.Dropout(.4)

        # One BatchNorm per conv layer: reusing a single BN module after two
        # different convolutions would share its affine weights and mix the
        # running statistics of both activations.
        self.batch1 = nn.BatchNorm2d(32)
        self.batch1b = nn.BatchNorm2d(32)
        self.batch2 = nn.BatchNorm2d(64)
        self.batch2b = nn.BatchNorm2d(64)
        self.batch3 = nn.BatchNorm1d(128)

        # Conv layers (padding=1 with 3x3 kernels keeps the spatial size)
        self.layer1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.layer2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.layer3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.layer4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # FC layers: the two 2x2/stride-2 pools floor-divide each spatial dim by 4.
        self.fc1 = nn.Linear(64 * (height // 4) * (width // 4), 128)
        self.fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        """Return raw class logits of shape (batch, output_size).

        ``x`` must have 3 channels (``layer1``'s input) and the spatial size
        given to the constructor.
        """
        x = self.batch1(self.relu(self.layer1(x)))
        x = self.batch1b(self.relu(self.layer2(x)))
        x = self.drop1(self.pool(x))

        x = self.batch2(self.relu(self.layer3(x)))
        x = self.batch2b(self.relu(self.layer4(x)))
        x = self.drop2(self.pool(x))

        x = self.batch3(self.relu(self.fc1(self.flatten(x))))
        return self.fc2(x)

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        return F.cross_entropy(self(images), labels)

    def validation_step(self, batch):
        """Evaluate one ``(images, labels)`` batch.

        Returns a dict holding the detached loss and the fraction of correct
        predictions, both as 0-dim tensors.
        """
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        _, preds = torch.max(out, dim=1)
        acc = (preds == labels).float().mean()
        return {'validation_loss': loss.detach(), 'validation_accuracy': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts from ``validation_step`` into Python floats.

        This is an unweighted mean over batches, so a smaller final batch
        counts as much as a full one.
        """
        epoch_accuracy = torch.stack([x['validation_accuracy'] for x in outputs]).mean()
        epoch_loss = torch.stack([x['validation_loss'] for x in outputs]).mean()
        return {'validation_accuracy': epoch_accuracy.item(), 'validation_loss': epoch_loss.item()}

    def fit(model, criterion, optimizer, scheduler, num_epochs=10, dataloaders=None):
        """Train ``model`` and return it loaded with its best validation weights.

        The first parameter is the network itself (it plays the role of
        ``self``). The body only uses generic ``nn.Module`` methods, so both
        ``model.fit(...)`` and ``Model.fit(any_module, ...)`` work.

        Args:
            model: network to train; batches are moved to the device of its
                first parameter (CPU if it has none).
            criterion: loss called as ``criterion(outputs, labels)``.
            optimizer: optimizer over ``model``'s parameters.
            scheduler: LR scheduler stepped once per epoch after the training
                phase, or ``None``.
            num_epochs: number of epochs to run.
            dataloaders: dict with ``'train'`` and ``'valid'`` iterables of
                ``(inputs, labels)`` batches. When omitted, a global named
                ``dataloaders`` is used (the original behaviour).

        Returns:
            ``model`` with the state dict that reached the highest validation
            accuracy (its starting weights if no epoch scored above 0).

        Raises:
            ValueError: if no ``dataloaders`` were passed or found globally.
        """
        import time

        if dataloaders is None:
            dataloaders = globals().get('dataloaders')
            if dataloaders is None:
                raise ValueError(
                    "fit() needs `dataloaders`: pass a dict with 'train' and 'valid' loaders")
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        since = time.time()
        best_acc = 0.0
        # Snapshot up front so there is always something to restore, even when
        # num_epochs == 0 or validation accuracy never rises above 0.
        best_model_wts = copy.deepcopy(model.state_dict())

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 15)

            for phase in ['train', 'valid']:
                training = phase == 'train'
                model.train(training)

                running_loss = 0.0
                running_corrects = 0
                seen = 0

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    if training:
                        optimizer.zero_grad()

                    with torch.set_grad_enabled(training):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        if training:
                            loss.backward()
                            optimizer.step()

                    batch_size = inputs.size(0)
                    running_loss += loss.item() * batch_size
                    running_corrects += (preds == labels).sum().item()
                    seen += batch_size

                # Since PyTorch 1.1 the scheduler must step *after* the optimizer;
                # stepping it first skips the initial learning rate.
                if training and scheduler is not None:
                    scheduler.step()

                # Average over the samples actually processed; max() guards an empty loader.
                epoch_loss = running_loss / max(seen, 1)
                epoch_acc = running_corrects / max(seen, 1)
                print('{} loss: {:.4f} acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

                if not training and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('Training took {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best validation Acc: {:.4f}'.format(best_acc))

        model.load_state_dict(best_model_wts)
        return model

In [None]:
def plot_results(x, label='Results', text=''):
    """Draw the per-epoch series ``x`` on the current matplotlib axes.

    ``label`` becomes the y-axis label and leads the title; ``text`` is
    appended to the title (e.g. to name the run).
    """
    axes = plt.gca()
    axes.plot(x, '-x')
    axes.set_xlabel('Epoch')
    axes.set_ylabel(f'{label}')
    axes.set_title(f'{label} vs. No. of epochs {text}')

In [None]:
# Training images (split into train/validation below) and the held-out test images.
dataset = Fruits360Dataset(root="./fruits-360/Training", transform=ToTensor())
testset = Fruits360Dataset(root="./fruits-360/Test", transform=ToTensor())

In [None]:
# Hold out 10% of the training images for validation. The split is seeded so the
# same images land in the validation set on every Restart & Run All.
VALIDATION_FRACTION = 0.1
SPLIT_SEED = 42
validation_size = int(len(dataset) * VALIDATION_FRACTION)
train_size = len(dataset) - validation_size  # lengths must sum to len(dataset)
train_dataset, validation_dataset = random_split(
    dataset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
len(train_dataset), len(validation_dataset)

In [None]:
# Data loaders for the two phases that `fit` iterates over ('train' and 'valid').
BATCH_SIZE = 64
device = get_device()
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    'valid': DataLoader(validation_dataset, batch_size=BATCH_SIZE),
}
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

# `fit` is defined inside `Model` and takes the network as its first argument, so
# call it through the class; a bare `fit(...)` raises NameError.
# NOTE(review): `model`, `criterion`, `optimizer` and `exp_scheduler` are not created
# anywhere in this notebook and must come from an earlier cell — TODO add that cell.
# `model` is assumed to already live on `device`.
model = Model.fit(model, criterion, optimizer, exp_scheduler, num_epochs=10)

