In [None]:
'''Importing Modules'''


# vanilla:
import os
import json
import time
from time import strptime
import datetime
from datetime import timedelta
import shutil
from collections import OrderedDict
import random
import sys
# external:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms, models
import PIL.Image as Image
import git
import splitfolders
from dotenv import load_dotenv
# custom:
from routines import *
from displays import *
import myTransforms
import consts

# Seed the Python, NumPy and PyTorch RNGs up front so that classifier weight
# initialisation, dropout masks and DataLoader shuffling repeat across
# Restart & Run All. Defined after the star-imports so this value wins.
# (The dataset split passes its own seed to splitfolders.ratio().)
RANDOM_SEED = 1337
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

println([('done.', 'g')])

In [None]:
'''Parsing the .env File'''


# Loading sensitive info from the dotenv file.
# It is needed in order to clone the data repo:
if not load_dotenv(consts.dotenv_path):
    e_msg = 'cannot find the required .env file'
    raise SystemExit(e_msg)
gh_token = os.getenv('GH_TOKEN')
gh_username = os.getenv('GH_USERNAME')
repo_name = os.getenv('REMOTE_REPO_NAME')

# Fail fast on absent/empty keys: os.getenv() returns None for a missing key,
# which would otherwise be interpolated as the literal text 'None' into the
# clone URL and only surface later as a confusing clone failure.
required_env = (('GH_TOKEN', gh_token),
                ('GH_USERNAME', gh_username),
                ('REMOTE_REPO_NAME', repo_name))
missing_keys = [key for key, value in required_env if not value]
if missing_keys:
    e_msg = f'the .env file is missing values for: {", ".join(missing_keys)}'
    raise SystemExit(e_msg)

# NOTE: repo_url embeds the access token - never print or log it.
repo_url = f'https://{gh_token}@github.com/{gh_username}/{repo_name}.git'
repo_dir_path = f'./{repo_name}'
dataset_dir_path = repo_dir_path + '/dataset'
classes_file_path = repo_dir_path + '/classes.json'

println([('done.', 'g')])

In [None]:
'''Cloning the Remote Data-Repository'''

# Checks if a leftover repo exists, overwrite it if so:
if os.path.exists(repo_dir_path):
    git.rmtree(repo_dir_path)

# Clones the repo; any failure (bad URL, bad credentials, network) aborts the
# notebook with the underlying cause instead of a generic message.
clone_error = None
try:
    git.Repo.clone_from(repo_url, repo_dir_path)
except Exception as e:
    # Git error text can include the command line, i.e. the token-bearing
    # URL, so redact the token before surfacing the cause.
    clone_error = f'{type(e).__name__}: {e}'
    if gh_token:
        clone_error = clone_error.replace(gh_token, '***')

# Raised outside the `except` block on purpose, so the original (unredacted)
# exception is not chained onto this one's traceback.
if clone_error is not None:
    e_msg = f'there is an issue with the remote repo URL ({clone_error})'
    raise SystemExit(e_msg)


println([('done.', 'g')])

In [None]:
'''Parsing the JSON File from the Data Repository'''


# classes.json is read as a two-element list: element [0] maps class ids to
# class names, element [1] holds the 'images_per_class' count.
try:
    with open(classes_file_path, 'r') as f:
        json_file = json.load(f)
except FileNotFoundError:
    e_msg = f'cannot locate the \'classes.json\' file in "{repo_name}".'\
        + '\nre-run the \'Cloning the Remote Data-Repository\' cell and try again.'
    raise SystemExit(e_msg) from None
except json.JSONDecodeError as e:
    e_msg = f'the \'classes.json\' file in "{repo_name}" is not valid JSON ({e}).'
    raise SystemExit(e_msg) from None

classes = OrderedDict(json_file[0])
images_per_class = json_file[1]['images_per_class']

# Paths of the class subdirs ('01', '02', ...) in class order. os.path.join
# uses the host OS separator, so these compare equal to the paths os.walk()
# yields on every platform (a hard-coded '\\' only ever matched on Windows).
dir_names = [os.path.join(dataset_dir_path, '%.2d' % i) for i in range(1, len(classes) + 1)]

# Displays the JSON file metadata:
println(['total classes', 'images per class'], header=True)
println([len(classes), images_per_class])

println([('done.', 'g')])

In [None]:
'''Validating the Dataset Directory'''


println([('performing a validation of the cloned data repo according to its JSON file,\n'
          'before any further training can take place...', 'y')])

# Number of class subdirs actually present; only directories count, so a stray
# file at the dataset root does not masquerade as a class.
with os.scandir(dataset_dir_path) as root_entries:
    total_dirs = sum(1 for entry in root_entries if entry.is_dir())

