Skip to content

Commit

Permalink
BREAKING CHANGE: Replace public GraphDef eval support with SavedModel…
Browse files Browse the repository at this point in the history
… support.

PiperOrigin-RevId: 276197034
Change-Id: I621f25350c4dd6478a1470b098a66060f17e840d
  • Loading branch information
joel-shor authored and Copybara-Service committed Oct 23, 2019
1 parent df666b9 commit 5fbbba1
Show file tree
Hide file tree
Showing 17 changed files with 888 additions and 1,149 deletions.
6 changes: 0 additions & 6 deletions tensorflow_gan/examples/mnist/conditional_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@
flags.DEFINE_integer('noise_dims', 64,
'Dimensions of the generator noise vector')

flags.DEFINE_string(
'classifier_filename', None,
'Location of the pretrained classifier. If `None`, use '
'default.')

flags.DEFINE_integer(
'max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
Expand All @@ -55,7 +50,6 @@ def main(_):
hparams = conditional_eval_lib.HParams(FLAGS.checkpoint_dir, FLAGS.eval_dir,
FLAGS.num_images_per_class,
FLAGS.noise_dims,
FLAGS.classifier_filename,
FLAGS.max_number_of_evaluations,
FLAGS.write_to_disk)
conditional_eval_lib.evaluate(hparams, run_eval_loop=True)
Expand Down
8 changes: 3 additions & 5 deletions tensorflow_gan/examples/mnist/conditional_eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

HParams = collections.namedtuple('HParams', [
'checkpoint_dir', 'eval_dir', 'num_images_per_class', 'noise_dims',
'classifier_filename', 'max_number_of_evaluations', 'write_to_disk'
'max_number_of_evaluations', 'write_to_disk'
])


Expand Down Expand Up @@ -61,12 +61,10 @@ def evaluate(hparams, run_eval_loop=True):

# Calculate evaluation metrics.
tf.compat.v1.summary.scalar(
'MNIST_Classifier_score',
util.mnist_score(images, hparams.classifier_filename))
'MNIST_Classifier_score', util.mnist_score(images))
tf.compat.v1.summary.scalar(
'MNIST_Cross_entropy',
util.mnist_cross_entropy(images, one_hot_labels,
hparams.classifier_filename))
util.mnist_cross_entropy(images, one_hot_labels))

# Write images to disk.
image_write_ops = None
Expand Down
1 change: 0 additions & 1 deletion tensorflow_gan/examples/mnist/conditional_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def test_build_graph(self):
eval_dir='/tmp/mnist/',
num_images_per_class=10,
noise_dims=64,
classifier_filename=None,
max_number_of_evaluations=None,
write_to_disk=True)
conditional_eval_lib.evaluate(hparams, run_eval_loop=False)
Expand Down
6 changes: 0 additions & 6 deletions tensorflow_gan/examples/mnist/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@
flags.DEFINE_integer('noise_dims', 64,
'Dimensions of the generator noise vector')

flags.DEFINE_string(
'classifier_filename', None,
'Location of the pretrained classifier. If `None`, use '
'default.')

