In [None]:
import os
import re
import shutil
import urllib.request
import zipfile
from glob import glob

import h5py
import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.color
import skimage.io
import skimage.transform
import skimage.util
from tqdm import tqdm

# NOTE(review): registers tqdm progress helpers on pandas objects; pandas is
# not otherwise used in this notebook, and this call needs pandas installed.
tqdm.pandas(desc="progress-bar")

### Download the dataset and unpack it

In [None]:
# Download the NIST SD18 archive (once) and unpack it.
# `~/.keras/datasets` is the default cache used by tf.keras.utils.get_file, so
# an archive fetched by an earlier TensorFlow-based version of this cell is
# reused. Only the standard library is needed: `tf` was never imported here.
_URL = 'https://s3.amazonaws.com/nist-srd/SD18/sd18.zip'
CACHE_DIR = os.path.join(os.path.expanduser('~'), '.keras', 'datasets')

os.makedirs(CACHE_DIR, exist_ok=True)
path_to_zip = os.path.join(CACHE_DIR, 'sd18.zip')
if not os.path.exists(path_to_zip):
    # Download to a temporary name first, so an interrupted download is never
    # mistaken for a complete cached archive on the next run.
    partial = path_to_zip + '.part'
    urllib.request.urlretrieve(_URL, partial)
    os.replace(partial, path_to_zip)

# Extract next to the archive, as get_file(..., extract=True) does.
with zipfile.ZipFile(path_to_zip) as archive:
    archive.extractall(CACHE_DIR)

PATH = os.path.join(os.path.dirname(path_to_zip), 'sd18/')
assert os.path.isdir(PATH), 'expected extracted folder at {}'.format(PATH)

### Organize the images

In [None]:
# Copy every single-view mugshot into a flat `data/` folder, renamed to
# `mugshot_<side>.<index>.png`, so later cells can select views with a glob.
os.makedirs('data', exist_ok=True)

filenames = glob(os.path.join(PATH, 'single', 'f1_p1', '*', '*.png'))
assert filenames, 'no PNG files found under {}'.format(PATH)

for filename in tqdm(filenames):
    # basename() handles both '/' and '\\'; splitting on '/' broke on Windows.
    parts = os.path.basename(filename).split('_')
    # remove leading zeros from index (zero runs not preceded by a digit)
    indx = re.sub(r'(?<!\d)0+', '', parts[0])
    side = parts[2].split('.')[0].lower()
    new_file = os.path.join('data', 'mugshot_{}.{}.png'.format(side, indx))
    shutil.copyfile(filename, new_file)

In [None]:
# Convert grayscale images to 3-channel RGB and resize each one to
# IMG_SIZE x IMG_SIZE 8-bit pixels, overwriting the file in place.
IMG_SIZE = 256  # must match the image shape of the HDF5 datasets created below

filenames = glob(os.path.join('data', '*.png'))
for filename in tqdm(filenames):
    im = skimage.io.imread(filename)
    # Expand only genuinely single-channel images. Depending on the
    # scikit-image version, gray2rgb may not pass RGB input through unchanged,
    # so re-running this cell on already-converted files could add an extra axis.
    if im.ndim == 2:
        im = skimage.color.gray2rgb(im)
    im = skimage.transform.resize(im, (IMG_SIZE, IMG_SIZE), anti_aliasing=True)
    im = skimage.util.img_as_ubyte(im)
    skimage.io.imsave(filename, im)

In [None]:
# Mirror left-profile images horizontally so every side view faces right,
# then rename `mugshot_l.<index>.png` -> `mugshot_r.<index>.png`.
filenames = glob(os.path.join('data', 'mugshot_l.*.png'))
for filename in tqdm(filenames):
    im = skimage.io.imread(filename)
    skimage.io.imsave(filename, np.fliplr(im))
    # Rewrite only the view prefix of the file name. A whole-path
    # str.replace('_l', '_r') would also alter any directory containing '_l'.
    dirname, basename = os.path.split(filename)
    new_filename = os.path.join(dirname, basename.replace('mugshot_l.', 'mugshot_r.', 1))
    # NOTE(review): if a subject also has a genuine right view, os.rename
    # overwrites it on POSIX (and fails on Windows). Confirm SD18 never has both.
    os.rename(filename, new_filename)

