# Example of Downloading and Formatting Data for notMNIST
This notebook fetches and prepares the data (notMNIST: images of the letters A–J in a 28×28 MNIST-like format) for the convolutional neural network in the other notebook. It may also be useful for testing how different model architectures perform on a larger dataset. Much of the code here isn't needed, but some of the functions may be useful for our own images.

Much of this code is adapted from Google's example code.

In [14]:
# Import bunch of useful packages
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
import imagehash  # needs to be installed

from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle
from PIL import Image

### Useful functions

Some functions to download and extract datasets from a URL:

In [15]:
def download_progress_hook(count, blockSize, totalSize):
    """Print download progress; intended as the ``reporthook`` of ``urlretrieve``.

    Writes ``"N%"`` whenever the percentage is a multiple of 5 and ``"."``
    for any other new percentage. Each percentage value is reported at most
    once; the last one reported is kept in the module-level
    ``last_percent_reported``.

    Args:
        count: number of blocks transferred so far.
        blockSize: size of one block in bytes.
        totalSize: total download size in bytes; ``urlretrieve`` passes -1
            when the size is unknown.
    """
    global last_percent_reported
    if totalSize <= 0:
        # Unknown (or zero) size: there is no meaningful percentage, and
        # dividing by it would fail (0) or go negative (-1).
        return

    # The final block usually overshoots the total, so cap at 100%.
    percent = min(100, int(count * blockSize * 100 / totalSize))
    if percent == last_percent_reported:
        return

    if percent % 5 == 0:
        sys.stdout.write('%s%%' % percent)
    else:
        sys.stdout.write('.')
    sys.stdout.flush()
    last_percent_reported = percent

In [16]:
def maybe_download(filename, expected_bytes, force=False):
    """Download ``url + filename`` into ``data_root`` unless it is already there.

    Uses the module-level ``url`` and ``data_root`` settings, and
    ``download_progress_hook`` to show progress.

    Args:
        filename: file name relative to ``url`` and ``data_root``.
        expected_bytes: size the file should have; used to verify it.
        force: re-download even if the file already exists.

    Returns:
        The local path of the file. A size mismatch is only reported on
        stdout; the path is still returned.
    """
    dest_filename = os.path.join(data_root, filename)
    if force or not os.path.exists(dest_filename):
        print('Attempting to download', filename)
        urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)
        print('\nDownload complete!')
    statinfo = os.stat(dest_filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified', dest_filename)
    else:
        print('Failed to verify ' + dest_filename + '. Can you get to it with a browser?')
    return dest_filename

In [17]:
def maybe_extract(filename, force=False):
    """Extract a ``.tar.gz`` archive into ``data_root`` and list its class folders.

    The extracted root folder is taken to be ``filename`` with both
    extensions stripped (``dir/name.tar.gz`` -> ``dir/name``). That assumes
    the archive sits in ``data_root`` and its top-level folder is named after
    the archive; presumably true for the notMNIST tarballs, confirm for
    other data.

    Args:
        filename: path to the ``.tar.gz`` archive.
        force: extract again even if the root folder already exists.

    Returns:
        Sorted paths of the sub-directories of the root folder.

    Raises:
        Exception: if the number of sub-directories differs from the
            module-level ``num_classes``.
    """
    root = os.path.splitext(os.path.splitext(filename)[0])[0]
    if os.path.isdir(root) and not force:
        print('%s already present - skipping extraction of %s' % (root, filename))
    else:
        print('Extracting data for %s - this may take a while.' % root)
        sys.stdout.flush()  # show the message before the slow extraction starts
        with tarfile.open(filename) as tar:
            # The archive is downloaded from the network; where supported
            # (tarfile extraction filters), refuse members that would escape
            # data_root or carry unsafe metadata.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(data_root, filter='data')
            else:
                tar.extractall(data_root)
    data_folders = [
        os.path.join(root, d) for d in sorted(os.listdir(root)) if os.path.isdir(os.path.join(root, d))
    ]
    if len(data_folders) != num_classes:
        raise Exception("Expected %s folders, found %s instead" % (num_classes, len(data_folders)))
    return data_folders

Functions to check data:

In [18]:
def display_image_samples(data_folders, samples_per_class=1):
    """Plot sample images from each class folder, one figure per image.

    Args:
        data_folders: folders whose files are images.
        samples_per_class: how many files to show per folder (default 1).
            ``os.listdir`` order is arbitrary, so which files are shown is
            arbitrary too.

    Empty folders and non-file entries are skipped instead of raising.
    """
    for data_folder in data_folders:
        for name in os.listdir(data_folder)[:samples_per_class]:
            image_file = os.path.join(data_folder, name)
            if not os.path.isfile(image_file):
                continue
            image = plt.imread(image_file)
            plt.title(data_folder)
            plt.imshow(image)
            plt.show()

In [19]:
def load_letter(folder, min_num_images, image_size=28, pixel_depth=255.0):
    """Load every image of one class folder into a normalised float32 array.

    Each file in ``folder`` is read as an image and its pixels are scaled
    with ``(x - pixel_depth / 2) / pixel_depth``, i.e. from
    ``[0, pixel_depth]`` to ``[-0.5, 0.5]``. Files that fail to read with
    ``IOError`` are reported and skipped.

    Args:
        folder: directory containing only image files of one class.
        min_num_images: minimum number of images that must load.
        image_size: expected height and width of every image (default 28).
        pixel_depth: maximum raw pixel value (default 255.0).

    Returns:
        float32 array of shape ``(num_loaded, image_size, image_size)``.

    Raises:
        Exception: if an image has an unexpected shape, or if fewer than
            ``min_num_images`` images load.
    """
    image_files = [os.path.join(folder, name) for name in os.listdir(folder)]
    dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                         dtype=np.float32)
    num_images = 0
    for image_file in image_files:
        try:
            image_data = (_read_image(image_file) - pixel_depth / 2) / pixel_depth
            if image_data.shape != (image_size, image_size):
                raise Exception("Unexpected image shape: %s" % str(image_data.shape))
            dataset[num_images, :, :] = image_data
            num_images += 1
        except IOError as e:
            print("Could not read image file: ", e, ". It\'s ok - skipping")

    # Drop the unused tail left by skipped files.
    dataset = dataset[0:num_images, :, :]
    if num_images < min_num_images:
        raise Exception("Significantly fewer images than expected: %s < %s - please check."
                        % (num_images, min_num_images))
    print("Full dataset tensor: ", dataset.shape)
    print("Mean:\t%s" % np.mean(dataset))
    print("Stddev: %s" % np.std(dataset))
    return dataset


