In [1]:
# vanilla modules:
import os
import json
import time
import datetime
import shutil
from collections import OrderedDict

# external modules:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable as var
import torchvision
from torchvision import datasets, transforms, models
import PIL.Image as img
import git

# Prerequisites (third-party packages):
# - NumPy
# - Matplotlib
# - PyTorch, Torchvision
# - PIL (installed as the Pillow package)
# - GitPython

In [2]:
# Helpers:
class Style:
    """ANSI SGR escape sequences for decorating terminal / notebook output.

    Wrap text as ``Style.<CODE> + text + Style.END``; ``END`` resets all
    attributes back to the default.
    """
    # text attribute
    BOLD = '\x1b[1m'
    # foreground colours (SGR codes 30-37)
    BLACK = '\x1b[30m'
    RED = '\x1b[31m'
    GREEN = '\x1b[32m'
    YELLOW = '\x1b[33m'
    BLUE = '\x1b[34m'
    MAGENTA = '\x1b[35m'
    CYAN = '\x1b[36m'
    WHITE = '\x1b[37m'
    # reset everything
    END = '\x1b[0m'

def color(input, color, bold=False):
    """Wrap ``input`` in ANSI escape codes so it renders coloured.

    Args:
        input: any object; it is converted with ``str()``.
        color: colour name or its one-letter alias: 'k'/'black', 'r'/'red',
            'g'/'green', 'y'/'yellow', 'b'/'blue', 'm'/'magenta',
            'c'/'cyan', 'w'/'white'.
        bold: if truthy, ``Style.BOLD`` is prepended as well.

    Returns:
        ``[Style.BOLD +] <colour code> + str(input) + Style.END``

    Raises:
        SystemExit: if ``color`` is not one of the names above.
    """
    # (aliases, escape code) pairs replace the eight copy-pasted branches.
    # Built per call so Style is only looked up when color() actually runs.
    palette = (
        (('k', 'black'), Style.BLACK),
        (('r', 'red'), Style.RED),
        (('g', 'green'), Style.GREEN),
        (('y', 'yellow'), Style.YELLOW),
        (('b', 'blue'), Style.BLUE),
        (('m', 'magenta'), Style.MAGENTA),
        (('c', 'cyan'), Style.CYAN),
        (('w', 'white'), Style.WHITE),
    )
    for aliases, code in palette:
        # tuple membership: same equality-based test as the original list checks
        if color in aliases:
            prefix = Style.BOLD + code if bold else code
            return prefix + str(input) + Style.END
    raise SystemExit('invalid parameters')

def cmd(command):
    """Run ``command`` through the Windows command interpreter (``cmd /c``).

    The exit status reported by ``os.system`` is discarded; returns None.
    NOTE(review): ``command`` is pasted into the shell line unescaped, so only
    pass trusted strings.
    """
    shell_line = 'cmd /c "{}"'.format(command)
    os.system(shell_line)

In [9]:
# Clone the data repository from GitHub:
path = 'FlagNet_data'
# Never hard-code credentials: the token previously embedded here was published
# with the notebook and must be revoked on GitHub. Supply it through the
# GITHUB_TOKEN environment variable instead (it can be omitted for a public repo).
# NOTE: git records the clone URL, token included, in FlagNet_data/.git/config.
token = os.environ.get('GITHUB_TOKEN', '')
url = (f'https://{token}@github.com/yuval-ro/FlagNet_data.git' if token
       else 'https://github.com/yuval-ro/FlagNet_data.git')

# Always start from a fresh checkout. Removing a stale one explicitly replaces
# the old bare `except:`, which re-cloned after *any* failure (network, auth,
# even Ctrl-C). When no checkout existed, its rmtree of the missing path also
# masked the real clone error.
if os.path.exists(path):
    git.rmtree(path)
git.Repo.clone_from(url, path)

