In [1]:
'''Importing Modules'''

'''
Prerequisites:
NumPy             https://numpy.org/doc/stable/
Matplotlib        https://matplotlib.org/stable/index.html
PyTorch           https://pytorch.org/docs/stable/index.html
Torchvision       https://pytorch.org/vision/stable/index.html
PIL               https://pillow.readthedocs.io/en/stable/
GitPython         https://gitpython.readthedocs.io/en/stable/
split-folders:    https://pypi.org/project/split-folders/
python-dotenv:    https://pypi.org/project/python-dotenv/
'''

# vanilla:
import os
import json
import time
from time import strptime
import datetime
from datetime import timedelta
import shutil
from collections import OrderedDict
import random
import sys
# external:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms, models
import PIL.Image as Image
import git
import splitfolders
from dotenv import load_dotenv
# custom:
from routines import *
from displays import *
import myTransforms
import consts

println([('done.', 'g')])

[32mdone.               [0m


In [2]:
'''Parsing the .env File'''


# Loading sensitive info from the dotenv file.
#  These are needed in order to clone the data repo:
if not load_dotenv(consts.dotenv_path):
    e_msg = 'cannot find the required .env file'
    raise SystemExit(e_msg)
gh_token = os.getenv('GH_TOKEN')
gh_username = os.getenv('GH_USERNAME')
repo_name = os.getenv('REMOTE_REPO_NAME')

# Fail early with a readable message when a variable is absent. Otherwise a
#  missing REMOTE_REPO_NAME surfaces as a TypeError in the path building below,
#  and a missing token/username ends up as the literal text 'None' in the URL:
required_vars = {
    'GH_TOKEN': gh_token,
    'GH_USERNAME': gh_username,
    'REMOTE_REPO_NAME': repo_name,
}
missing_vars = [name for name, value in required_vars.items() if value is None]
if missing_vars:
    e_msg = f'the .env file is missing required variables: {missing_vars}'
    raise SystemExit(e_msg)

# NOTE: repo_url embeds the access token, so it must never be printed or logged.
repo_url = f'https://{gh_token}@github.com/{gh_username}/{repo_name}.git'
# os.path.join picks the separator of the running OS instead of a hard-coded '\\':
dataset_path = os.path.join(repo_name, 'dataset')
classes_path = os.path.join(repo_name, 'classes.json')
# Initial state of the clone-failure flag:
url_issue = False

println([('done.', 'g')])

[32mdone.               [0m


In [3]:
'''Cloning the Remote Data-Repository'''


# Checks if a leftover repo exists, overwrite it if so:
if os.path.exists(repo_name):
    git.rmtree(repo_name)
# Reset the failure flag on every run of this cell. Without this, a flag left
#  at True by a previous failed attempt would make a later successful clone
#  still abort below:
url_issue = False
# Clones the repo, and raises an exception if the remote URL is corrupted:
try:
    git.Repo.clone_from(repo_url, repo_name)
except Exception:
    # NOTE: the error text is not echoed here. repo_url is built from GH_TOKEN,
    #  so printing the failing command could leak the secret.
    url_issue = True
if url_issue:
    e_msg = 'there is an issue with the remote repo URL'
    raise SystemExit(e_msg)


println([('done.', 'g')])

[32mdone.               [0m


In [4]:
'''Parsing the JSON File from the Data Repository'''


# Reads classes.json from the cloned repo into `classes`: the first element of
#  the JSON document, converted to a dict.
json_not_found = False
try:
    # Only the file access lives inside the try, so the handler below cannot
    #  mask an unrelated failure. JSON is read as UTF-8 regardless of the
    #  platform's default encoding:
    with open(classes_path, 'r', encoding='utf-8') as f:
        json_file = json.load(f)
    classes = dict(json_file[0])
except FileNotFoundError:
    json_not_found = True
if json_not_found:
    e_msg = f'cannot locate the \'classes.json\' file in "{repo_name}".'\
        + '\nre-run the \'Cloning the Remote Data-Repository\' cell and try again.'
    raise SystemExit(e_msg)

# Displays the JSON file metadata:
println(['total classes'], header=True)
println([len(classes)])

println([('done.', 'g')])

[1mtotal classes       [0m
[0m[0m15                  [0m
[32mdone.               [0m


In [24]:
'''Parsing the Input Directory Supplied by the User'''


# Scans consts.input_path, splits its entries into accepted image files and
#  everything else, prints a table describing each image, and aborts (SystemExit)
#  when the directory is missing, holds invalid entries, holds no images, or
#  holds an image that does not have exactly three channels.
println([(f'looking for image files in \'{consts.input_path}\'...', 'y')])

files = []
bad_files = []
no_input_dir = False
got_greyscale = False

try:
    # os.listdir returns entries in arbitrary order; sorting keeps the table
    #  numbering stable between runs:
    for file in sorted(os.listdir(consts.input_path)):
        if file.endswith(consts.valid_filetypes):
            files.append(file)
        else:
            bad_files.append(file)
except FileNotFoundError:
    no_input_dir = True
if no_input_dir:
    e_msg = 'cannot find the input directory (it must be named \'input\' and be in root)'
    raise SystemExit(e_msg)
if bad_files:
    e_msg = f'invalid files were found in the input directory, please remove them:\n{bad_files}'
    raise SystemExit(e_msg)
if not files:
    e_msg = 'cannot find any image files in the input directory'
    raise SystemExit(e_msg)

# Displaying:
println(['#', 'filename', 'height', 'width', 'color'], header=True)
for idx, file in enumerate(files, start=1):
    # The context manager releases the file handle as soon as the pixel data
    #  has been copied into the tensor:
    with Image.open(os.path.join(consts.input_path, file)) as img:
        tensor = transforms.PILToTensor()(img)
    # PILToTensor yields a (channels, height, width) tensor:
    channels, height, width = tensor.shape
    # NOTE(review): every image without exactly three channels is labelled
    #  'greyscale', which also covers e.g. four-channel images. Confirm that
    #  rejecting those is intended.
    if channels == 3:
        color = ('rgb', 'g')
    else:
        color = ('greyscale', 'r')
        got_greyscale = True
    println([idx, file, height, width, color])
if got_greyscale:
    e_msg = 'cannot parse greyscale images, revise your input'
    raise SystemExit(e_msg)

println([('done.', 'g')])

[33mlooking for image files in './input/'...[0m
[1m#                   [1mfilename            [1mheight              [1mwidth               [1mcolor               [0m
[0m[0m1                   [0m[0m1.jpg               [0m[0m900                 [0m[0m1200                [32mrgb                 [0m
[0m[0m2                   [0m[0m2.png               [0m[0m650                 [0m[0m1169                [32mrgb                 [0m
[0m[0m3                   [0m[0m3.png               [0m[0m427                 [0m[0m640                 [31mgreyscale           [0m


SystemExit: cannot parse greyscale images

In [6]:
'''Instantiating Our Model'''


# Announce the scan of the checkpoints directory:
println([(f'looking for \'.pth\' files in \'{consts.checkpoints_path}\' to instanctiate a new model...', 'y')])

# Gather every '.pth' entry. A missing directory is only flagged here and is
#  reported right after the scan:
checkpoints = []
no_checkpoints_dir = False
try:
    checkpoints = [
        entry
        for entry in os.listdir(consts.checkpoints_path)
        if entry.endswith('.pth')
    ]
except FileNotFoundError:
    no_checkpoints_dir = True

# Abort when there is nothing to load the model from:
if no_checkpoints_dir:
    raise SystemExit(f'cannot find the checkpoints dir: \'{consts.checkpoints_path}\'')
if not checkpoints:
    raise SystemExit(f'cannot find any \'.pth\' files in \'{consts.checkpoints_path}\'')

# Build the model from the checkpoint chosen by latestCheckpoint()
#  (latestCheckpoint / loadCheckpoint presumably come from the star-imported
#  custom modules - TODO confirm), then move it to the CPU and switch it to
#  evaluation mode:
pretrained = True
weights = 'DEFAULT' if pretrained else None
latest_checkpoint = latestCheckpoint()
model = loadCheckpoint(latest_checkpoint, weights=weights)
model.to('cpu')
model.eval()


println([('done.', 'g')])

[33mlooking for '.pth' files in './checkpoints' to instanctiate a new model...[0m
[32mdone.               [0m


In [7]:
'''Creating the Images List'''


# Opens every accepted input file, runs it through
#  myTransforms.pilimg_transforms and stores the result next to its filename.
images = [] # list of tuples: (filename, Tensor)

for file in files:
    # os.path.join replaces the manual '//' concatenation, and the context
    #  manager closes the file handle once the transform has consumed the image.
    # Assumes pilimg_transforms returns a fully materialised tensor (see the
    #  list comment above) rather than a lazy PIL image - TODO confirm.
    with Image.open(os.path.join(consts.input_path, file)) as pil_image:
        images.append((file, myTransforms.pilimg_transforms(pil_image)))

In [22]:
'''Predicting the Supplied Images'''


# Runs the model on every prepared image and records the predicted class name,
#  then prints the results as a table.
predictions = [] # list of tuples: (filename, prediction)

# Inference only: no_grad() avoids building an autograd graph for each image.
with torch.no_grad():
    for filename, img in images:
        # Add a leading batch dimension of size one:
        inputs = img.unsqueeze(dim=0)
        # Calling the module (instead of .forward) also runs registered hooks:
        log_probabilities = model(inputs)
        # Only the best class is needed. exp() is monotonic, so the argmax of
        #  the raw outputs equals the top-1 of the exponentiated values, and
        #  this no longer depends on the model having at least 15 outputs:
        top_class = int(log_probabilities.argmax(dim=1)[0])
        # The 0-based index is shifted by one and stringified to form the
        #  lookup key into `classes`:
        predictions.append((filename, classes[str(top_class + 1)]))

println(['#', 'filename', 'prediction'], header=True)
for idx, (filename, label) in enumerate(predictions, start=1):
    # 'uk' and 'usa' are shown fully upper-cased, every other label capitalised:
    shown = label.upper() if label in ['uk', 'usa'] else label.capitalize()
    println([idx, filename, shown])

println([('done.', 'g')])

[1m#                   [1mfilename            [1mprediction          [0m
[0m[0m1                   [0m[0m1.jpg               [0m[0mUSA                 [0m
[0m[0m2                   [0m[0m2.png               [0m[0mIsrael              [0m
[32mdone.               [0m


In [19]:
'''Generating the Results Directory'''


# Rebuilds consts.output_path from scratch and copies every predicted image
#  into a sub-directory named after its predicted class.

# If an output dir already exists, overwrite it:
if os.path.exists(consts.output_path):
    shutil.rmtree(consts.output_path)
os.makedirs(consts.output_path)

for filename, classname in predictions:
    # os.path.join works whether or not the configured paths end with a
    #  separator, unlike plain string concatenation:
    class_path = os.path.join(consts.output_path, classname)
    # Creates a new class subdir to dump the image in, if one does not exist yet:
    os.makedirs(class_path, exist_ok=True)
    shutil.copyfile(os.path.join(consts.input_path, filename),
                    os.path.join(class_path, filename))

println([('done.', 'g')])