def _read_image(image_file):
    """Read one image file into a float array of raw pixel values."""
    # scipy.ndimage.imread was removed in SciPy 1.2, so read with Pillow
    # (already a dependency of this notebook). Pillow raises OSError
    # (IOError on Python 2) for unreadable files, which load_letter skips.
    from PIL import Image as PILImage
    with PILImage.open(image_file) as img:
        return np.asarray(img, dtype=float)

Wrangling, storing, and merging data:

In [20]:
def maybe_pickle(data_folders, min_num_images_per_class, force=False):
    """Pickle each class folder's images to ``<folder>.pickle`` unless already done.

    Args:
        data_folders: class folders to load with ``load_letter``.
        min_num_images_per_class: passed to ``load_letter`` as the minimum
            number of images each folder must yield.
        force: re-pickle even when the pickle file already exists.

    Returns:
        The list of pickle paths, one per folder, in input order. A path is
        included even if saving it failed; that failure is only reported on
        stdout.
    """
    dataset_names = []
    for folder in data_folders:
        set_filename = folder + '.pickle'
        dataset_names.append(set_filename)
        if os.path.exists(set_filename) and not force:
            # You may override by setting force=True.
            print('%s already present - Skipping pickling.' % set_filename)
            continue

        print('Pickling %s.' % set_filename)
        dataset = load_letter(folder, min_num_images_per_class)
        try:
            with open(set_filename, 'wb') as f:
                pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            print('Unable to save data to', set_filename, ':', e)
            # A partially written pickle would be treated as "already
            # present" on the next run, so remove it.
            try:
                if os.path.exists(set_filename):
                    os.remove(set_filename)
            except OSError:
                pass

    return dataset_names

In [21]:
def make_arrays(row_num, img_size):
    """Allocate storage for ``row_num`` square images plus their labels.

    Returns a float32 array of shape ``(row_num, img_size, img_size)`` and an
    int32 array of shape ``(row_num,)``, both uninitialised. When
    ``row_num`` is falsy (0 or None) returns ``(None, None)`` instead.
    """
    if not row_num:
        return None, None
    images = np.empty((row_num, img_size, img_size), dtype=np.float32)
    targets = np.empty(row_num, dtype=np.int32)
    return images, targets

