In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile
import numpy as np
import six

from six.moves import cPickle

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image_utils
from tensor2tensor.data_generators import mnist
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry

import tensorflow as tf

# Download location and archive layout for CIFAR-10.
_CIFAR10_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
_CIFAR10_PREFIX = "cifar-10-batches-py/"
_CIFAR10_TRAIN_FILES = [
    "data_batch_1",
    "data_batch_2",
    "data_batch_3",
    "data_batch_4",
    "data_batch_5",
]
_CIFAR10_TEST_FILES = ["test_batch"]
# Side length used when reshaping the flat image rows.
_CIFAR10_IMAGE_SIZE = 32

# Download location and archive layout for CIFAR-100.
_CIFAR100_URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
_CIFAR100_PREFIX = "cifar-100-python/"
_CIFAR100_TRAIN_FILES = ["train"]
_CIFAR100_TEST_FILES = ["test"]
# Side length used when reshaping the flat image rows.
_CIFAR100_IMAGE_SIZE = 32




  return f(*args, **kwds)


In [2]:
def _get_cifar(directory, url):
  """Fetch the CIFAR archive into `directory` and unpack it there.

  The download step is delegated to `generator_utils.maybe_download`; the
  archive is then extracted unconditionally on every call.

  Args:
    directory: path of the directory that receives both the archive and its
      extracted contents.
    url: URL of the gzip-compressed tar archive; its basename is used as the
      local filename.
  """
  filename = os.path.basename(url)
  path = generator_utils.maybe_download(directory, filename, url)
  # Open the archive in a `with` block so the file handle is released as soon
  # as extraction finishes (or fails) instead of lingering until GC.
  # NOTE(review): extractall() trusts member paths inside the archive; only
  # point this at archives from a trusted source.
  with tarfile.open(path, "r:gz") as archive:
    archive.extractall(directory)


def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
  """Image generator for CIFAR-10 and CIFAR-100.

  Args:
    cifar_version: string; one of "cifar10", "cifar100" (CIFAR-100 with the
      "fine_labels" field) or "cifar20" (CIFAR-100 with the "coarse_labels"
      field).
    tmp_dir: path to temporary storage directory; the archive is downloaded
      and extracted here.
    training: a Boolean; if true, we use the train set, otherwise the test set.
    how_many: how many images and labels to generate.
    start_from: from which image to start.

  Returns:
    An instance of image_generator that produces CIFAR images and labels for
    the slice [start_from, start_from + how_many).

  Raises:
    ValueError: if `cifar_version` is not one of the supported names.
  """
  if cifar_version == "cifar10":
    url = _CIFAR10_URL
    train_files = _CIFAR10_TRAIN_FILES
    test_files = _CIFAR10_TEST_FILES
    prefix = _CIFAR10_PREFIX
    image_size = _CIFAR10_IMAGE_SIZE
    label_key = "labels"
  elif cifar_version in ("cifar100", "cifar20"):
    url = _CIFAR100_URL
    train_files = _CIFAR100_TRAIN_FILES
    test_files = _CIFAR100_TEST_FILES
    prefix = _CIFAR100_PREFIX
    image_size = _CIFAR100_IMAGE_SIZE
    # Both variants read the same files; they differ only in the label field.
    if cifar_version == "cifar100":
      label_key = "fine_labels"
    else:
      label_key = "coarse_labels"
  else:
    # Fail fast with a clear message instead of an UnboundLocalError on `url`
    # further down.
    raise ValueError(
        "Unsupported cifar_version %r; expected one of 'cifar10', 'cifar100', "
        "'cifar20'." % (cifar_version,))

  _get_cifar(tmp_dir, url)
  data_files = train_files if training else test_files
  all_images, all_labels = [], []
  for filename in data_files:
    path = os.path.join(tmp_dir, prefix, filename)
    with tf.gfile.Open(path, "rb") as f:
      # NOTE(review): unpickling executes arbitrary code from the file; this
      # must only ever be fed the archive downloaded above.
      if six.PY2:
        data = cPickle.load(f)
      else:
        # The batches hold Python 2 byte strings, so decode them as latin1.
        data = cPickle.load(f, encoding="latin1")
    images = data["data"]
    num_images = images.shape[0]
    # Each row is a flat image: unflatten to (N, 3, size, size), then move the
    # channel axis last in one batch transpose so every image is
    # (size, size, 3).
    images = images.reshape((num_images, 3, image_size, image_size))
    images = images.transpose((0, 2, 3, 1))
    all_images.extend(images[j] for j in range(num_images))
    labels = data[label_key]
    all_labels.extend(labels[j] for j in range(num_images))
  stop = start_from + how_many
  return image_utils.image_generator(all_images[start_from:stop],
                                     all_labels[start_from:stop])

In [3]:
# Build a generator over the first 2000 CIFAR-100 training images.
# Keyword arguments keep the count and the offset from being swapped: the
# signature is (..., how_many, start_from=0), so a positional `0, 2000`
# requests zero images starting at index 2000.
cifar_generator("cifar100", "./datasets/cifar-100-parsed", training=True,
                how_many=2000, start_from=0)

INFO:tensorflow:Not downloading, file already found: ./datasets/cifar-100-parsed/cifar-100-python.tar.gz


<generator object image_generator at 0x12bf37678>