# Counts the files (non-directory entries, the same thing os.walk() reports as
# filenames) in each expected class dir, iterating `dir_names` so that
# files_per_class[i] lines up with the i-th class displayed below.
# A missing class dir is recorded as bad instead of being silently skipped.
files_per_class = []
bad_dirs = []
for dir_path in dir_names:
    if not os.path.isdir(dir_path):
        files_per_class.append(0)
        bad_dirs.append(dir_path)
        continue
    with os.scandir(dir_path) as entries:
        images_in_dir = sum(1 for entry in entries if not entry.is_dir())
    files_per_class.append(images_in_dir)
    if images_in_dir != images_per_class:
        bad_dirs.append(dir_path)

# Raise exceptions if needed (class-count mismatch takes precedence):
if total_dirs != len(classes):
    e_msg = f'number of classes according to the JSON file ({len(classes)})'\
        + f' does not correlate with total dirs ({total_dirs})'\
        + f' in "{dataset_dir_path}".'\
        + '\nre-run the \'Cloning the Remote Data-Repository\' cell then re-run this cell.'
    raise SystemExit(e_msg)
if bad_dirs:
    e_msg = f'image count in the following directories is incorrect: {bad_dirs}'
    raise SystemExit(e_msg)

# Displaying...
# Each class's image count is green when it equals the JSON-defined count and
# red otherwise; the same rule applies to the grand total.
println(['id', 'parsed class', 'images found'], header=True)
for i, (ID, Class) in enumerate(classes.items()):
    println([ID,
            Class.upper() if Class in ['uk', 'usa'] else Class.capitalize(),
            (files_per_class[i], ('g' if files_per_class[i] == images_per_class else 'r'))])
total_images = sum(files_per_class)
println(['', '', 'total images'], header=True)
println(['', '', (total_images, ('g' if total_images == len(classes) * images_per_class else 'r'))])
println([('done.', 'g')])

In [None]:
'''Splitting the Dataset'''


println([('creating a new \'sets\' dir, with three subdirs of images: '
          '\'train\', \'valid\', \'test\'...', 'y')])

# Deleting a leftover 'sets' directory if such exists:
if os.path.exists(consts.sets_path):
    shutil.rmtree(consts.sets_path)

# Randomly splitting the dataset into 'train', 'valid', 'test' image directories.
# The fixed seed keeps the split identical across runs; move=False leaves the
# cloned dataset in place.
split_seed = 1337
split_ratio = (.8, .1, .1)  # train / valid / test fractions
split_error = None
try:
    splitfolders.ratio(
        dataset_dir_path,
        output=consts.sets_path,
        seed=split_seed,
        ratio=split_ratio,
        group_prefix=None,
        move=False)
except Exception as e:
    # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
    # still propagate; keep the cause for the error message.
    split_error = f'{type(e).__name__}: {e}'
if split_error is not None:
    e_msg = f'exception raised within a splitfolders.ratio() call ({split_error})'
    raise SystemExit(e_msg)

println([('done.', 'g')])

In [None]:
'''Creating DataLoaders'''


batch_size = 32
print(f'batch size is {batch_size}')

# One ImageFolder + DataLoader pair per split. Every split gets its own
# transform pipeline from myTransforms; only the training loader shuffles.

# -- train --
train_data = datasets.ImageFolder(consts.trainset_path,
                                  transform=myTransforms.train_transforms)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True)

# -- valid --
valid_data = datasets.ImageFolder(consts.validset_path,
                                  transform=myTransforms.valid_transforms)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)

# -- test --
test_data = datasets.ImageFolder(consts.testset_path,
                                 transform=myTransforms.test_transforms)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

println([('done.', 'g')])

In [None]:
'''Instantiating a Model and a Classifier'''


pretrained = True
weights = ('DEFAULT' if pretrained else None)
model = models.vgg16(weights=weights)
model_name = 'VGG16'
# Freeze the backbone so backprop only reaches the classifier. The new head
# is created below, after this loop, so its parameters stay trainable.
for param in model.parameters():
    param.requires_grad = False

dropout_probability = .5
# Read the head's input width from the stock VGG classifier (its first layer is
# a Linear) instead of hard-coding 25088, so it follows the backbone.
in_features = model.classifier[0].in_features
out_features = 1024
od = OrderedDict([('fc1', nn.Linear(in_features, out_features)),
                  ('drop', nn.Dropout(p=dropout_probability)),
                  ('relu', nn.ReLU()),
                  ('fc2', nn.Linear(out_features, len(classes))),
                  # Log-probabilities, to pair with NLLLoss during training.
                  ('output', nn.LogSoftmax(dim=1))])
classifier = nn.Sequential(od)
model.classifier = classifier

# Displaying:
println(['Model'], header=True)
println([f'{model_name}, ' + ('Pretrained ' if pretrained else 'Not Pretrained')])
println(['Classifier'], header=True)
print([f'{layer}' for layer in od.keys()])

println([('done.', 'g')])

In [None]:
'''Loading a Model Checkpoint'''


println([('looking for \'.pth\' files in the default checkpoints folder; '
          'will load the latest one, but if none were found it is still OK...', 'y')])