In [22]:
# Combines all datasets -- the remove_similar flag makes this more complicated than it would otherwise be
def merge_datasets(pickle_files, train_size, valid_size=0, remove_similar=False):
    """Merge per-class pickles into labelled validation and training arrays.

    Each pickle in ``pickle_files`` holds the images of one class, and class
    ``k`` (its position in the list) gets label ``k``. Each class is
    shuffled. Its first ``valid_size // num_classes`` images then go to the
    validation set and the next ``train_size // num_classes`` to the
    training set. Arrays are sized with the module-level ``image_size``.

    Args:
        pickle_files: paths to per-class pickles, one per class.
        train_size: requested total number of training rows.
        valid_size: requested total number of validation rows; 0 skips the
            validation set.
        remove_similar: if True, drop near-duplicates from each class with
            ``remove_similars`` before slicing.

    Returns:
        ``(valid_dataset, valid_labels, train_dataset, train_labels)``. A
        set whose per-class share is 0 is returned as ``(None, None)``. If a
        requested size is not a multiple of the number of classes, the
        remainder is dropped, so every returned row holds real data.

    Raises:
        ValueError: if ``pickle_files`` is empty, or a class has fewer images
            than its validation + training share.
    """
    num_classes = len(pickle_files)
    if num_classes == 0:
        raise ValueError("merge_datasets needs at least one pickle file")
    vsize_per_class = valid_size // num_classes
    tsize_per_class = train_size // num_classes
    # Allocate only rows that will be filled: the arrays are uninitialised,
    # so remainder rows from the integer division would otherwise hold
    # garbage images and labels.
    valid_dataset, valid_labels = make_arrays(vsize_per_class * num_classes, image_size)
    train_dataset, train_labels = make_arrays(tsize_per_class * num_classes, image_size)
    per_class_needed = vsize_per_class + tsize_per_class

    for label, pickle_file in enumerate(pickle_files):
        try:
            with open(pickle_file, 'rb') as f:
                letter_set = pickle.load(f)
            # Shuffle so each class's validation/training split is random.
            np.random.shuffle(letter_set)
            # NOTE(review): label 8 (the 9th class, 'I' if the folders are
            # A-J) is never de-duplicated; the reason is not stated here --
            # TODO confirm.
            if remove_similar and label != 8:
                letter_set, _ = remove_similars(letter_set)
            if len(letter_set) < per_class_needed:
                raise ValueError("Class %s has only %s images, %s needed"
                                 % (label, len(letter_set), per_class_needed))

            # Classes fill consecutive, equally sized slices in label order.
            v_start = label * vsize_per_class
            t_start = label * tsize_per_class
            if valid_dataset is not None:
                valid_dataset[v_start:v_start + vsize_per_class, :, :] = letter_set[:vsize_per_class, :, :]
                valid_labels[v_start:v_start + vsize_per_class] = label
            if train_dataset is not None:
                train_dataset[t_start:t_start + tsize_per_class, :, :] = \
                    letter_set[vsize_per_class:per_class_needed, :, :]
                train_labels[t_start:t_start + tsize_per_class] = label
        except Exception as e:
            print('Unable to process data from', pickle_file, ':', e)
            raise

    return valid_dataset, valid_labels, train_dataset, train_labels

In [23]:
def randomise(dataset, labels):
    """Shuffle images and labels with one shared random permutation.

    Row ``i`` of the returned images still matches element ``i`` of the
    returned labels. Both results are new arrays; the inputs are left as-is.
    """
    order = np.random.permutation(labels.shape[0])
    return dataset[order, :, :], labels[order]

If duplicates need to be removed, the functions below drop images that are too similar to one another (images whose perceptual hashes are identical):

In [29]:
def remove_similars(dataset, labels=None, hash_fn=None):
    """Drop images whose hash equals that of an earlier image.

    Args:
        dataset: array of images indexed as ``dataset[i, :, :]``.
        labels: optional per-image label array, filtered in step with
            ``dataset``.
        hash_fn: callable mapping one image to a string key; images with
            equal keys count as duplicates. Defaults to ``hash_image``.

    Returns:
        ``(dataset, labels)`` keeping the first image of each key, in the
        original order. ``labels`` is None when none was given.
    """
    if hash_fn is None:
        hash_fn = hash_image
    num_before = len(dataset)
    hashes = [hash_fn(img) for img in dataset]
    _, first_index = np.unique(hashes, return_index=True)
    # np.unique orders its results by hash value. Sort the indices back into
    # dataset order so a caller's prior shuffle survives; otherwise slicing
    # from the front picks images grouped by hash value.
    keep = np.sort(first_index)
    # `is not None`: truth-testing a multi-element numpy array raises.
    if labels is not None:
        labels = labels[keep]
    num_after = len(keep)
    dataset = dataset[keep, :, :]
    print("{} repeated samples removed of {}".format(num_before - num_after, num_before))
    return dataset, labels