In [12]:
# Dataset sanity check and display:
IMAGES_PER_CLASS = 30
# os.path.join instead of hard-coded '\\' so the check also runs outside Windows
# (on Windows the resulting paths are unchanged).
dataset_dir = os.path.join(path, 'dataset')
classes_json = os.path.join(path, 'classes.json')
images_per_class = []
bad_dirs = []
json_ne_dirs = False

# Locates and parses the "classes.json" file (JSON is UTF-8 by specification):
with open(classes_json, 'r', encoding='utf-8') as f:
    classes = OrderedDict(json.load(f))

# Expected class subdirectories "01", "02", ... - one per json entry, in json order:
dir_names = [os.path.join(dataset_dir, '%.2d' % i) for i in range(1, len(classes) + 1)]

# Checks the number of classes defined in the json equals the number of entries in the dataset dir:
if len(os.listdir(dataset_dir)) != len(classes):
    json_ne_dirs = True

# Checks the number of images of each expected class subdir. Iterating dir_names
# (instead of os.walk, whose order is filesystem-dependent) keeps images_per_class
# aligned index-for-index with classes.items(); a missing subdir counts as 0 and
# is reported in bad_dirs instead of silently shortening the list.
for dir_name in dir_names:
    if os.path.isdir(dir_name):
        with os.scandir(dir_name) as entries:
            # same notion of "file" as os.walk: every entry that is not a directory
            images_in_dir = sum(1 for entry in entries if not entry.is_dir())
    else:
        images_in_dir = 0
    images_per_class.append(images_in_dir)
    if images_in_dir != IMAGES_PER_CLASS:
        bad_dirs.append(dir_name)

# Displays the metadata nicely:
def println(s1='', s2='', s3='', bold=False):
    """Print one table row as three left-aligned, 20-character-wide columns.

    Columns are separated by a single space. With ``bold`` truthy the whole row
    is wrapped in ``Style.BOLD`` ... ``Style.END``.
    """
    cells = ['{0:<20}'.format(value) for value in (s1, s2, s3)]
    row = ' '.join(cells)
    print(Style.BOLD + row + Style.END if bold else row)


# Class names that are acronyms are shown fully upper-case; others are capitalized.
ACRONYM_CLASSES = ['uk', 'usa']

println('id', 'class', 'images', bold=True)

for i, (ID, Class) in enumerate(classes.items()):
    # Guard the lookup: if fewer counts than classes were collected (e.g. a class
    # subdir is missing), show a red 'missing' instead of crashing with an
    # IndexError before the descriptive checks below can run.
    count = images_per_class[i] if i < len(images_per_class) else None
    println(ID,
            Class.upper() if Class in ACRONYM_CLASSES else Class.capitalize(),
            count if count == IMAGES_PER_CLASS else color('missing' if count is None else count, 'r'))

total_images = sum(images_per_class)
println(s3='total', bold=True)
println(s3=total_images if total_images == len(classes) * IMAGES_PER_CLASS else color(total_images, 'r'))

# Throws exceptions if needed:
if json_ne_dirs:
    raise SystemExit(f'number of classes according to the json file ({len(classes)}) does not correlate with total dirs ({len(os.listdir(dataset_dir))}) in \"{dataset_dir}\"')
elif len(images_per_class) != len(classes):
    # Without this check a missing class dir could still end in 'all okay!'.
    raise SystemExit(f'found image counts for {len(images_per_class)} class directories, expected {len(classes)}')
elif bad_dirs:
    raise SystemExit(f'image count in the following directories is incorrect: {bad_dirs}')
else:
    println(s3=color('all okay!', 'g'), bold=True)

[1mid                   class                images              [0m
1                    Australia            30                  
2                    Brazil               30                  
3                    Canada               30                  
4                    China                30                  
5                    France               30                  
6                    Germany              30                  
7                    India                30                  
8                    Israel               30                  
9                    Italy                30                  
10                   Japan                30                  
11                   Russia               30                  
12                   Spain                30                  
13                   Sweden               30                  
14                   UK                   30                  
15                   USA                  30   