if os.path.exists(consts.checkpoints_path):
    latest_checkpoint = latestCheckpoint()
    if latest_checkpoint is None:
        print('no \'.pth\' files were found.')
        print('no checkpoint loaded.')
    else:
        model = loadCheckpoint(latest_checkpoint, weights)
        # `model` was just replaced: re-point the `classifier` global at the
        # loaded model's head so later cells don't use the discarded one.
        classifier = model.classifier
        print(f'checkpoint loaded from: \"{latest_checkpoint}\"')
else:
    # makedirs also creates missing parent dirs; create first, then report.
    os.makedirs(consts.checkpoints_path, exist_ok=True)
    print('no checkpoints directory found, created a new one.')
    print('no checkpoint loaded.')

println([('done.', 'g')])

In [None]:
'''Defining the Training Hyperparameters'''


# Hyperparameters:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10
learning_rate = .001
# NLLLoss expects log-probabilities (the classifier built earlier ends in LogSoftmax).
criterion = nn.NLLLoss()
# Move the model to the target hardware BEFORE constructing the optimizer,
# as the torch.optim documentation recommends.
model.to(device)
# Only train the CLASSIFIER parameters, FEATURE parameters are frozen!
optimizer = optim.Adam(model.classifier.parameters(), lr=learning_rate)
# Hyperparameters names for displaying (device.type also matches indexed
# CUDA devices such as 'cuda:0'):
device_name = ('GPU' if device.type == 'cuda' else 'CPU')
criterion_name = 'Negative Log Likelihood Loss'
optimizer_name = 'Adam'
# Displaying:
println(['model', 'pretrained', 'device', 'epochs'], header=True)
println([model_name, ('yes' if pretrained else 'no'),
        (device_name, ('g' if device_name == 'GPU' else 'r')), epochs])
println(['learning rate', 'loss function', 'optimizer'], header=True)
println([learning_rate, criterion_name, optimizer_name])

println([('done.', 'g')])

In [None]:
'''Model Training, Validation, and Testing'''


def _evaluate(hyperparams, loader):
    """Runs `test()` over every batch of `loader` with autograd disabled.

    Args:
        hyperparams: the (model, optimizer, device, criterion) tuple that
            `test()` takes; its model is switched to eval mode first.
        loader: an iterable of (inputs, labels) batches.

    Returns:
        (loss, accuracy): the per-batch values returned by `test()`, summed
        over all batches ((0, 0) for an empty loader).
    """
    hyperparams[0].eval()
    total_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for inputs, labels in loader:
            loss, acc = test(hyperparams, inputs, labels)
            total_loss += loss
            total_accuracy += acc
    return total_loss, total_accuracy


println([(f'training on {device_name} started, might take a few minutes to complete...', 'y')])

# Training and validating part, displaying too:
train_metadata = []
start_training_time = time.time()
println(['epoch', 'time', 'train loss', 'valid loss', 'accuracy'], header=True)

for idx in range(epochs):
    # Keep the model object up-to-date (because we send it to another function):
    hyperparams = (model, optimizer, device, criterion)
    start_time = time.time()
    train_loss = 0
    # TRAINING mode: one pass over the entire train image set.
    model.train()
    for inputs, labels in train_loader:
        train_loss += train(hyperparams, inputs, labels)
    # EVALUATION mode: one pass over the entire valid image set.
    valid_loss, accuracy = _evaluate(hyperparams, valid_loader)
    end_time = time.time()
    aggregated_metadata = (idx, start_time, end_time, train_loss, valid_loss, accuracy, (train_loader, valid_loader))
    # Collect this epoch's metadata and add it to the list:
    collect(train_metadata, aggregated_metadata)
    # Display this epoch's metadata:
    displayTrain(train_metadata, idx)

end_training_time = time.time()
total_training_time = end_training_time - start_training_time

# Displaying the collected training metadata:
println([('training finished, results:', 'y')])
displayTrain(train_metadata)
println([f'total training time: {timedelta(seconds=round(total_training_time))}'])


# Testing part: evaluated exactly like validation (eval mode, no autograd).
println([('testing the trained model:', 'y')])
hyperparams = (model, optimizer, device, criterion)
test_loss, accuracy = _evaluate(hyperparams, test_loader)
displayTest(test_loss, accuracy, test_loader)

println([('done.', 'g')])

In [None]:
'''Saving a Model Checkpoint'''


timestamp = datetime.datetime.now().strftime(consts.checkpoint_timestamp_format)
checkpoint_name = f'{timestamp}.pth'
model.class_to_idx = train_data.class_to_idx
checkpoint = {'network': 'vgg16',
              'input_size': in_features,
              'output_size': len(classes),
              'batch_size': batch_size,
              # Take the head from the model itself: the `classifier` global
              # can be stale if the checkpoint-loading cell replaced `model`.
              'classifier': model.classifier,
              'epochs': epochs,
              'optimizer': optimizer.state_dict(),
              'state_dict': model.state_dict(),
              'class_to_idx': model.class_to_idx}
checkpoint_path = os.path.join(consts.checkpoints_path, checkpoint_name)
torch.save(checkpoint, checkpoint_path)
print(f'checkpoint saved to \"{checkpoint_path}\"')

println([('done.', 'g')])
