Extract a base Model class from the current Model class, rename the current class to CNNModel, and make it a subclass of Model.

PiperOrigin-RevId: 207753123
protoget authored and tensorflower-gardener committed Aug 7, 2018
1 parent 6eeaa75 commit 634b921
Showing 18 changed files with 76 additions and 53 deletions.
18 changes: 12 additions & 6 deletions scripts/tf_cnn_benchmarks/benchmark_cnn.py
@@ -2773,8 +2773,8 @@ def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
     return tf.group(*queue_ops)
 
 
-class BenchmarkSeq2Seq(BenchmarkCNN):
-  """Class for benchmarking a seq2seq network."""
+class BenchmarkNMT(BenchmarkCNN):
+  """Class for benchmarking an NMT network."""
 
   def __init__(self, params, dataset=None, model=None):
     # pylint:disable=super-init-not-called
@@ -2788,16 +2788,22 @@ def _build_graph(self):
     """
     pass
 
-  def _build_model_single_session_with_dataset_prefetching(self):
-    pass
+  def _build_model(self):
+    """Build the TensorFlow graph."""
+
+    # Not implemented since FLAGS.dataset_use_prefetch defaults to True.
+    raise NotImplementedError
 
   def _build_model_single_session(self):
-    pass
+    """Build the TensorFlow graph for multiple replicas in a single_session."""
+
+    # Not implemented since FLAGS.dataset_use_prefetch defaults to True.
+    raise NotImplementedError
 
   def _build_model_with_dataset_prefetching(self):
     pass
 
-  def _build_model(self):
+  def _build_model_single_session_with_dataset_prefetching(self):
     pass
 
 
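For orientation: the renamed BenchmarkNMT is still a stub, and only the dataset-prefetching build path is meant to work. A model driven by it would subclass the new framework-agnostic Model base class added in models/model.py further below, rather than CNNModel. A minimal sketch of such a model follows; the class name, vocabulary/embedding sizes, and the body of build_network are illustrative assumptions, not part of this commit, which only declares the hook.

```python
# Sketch only. Everything except the new Model constructor signature and the
# build_network override point is an illustrative assumption.
import tensorflow as tf

from models import model


class SketchNMTModel(model.Model):
  """Hypothetical NMT model config; needs no image_size or layer_counts."""

  def __init__(self):
    # New base-class signature: (model_name, batch_size, learning_rate,
    # fp16_loss_scale).
    super(SketchNMTModel, self).__init__(
        'sketch_nmt', batch_size=128, learning_rate=0.5, fp16_loss_scale=128)

  def build_network(self, inputs, **kwargs):
    # A real model would build its encoder/decoder graph from `inputs` here;
    # this trivial stand-in just embeds the first input tensor.
    del kwargs
    embeddings = tf.get_variable('nmt_embeddings', shape=[32000, 512])
    return tf.nn.embedding_lookup(embeddings, inputs[0])
```

Nothing in this commit wires such a model into BenchmarkNMT yet; the non-prefetching _build_model* paths above simply raise NotImplementedError.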
@@ -434,7 +434,7 @@ def _test_variable_update(self,
       actual_losses.append([x.loss for x in outputs])
 
     inputs = test_util.get_fake_var_update_inputs()
-    expected_losses = test_util.TestModel().manually_compute_losses(
+    expected_losses = test_util.TestCNNModel().manually_compute_losses(
         inputs, num_workers, params)
     if params.variable_update == 'distributed_all_reduce':
       # In distributed all reduce, each step, the controller outputs the average
@@ -91,7 +91,7 @@ def run_with_real_model(params):
 
 def run_with_test_model(params):
   """Runs tf_cnn_benchmarks with a test model."""
-  model = test_util.TestModel()
+  model = test_util.TestCNNModel()
   inputs = test_util.get_fake_var_update_inputs()
   with test_util.monkey_patch(benchmark_cnn,
                               LOSS_AND_ACCURACY_DIGITS_TO_SHOW=15):
8 changes: 4 additions & 4 deletions scripts/tf_cnn_benchmarks/benchmark_cnn_test.py
@@ -1003,7 +1003,7 @@ def setUp(self):
   def _get_benchmark_cnn_losses(self, inputs, params):
     """Returns the losses of BenchmarkCNN on the given inputs and params."""
     logs = []
-    model = test_util.TestModel()
+    model = test_util.TestCNNModel()
     with test_util.monkey_patch(benchmark_cnn,
                                 log_fn=test_util.print_and_add_to_list(logs),
                                 LOSS_AND_ACCURACY_DIGITS_TO_SHOW=15):
@@ -1022,16 +1022,16 @@ def _get_benchmark_cnn_losses(self, inputs, params):
   def _test_variable_update(self, params):
     """Tests variables are updated correctly when the given params are used.
 
-    A BenchmarkCNN is created with a TestModel, and is run with some scalar
+    A BenchmarkCNN is created with a TestCNNModel, and is run with some scalar
     images. The losses are then compared with the losses obtained with
-    TestModel().manually_compute_losses()
+    TestCNNModel().manually_compute_losses()
 
     Args:
       params: a Params tuple used to create BenchmarkCNN.
     """
     inputs = test_util.get_fake_var_update_inputs()
     actual_losses = self._get_benchmark_cnn_losses(inputs, params)
-    expected_losses, = test_util.TestModel().manually_compute_losses(
+    expected_losses, = test_util.TestCNNModel().manually_compute_losses(
         inputs, 1, params)
     rtol = 3e-2 if params.use_fp16 else 1e-5
     self.assertAllClose(actual_losses[:len(expected_losses)], expected_losses,
4 changes: 2 additions & 2 deletions scripts/tf_cnn_benchmarks/models/alexnet_model.py
@@ -25,7 +25,7 @@
 from models import model
 
 
-class AlexnetModel(model.Model):
+class AlexnetModel(model.CNNModel):
   """Alexnet cnn model."""
 
   def __init__(self):
@@ -48,7 +48,7 @@ def add_inference(self, cnn):
     cnn.dropout()
 
 
-class AlexnetCifar10Model(model.Model):
+class AlexnetCifar10Model(model.CNNModel):
   """Alexnet cnn model for cifar datasets.
 
   The model architecture follows the one defined in the tensorflow tutorial
2 changes: 1 addition & 1 deletion scripts/tf_cnn_benchmarks/models/densenet_model.py
@@ -24,7 +24,7 @@
 from models import model as model_lib
 
 
-class DensenetCifar10Model(model_lib.Model):
+class DensenetCifar10Model(model_lib.CNNModel):
   """Densenet cnn network configuration."""
 
   def __init__(self, model, layer_counts, growth_rate):
3 changes: 2 additions & 1 deletion scripts/tf_cnn_benchmarks/models/googlenet_model.py
@@ -25,7 +25,8 @@
 from models import model
 
 
-class GooglenetModel(model.Model):
+class GooglenetModel(model.CNNModel):
+  """GoogLeNet."""
 
   def __init__(self):
     super(GooglenetModel, self).__init__('googlenet', 224, 32, 0.005)
6 changes: 4 additions & 2 deletions scripts/tf_cnn_benchmarks/models/inception_model.py
@@ -41,7 +41,8 @@
 from models import model
 
 
-class Inceptionv3Model(model.Model):
+class Inceptionv3Model(model.CNNModel):
+  """InceptionV3."""
 
   def __init__(self, auxiliary=False):
     self._auxiliary = auxiliary
@@ -157,7 +158,8 @@ def inception_v4_rb(cnn):
   cnn.inception_module('incept_v4_rb', cols)
 
 
-class Inceptionv4Model(model.Model):
+class Inceptionv4Model(model.CNNModel):
+  """Inceptionv4."""
 
   def __init__(self):
     super(Inceptionv4Model, self).__init__('inception4', 299, 32, 0.005)
3 changes: 2 additions & 1 deletion scripts/tf_cnn_benchmarks/models/lenet_model.py
@@ -24,7 +24,8 @@
 from models import model
 
 
-class Lenet5Model(model.Model):
+class Lenet5Model(model.CNNModel):
+  """Lenet5."""
 
   def __init__(self):
     super(Lenet5Model, self).__init__('lenet5', 28, 32, 0.005)
3 changes: 1 addition & 2 deletions scripts/tf_cnn_benchmarks/models/mobilenet_v2.py
@@ -186,7 +186,7 @@ def training_scope(**kwargs):
   return lib.training_scope(**kwargs)
 
 
-class MobilenetModel(model.Model):
+class MobilenetModel(model.CNNModel):
   """Mobilenet model configuration."""
 
   def __init__(self):
@@ -196,4 +196,3 @@ def add_inference(self, cnn):
     with tf.contrib.slim.arg_scope(training_scope(is_training=cnn.phase_train)):
       cnn.top_layer, _ = mobilenet(cnn.top_layer, is_training=cnn.phase_train)
       cnn.top_size = cnn.top_layer.shape[-1].value
-
47 changes: 30 additions & 17 deletions scripts/tf_cnn_benchmarks/models/model.py
@@ -19,31 +19,20 @@
 
 
 class Model(object):
-  """Base model configuration for CNN benchmarks."""
+  """Base model config for DNN benchmarks."""
 
-  def __init__(self,
-               model,
-               image_size,
-               batch_size,
-               learning_rate,
-               layer_counts=None,
-               fp16_loss_scale=128):
-    self.model = model
-    self.image_size = image_size
+  def __init__(self, model_name, batch_size, learning_rate, fp16_loss_scale):
+    self.model = model_name
     self.batch_size = batch_size
     self.default_batch_size = batch_size
     self.learning_rate = learning_rate
-    self.layer_counts = layer_counts
     # TODO(reedwm) Set custom loss scales for each model instead of using the
     # default of 128.
     self.fp16_loss_scale = fp16_loss_scale
 
   def get_model(self):
     return self.model
 
-  def get_image_size(self):
-    return self.image_size
-
   def get_batch_size(self):
     return self.batch_size
 
@@ -53,9 +42,6 @@ def set_batch_size(self, batch_size):
   def get_default_batch_size(self):
     return self.default_batch_size
 
-  def get_layer_counts(self):
-    return self.layer_counts
-
   def get_fp16_loss_scale(self):
     return self.fp16_loss_scale
 
@@ -67,6 +53,33 @@ def get_learning_rate(self, global_step, batch_size):
   def add_inference(self, unused_cnn):
     raise ValueError('Must be implemented in derived classes')
 
+  def build_network(self, inputs, **kwargs):
+    del inputs
+    del kwargs
+    raise ValueError('Must be implemented in derived classes')
+
+
+class CNNModel(Model):
+  """Base model configuration for CNN benchmarks."""
+
+  def __init__(self,
+               model,
+               image_size,
+               batch_size,
+               learning_rate,
+               layer_counts=None,
+               fp16_loss_scale=128):
+    super(CNNModel, self).__init__(model, batch_size, learning_rate,
+                                   fp16_loss_scale)
+    self.image_size = image_size
+    self.layer_counts = layer_counts
+
+  def get_image_size(self):
+    return self.image_size
+
+  def get_layer_counts(self):
+    return self.layer_counts
+
   def skip_final_affine_layer(self):
     """Returns if the caller of this class should skip the final affine layer.
 
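To make the new split concrete, here is a minimal sketch of a CNN config written against the refactored hierarchy. The constructor signature matches the CNNModel shown above; the class name and the layers in add_inference are illustrative assumptions (the cnn argument is the benchmark's layer-builder object, as in the other model files in this commit).

```python
# Sketch only; 'sketch_cnn' and the layer stack are illustrative assumptions.
from models import model


class SketchCNNModel(model.CNNModel):
  """Hypothetical CNN config: image_size and layer_counts stay on CNNModel."""

  def __init__(self):
    # (model_name, image_size, batch_size, learning_rate), as in the diff above.
    super(SketchCNNModel, self).__init__('sketch_cnn', 224, 32, 0.005)

  def add_inference(self, cnn):
    # `cnn` is the benchmark's convnet builder; these calls mirror the ones
    # in the alexnet/vgg model files touched by this commit.
    cnn.conv(64, 3, 3)
    cnn.mpool(2, 2)
    cnn.affine(512)
```

The existing subclasses in this commit (AlexNet, VGG, ResNet, NASNet, and so on) only needed their base class switched from model.Model to model.CNNModel; their constructors and add_inference bodies are unchanged.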
6 changes: 3 additions & 3 deletions scripts/tf_cnn_benchmarks/models/nasnet_model.py
@@ -532,7 +532,7 @@ def add_and_check_endpoint(endpoint_name, net):
   return logits, end_points
 
 
-class NasnetModel(model.Model):
+class NasnetModel(model.CNNModel):
   """Nasnet model configuration."""
 
   def __init__(self):
@@ -547,7 +547,7 @@ def add_inference(self, cnn):
     cnn.top_size = cnn.top_layer.shape[-1].value
 
 
-class NasnetLargeModel(model.Model):
+class NasnetLargeModel(model.CNNModel):
   """Nasnet model configuration."""
 
   def __init__(self):
@@ -562,7 +562,7 @@ def add_inference(self, cnn):
     cnn.top_size = cnn.top_layer.shape[-1].value
 
 
-class NasnetCifarModel(model.Model):
+class NasnetCifarModel(model.CNNModel):
   """Nasnet cifar model configuration."""
 
   def __init__(self):
2 changes: 1 addition & 1 deletion scripts/tf_cnn_benchmarks/models/official_resnet_model.py
@@ -23,7 +23,7 @@
 from models import model as model_lib
 
 
-class ImagenetResnetModel(model_lib.Model):
+class ImagenetResnetModel(model_lib.CNNModel):
   """Official resnet models."""
 
   def __init__(self, resnet_size, version=2):
3 changes: 2 additions & 1 deletion scripts/tf_cnn_benchmarks/models/overfeat_model.py
@@ -26,7 +26,8 @@
 from models import model
 
 
-class OverfeatModel(model.Model):
+class OverfeatModel(model.CNNModel):
+  """OverfeatModel."""
 
   def __init__(self):
     super(OverfeatModel, self).__init__('overfeat', 231, 32, 0.005)
4 changes: 2 additions & 2 deletions scripts/tf_cnn_benchmarks/models/resnet_model.py
@@ -235,7 +235,7 @@ def residual_block(cnn, depth, stride, version):
   cnn.top_size = depth
 
 
-class ResnetModel(model_lib.Model):
+class ResnetModel(model_lib.CNNModel):
   """Resnet cnn network configuration."""
 
   def __init__(self, model, layer_counts):
@@ -323,7 +323,7 @@ def create_resnet152_v2_model():
   return ResnetModel('resnet152_v2', (3, 8, 36, 3))
 
 
-class ResnetCifar10Model(model_lib.Model):
+class ResnetCifar10Model(model_lib.CNNModel):
   """Resnet cnn network configuration for Cifar 10 dataset.
 
   V1 model architecture follows the one defined in the paper:
4 changes: 2 additions & 2 deletions scripts/tf_cnn_benchmarks/models/trivial_model.py
@@ -17,7 +17,7 @@
 from models import model
 
 
-class TrivialModel(model.Model):
+class TrivialModel(model.CNNModel):
   """Trivial model configuration."""
 
   def __init__(self):
@@ -29,7 +29,7 @@ def add_inference(self, cnn):
     cnn.affine(4096)
 
 
-class TrivialCifar10Model(model.Model):
+class TrivialCifar10Model(model.CNNModel):
   """Trivial cifar10 model configuration."""
 
   def __init__(self):
6 changes: 3 additions & 3 deletions scripts/tf_cnn_benchmarks/models/vgg_model.py
@@ -53,7 +53,7 @@ def _construct_vgg(cnn, num_conv_layers):
   cnn.dropout()
 
 
-class Vgg11Model(model.Model):
+class Vgg11Model(model.CNNModel):
 
   def __init__(self):
     super(Vgg11Model, self).__init__('vgg11', 224, 64, 0.005)
@@ -62,7 +62,7 @@ def add_inference(self, cnn):
     _construct_vgg(cnn, [1, 1, 2, 2, 2])
 
 
-class Vgg16Model(model.Model):
+class Vgg16Model(model.CNNModel):
 
   def __init__(self):
     super(Vgg16Model, self).__init__('vgg16', 224, 64, 0.005)
@@ -71,7 +71,7 @@ def add_inference(self, cnn):
     _construct_vgg(cnn, [2, 2, 3, 3, 3])
 
 
-class Vgg19Model(model.Model):
+class Vgg19Model(model.CNNModel):
 
   def __init__(self):
     super(Vgg19Model, self).__init__('vgg19', 224, 64, 0.005)
6 changes: 3 additions & 3 deletions scripts/tf_cnn_benchmarks/test_util.py
@@ -433,7 +433,7 @@ def manually_compute_losses(numpy_inputs, inputs_placeholder, loss, num_workers,
   return losses
 
 
-class TestModel(model.Model):
+class TestCNNModel(model.CNNModel):
   """A simple model used for testing.
 
   The input is a 1-channel 1x1 image, consisting of a single number. The model
@@ -444,8 +444,8 @@ class TestModel(model.Model):
   """
 
   def __init__(self):
-    super(TestModel, self).__init__('test_model', image_size=1, batch_size=1,
-                                    learning_rate=1)
+    super(TestCNNModel, self).__init__(
+        'test_cnn_model', image_size=1, batch_size=1, learning_rate=1)
 
   VAR_A_INITIAL_VALUE = 1.
   VAR_B_INITIAL_VALUE = 2.