flags.DEFINE_integer(
'max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
Expand All @@ -60,7 +55,6 @@ def main(_):
hparams = eval_lib.HParams(FLAGS.checkpoint_dir, FLAGS.eval_dir,
FLAGS.dataset_dir, FLAGS.num_images_generated,
FLAGS.eval_real_images, FLAGS.noise_dims,
FLAGS.classifier_filename,
FLAGS.max_number_of_evaluations,
FLAGS.write_to_disk)
eval_lib.evaluate(hparams, run_eval_loop=True)
Expand Down
13 changes: 5 additions & 8 deletions tensorflow_gan/examples/mnist/eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

HParams = collections.namedtuple('HParams', [
'checkpoint_dir', 'eval_dir', 'dataset_dir', 'num_images_generated',
'eval_real_images', 'noise_dims', 'classifier_filename',
'max_number_of_evaluations', 'write_to_disk'
'eval_real_images', 'noise_dims', 'max_number_of_evaluations',
'write_to_disk'
])


Expand All @@ -52,8 +52,7 @@ def evaluate(hparams, run_eval_loop=True):
image_write_ops = None
if hparams.eval_real_images:
tf.compat.v1.summary.scalar(
'MNIST_Classifier_score',
util.mnist_score(real_images, hparams.classifier_filename))
'MNIST_Classifier_score', util.mnist_score(real_images))
else:
# In order for variables to load, use the same variable scope as in the
# train job.
Expand All @@ -63,11 +62,9 @@ def evaluate(hparams, run_eval_loop=True):
is_training=False)
tf.compat.v1.summary.scalar(
'MNIST_Frechet_distance',
util.mnist_frechet_distance(real_images, images,
hparams.classifier_filename))
util.mnist_frechet_distance(real_images, images))
tf.compat.v1.summary.scalar(
'MNIST_Classifier_score',
util.mnist_score(images, hparams.classifier_filename))
'MNIST_Classifier_score', util.mnist_score(images))
if hparams.num_images_generated >= 100 and hparams.write_to_disk:
reshaped_images = tfgan.eval.image_reshaper(
images[:100, ...], num_cols=10)
Expand Down
1 change: 0 additions & 1 deletion tensorflow_gan/examples/mnist/eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def test_build_graph(self, eval_real_images, mock_util, mock_provide_data):
num_images_generated=1000,
eval_real_images=eval_real_images,
noise_dims=64,
classifier_filename=None,
max_number_of_evaluations=None,
write_to_disk=True)

Expand Down
7 changes: 1 addition & 6 deletions tensorflow_gan/examples/mnist/infogan_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,6 @@
flags.DEFINE_integer('continuous_noise_dims', 2,
'The number of dimensions of the continuous noise.')

flags.DEFINE_string(
'classifier_filename', None,
'Location of the pretrained classifier. If `None`, use '
'default.')

flags.DEFINE_integer(
'max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
Expand All @@ -67,7 +62,7 @@ def main(_):
hparams = infogan_eval_lib.HParams(
FLAGS.checkpoint_dir, FLAGS.eval_dir, FLAGS.noise_samples,
FLAGS.unstructured_noise_dims, FLAGS.continuous_noise_dims,
FLAGS.classifier_filename, FLAGS.max_number_of_evaluations,
FLAGS.max_number_of_evaluations,
FLAGS.write_to_disk)
infogan_eval_lib.evaluate(hparams, run_eval_loop=True)

Expand Down
5 changes: 2 additions & 3 deletions tensorflow_gan/examples/mnist/infogan_eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

HParams = collections.namedtuple('HParams', [
'checkpoint_dir', 'eval_dir', 'noise_samples', 'unstructured_noise_dims',
'continuous_noise_dims', 'classifier_filename', 'max_number_of_evaluations',
'continuous_noise_dims', 'max_number_of_evaluations',
'write_to_disk'
])

Expand Down Expand Up @@ -96,8 +96,7 @@ def generator_fn(inputs):
all_images = tf.concat(
[categorical_images, continuous1_images, continuous2_images], 0)
tf.compat.v1.summary.scalar(
'MNIST_Classifier_score',
util.mnist_score(all_images, hparams.classifier_filename))
'MNIST_Classifier_score', util.mnist_score(all_images))

# Write images to disk.
image_write_ops = []
Expand Down
1 change: 0 additions & 1 deletion tensorflow_gan/examples/mnist/infogan_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_build_graph(self):
noise_samples=6,
unstructured_noise_dims=62,
continuous_noise_dims=2,
classifier_filename=None,
max_number_of_evaluations=None,
write_to_disk=True)
infogan_eval_lib.evaluate(hparams, run_eval_loop=False)
Expand Down
68 changes: 8 additions & 60 deletions tensorflow_gan/examples/mnist/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from __future__ import division
from __future__ import print_function

import os
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin

import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as tfhub
import tensorflow_probability as tfp

ds = tfp.distributions
Expand All @@ -43,56 +43,30 @@
'get_infogan_noise',
]

