In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch import optim
import torch.nn.functional as F
import time
import glob
from PIL import Image

from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torchvision.transforms import Resize, ToTensor, Normalize, Compose
from torchvision import models


In [2]:
# Relative locations of the training and validation image folders.
train_path, val_path = 'images/train', 'images/val'

In [3]:
# Build the image datasets and their batch loaders.
# Both splits share the same preprocessing: resize to 224x224, convert to a
# tensor, then normalise each channel with the given mean / std values.
vgg_train_dataset = ImageFolder(
    root=train_path,
    transform=Compose([
        Resize(size=(224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
# Only the training batches are shuffled.
vgg_train_loader = DataLoader(dataset=vgg_train_dataset, batch_size=16, shuffle=True)

vgg_val_dataset = ImageFolder(
    root=val_path,
    transform=Compose([
        Resize(size=(224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
vgg_val_loader = DataLoader(dataset=vgg_val_dataset, batch_size=16)

In [4]:
#training function
device = torch.device("cuda" if torch.cuda.is_available() 
                                  else "cpu")


def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, n_correct)``.

    With an optimizer, the model is put in training mode and its parameters
    are updated after every batch. Without one, the model is put in
    evaluation mode and the pass runs with gradient tracking disabled.
    """
    is_training = optimizer is not None
    # Layers such as dropout or batch-norm behave differently per mode, so the
    # mode has to be set explicitly for each pass.
    model.train(is_training)

    losses = []
    n_correct = 0
    with torch.set_grad_enabled(is_training):
        for b_x, b_y in loader:
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            # Compute predictions and loss
            pred = model(b_x)
            loss = loss_fn(pred, b_y)
            losses.append(loss.item())

            # Count number of correct predictions (class axis is dim 1)
            hard_preds = pred.argmax(dim=1)
            n_correct += torch.sum(hard_preds == b_y).item()

            if is_training:
                # Backpropagate, update, then clear gradients for the next batch
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

    return sum(losses)/len(losses), n_correct


def train(model, datasets, loaders, num_epochs, loss_fn, optimizer, silent=False):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. It must already live on ``device``; only the
        batches are moved there by this function.
    datasets : sequence
        ``[train_dataset, val_dataset]``. Only ``len()`` of each is used, as
        the denominator of the accuracy.
    loaders : sequence
        ``[train_loader, val_loader]``, iterables yielding ``(inputs, targets)``
        batches.
    num_epochs : int
        Number of passes over the training loader.
    loss_fn : callable
        Called as ``loss_fn(predictions, targets)``; must return a scalar tensor.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per training batch.
    silent : bool, optional
        When True, the per-epoch metrics line is not printed.

    Returns
    -------
    tuple of numpy.ndarray
        ``(train_accuracy, val_accuracy, train_avg_loss, val_avg_loss)``, each
        of length ``num_epochs``.

    Notes
    -----
    Training metrics are accumulated while the weights are being updated, so
    they describe the model as it evolves during the epoch. Validation runs in
    evaluation mode without gradient tracking. The train/eval mode the model
    had on entry is restored before returning.
    """
    train_dataset = datasets[0]
    val_dataset = datasets[1]
    train_loader = loaders[0]
    val_loader = loaders[1]

    train_accuracy = np.zeros(num_epochs)
    train_avg_loss = np.zeros(num_epochs)
    val_accuracy = np.zeros(num_epochs)
    val_avg_loss = np.zeros(num_epochs)

    # Remember the caller's mode so this function has no lasting effect on it.
    was_training = model.training
    try:
        for epoch in range(num_epochs):
            start_time = time.time()

            # Optimise on the training set
            train_avg_loss[epoch], n_correct = _run_epoch(model, train_loader, loss_fn, optimizer)
            train_accuracy[epoch] = n_correct/len(train_dataset)

            # Compute accuracy and loss in the entire validation set
            val_avg_loss[epoch], n_correct = _run_epoch(model, val_loader, loss_fn)
            val_accuracy[epoch] = n_correct/len(val_dataset)

            if not silent:
                # Display metrics
                display_str = 'Epoch {} ({:.1f}s) '
                display_str += '\tLoss: {:.3f} '
                display_str += '\tLoss (val): {:.3f}'
                display_str += '\tAccuracy: {:.2f} '
                display_str += '\tAccuracy (val): {:.2f}'
                print(display_str.format(epoch, time.time() - start_time, train_avg_loss[epoch],
                                         val_avg_loss[epoch], train_accuracy[epoch], val_accuracy[epoch]))
    finally:
        model.train(was_training)

    return train_accuracy, val_accuracy, train_avg_loss, val_avg_loss


In [5]:
#load vgg 16 model
# Builds a VGG-16 network through torchvision's `models.vgg16`, passing
# `pretrained=True` to request pretrained weights.
# NOTE(review): recent torchvision releases favour a `weights=` argument over
# `pretrained=` — confirm against the installed version before changing.
vgg_model = models.vgg16(pretrained=True)

In [7]:
#Train model
# Freeze the convolutional feature extractor so its weights are not updated.
for param in vgg_model.features.parameters():
    param.requires_grad = False

# Swap the stock classifier for a small two-output head that emits
# log-probabilities (paired with NLLLoss below).
num_features = vgg_model.classifier[0].in_features
new_top = nn.Sequential(
    nn.Linear(num_features, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
    nn.LogSoftmax(dim=1),
)
vgg_model.classifier = new_top
vgg_model.to(device)

# Run 10 epochs with Adam and keep the per-epoch metric curves.
vgg_train_acc, vgg_val_acc, vgg_train_loss, vgg_val_loss = train(
    model=vgg_model,
    datasets=[vgg_train_dataset, vgg_val_dataset],
    loaders=[vgg_train_loader, vgg_val_loader],
    num_epochs=10,
    loss_fn=nn.NLLLoss(),
    optimizer=optim.Adam(vgg_model.parameters(), lr=0.0001),
)

Epoch 0 (461.9s) 	Loss: 0.180 	Loss (val): 0.095	Accuracy: 0.95 	Accuracy (val): 0.97
Epoch 1 (472.6s) 	Loss: 0.035 	Loss (val): 0.091	Accuracy: 1.00 	Accuracy (val): 0.97
Epoch 2 (487.8s) 	Loss: 0.016 	Loss (val): 0.081	Accuracy: 1.00 	Accuracy (val): 0.98
Epoch 3 (483.8s) 	Loss: 0.008 	Loss (val): 0.081	Accuracy: 1.00 	Accuracy (val): 0.97
Epoch 4 (487.7s) 	Loss: 0.005 	Loss (val): 0.081	Accuracy: 1.00 	Accuracy (val): 0.97
Epoch 5 (478.0s) 	Loss: 0.004 	Loss (val): 0.081	Accuracy: 1.00 	Accuracy (val): 0.97
Epoch 6 (476.6s) 	Loss: 0.003 	Loss (val): 0.081	Accuracy: 1.00 	Accuracy (val): 0.98
Epoch 7 (475.3s) 	Loss: 0.002 	Loss (val): 0.083	Accuracy: 1.00 	Accuracy (val): 0.97
Epoch 8 (476.3s) 	Loss: 0.002 	Loss (val): 0.083	Accuracy: 1.00 	Accuracy (val): 0.98
Epoch 9 (487.4s) 	Loss: 0.001 	Loss (val): 0.086	Accuracy: 1.00 	Accuracy (val): 0.97


In [8]:
# Persist only the model's state dict (not the full module object) to disk.
torch.save(obj=vgg_model.state_dict(), f='my_fruit_vgg')