Merge internal changes into public repository (change 217345822)
shizhiw committed Oct 16, 2018
1 parent 5cb81b5 · commit 82eb52b
Showing 4 changed files with 36 additions and 16 deletions.
models/experimental/ncf/ncf_main.py (1 addition, 1 deletion)
@@ -493,7 +493,7 @@ def logits_fn(embedding, params):
   mlp_vector = tf.keras.layers.concatenate([mlp_user_input, mlp_item_input])
 
   num_layer = len(model_layers)  # Number of layers in the MLP
-  for layer in xrange(1, num_layer):
+  for layer in range(1, num_layer):
     model_layer = tf.keras.layers.Dense(
         model_layers[layer],
         kernel_regularizer=tf.keras.regularizers.l2(mlp_reg_layers[layer]),
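
The single change above is a Python 3 compatibility fix: xrange was removed in Python 3, while range exists in both (eagerly building a list in Python 2, lazy in Python 3). A minimal sketch of the portable idiom, assuming the six package is available; the commit itself simply switches to the builtin:

# Portable sketch, not from the commit: six.moves.range resolves to xrange
# on Python 2 and to the builtin lazy range on Python 3.
from six.moves import range

num_layer = 4
for layer in range(1, num_layer):  # yields 1, 2, 3 without building a list on py3
  print(layer)
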
models/official/amoeba_net/amoeba_net.py (13 additions, 3 deletions)
@@ -137,6 +137,8 @@
     'stem_reduction_size', 32, 'Stem filter size.')
 flags.DEFINE_float(
     'weight_decay', 4e-05, 'Weight decay for slim model.')
+flags.DEFINE_integer(
+    'num_label_classes', 1001, 'The number of classes that images fit into.')
 
 # Training hyper-parameters
 flags.DEFINE_float(
@@ -166,6 +168,11 @@
 flags.DEFINE_integer(
     'image_size', 299, 'Size of image, assuming image height and width.')
 
+flags.DEFINE_integer(
+    'num_train_images', 1281167, 'The number of images in the training set.')
+flags.DEFINE_integer(
+    'num_eval_images', 50000, 'The number of images in the evaluation set.')
+
 flags.DEFINE_bool(
     'use_bp16', True, 'If True, use bfloat16 for activations')
 
@@ -183,7 +190,7 @@ def build_run_config():
       zone=FLAGS.tpu_zone,
       project=FLAGS.gcp_project)
 
-  eval_steps = model_lib.NUM_EVAL_IMAGES // FLAGS.eval_batch_size
+  eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
   iterations_per_loop = (eval_steps if FLAGS.mode == 'eval'
                          else FLAGS.iterations_per_loop)
   save_checkpoints_steps = FLAGS.save_checkpoints_steps or iterations_per_loop
@@ -261,6 +268,9 @@ def override_with_flags(hparams):
       'weight_decay',
       'num_shards',
       'distributed_group_size',
+      'num_train_images',
+      'num_eval_images',
+      'num_label_classes',
   ]
   for flag_name in override_flag_names:
     flag_value = getattr(FLAGS, flag_name, 'INVALID')
@@ -314,8 +324,8 @@ def main(_):
   estimator_parmas = {}
 
   train_steps_per_epoch = int(
-      math.ceil(model_lib.NUM_TRAIN_IMAGES / float(hparams.train_batch_size)))
-  eval_steps = model_lib.NUM_EVAL_IMAGES // hparams.eval_batch_size
+      math.ceil(hparams.num_train_images / float(hparams.train_batch_size)))
+  eval_steps = hparams.num_eval_images // hparams.eval_batch_size
   eval_batch_size = (None if mode == 'train' else
                      hparams.eval_batch_size)
 
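
The new flags only take effect because override_with_flags copies them onto the hparams object that the model code reads. A minimal self-contained sketch of that copy pattern, assuming TF 1.x's tf.contrib.training.HParams (values are illustrative; this is not the file's exact code):

import tensorflow as tf

# Stand-in for parsed absl flags; values are illustrative.
flag_overrides = {'num_train_images': 100000,
                  'num_eval_images': 10000,
                  'num_label_classes': 11}

hparams = tf.contrib.training.HParams(  # defaults match imagenet_hparams()
    num_train_images=1281167, num_eval_images=50000, num_label_classes=1001)

for name, value in flag_overrides.items():
  if value is not None:                 # unset flags leave the default alone
    hparams.set_hparam(name, value)

eval_steps = hparams.num_eval_images // 1000  # 10 with the overrides above
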
models/official/amoeba_net/amoeba_net_model.py (21 additions, 11 deletions)
@@ -29,11 +29,6 @@
 import model_specs
 
 
-# Dataset constants
-NUM_TRAIN_IMAGES = 1281167
-NUM_EVAL_IMAGES = 50000
-LABEL_CLASSES = 1001
-
 # Random cropping constants
 _RESIZE_SIDE_MIN = 300
 _RESIZE_SIDE_MAX = 600
@@ -56,7 +51,9 @@ def imagenet_hparams():
       ##########################################################################
 
       image_size=299,
-
+      num_train_images=1281167,
+      num_eval_images=50000,
+      num_label_classes=1001,
       ##########################################################################
       # Architectural params. ##################################################
       ##########################################################################
@@ -156,7 +153,18 @@ def imagenet_hparams():
 
 
 def build_hparams(cell_name='amoeba_net_d'):
-  """Build tf.Hparams for training Amoeba Net."""
+  """Build tf.Hparams for training Amoeba Net.
+
+  Args:
+    cell_name: Which of the cells in model_specs.py to use to build the
+               amoebanet neural network; the cell names defined in that
+               module correspond to architectures discovered by an
+               evolutionary search described in
+               https://arxiv.org/abs/1802.01548.
+
+  Returns:
+    A set of tf.HParams suitable for Amoeba Net training.
+  """
   hparams = imagenet_hparams()
   operations, hiddenstate_indices, used_hiddenstates = (
       model_specs.get_normal_cell(cell_name))
@@ -222,7 +230,8 @@ def _calc_num_trainable_params(self):
 
   def _build_learning_rate_schedule(self, global_step):
     """Build learning rate."""
-    steps_per_epoch = NUM_TRAIN_IMAGES // self.hparams.train_batch_size
+    steps_per_epoch = (
+        self.hparams.num_train_images // self.hparams.train_batch_size)
     lr_warmup_epochs = 0
     if self.hparams.lr_decay_method == 'exponential':
       lr_warmup_epochs = self.hparams.lr_warmup_epochs
@@ -244,7 +253,8 @@ def _build_network(self, features, labels, mode):
     """Build a network that returns loss and logits from features and labels."""
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
     is_predict = (mode == tf.estimator.ModeKeys.PREDICT)
-    steps_per_epoch = float(NUM_TRAIN_IMAGES) / self.hparams.train_batch_size
+    steps_per_epoch = float(
+        self.hparams.num_train_images) / self.hparams.train_batch_size
     num_total_steps = int(steps_per_epoch * self.hparams.num_epochs)
     self.hparams.set_hparam('drop_path_burn_in_steps', num_total_steps)
 
@@ -260,10 +270,10 @@ def _build_network(self, features, labels, mode):
                      formatted_hparams(hparams)))
 
     logits, end_points = model_builder.build_network(
-        features, LABEL_CLASSES, is_training, hparams)
+        features, hparams.num_label_classes, is_training, hparams)
 
     if not is_predict:
-      labels = tf.one_hot(labels, LABEL_CLASSES)
+      labels = tf.one_hot(labels, hparams.num_label_classes)
       loss = model_builder.build_softmax_loss(
           logits,
           end_points,
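
With the module-level constants deleted, every consumer in the model reads the dataset sizes from hparams, so a single override reaches all call sites. A rough sketch of the two call sites this file's hunks touch, using the ImageNet defaults plus an illustrative batch size (FakeHParams is a stand-in, not the real class):

import tensorflow as tf

class FakeHParams(object):
  """Stand-in for the HParams object built by imagenet_hparams()."""
  num_train_images = 1281167
  num_label_classes = 1001
  train_batch_size = 256  # illustrative; set elsewhere in the real pipeline

hparams = FakeHParams()
steps_per_epoch = hparams.num_train_images // hparams.train_batch_size  # 5004
labels = tf.one_hot([0, 5, 42], hparams.num_label_classes)  # shape [3, 1001]
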
models/official/amoeba_net/tf_hub.py (1 addition, 1 deletion)
@@ -296,7 +296,7 @@ def main(_):
   # fine-tuning Hub image modules. Disable aux heads to avoid putting unused
   # variables and ops into the module.
   hparams.set_hparam('use_aux_head', False)
-  eval_steps = model_lib.NUM_EVAL_IMAGES // FLAGS.eval_batch_size
+  eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
   export_path = FLAGS.export_path or (model_dir + '/export')
 
   input_pipeline = model_lib.InputPipeline(
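
One property shared by every eval_steps call site in this commit: floor division drops any trailing partial batch, so a num_eval_images that is not a multiple of eval_batch_size leaves some images unevaluated. Illustrative arithmetic (the batch size here is an example, not a flag default):

num_eval_images, eval_batch_size = 50000, 1024   # illustrative batch size
eval_steps = num_eval_images // eval_batch_size  # 48 full steps
images_evaluated = eval_steps * eval_batch_size  # 49152 of the 50000 images
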
