In [16]:
import tensorflow as tf
from six.moves import urllib
import sys
import numpy as np
import _pickle as cPickle
import tarfile
import os

In [2]:
DATA_URL = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
DATA_DIR = './data/cifar10_data/'
CIFAR10_DIR = 'cifar-10-batches-py/'

train_dir = 'cifar10_train'
test_dir = 'cifar10_test'

NUM_TRAIN_SAMPLES = 50000
IMAGE_SIZE = 32
NUM_CHANNELS = 3

In [9]:
def maybe_download_and_extract():
    """Download and extract the tarball from Alex's website."""
    dest_directory = DATA_DIR
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    extracted_dir_path = os.path.join(dest_directory, CIFAR10_DIR)
    if not os.path.exists(extracted_dir_path):
        tarfile.open(filepath, 'r:gz').extractall(dest_directory)
        print('Successfully extracted')
    else:
        print('File present')

In [17]:
maybe_download_and_extract()

Successfully extracted


In [14]:
def load_batch(fpath, label_key='labels'):
    """Internal utility for parsing CIFAR data.
    # Arguments
        fpath: path the file to parse.
        label_key: key for label data in the retrieve
            dictionary.
    # Returns
        A tuple `(data, labels)`.
    """
    f = open(fpath, 'rb')
    d = cPickle.load(f, encoding='bytes')
    # decode utf8
    d_decoded = {}
    for k, v in d.items():
        d_decoded[k.decode('utf8')] = v
    f.close()
    data = d_decoded['data']
    labels = d_decoded[label_key]

    data = data.reshape(data.shape[0], NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
    return data, labels

In [20]:
x_train = None
y_train = None

for i in range(1, 6):
    fpath = os.path.join(DATA_DIR, CIFAR10_DIR, 'data_batch_' + str(i))
    data, labels = load_batch(fpath)
    if x_train is None:
        x_train = data
        y_train = labels
    else:
        x_train = np.concatenate((x_train, data), axis=0)
        y_train = np.concatenate((y_train, labels), axis=0)

fpath = os.path.join(DATA_DIR, CIFAR10_DIR, 'test_batch')
x_test, y_test = load_batch(fpath)

y_train = np.reshape(y_train, (len(y_train), 1))
y_test = np.reshape(y_test, (len(y_test), 1))

x_train = x_train.transpose(0, 2, 3, 1)
x_test = x_test.transpose(0, 2, 3, 1)

print("x_train:", str(x_train.shape))
print("y_train:", str(y_train.shape))

print("x_test:", str(x_train.shape))
print("y_test:", str(y_train.shape))

x_train: (50000, 32, 32, 3)
y_train: (50000, 1)
x_test: (50000, 32, 32, 3)
y_test: (50000, 1)


In [None]:
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var