# The references to `MODEL_GRAPH_DEF` below are removed in open source by a
# copy bara transformation..
# Prepend `../`, since paths start from `third_party/tensorflow`.
MODEL_GRAPH_DEF = '../py/tensorflow_gan/examples/mnist/data/classify_mnist_graph_def.pb'
# The open source code finds the graph def by relative filepath.
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
MODEL_BY_FN = os.path.join(CUR_DIR, 'data', 'classify_mnist_graph_def.pb')
MNIST_MODULE = 'https://tfhub.dev/tensorflow/tfgan/eval/mnist/logits/1'

INPUT_TENSOR = 'inputs:0'
OUTPUT_TENSOR = 'logits:0'


def mnist_score(images,
graph_def_filename=None,
input_tensor=INPUT_TENSOR,
output_tensor=OUTPUT_TENSOR,
num_batches=1):
def mnist_score(images, num_batches=1):
"""Get MNIST classifier score.
Args:
images: A minibatch tensor of MNIST digits. Shape must be [batch, 28, 28,
1].
graph_def_filename: Location of a frozen GraphDef binary file on disk. If
`None`, uses a default graph.
input_tensor: GraphDef's input tensor name.
output_tensor: GraphDef's output tensor name.
num_batches: Number of batches to split `generated_images` in to in order to
efficiently run them through Inception.
Returns:
The classifier score, a floating-point scalar.
"""
images.shape.assert_is_compatible_with([None, 28, 28, 1])

graph_def = _graph_def_from_par_or_disk(graph_def_filename)
mnist_classifier_fn = lambda x: tfgan.eval.run_image_classifier( # pylint: disable=g-long-lambda
x, graph_def, input_tensor, output_tensor)

mnist_classifier_fn = tfhub.load(MNIST_MODULE)
score = tfgan.eval.classifier_score(images, mnist_classifier_fn, num_batches)
score.shape.assert_is_compatible_with([])

return score


def mnist_frechet_distance(real_images,
generated_images,
graph_def_filename=None,
input_tensor=INPUT_TENSOR,
output_tensor=OUTPUT_TENSOR,
num_batches=1):
def mnist_frechet_distance(real_images, generated_images, num_batches=1):
"""Frechet distance between real and generated images.
This technique is described in detail in https://arxiv.org/abs/1706.08500.
Expand All @@ -102,10 +76,6 @@ def mnist_frechet_distance(real_images,
real_images: Real images to use to compute Frechet Inception distance.
generated_images: Generated images to use to compute Frechet Inception
distance.
graph_def_filename: Location of a frozen GraphDef binary file on disk. If
`None`, uses a default graph.
input_tensor: GraphDef's input tensor name.
output_tensor: GraphDef's output tensor name.
num_batches: Number of batches to split images into in order to efficiently
run them through the classifier network.
Expand All @@ -114,42 +84,27 @@ def mnist_frechet_distance(real_images,
"""
real_images.shape.assert_is_compatible_with([None, 28, 28, 1])
generated_images.shape.assert_is_compatible_with([None, 28, 28, 1])

graph_def = _graph_def_from_par_or_disk(graph_def_filename)
mnist_classifier_fn = lambda x: tfgan.eval.run_image_classifier( # pylint: disable=g-long-lambda
x, graph_def, input_tensor, output_tensor)

mnist_classifier_fn = tfhub.load(MNIST_MODULE)
frechet_distance = tfgan.eval.frechet_classifier_distance(
real_images, generated_images, mnist_classifier_fn, num_batches)
frechet_distance.shape.assert_is_compatible_with([])

return frechet_distance


def mnist_cross_entropy(images,
one_hot_labels,
graph_def_filename=None,
input_tensor=INPUT_TENSOR,
output_tensor=OUTPUT_TENSOR):
def mnist_cross_entropy(images, one_hot_labels):
"""Returns the cross entropy loss of the classifier on images.
Args:
images: A minibatch tensor of MNIST digits. Shape must be [batch, 28, 28,
1].
one_hot_labels: The one hot label of the examples. Tensor size is [batch,
10].
graph_def_filename: Location of a frozen GraphDef binary file on disk. If
`None`, uses a default graph embedded in the par file.
input_tensor: GraphDef's input tensor name.
output_tensor: GraphDef's output tensor name.
Returns:
A scalar Tensor representing the cross entropy of the image minibatch.
"""
graph_def = _graph_def_from_par_or_disk(graph_def_filename)