### Save the image dataset into an HDF5 file

In [None]:
hdf5_path = 'data/dataset.hdf5'
frnt_path = 'data/mugshot_f.*.png'
side_path = 'data/mugshot_r.*.png'
TRAIN_FRACTION = 0.8  # share of paired subjects used for training


def _index_of(path):
    """Return the `<index>` part of a `mugshot_<side>.<index>.png` path."""
    return os.path.basename(path).split('.', 1)[1].rsplit('.', 1)[0]


def pair_by_index(inpt_paths, real_paths):
    """Pair input and target images that share a subject index.

    Args:
        inpt_paths: paths of the form `.../mugshot_<side>.<index>.png`.
        real_paths: paths of the same form for the target view.

    Returns:
        Two equally long lists ordered by index, holding only indices present
        in both inputs. Unpaired images are dropped.
    """
    inpt_by_id = {_index_of(p): p for p in inpt_paths}
    real_by_id = {_index_of(p): p for p in real_paths}
    # (len, text) orders digit strings without leading zeros numerically,
    # and is still deterministic for any other index text.
    ids = sorted(inpt_by_id.keys() & real_by_id.keys(), key=lambda k: (len(k), k))
    return [inpt_by_id[k] for k in ids], [real_by_id[k] for k in ids]


# glob() returns files in arbitrary order, so frnt[i] and side[i] only show
# the same subject after this explicit pairing.
frnt, side = pair_by_index(glob(frnt_path), glob(side_path))

n_train = int(TRAIN_FRACTION * len(frnt))
train_inpt = frnt[:n_train]
train_real = side[:n_train]
test_inpt = frnt[n_train:]
test_real = side[n_train:]

In [None]:
# Create one fixed-size dataset per split/role, shaped
# (number of images, image_height, image_width, image_depth).
IMG_SHAPE = (256, 256, 3)  # must match the resize / gray2rgb preprocessing

train_shape = (len(train_inpt),) + IMG_SHAPE
test_shape = (len(test_inpt),) + IMG_SHAPE

# Open the HDF5 file and create the h5py datasets. The handle stays open for
# the next cell, which fills and closes it.
hdf5_file = h5py.File(hdf5_path, mode='w')

# uint8: the PNGs hold 8-bit pixels (0..255), which int8 cannot represent
# above 127. Each dataset is sized from its own list so an input/target
# length mismatch cannot cause out-of-range writes.
for name, paths in (('train_inpt', train_inpt), ('train_real', train_real),
                    ('test_inpt', test_inpt), ('test_real', test_real)):
    hdf5_file.create_dataset(name, (len(paths),) + IMG_SHAPE, np.uint8)

In [None]:
def _write_images(dataset, paths, desc=None):
    """Read each image in `paths` and store it in the matching row of `dataset`."""
    for row, path in enumerate(tqdm(paths, desc=desc)):
        dataset[row, ...] = skimage.io.imread(path)


try:
    for name, paths in (('train_inpt', train_inpt), ('train_real', train_real),
                        ('test_inpt', test_inpt), ('test_real', test_real)):
        _write_images(hdf5_file[name], paths, desc=name)
finally:
    # Close the file even if an image fails to load, so the handle is not leaked.
    hdf5_file.close()

### Check that the data was saved correctly to the HDF5 file

In [None]:
# Re-open the HDF5 file read-only, report each dataset's shape and dtype,
# and close it again when done.
hdf5_path = 'data/dataset.hdf5'
with h5py.File(hdf5_path, 'r') as hdf5_file:
    for name, dataset in hdf5_file.items():
        print(name, dataset.shape, dataset.dtype)
    # Number of training samples
    num_data = hdf5_file['train_inpt'].shape[0]

print(num_data)

### Clean up intermediate files

In [None]:
# The per-image PNGs are no longer needed once they are stored in the HDF5
# file, so delete them.
for png_path in glob('data/*.png'):
    os.remove(png_path)

# Delete the extracted SD18 directory (the downloaded zip archive is kept).
shutil.rmtree(PATH)