def hash_image(image):
    image_int = (image * 255 + 255).astype(np.int32)
    hashed_image = imagehash.dhash(Image.fromarray(image_int).resize((16, 16), Image.ANTIALIAS))
    return str(hashed_image)

In [25]:
def display_image_sample_pretty(dataset, labels):
    """Show up to 12 random samples in a 3x4 grid, titled with their letter.

    Label ``k`` is displayed as the ``k``-th character of ``'ABCDEFGHIJ'``.
    """
    letters = 'ABCDEFGHIJ'
    chosen = np.random.permutation(len(labels))[0:12]
    images = dataset[chosen, :, :]
    picked_labels = labels[chosen]

    for position, (image, label) in enumerate(zip(images, picked_labels), start=1):
        title = letters[label]
        plt.subplot(3, 4, position)
        plt.axis('off')
        plt.title(title)
        plt.imshow(image)

    plt.show()

### Running the functions


In [30]:
# Setting URL and folder for data
url = 'http://commondatastorage.googleapis.com/books1000/'
last_percent_reported = None  # progress state read and updated by download_progress_hook
data_root = 'mnist/'  # downloads, extracted folders and pickles all live here
# urlretrieve does not create missing directories, so make sure data_root exists.
if not os.path.isdir(data_root):
    os.makedirs(data_root)

num_classes = 10
image_size = 28

train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)

train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)

# pprint(train_folders)
# display_image_samples(train_folders)
# display_image_samples(test_folders)

# One pickle per class; the second argument is the minimum number of images
# each class folder must yield.
train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)

train_size = 200000
valid_size = 10000
test_size = 10000

# Validation and training both come from the large set; the test set comes
# from the small set with near-duplicate images removed.
valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(
    train_datasets, train_size, valid_size, remove_similar=False)
_, _, test_dataset, test_labels = merge_datasets(test_datasets, test_size, remove_similar=True)

print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
print('Testing:', test_dataset.shape, test_labels.shape)

# merge_datasets fills rows class by class; shuffle so classes are interleaved.
train_dataset, train_labels = randomise(train_dataset, train_labels)
valid_dataset, valid_labels = randomise(valid_dataset, valid_labels)
test_dataset, test_labels = randomise(test_dataset, test_labels)

pickle_file = os.path.join(data_root, 'MNISTu.pickle')

try:
    with open(pickle_file, 'wb') as f:
        save = {
            'train_dataset': train_dataset,
            'train_labels': train_labels,
            'valid_dataset': valid_dataset,
            'valid_labels': valid_labels,
            'test_dataset': test_dataset,
            'test_labels': test_labels
        }
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
except Exception as e:
    print("Unable to save pickle file: ", pickle_file, ": ", e)
    raise

# NOTE: the pickle is not compressed; this is simply its size in bytes.
statinfo = os.stat(pickle_file)
print("Compressed size: ", statinfo.st_size)

Found and verified mnist/notMNIST_large.tar.gz
Found and verified mnist/notMNIST_small.tar.gz
mnist/notMNIST_large already present - skipping extraction of mnist/notMNIST_large.tar.gz
mnist/notMNIST_small already present - skipping extraction of mnist/notMNIST_small.tar.gz
mnist/notMNIST_large\A.pickle already present - Skipping pickling.
mnist/notMNIST_large\B.pickle already present - Skipping pickling.
mnist/notMNIST_large\C.pickle already present - Skipping pickling.
mnist/notMNIST_large\D.pickle already present - Skipping pickling.
mnist/notMNIST_large\E.pickle already present - Skipping pickling.
mnist/notMNIST_large\F.pickle already present - Skipping pickling.
mnist/notMNIST_large\G.pickle already present - Skipping pickling.
mnist/notMNIST_large\H.pickle already present - Skipping pickling.
mnist/notMNIST_large\I.pickle already present - Skipping pickling.
mnist/notMNIST_large\J.pickle already present - Skipping pickling.
mnist/notMNIST_small\A.pickle already present - Skipping

If you run this, make sure you don't push the downloaded notMNIST data (everything under `mnist/`) to the repository; it would make cloning very slow.