In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd
import requests 
import tarfile
from io import StringIO
import os
import pickle
from PIL import Image
import matplotlib.pyplot as plt
from functools import partial
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras import layers
from tensorflow.python.keras import backend as K

%matplotlib inline

## Downloading the CIFAR-10 dataset

In [None]:
def download_cifar10(url):
    f = os.path.join('cifar10', "cifar-10-python.tar.gz")
    try:
        files = [
            os.path.join('cifar10', 'cifar-10-batches-py', 'meta_data'),
            os.path.join('cifar10', 'cifar-10-batches-py', 'data_batch_1'),
            os.path.join('cifar10', 'cifar-10-batches-py', 'data_batch_2'),
            os.path.join('cifar10', 'cifar-10-batches-py', 'data_batch_3'),
            os.path.join('cifar10', 'cifar-10-batches-py', 'data_batch_4'),
            os.path.join('cifar10', 'cifar-10-batches-py', 'data_batch_5'),
            os.path.join('cifar10', 'cifar-10-batches-py', 'test_batch')
        ]
        print('Checking if the following files exists\n{}\n'.format(files))
        
        assert_msg = 'Some of the files were missing'
        assert all([os.path.exists(path) for path in files]), assert_msg
    except:
        try:
            print("Extracting {}".format(f))
            filename = os.path.join('cifar10', f)
            tar = tarfile.open(f, "r:gz")
            tar.extractall('cifar10')
            tar.close()
        except FileNotFoundError:
            print("Unable to find the file {}".format(f))
            print('Downloading CIFAR-10 from {}'.format(url))
            res = requests.get(url, stream=True)
            total_length = int(res.headers.get('content-length'))
            print('Detected data size: {}KB'.format(total_length))
            print('Making a directory cifar-10 to store data')
            if not os.path.exists('cifar10'):
                os.mkdir('cifar10')
            with open(f, 'wb') as f:
                print('Downloading data')
                for data in res.iter_content(chunk_size=1024*1024):
                    print('.',end='')
                    f.write(data)
            
            print("Extracting {}".format(f))
            filename = os.path.join('cifar10', f)
            tar = tarfile.open(f, "r:gz")
            tar.extractall('cifar10')
            tar.close()
            
        except Exception as ex:
            print(ex)
        
    print('\n\tDone')
    
download_cifar10('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')

Checking if the following files exists
['cifar10\\cifar-10-batches-py\\meta_data', 'cifar10\\cifar-10-batches-py\\data_batch_1', 'cifar10\\cifar-10-batches-py\\data_batch_2', 'cifar10\\cifar-10-batches-py\\data_batch_3', 'cifar10\\cifar-10-batches-py\\data_batch_4', 'cifar10\\cifar-10-batches-py\\data_batch_5', 'cifar10\\cifar-10-batches-py\\test_batch']

Extracting cifar10\cifar-10-python.tar.gz
Unable to find the file cifar10\cifar-10-python.tar.gz
Downloading CIFAR-10 from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Detected data size: 170498071KB
Making a directory cifar-10 to store data
Downloading data
...........................

## Few helper functions for processing images 
* `vec2image` - Takes in a vector of 3072 elements, reshap it to 32x32x3, subtract the mean and optionally flip the image
* `get_image` - Given a single class gets an image (for visual inspection)

In [None]:
def vec2image(image_vec, normalize=False, global_mean=None, flip=False):
    """ Creating a 2D image from the 1D vector in the data """
    assert image_vec.size == 3072, "This (shape:{}) is not a CIFAR-10 Image".format(image_vec.shape)
    img_mat = image_vec.reshape(32,32,3, order='F')
    img_mat = np.rot90(img_mat,3)
    if normalize and global_mean:
        img_mat = img_mat.astype(np.float32) - global_mean
    if flip:
        img_mat = np.flip(img_mat, axis=1)
    return img_mat

def get_image(file, class_label):
    """ Given a class get an image """
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    
    for image_vec, label in zip(dict[b"data"], dict[b"labels"]):
        if label==class_label:
            img_mat = vec2image(image_vec)
            return img_mat
        else:
            continue
    return None

def get_label_to_name_map(file):
    """ Get a list of label names in the data """
    with open(file, 'rb') as f:
        label_dict = pickle.load(f, encoding='bytes')
        return [str(v,'utf-8') for v in label_dict[b"label_names"]]
    
    return None

# file batches.meta for label to string map
label_map = get_label_to_name_map(os.path.join('cifar10', 'cifar-10-batches-py', 'batches.meta'))

plt.subplots(2,5)
for cls in range(10):
    img = get_image(os.path.join('cifar10', 'cifar-10-batches-py', 'data_batch_1'), cls)
    plt.subplot(2,5,cls+1)
    plt.imshow(img)
    plt.title(label_map[cls])
    plt.axis('off')

## Defining the Keras graph

In [None]:
tf.reset_default_graph()
K.clear_session()

# TODO: Define the model

## Running the model
* Using `train_on_batch` function

In [None]:
data_dir = os.path.join('cifar10','cifar-10-batches-py')
n_classes = 10
batch_size = 32

train_file_list = [
    os.path.join(data_dir, f) \
    for f in ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
]
test_file_list = [os.path.join(data_dir, 'test_batch')]

for ep in range(20):
    losses = []
    accuracy = []
    for f in train_file_list:

        with open(f, 'rb') as fo:
            # labels, data
            dict = pickle.load(fo, encoding='bytes')
            part_vec2image = partial(
                vec2image, normalize=True, global_mean=np.mean(dict[b"data"])
            )


            " Going through each batch in the data file"
            for di in range(len(dict[b"labels"])//batch_size):
                # Defining random indices
                rand_idx = np.random.randint(
                    0, len(dict[b"labels"]), batch_size
                )
            
                # Creating onehot labels from class labels
                batch_one_hot = np.zeros(shape=(batch_size, n_classes))
                batch_one_hot[
                    np.arange(batch_size), np.array(dict[b"labels"])[rand_idx]
                ] = 1.0

                # Creating a batch of images
                batch_images = np.apply_along_axis(
                    part_vec2image, axis=1, arr=dict[b"data"][rand_idx,:]
                )
                
                # TODO: Save a set of normalized images

                # Training the CNN on batch of images
                # TODO: Train the model on a single batch

                # TODO: Evaluate the model

                losses.append(loss)
                accuracy.append(acc)
        print('Loss for epoch: {}'.format(np.mean(losses)))
        print('Train accuracy for epoch: {}'.format(np.mean(accuracy)))

## Running the model
* Using `fit` function with validation