diff --git a/tensor2tensor/models/common_hparams.py b/tensor2tensor/models/common_hparams.py index c8c458414..f48a67c15 100644 --- a/tensor2tensor/models/common_hparams.py +++ b/tensor2tensor/models/common_hparams.py @@ -61,6 +61,7 @@ def basic_params1(): weight_noise=0.0, learning_rate_decay_scheme="none", learning_rate_warmup_steps=100, + learning_rate_cosine_cycle_steps=250000, learning_rate=0.1, sampling_method="argmax", # "argmax" or "random" problem_choice="adaptive", # "uniform", "adaptive", "distributed" diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 15a712ef2..d38f97fb0 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -58,9 +58,15 @@ def inverse_exp_decay(max_step, min_value=0.01): return inv_base**tf.maximum(float(max_step) - step, 0.0) -def shakeshake2_py(x, y, equal=False): +def shakeshake2_py(x, y, equal=False, individual=False): """The shake-shake sum of 2 tensors, python version.""" - alpha = 0.5 if equal else tf.random_uniform([]) + if equal: + alpha = 0.5 + if individual: + alpha = tf.random_uniform(tf.get_shape(x)[:1]) + else: + alpha = tf.random_uniform([]) + return alpha * x + (1.0 - alpha) * y @@ -72,6 +78,14 @@ def shakeshake2_grad(x1, x2, dy): return dx +@function.Defun() +def shakeshake2_indiv_grad(x1, x2, dy): + """Overriding gradient for shake-shake of 2 tensors.""" + y = shakeshake2_py(x1, x2, individual=True) + dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy]) + return dx + + @function.Defun() def shakeshake2_equal_grad(x1, x2, dy): """Overriding gradient for shake-shake of 2 tensors.""" @@ -86,6 +100,11 @@ def shakeshake2(x1, x2): return shakeshake2_py(x1, x2) +@function.Defun(grad_func=shakeshake2_indiv_grad) +def shakeshake2_indiv(x1, x2): + return shakeshake2_py(x1, x2, individual=True) + + @function.Defun(grad_func=shakeshake2_equal_grad) def shakeshake2_eqgrad(x1, x2): """The shake-shake function with a different alpha for forward/backward.""" diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index ae0e0da61..214aec245 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -30,6 +30,7 @@ from tensor2tensor.models import modalities from tensor2tensor.models import multimodel from tensor2tensor.models import neural_gpu +from tensor2tensor.models import shake_shake from tensor2tensor.models import slicenet from tensor2tensor.models import transformer from tensor2tensor.models import transformer_alternative diff --git a/tensor2tensor/models/shake_shake.py b/tensor2tensor/models/shake_shake.py new file mode 100644 index 000000000..f87eaa335 --- /dev/null +++ b/tensor2tensor/models/shake_shake.py @@ -0,0 +1,143 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.models import common_hparams +from tensor2tensor.models import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +def shake_shake_block_branch(x, conv_filters, stride): + x = tf.nn.relu(x) + x = tf.layers.conv2d( + x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME') + x = tf.layers.batch_normalization(x) + x = tf.nn.relu(x) + x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME') + x = tf.layers.batch_normalization(x) + return x + + +def downsampling_residual_branch(x, conv_filters): + x = tf.nn.relu(x) + + x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2)) + x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding='SAME') + + x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]]) + x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2)) + x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding='SAME') + + return tf.concat([x1, x2], axis=3) + + +def shake_shake_block(x, conv_filters, stride, hparams): + with tf.variable_scope('branch_1'): + branch1 = shake_shake_block_branch(x, conv_filters, stride) + with tf.variable_scope('branch_2'): + branch2 = shake_shake_block_branch(x, conv_filters, stride) + if x.shape[-1] == conv_filters: + skip = tf.identity(x) + else: + skip = downsampling_residual_branch(x, conv_filters) + + # TODO(rshin): Use different alpha for each image in batch. + if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN: + if hparams.shakeshake_type == 'batch': + shaken = common_layers.shakeshake2(branch1, branch2) + elif hparams.shakeshake_type == 'image': + shaken = common_layers.shakeshake2_indiv(branch1, branch2) + elif hparams.shakeshake_type == 'equal': + shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True) + else: + raise ValueError('Invalid shakeshake_type: {!r}'.format(shaken)) + else: + shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True) + shaken.set_shape(branch1.get_shape()) + + return skip + shaken + + +def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams): + with tf.variable_scope('block_0'): + x = shake_shake_block(x, conv_filters, initial_stride, hparams) + for i in xrange(1, num_blocks): + with tf.variable_scope('block_{}'.format(i)): + x = shake_shake_block(x, conv_filters, 1, hparams) + return x + + +@registry.register_model +class ShakeShake(t2t_model.T2TModel): + '''Implements the Shake-Shake architecture. + + From + This is intended to match the CIFAR-10 version, and correspond to + "Shake-Shake-Batch" in Table 1. + ''' + + def model_fn_body(self, features): + hparams = self._hparams + print(hparams.learning_rate) + + inputs = features["inputs"] + assert (hparams.num_hidden_layers - 2) % 6 == 0 + blocks_per_stage = (hparams.num_hidden_layers - 2) // 6 + + # For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16 + # filters then a batch norm. Instead we will rely on the one in + # SmallImageModality, which seems to instead use a layer norm. + x = inputs + mode = hparams.mode + with tf.variable_scope('shake_shake_stage_1'): + x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, + hparams) + with tf.variable_scope('shake_shake_stage_2'): + x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2, + hparams) + with tf.variable_scope('shake_shake_stage_3'): + x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2, + hparams) + + # For canonical Shake-Shake, we should perform 8x8 average pooling and then + # have a fully-connected layer (which produces the logits for each class). + # Instead, we rely on the Xception exit flow in ClassLabelModality. + # + # Also, this model_fn does not return an extra_loss. However, TensorBoard + # reports an exponential moving average for extra_loss, where the initial + # value for the moving average may be a large number, so extra_loss will + # look large at the beginning of training. + return x + + +@registry.register_hparams +def shakeshake_cifar10(): + hparams = common_hparams.basic_params1() + # This leads to effective batch size 128 when number of GPUs is 1 + hparams.batch_size = 4096 * 8 + hparams.hidden_size = 16 + hparams.dropout = 0 + hparams.label_smoothing = 0.0 + hparams.clip_grad_norm = 2.0 + hparams.num_hidden_layers = 26 + hparams.kernel_height = -1 # Unused + hparams.kernel_width = -1 # Unused + hparams.learning_rate_decay_scheme = "cosine" + # Model should be run for 700000 steps with batch size 128 (~1800 epochs) + hparams.learning_rate_cosine_cycle_steps = 700000 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 3000 + hparams.initializer = "uniform_unit_scaling" + hparams.initializer_gain = 1.0 + # TODO(rshin): Adjust so that effective value becomes ~1e-4 + hparams.weight_decay = 3.0 + hparams.optimizer = "Momentum" + hparams.optimizer_momentum_momentum = 0.9 + hparams.add_hparam('base_filters', 16) + hparams.add_hparam('shakeshake_type', 'batch') + return hparams diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 057612ecb..3c693d08e 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -321,6 +321,9 @@ def learning_rate_decay(): (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5) elif hparams.learning_rate_decay_scheme == "exp100k": return 0.94**(step // 100000) + elif hparams.learning_rate_decay_scheme == "cosine": + cycle_steps = hparams.learning_rate_cosine_cycle_steps + return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps)) inv_base = tf.exp(tf.log(0.01) / warmup_steps) inv_decay = inv_base**(warmup_steps - step)