Permalink
Branch: master
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
185 lines (161 sloc) 6.43 KB
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utils for CIFAR-10 and CIFAR-100."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import cPickle
import os
import augmentation_transforms
import numpy as np
import policies as found_policies
import tensorflow as tf
# pylint:disable=logging-format-interpolation
class DataSet(object):
  """Dataset object that produces augmented training and eval data.

  On construction the CIFAR-10 or CIFAR-100 files found under
  `hparams.data_path` are loaded, scaled to [0, 1], normalized with
  `augmentation_transforms.MEANS` / `STDS`, and split into train, validation
  and (optionally) test arrays. Labels are stored one-hot encoded.

  Fields read from `hparams`: `dataset` ('cifar10' or 'cifar100'),
  `data_path`, `train_size`, `validation_size`, `eval_test`, `batch_size`.

  Attributes:
    hparams: The hyperparameter object passed to the constructor.
    epochs: Number of times `next_batch` has wrapped around the training set.
    curr_train_index: Offset of the next minibatch inside `train_images`.
    good_policies: Value returned by `found_policies.good_policies()`; one
      entry is drawn at random for every training image.
    train_images, train_labels: Training split.
    val_images, val_labels: Validation split.
    test_images, test_labels: Test split; only set when `hparams.eval_test`
      is truthy.
    num_train: Number of training examples.
  """

  def __init__(self, hparams):
    """Loads the dataset from disk and builds the train/val/test splits.

    Args:
      hparams: Hyperparameter object; see the class docstring for the fields
        that are read.

    Raises:
      AssertionError: If `train_size + validation_size` exceeds 50000.
      NotImplementedError: If `hparams.dataset` is neither 'cifar10' nor
        'cifar100'.
    """
    self.hparams = hparams
    self.epochs = 0
    self.curr_train_index = 0
    self.good_policies = found_policies.good_policies()

    # Determine how many data batches to load. The training data is always
    # 5 batches of 10000 images; the test set adds one more batch.
    num_data_batches_to_load = 5
    images_per_batch = 10000
    assert hparams.train_size + hparams.validation_size <= 50000, (
        'train_size + validation_size must not exceed 50000')

    total_batches_to_load = num_data_batches_to_load
    train_dataset_size = images_per_batch * num_data_batches_to_load
    total_dataset_size = train_dataset_size
    if hparams.eval_test:
      total_batches_to_load += 1
      total_dataset_size += images_per_batch

    # Pick the file list, the raw uint8 buffer layout and the label key for
    # the requested dataset.
    if hparams.dataset == 'cifar10':
      tf.logging.info('Cifar10')
      all_data = np.empty(
          (total_batches_to_load, images_per_batch, 3072), dtype=np.uint8)
      datafiles = [
          'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
          'data_batch_5']
      if hparams.eval_test:
        datafiles.append('test_batch')
      num_classes = 10
      label_key = 'labels'
    elif hparams.dataset == 'cifar100':
      # CIFAR-100 training data is loaded from a single 'train' file.
      all_data = np.empty((1, train_dataset_size, 3072), dtype=np.uint8)
      datafiles = ['train']
      if hparams.eval_test:
        test_data = np.empty((1, images_per_batch, 3072), dtype=np.uint8)
        datafiles.append('test')
      num_classes = 100
      label_key = 'fine_labels'
    else:
      raise NotImplementedError(
          'Unimplemented dataset: {}'.format(hparams.dataset))

    all_labels = []
    for file_num, f in enumerate(datafiles):
      d = unpickle(os.path.join(hparams.data_path, f))
      if f == 'test':
        # CIFAR-100 test images are appended after the training images.
        test_data[0] = d['data']
        all_data = np.concatenate([all_data, test_data], axis=1)
      else:
        # Assigning into the preallocated buffer already copies the pixels.
        all_data[file_num] = d['data']
      all_labels.extend(d[label_key])

    # Using the explicit image count (instead of -1) keeps the check that
    # exactly `total_dataset_size` images were loaded.
    all_data = all_data.reshape(total_dataset_size, 3, 32, 32)
    # Channels-first -> channels-last; copy() makes the result contiguous.
    all_data = all_data.transpose(0, 2, 3, 1).copy()
    all_data = all_data / 255.0
    mean = augmentation_transforms.MEANS
    std = augmentation_transforms.STDS
    tf.logging.info('mean:{} std: {}'.format(mean, std))

    all_data = (all_data - mean) / std
    # Indexing the identity matrix by class id yields one-hot label rows.
    all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
    assert len(all_data) == len(all_labels)
    tf.logging.info(
        'In CIFAR10 loader, number of images: {}'.format(len(all_data)))

    # Break off test data
    if hparams.eval_test:
      self.test_images = all_data[train_dataset_size:]
      self.test_labels = all_labels[train_dataset_size:]

    # Shuffle the rest of the data. Seeding NumPy's global RNG makes the
    # train/val split identical on every run; note that it also resets the
    # global random state used by any later np.random call.
    all_data = all_data[:train_dataset_size]
    all_labels = all_labels[:train_dataset_size]
    np.random.seed(0)
    perm = np.arange(len(all_data))
    np.random.shuffle(perm)
    all_data = all_data[perm]
    all_labels = all_labels[perm]

    # Break into train and val
    train_size, val_size = hparams.train_size, hparams.validation_size
    self.train_images = all_data[:train_size]
    self.train_labels = all_labels[:train_size]
    self.val_images = all_data[train_size:train_size + val_size]
    self.val_labels = all_labels[train_size:train_size + val_size]
    self.num_train = self.train_images.shape[0]

  def next_batch(self):
    """Return the next minibatch of augmented data.

    When fewer than `hparams.batch_size` training examples remain, the
    training set is reshuffled, the epoch counter is incremented and the
    batch is taken from the start of the reshuffled data.

    Returns:
      A tuple `(images, labels)` where `images` is a float32 array of the
      augmented images and `labels` is the matching slice of `train_labels`.
    """
    batch_size = self.hparams.batch_size
    if self.curr_train_index + batch_size > self.num_train:
      # Increase epoch number. reset() zeroes self.epochs, so the new value
      # is computed first and restored afterwards.
      epoch = self.epochs + 1
      self.reset()
      self.epochs = epoch

    start = self.curr_train_index
    end = start + batch_size
    images = self.train_images[start:end]
    labels = self.train_labels[start:end]

    final_imgs = []
    for data in images:
      # A policy is drawn independently for every image.
      epoch_policy = self.good_policies[np.random.choice(
          len(self.good_policies))]
      final_img = augmentation_transforms.apply_policy(epoch_policy, data)
      final_img = augmentation_transforms.random_flip(
          augmentation_transforms.zero_pad_and_crop(final_img, 4))
      # Apply cutout
      final_img = augmentation_transforms.cutout_numpy(final_img)
      final_imgs.append(final_img)

    self.curr_train_index += batch_size
    return (np.array(final_imgs, np.float32), labels)

  def reset(self):
    """Reset training data and index into the training data.

    Sets `epochs` and `curr_train_index` to 0 and applies one random
    permutation to both `train_images` and `train_labels`, so image/label
    pairs stay aligned.
    """
    self.epochs = 0
    # Shuffle the training data
    perm = np.arange(self.num_train)
    np.random.shuffle(perm)
    assert self.num_train == self.train_images.shape[
        0], 'Error incorrect shuffling mask'
    self.train_images = self.train_images[perm]
    self.train_labels = self.train_labels[perm]
    self.curr_train_index = 0
def unpickle(f):
  """Loads and returns the pickled object stored in file `f`.

  Args:
    f: Path of the pickle file, opened through `tf.gfile`.

  Returns:
    The object produced by `cPickle.load`.
  """
  tf.logging.info('loading file: {}'.format(f))
  # NOTE(review): unpickling can execute arbitrary code; only point this at
  # trusted dataset files.
  # Pickle data is binary, so open in 'rb'. The `with` block closes the handle
  # even when loading raises.
  with tf.gfile.Open(f, 'rb') as fo:
    return cPickle.load(fo)