In [1]:
'''Importing Modules'''

# vanilla:
import os
import json
import time
import datetime
import shutil
from collections import OrderedDict
# external:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models
import git
import splitfolders
from dotenv import load_dotenv
# custom:
from routines import *
from displays import *
import myTransforms
import consts

print_msg('done.', 'g')

[32mdone.[0m


In [2]:
'''Parsing the .env File'''

# Loading sensitive info from the dotenv file.
# It is needed in order to clone the data repo:
if not load_dotenv(consts.dotenv_path):
    e_msg = 'cannot find the required .env file'
    raise SystemExit(e_msg)

# Every variable below is interpolated into the clone URL and the local paths,
# so a missing or empty one must stop the notebook here; otherwise the URL
# would silently contain the literal text 'None' and the clone would fail
# later with a misleading error:
required_env_vars = ('GH_TOKEN', 'GH_USERNAME', 'REMOTE_REPO_NAME')
missing_env_vars = [name for name in required_env_vars if not os.getenv(name)]
if missing_env_vars:
    e_msg = f'the .env file is missing required variable(s): {missing_env_vars}'
    raise SystemExit(e_msg)

gh_token = os.getenv('GH_TOKEN')
gh_username = os.getenv('GH_USERNAME')
repo_name = os.getenv('REMOTE_REPO_NAME')
# The token is embedded in the URL, so `repo_url` must never be printed:
repo_url = f'https://{gh_token}@github.com/{gh_username}/{repo_name}.git'
# Local locations of the cloned repo, its image dataset and its classes file:
repo_dir_path = f'./{repo_name}'
dataset_dir_path = repo_dir_path + '/dataset'
classes_file_path = repo_dir_path + '/classes.json'

print_msg('done.', 'g')

[32mdone.[0m


In [3]:
'''Cloning the Remote Data-Repository'''

url_issue = False
clone_error = ''
# Checks if a leftover repo exists, overwrite it if so:
if os.path.exists(repo_dir_path):
    git.rmtree(repo_dir_path)
# Clones the repo into the same directory that was just cleaned up.
# The failure is recorded in flags and re-raised outside the `except` block
# so the notebook shows a single SystemExit rather than a chained traceback:
try:
    git.Repo.clone_from(repo_url, repo_dir_path)
except Exception as e:
    url_issue = True
    clone_error = f'{type(e).__name__}: {e}'
    # The clone URL embeds the access token; redact it before it can reach
    # the notebook output:
    if gh_token:
        clone_error = clone_error.replace(gh_token, '***')
if url_issue:
    e_msg = 'there is an issue with the remote repo URL'
    # Keep the underlying cause visible instead of discarding it:
    raise SystemExit(f'{e_msg}\n{clone_error}')

print_msg('done.', 'g')

[32mdone.[0m


In [4]:
'''Parsing the JSON File from the Data Repository'''

# Only the file read is guarded, so a FileNotFoundError raised by anything
# else is not mistaken for a missing JSON file:
json_not_found = False
try:
    with open(classes_file_path, 'r') as f:
        json_file = json.load(f)
except FileNotFoundError:
    json_not_found = True
if json_not_found:
    e_msg=f'cannot locate the \'classes.json\' file in "{repo_name}".'\
        + f'\nre-run the \'Cloning the Remote Data-Repository\' cell and try again.'
    raise SystemExit(e_msg)

# First JSON element: the classes mapping (iterated later as id -> class name);
# second element: metadata holding the expected number of images per class.
classes = OrderedDict(json_file[0])
images_per_class = json_file[1]['images_per_class']
# Creates a list of all subdir names (strings) within dataset dir.
# Subdirs are named by a zero-padded, two-digit, 1-based index. os.path.join
# is used so the result matches what os.walk() yields on every OS (a
# hard-coded backslash only matches on Windows):
dir_names = [os.path.join(dataset_dir_path, '%.2d' % i) for i in range(1, len(classes) + 1)]

# Displays the JSON file metadata:
print_matrix([
    ('total classes', len(classes)),
    ('images per class', images_per_class)
    ], vector=True)

print_msg('done.', 'g')

[1mtotal classes       [0m
[0m15                  [0m

[1mimages per class    [0m
[0m30                  [0m

[32mdone.[0m


In [5]:
'''Validating the Dataset Directory'''

msg = 'performing a valdiation of the cloned data repo according its JSON file \
before any further training can take place...'
print_msg(msg)

# Validates the number of classes defined in the JSON equals to number of classes subdirs:
total_dirs = len(os.listdir(dataset_dir_path))
json_ne_dirs = total_dirs != len(classes)

# Counts the files found in every expected class subdir. The counts are keyed
# by directory path because os.walk() gives no ordering guarantee, while the
# table below is printed in class order:
expected_dirs = set(dir_names)
found_counts = {
    dir_path: len(file_names)
    for dir_path, _, file_names in os.walk(dataset_dir_path)
    if dir_path in expected_dirs  # skips junk directories
}
# One count per class, aligned with `dir_names` (and therefore with `classes`);
# an expected subdir that does not exist counts as 0 images:
files_per_class = [found_counts.get(dir_name, 0) for dir_name in dir_names]
# Validates that the number of images in each class subdir equals to the one defined in the JSON:
bad_dirs = [dir_name for dir_name, found in zip(dir_names, files_per_class)
            if found != images_per_class]

# Raise exceptions if needed:
if json_ne_dirs:
    e_msg=f'number of classes according to the JSON file ({len(classes)})'\
        + f' does not correlate with total dirs ({total_dirs})'\
        + f' in \"{dataset_dir_path}\".'\
        + f'\nre-run \'Cloning the Remote Data-Repository\' cell then re-run this cell.'
    raise SystemExit(e_msg)
elif bad_dirs:
    e_msg=f'image count in the following directories is incorrect: {bad_dirs}'
    raise SystemExit(e_msg)

# Displaying...
# If the number of files found in a class subdir does not strictly equal
#  to the defined number (from the JSON file), the number will be highlighted
#  with red color; elsewise, in green.

print_header(['id', 'parsed class', 'images found'])
# `class_id` / `class_name` avoid shadowing the builtin `id`:
for (class_id, class_name), found in zip(classes.items(), files_per_class):
    print_line([class_id,
            class_name.upper() if class_name in ['uk','usa'] else class_name.capitalize(),
            (found, ('g' if found == images_per_class else 'r'))])
total_images = sum(files_per_class)
expected_total = len(classes) * images_per_class
print_header(['','','total images'])
print_line(['','',(total_images, ('g' if total_images == expected_total else 'r'))])

print_msg('done.', 'g')

[33mperforming a valdiation of the cloned data repo according its JSON file before any further training can take place...[0m
[1mid                  [1mparsed class        [1mimages found        [0m
[0m1                   [0mAustralia           [32m30                  [0m
[0m2                   [0mBrazil              [32m30                  [0m
[0m3                   [0mCanada              [32m30                  [0m
[0m4                   [0mChina               [32m30                  [0m
[0m5                   [0mFrance              [32m30                  [0m
[0m6                   [0mGermany             [32m30                  [0m
[0m7                   [0mIndia               [32m30                  [0m
[0m8                   [0mIsrael              [32m30                  [0m
[0m9                   [0mItaly               [32m30                  [0m
[0m10                  [0mJapan               [32m30                  [0m
[0m11          

In [6]:
'''Splitting the Dataset'''

msg = 'creating a new \'sets\' dir, with three subdirs of images: train\', \'valid\', \'test\'...'
print_msg(msg)

# Deleting a leftover 'sets' directory if such exists:

if os.path.exists(consts.sets_path):
    shutil.rmtree(consts.sets_path)

# Randomly splitting the dataset into 'train', 'valid', 'test' image directories
# (80% / 10% / 10%); the fixed seed keeps the split reproducible and
# move=False copies the files, leaving the cloned dataset untouched:
splitfolder_issue = False
splitfolder_error = ''
try:
    splitfolders.ratio(
        dataset_dir_path,
        output=consts.sets_path,
        seed=1337,
        ratio=(.8, .1, .1),
        group_prefix=None,
        move=False)
except Exception as e:
    # `Exception` (not a bare except) so that KeyboardInterrupt / SystemExit
    # are not swallowed; the cause is kept for the error message:
    splitfolder_issue = True
    splitfolder_error = f'{type(e).__name__}: {e}'
if splitfolder_issue:
    e_msg = 'exception raised within a splitfolders.ratio() call'
    raise SystemExit(f'{e_msg}\n{splitfolder_error}')

print_msg('done.', 'g')

[33mcreating a new 'sets' dir, with three subdirs of images: train', 'valid', 'test'...[0m


Copying files: 450 files [00:01, 251.67 files/s]

[32mdone.[0m





In [7]:
'''Creating DataLoaders'''

batch_size = 32
print(f'batch size is {batch_size}')

# One ImageFolder dataset per split, each paired with its own transform
# pipeline (built in train -> valid -> test order):
train_data, valid_data, test_data = (
    datasets.ImageFolder(set_path, transform=set_transforms)
    for set_path, set_transforms in (
        (consts.trainset_path, myTransforms.train_transforms),
        (consts.validset_path, myTransforms.valid_transforms),
        (consts.testset_path, myTransforms.test_transforms),
    )
)

# Only the training loader shuffles; the evaluation loaders keep the
# default (sequential) order:
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)
valid_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=eval_data, batch_size=batch_size)
    for eval_data in (valid_data, test_data)
)

print_msg('done.', 'g')

batch size is 32
[32mdone.[0m


In [8]:
'''Instantiating a Model and a Classifier'''

pretrained = True
weights = 'DEFAULT' if pretrained else None
model = models.vgg16(weights=weights)
model_name = 'VGG16'
for param in model.parameters():
    # Freeze the MODEL parameters so we don't backprop through them! Only through the classifier.
    param.requires_grad = False
dropout_probability = .5
# Width of the flattened feature vector, read from the first Linear layer of
# the stock VGG16 head instead of hard-coding 25088, so it always matches
# the backbone that was actually instantiated:
in_features = model.classifier[0].in_features
out_features = 1024
# The replacement head is created AFTER the freezing loop above, so its
# parameters keep requires_grad=True and are the only trainable ones.
# It ends with LogSoftmax over the class dimension (dim=1):
od = OrderedDict([('fc1', nn.Linear(in_features, out_features)),
                ('drop', nn.Dropout(p=dropout_probability)),
                ('relu', nn.ReLU()),
                ('fc2', nn.Linear(out_features, len(classes))),
                ('output', nn.LogSoftmax(dim=1))])
classifier = nn.Sequential(od)
model.classifier = classifier

# Displaying:
print_matrix(
    [
        ('Model', f'{model_name}, ' + ('Pretrained ' if pretrained else 'Not Pretrained')),
        ('Classifier', str(list(od.keys()))),
    ],
    vector=True)

print_msg('done.', 'g')

[1mModel               [0m
[0mVGG16, Pretrained   [0m

[1mClassifier          [0m
[0m['fc1', 'drop', 'relu', 'fc2', 'output'][0m

[32mdone.[0m


In [9]:
'''Loading a Model Checkpoint'''

msg = 'looking for \'.pth\' files in the default checkpoints folder;\n\
will load the latest one found, but if none were found it is still OK...'
print_msg(msg)

if os.path.exists(consts.checkpoints_path):
    latest_checkpoint = latestCheckpoint()
    if latest_checkpoint is None:
        print('no \'.pth\' files were found.')
        print('no checkpoint loaded.')
    else:
        # The loaded model REPLACES the one built in the previous cell.
        # NOTE(review): confirm loadCheckpoint() re-applies the frozen
        # feature parameters - it is defined outside this notebook.
        model = loadCheckpoint(latest_checkpoint, weights)
        print(f'checkpoint loaded from: \"{latest_checkpoint}\"')
else:
    # Create the directory first, so the message below is only printed once
    # it is true; makedirs also copes with missing parent directories:
    os.makedirs(consts.checkpoints_path, exist_ok=True)
    print('no checkpoints directory found, created a new one.')
    print('no checkpoint loaded.')

print_msg('done.', 'g')

[33mlooking for '.pth' files in the default checkpoints folder;
will load the latest one found, but if none were found it is still OK...[0m
no checkpoints directory found, created a new one.
no checkpoint loaded.
[32mdone.[0m


In [10]:
'''Hyperparameters'''

# Use the GPU when one is available, otherwise fall back to the CPU:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
epochs = 10
learning_rate = .001
criterion = nn.NLLLoss()
# The optimizer is handed the CLASSIFIER parameters only; the FEATURE
# parameters are frozen and are never updated:
optimizer = optim.Adam(params=model.classifier.parameters(), lr=learning_rate)
# Moving the model instance to the selected hardware:
model.to(device)

# Human-readable hyperparameter names, used for displaying only:
device_name = {'cuda': 'GPU'}.get(device.type, 'CPU')
criterion_name = 'Negative Log Loss'
optimizer_name = 'Adam'

# Displaying:
print_matrix(
    [
        ('model', model_name),
        ('pretrained', 'yes' if pretrained else 'no'),
        ('device', device_name),
        ('epochs', epochs),
        ('learning rate', learning_rate),
        ('loss function', criterion_name),
        ('optimizer', optimizer_name),
    ],
    vector=True,
)

print_msg('done.', 'g')

[1mmodel               [0m
[0mVGG16               [0m

[1mpretrained          [0m
[0myes                 [0m

[1mdevice              [0m
[0mGPU                 [0m

[1mepochs              [0m
[0m10                  [0m

[1mlearning rate       [0m
[0m0.001               [0m

[1mloss function       [0m
[0mNegative Log Loss   [0m

[1moptimizer           [0m
[0mAdam                [0m

[32mdone.[0m


In [11]:
'''Model Training, Validation, and Testing'''

msg = f'training on {device_name} started, might take a few minutes to complete...'
print_msg(msg)

# Bundle passed to the train()/test() helpers. None of these names is rebound
# below, so the tuple is built once and always refers to the live objects:
hyperparams = (model, optimizer, device, criterion)

# Training and validating part, displaying too:
train_metadata = []
start_training_time = time.time()
print_header(['epoch', 'time', 'train loss', 'valid loss', 'accuracy'])

for idx in range(epochs):
    # Epoch metadata (losses / accuracy are accumulated over the batches):
    start_time = time.time()
    train_loss = 0
    valid_loss = 0
    accuracy = 0
    # Switching model mode to TRAINING.
    # Training the model using the entire train image set:
    model.train()
    for inputs, labels in train_loader:
        train_loss += train(hyperparams, inputs, labels)
    # Switching model mode to EVALUATION.
    # Validating the model using the entire valid image set, with gradient
    # tracking disabled:
    model.eval()
    with torch.no_grad():
        for inputs, labels in valid_loader:
            loss, acc = test(hyperparams, inputs, labels)
            valid_loss += loss
            accuracy += acc
    end_time = time.time()
    aggregated_metadata = (idx, start_time, end_time, train_loss, valid_loss, accuracy, (train_loader, valid_loader))
    # Collect this epoch's metadata and add it to the list:
    collect(train_metadata, aggregated_metadata)
    # Display this epoch's metadata:
    print_epoch(train_metadata, idx)

end_training_time = time.time()
total_training_time = end_training_time - start_training_time

# Displaying the collected training metadata:
print_msg('training finished, results:')
print_train_summary(train_metadata)


# Testing part:
print_msg('testing the trained model:')
test_loss = 0
accuracy = 0
# Testing loop. Like validation, it runs in evaluation mode and under
# no_grad(): nothing is back-propagated here, so tracking gradients would
# only cost memory and time:
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        loss, acc = test(hyperparams, inputs, labels)
        test_loss += loss
        accuracy += acc
print_test_summary(test_loss, accuracy, test_loader)

print_msg('done.', 'g')

[33mtraining on GPU started, might take a few minutes to complete...[0m
[1mepoch               [1mtime                [1mtrain loss          [1mvalid loss          [1maccuracy            [0m
[0m1                   [0m00:09               [0m5.713               [0m3.783               [0m0.302               [0m
[0m2                   [0m00:05               [0m3.094               [0m1.495               [0m0.472               [0m
[0m3                   [0m00:05               [0m1.770               [0m1.113               [0m0.720               [0m
[0m4                   [0m00:05               [0m1.550               [0m0.836               [0m0.775               [0m
[0m5                   [0m00:05               [0m1.330               [0m0.883               [0m0.728               [0m
[0m6                   [0m00:05               [0m1.306               [0m0.722               [0m0.760               [0m
[0m7                   [0m00:04               [0

In [12]:
'''Saving a Model Checkpoint After Training'''

# The checkpoint file is named after the time it was saved at:
timestamp = datetime.datetime.now().strftime(consts.checkpoint_timestamp_format)
checkpoint_name = f'{timestamp}.pth'
# Keep the class-name -> index mapping of the training set with the model:
model.class_to_idx = train_data.class_to_idx
checkpoint = {'network': 'vgg16',
              'input_size': in_features,
              'output_size': len(classes),
              'batch_size': batch_size,
              # Saved from the model itself: when a checkpoint was loaded
              # earlier, `model` was replaced and the notebook-level
              # `classifier` variable no longer refers to the trained head.
              'classifier' : model.classifier,
              'epochs': epochs,
              'optimizer': optimizer.state_dict(),
              'state_dict': model.state_dict(),
              'class_to_idx': model.class_to_idx}
# Make sure the target directory exists even if the checkpoint-loading cell
# was skipped:
os.makedirs(consts.checkpoints_path, exist_ok=True)
checkpoint_path = consts.checkpoints_path + '/' + checkpoint_name
torch.save(checkpoint, checkpoint_path)
print(f'checkpoint saved to \"{checkpoint_path}\"')

print_msg('done.', 'g')

checkpoint saved to "./checkpoints/121156_260223.pth"
[32mdone.[0m
