# Get Some Data
(*Objective: Download CIFAR10 into Google Drive, convert it to TFRecords. Time: 5 mins*) 

## The CIFAR10 dataset

We will use the [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset, which contains 32x32 RGB images. It is widely used for proof-of-concept testing of machine learning methods and for evaluating CNN image classifiers. 

<img src="https://miro.medium.com/max/824/1*SZnidBt7CQ4Xqcag6rd8Ew.png" width="700" border="1"/>

The best reported accuracy on this dataset has reached 99%. You can find the leaderboard on [paperswithcode.com](https://paperswithcode.com/sota/image-classification-on-cifar-10), a great site that hosts papers and links to their corresponding source code.

We start by mounting Google Drive in this runtime:

In [0]:
from google.colab import drive
# Mount Google Drive at /content/gdrive/; the dataset folders used below live under it.
drive.mount('/content/gdrive/')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive/


## Downloading from URL
`DataSetCifar10Downloader` is a class that checks whether the CIFAR10 files already exist and, if needed, downloads the dataset from [Alex Krizhevsky's](https://qz.com/1307091/the-inside-story-of-how-ai-got-good-enough-to-dominate-silicon-valley/) page. It extracts the archive and moves the extracted files into the `data/cifar10` subfolder of your tutorial workspace. 

In [0]:
import os
import shutil
import sys
import zipfile
import tarfile
import pickle
from urllib.request import urlretrieve
import numpy as np

# Folder on the mounted Drive that holds the extracted CIFAR10 python batches.
ORIGINAL_DATASET_FOLDER = "/content/gdrive/My Drive/Colab Notebooks/OOT2019/data/cifar10"
# Folder on the mounted Drive that receives the TFRecord files and pixel statistics.
TFRECORDS_DATASET_FOLDER = "/content/gdrive/My Drive/Colab Notebooks/OOT2019/data/tfcifar10"


# =======================================================================================================================
class DataSetCifar10Downloader(object):
    """Fetch the CIFAR10 python archive and place its batches in a data folder.

    ``Download()`` does nothing when the data folder already exists. Otherwise
    the archive is downloaded into the temp folder (unless it is already
    there), extracted, the ``cifar-10-batches-py`` directory is moved to the
    data folder, and the archive is deleted.
    """
    DOWNLOAD_URL = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    
    # --------------------------------------------------------------------------------------------------------
    def __init__(self, p_sDataFolder):
        """Remember the target folder.

        Args:
            p_sDataFolder: Destination folder for the extracted batch files.
        """
        # Scratch location for the archive and its extracted contents.
        self.TempFolder = "/tmp"
        # Final location of the extracted dataset.
        self.DataFolder = p_sDataFolder
    # --------------------------------------------------------------------------------------------------------            
    def _downloadProgressCallBack(self, count, block_size, total_size):
        """``urlretrieve`` report hook: rewrite the progress line on stdout.

        Args:
            count: Number of blocks transferred so far.
            block_size: Size of one block in bytes.
            total_size: Total size in bytes; zero or negative when unknown.
        """
        if total_size > 0:
            # The final block is usually partial, so count * block_size can overshoot.
            pct_complete = min(1.0, float(count * block_size) / total_size)
            msg = "\r- Download progress: {0:.1%}".format(pct_complete)
        else:
            # No usable total size: report the byte count instead of dividing by it.
            msg = "\r- Download progress: {0} bytes".format(count * block_size)
        sys.stdout.write(msg)
        sys.stdout.flush()
    # --------------------------------------------------------------------------------------------------------
    def __ensureDataSetIsOnDisk(self):
        """Download (if needed) and extract the archive, then move it to ``DataFolder``."""
        sSuffix = self.DOWNLOAD_URL.split('/')[-1]
        sArchiveFileName = os.path.join(self.TempFolder, sSuffix)
        
        if not os.path.isfile(sArchiveFileName):
            # Download under a temporary name so that an interrupted transfer is
            # never mistaken for a complete archive on the next run.
            sPartialFileName = sArchiveFileName + ".part"
            urlretrieve(url=self.DOWNLOAD_URL, filename=sPartialFileName, reporthook=self._downloadProgressCallBack)
            os.replace(sPartialFileName, sArchiveFileName)
            print()
            print("Download finished. Extracting files.")

        if sArchiveFileName.endswith(".zip"):
            with zipfile.ZipFile(file=sArchiveFileName, mode="r") as oArchive:
                oArchive.extractall(self.TempFolder)
        elif sArchiveFileName.endswith((".tar.gz", ".tgz")):
            with tarfile.open(name=sArchiveFileName, mode="r:gz") as oArchive:
                # The archive arrives over plain HTTP; where the interpreter offers
                # extraction filters, refuse members that would escape TempFolder.
                if hasattr(tarfile, "data_filter"):
                    oArchive.extractall(self.TempFolder, filter="data")
                else:
                    oArchive.extractall(self.TempFolder)
        print("Done.")

        # Make sure the destination's parent exists before moving the folder there.
        sParentFolder = os.path.dirname(os.path.normpath(self.DataFolder))
        if sParentFolder:
            os.makedirs(sParentFolder, exist_ok=True)
        shutil.move(os.path.join(self.TempFolder, "cifar-10-batches-py"), self.DataFolder)

        os.remove(sArchiveFileName)
    # --------------------------------------------------------------------------------------------------------
    def Download(self):
        """Fetch the dataset unless ``DataFolder`` already exists."""
        if not os.path.exists(self.DataFolder):
            self.__ensureDataSetIsOnDisk()
    # --------------------------------------------------------------------------------------------------------
# =======================================================================================================================


# Fetch CIFAR10 into the Drive workspace; Download() is a no-op when the folder already exists.
oDataSet = DataSetCifar10Downloader(ORIGINAL_DATASET_FOLDER)
oDataSet.Download()

## Converting to TFRecords format

We are going to convert the images into the TFRecord format and store the pixel mean and standard deviation for later use.

In [0]:
import os
import tensorflow as tf
import joblib

# --------------------------------------------------------------------------------------------------------
def convert_to_tfdataset(p_sSourceDataFolder, p_sDestDataFolder):
    """Convert the pickled CIFAR10 batches into TFRecord files.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch`` from
    ``p_sSourceDataFolder`` and writes three files into ``p_sDestDataFolder``
    (created if missing):

    * ``train.tf``    -- one ``tf.train.Example`` per training image
    * ``test.tf``     -- one ``tf.train.Example`` per test image
    * ``meanstd.pkl`` -- joblib dump of ``{'mean': ..., 'std': ...}``, the
      per-channel pixel statistics of the training images

    Each example stores ``height``, ``width``, ``depth`` and ``label`` as int64
    features and the NHWC uint8 pixels as the ``image_raw`` bytes feature.

    Args:
        p_sSourceDataFolder: Folder holding the extracted CIFAR10 python batches.
        p_sDestDataFolder: Folder that receives the generated files.
    """
    # --------------------------------------------------------------------------------------------------------
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    # --------------------------------------------------------------------------------------------------------
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    # --------------------------------------------------------------------------------------------------------
    def _load_batch(p_sFileName):
        # Returns (flat uint8 pixel rows, int32 labels) for one pickled batch file.
        # NOTE(security): pickle can execute arbitrary code while loading; only
        # run this on batch files obtained from the official CIFAR10 archive.
        with open(p_sFileName, 'rb') as oFile:
            data_batch = pickle.load(oFile, encoding='latin1')
        return (np.asarray(data_batch['data'], dtype=np.uint8),
                np.asarray(data_batch['labels'], dtype=np.int32))
    # --------------------------------------------------------------------------------------------------------
    def _to_nhwc(p_nFlatImages):
        # Each row holds 3072 values laid out channel-first (3 x 32 x 32).
        images = np.reshape(p_nFlatImages, [-1, 3, 32, 32])
        return np.transpose(images, axes=[0, 2, 3, 1])  # NCHW -> NHWC
    # --------------------------------------------------------------------------------------------------------
    def save_to_records(p_sFileName, p_nImages, p_nLabels, p_sSubsetName):
        print("[>] Converting %s subset" % p_sSubsetName)
        nHeight, nWidth, nDepth = (int(n) for n in p_nImages.shape[1:4])
        # tf.python_io does not exist in TF 2.x; prefer tf.io and fall back for old 1.x releases.
        cWriter = getattr(getattr(tf, "io", None), "TFRecordWriter", None) or tf.python_io.TFRecordWriter
        # The context manager closes the writer so that every record is flushed to disk.
        with cWriter(p_sFileName) as writer:
            for i, (image, label) in enumerate(zip(p_nImages, p_nLabels)):
                example = tf.train.Example(features=tf.train.Features(feature={
                    'height'    : _int64_feature(nHeight),
                    'width'     : _int64_feature(nWidth),
                    'depth'     : _int64_feature(nDepth),
                    'label'     : _int64_feature(int(label)),
                    'image_raw' : _bytes_feature(image.tobytes())
                    }))
                writer.write(example.SerializeToString())
                if ((i+1) % 5000) == 0:
                    print(" |__ Written %d tfrecords" % (i+1))
    # --------------------------------------------------------------------------------------------------------
    os.makedirs(p_sDestDataFolder, exist_ok=True)

    # train set
    train_batches = [_load_batch(os.path.join(p_sSourceDataFolder, 'data_batch_%d' % n))
                     for n in range(1, 6)]
    train_images = _to_nhwc(np.concatenate([data for data, _ in train_batches]))
    train_labels = np.concatenate([labels for _, labels in train_batches])
    del train_batches  # the concatenated copies are all that is needed from here on
    save_to_records(os.path.join(p_sDestDataFolder, "train.tf"), train_images, train_labels, "training")

    # mean and std
    print("[>] Calculating and storing pixel mean and std values")
    train_pixels = train_images.astype(np.float32)  # converted once, shared by mean and std
    image_mean = np.mean(train_pixels, axis=(0,1,2))
    image_std = np.std(train_pixels, axis=(0,1,2))
    del train_pixels
    joblib.dump({'mean': image_mean, 'std': image_std}, os.path.join(p_sDestDataFolder, "meanstd.pkl"), compress=5)

    # test set
    test_data, test_labels = _load_batch(os.path.join(p_sSourceDataFolder, 'test_batch'))
    test_images = _to_nhwc(test_data)
    save_to_records(os.path.join(p_sDestDataFolder, "test.tf"), test_images, test_labels, "testing")
# --------------------------------------------------------------------------------------------------------

# Build train.tf, test.tf and meanstd.pkl from the downloaded batches.
convert_to_tfdataset(ORIGINAL_DATASET_FOLDER, TFRECORDS_DATASET_FOLDER)