logits = tfgan.eval.run_image_classifier(images, graph_def, input_tensor,
output_tensor)
logits = tfhub.load(MNIST_MODULE)(images)
return tf.compat.v1.losses.softmax_cross_entropy(
one_hot_labels, logits, loss_collection=None)

Expand Down Expand Up @@ -327,10 +282,3 @@ def get_infogan_noise(batch_size, categorical_dim, structured_continuous_dim,
continuous_noise = continuous_dist.sample([batch_size])

return [unstructured_noise], [categorical_noise, continuous_noise]


def _graph_def_from_par_or_disk(filename):
if filename is None:
return tfgan.eval.get_graph_def_from_disk(MODEL_BY_FN)
else:
return tfgan.eval.get_graph_def_from_disk(filename)
34 changes: 6 additions & 28 deletions tensorflow_gan/examples/self_attention_estimator/eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@
import tensorflow_gan as tfgan # tf


# TODO(joelshor, marvinritter): Make a combined TPU/CPU/GPU graph the TF-GAN
# default, so this isn't necessary.
def default_graph_def_fn():
url = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05_v4.tar.gz'
graph_def = 'inceptionv1_for_inception_score_tpu.pb'
return tfgan.eval.get_graph_def_from_url_tarball(
url, graph_def, os.path.basename(url))


def get_activations(get_images_fn, num_batches, get_logits=False):
"""Get Inception activations.
Expand All @@ -51,27 +42,14 @@ def get_activations(get_images_fn, num_batches, get_logits=False):
Returns:
1 or 2 Tensors of Inception activations.
"""
def sample_fn(_):
images = get_images_fn()
inception_img_sz = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE
larger_images = tf.compat.v1.image.resize(
images, [inception_img_sz, inception_img_sz],
method=tf.image.ResizeMethod.BILINEAR)
return larger_images


# Image resizing happens inside the Inception SavedModel.
outputs = tfgan.eval.sample_and_run_inception(
sample_fn=lambda _: get_images_fn(),
sample_inputs=[1.0] * num_batches) # dummy inputs
if get_logits:
output_tensor = (tfgan.eval.INCEPTION_OUTPUT,
tfgan.eval.INCEPTION_FINAL_POOL)
return outputs['logits'], outputs['pool_3']
else:
output_tensor = tfgan.eval.INCEPTION_FINAL_POOL
output = tfgan.eval.sample_and_run_inception(
sample_fn,
sample_inputs=[1.0] * num_batches, # dummy inputs
output_tensor=output_tensor,
default_graph_def_fn=default_graph_def_fn)

return output
return outputs['pool_3']


def get_activations_from_dataset(image_ds, num_batches, get_logits=False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@
mock = tf.compat.v1.test.mock


def _mock_inception(*args, **kwargs):
def _mock_inception(*args, **kwargs): # pylint: disable=function-redefined
del args, kwargs
return tf.zeros([12, 2048])


return {
'logits': tf.zeros([12, 1008]),
'pool_3': tf.zeros([12, 2048]),
}


class EvalLibTest(tf.test.TestCase):
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_gan/python/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@
# Collapse eval into a single namespace.
from .classifier_metrics import *
from .eval_utils import *
from .inception_metrics import *
from .sliced_wasserstein import *
from .summaries import *

# Collect list of exposed symbols.
from .classifier_metrics import __all__ as classifier_metrics_symbols
from .eval_utils import __all__ as eval_utils_symbols
from .inception_metrics import __all__ as inception_metrics_symbols
from .sliced_wasserstein import __all__ as sliced_wasserstein_symbols
from .summaries import __all__ as summaries_symbols
__all__ = classifier_metrics_symbols
__all__ += eval_utils_symbols
__all__ += inception_metrics_symbols
__all__ += sliced_wasserstein_symbols
__all__ += summaries_symbols
Loading

0 comments on commit 5fbbba1

Please sign in to comment.