Skip to content

Commit

Permalink
Class names for cifar/fashion mnist (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Aug 16, 2018
1 parent ea173d0 commit f227f45
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 28 deletions.
3 changes: 1 addition & 2 deletions examples/FasterRCNN/basemodel.py
Expand Up @@ -4,7 +4,6 @@
from contextlib import contextmanager, ExitStack
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable

from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
Expand Down Expand Up @@ -49,7 +48,7 @@ def freeze_affine_getter(getter, *args, **kwargs):
if name.endswith('/gamma') or name.endswith('/beta'):
kwargs['trainable'] = False
ret = getter(*args, **kwargs)
add_model_variable(ret)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, ret)
else:
ret = getter(*args, **kwargs)
return ret
Expand Down
41 changes: 29 additions & 12 deletions tensorpack/dataflow/dataset/cifar.py
Expand Up @@ -66,14 +66,22 @@ def read_cifar(filenames, cifar_classnum):
def get_filenames(dir, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
filenames = [os.path.join(
train_files = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
test_files = [os.path.join(
dir, 'cifar-10-batches-py', 'test_batch')]
meta_file = os.path.join(dir, 'cifar-10-batches-py', 'batches.meta')
elif cifar_classnum == 100:
filenames = [os.path.join(dir, 'cifar-100-python', 'train'),
os.path.join(dir, 'cifar-100-python', 'test')]
return filenames
train_files = [os.path.join(dir, 'cifar-100-python', 'train')]
test_files = [os.path.join(dir, 'cifar-100-python', 'test')]
meta_file = os.path.join(dir, 'cifar-100-python', 'meta')
return train_files, test_files, meta_file


def _parse_meta(filename, cifar_classnum):
with open(filename, 'rb') as f:
obj = pickle.load(f)
return obj['label_names' if cifar_classnum == 10 else 'fine_label_names']


class CifarBase(RNGDataFlow):
Expand All @@ -84,14 +92,15 @@ def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
if dir is None:
dir = get_dataset_path('cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir, cifar_classnum)
train_files, test_files, meta_file = get_filenames(dir, cifar_classnum)
if train_or_test == 'train':
self.fs = fnames[:-1]
self.fs = train_files
else:
self.fs = [fnames[-1]]
self.fs = test_files
for f in self.fs:
if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f)
self._label_names = _parse_meta(meta_file, cifar_classnum)
self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir
Expand All @@ -110,14 +119,22 @@ def get_data(self):

def get_per_pixel_mean(self):
"""
return a mean image of all (train and test) images of size 32x32x3
Returns:
a mean image of all (train and test) images of size 32x32x3
"""
fnames = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(fnames, self.cifar_classnum)]
train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(train_files + test_files, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0)
return mean

def get_label_names(self):
"""
Returns:
[str]: name of each class.
"""
return self._label_names

def get_per_channel_mean(self):
"""
return three values as mean of each channel
Expand Down
27 changes: 20 additions & 7 deletions tensorpack/dataflow/dataset/mnist.py
Expand Up @@ -67,8 +67,8 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int.
"""

DIR_NAME = 'mnist_data'
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
_DIR_NAME = 'mnist_data'
_SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

def __init__(self, train_or_test, shuffle=True, dir=None):
"""
Expand All @@ -77,15 +77,15 @@ def __init__(self, train_or_test, shuffle=True, dir=None):
shuffle (bool): shuffle the dataset
"""
if dir is None:
dir = get_dataset_path(self.DIR_NAME)
dir = get_dataset_path(self._DIR_NAME)
assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test
self.shuffle = shuffle

def get_images_and_labels(image_file, label_file):
f = maybe_download(self.SOURCE_URL + image_file, dir)
f = maybe_download(self._SOURCE_URL + image_file, dir)
images = extract_images(f)
f = maybe_download(self.SOURCE_URL + label_file, dir)
f = maybe_download(self._SOURCE_URL + label_file, dir)
labels = extract_labels(f)
assert images.shape[0] == labels.shape[0]
return images, labels
Expand Down Expand Up @@ -113,8 +113,21 @@ def get_data(self):


class FashionMnist(Mnist):
DIR_NAME = 'fashion_mnist_data'
SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
"""
Same API as :class:`Mnist`, but more fashion.
"""

_DIR_NAME = 'fashion_mnist_data'
_SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'

def get_label_names(self):
"""
Returns:
[str]: the name of each class
"""
# copied from https://github.com/zalandoresearch/fashion-mnist
return ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


if __name__ == '__main__':
Expand Down
3 changes: 2 additions & 1 deletion tensorpack/dataflow/dataset/svhn.py
Expand Up @@ -64,7 +64,8 @@ def get_data(self):
@staticmethod
def get_per_pixel_mean():
"""
return 32x32x3 image
Returns:
a 32x32x3 image
"""
a = SVHNDigit('train')
b = SVHNDigit('test')
Expand Down
5 changes: 2 additions & 3 deletions tensorpack/models/batch_norm.py
Expand Up @@ -3,7 +3,6 @@


import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
import re
import six
Expand Down Expand Up @@ -191,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
add_model_variable(v)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if not ctx.is_main_training_tower or internal_update:
restore_collection(coll_bk)

Expand Down Expand Up @@ -354,7 +353,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
add_model_variable(v)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
else:
# only run UPDATE_OPS in the first tower
restore_collection(coll_bk)
Expand Down
13 changes: 10 additions & 3 deletions tensorpack/tfutils/varreplace.py
Expand Up @@ -3,7 +3,6 @@
# Credit: Qinyao He

import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from contextlib import contextmanager

from .common import get_tf_version_tuple
Expand All @@ -13,6 +12,13 @@

@contextmanager
def custom_getter_scope(custom_getter):
"""
Args:
custom_getter: the same as in :func:`tf.get_variable`
Returns:
The current variable scope with a custom_getter.
"""
scope = tf.get_variable_scope()
if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope(
Expand All @@ -35,7 +41,8 @@ def remap_variables(fn):
fn (tf.Variable -> tf.Tensor)
Returns:
a context where all the variables will be mapped by fn.
The current variable scope with a custom_getter that maps
all the variables by fn.
Example:
.. code-block:: python
Expand Down Expand Up @@ -83,7 +90,7 @@ def custom_getter(getter, *args, **kwargs):
kwargs['trainable'] = False
v = getter(*args, **kwargs)
if skip_collection:
add_model_variable(v)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if trainable and stop_gradient:
v = tf.stop_gradient(v, name='freezed_' + name)
return v
Expand Down

0 comments on commit f227f45

Please sign in to comment.