This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
1 change: 1 addition & 0 deletions tensor2tensor/models/common_hparams.py
@@ -61,6 +61,7 @@ def basic_params1():
weight_noise=0.0,
learning_rate_decay_scheme="none",
learning_rate_warmup_steps=100,
learning_rate_cosine_cycle_steps=250000,
learning_rate=0.1,
sampling_method="argmax", # "argmax" or "random"
problem_choice="adaptive", # "uniform", "adaptive", "distributed"
23 changes: 21 additions & 2 deletions tensor2tensor/models/common_layers.py
@@ -58,9 +58,15 @@ def inverse_exp_decay(max_step, min_value=0.01):
return inv_base**tf.maximum(float(max_step) - step, 0.0)


def shakeshake2_py(x, y, equal=False):
def shakeshake2_py(x, y, equal=False, individual=False):
"""The shake-shake sum of 2 tensors, python version."""
alpha = 0.5 if equal else tf.random_uniform([])
  if equal:
    alpha = 0.5
  elif individual:
    # One alpha per example, shaped so it broadcasts over all remaining dims.
    alpha_shape = tf.concat(
        [tf.shape(x)[:1], tf.ones_like(tf.shape(x)[1:])], axis=0)
    alpha = tf.random_uniform(alpha_shape)
  else:
    alpha = tf.random_uniform([])

return alpha * x + (1.0 - alpha) * y


@@ -72,6 +78,14 @@ def shakeshake2_grad(x1, x2, dy):
return dx


@function.Defun()
def shakeshake2_indiv_grad(x1, x2, dy):
"""Overriding gradient for shake-shake of 2 tensors."""
y = shakeshake2_py(x1, x2, individual=True)
dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
return dx


@function.Defun()
def shakeshake2_equal_grad(x1, x2, dy):
"""Overriding gradient for shake-shake of 2 tensors."""
@@ -86,6 +100,11 @@ def shakeshake2(x1, x2):
return shakeshake2_py(x1, x2)


@function.Defun(grad_func=shakeshake2_indiv_grad)
def shakeshake2_indiv(x1, x2):
return shakeshake2_py(x1, x2, individual=True)


@function.Defun(grad_func=shakeshake2_equal_grad)
def shakeshake2_eqgrad(x1, x2):
"""The shake-shake function with a different alpha for forward/backward."""
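For reference, here is a toy numpy sketch (an illustration only, not part of the diff) of what the three mixing modes added above compute for a pair of branch outputs; the array names and shapes are made up for the example:

import numpy as np

rng = np.random.RandomState(0)
branch1 = rng.randn(4, 8)  # pretend outputs of branch 1 for a batch of 4
branch2 = rng.randn(4, 8)  # pretend outputs of branch 2

# "batch" mode (shakeshake2): a single alpha shared by every example.
alpha = rng.uniform()
shaken_batch = alpha * branch1 + (1.0 - alpha) * branch2

# "image" mode (shakeshake2_indiv): one alpha per example, broadcast over the
# remaining axes.
alpha_per_example = rng.uniform(size=(4, 1))
shaken_image = alpha_per_example * branch1 + (1.0 - alpha_per_example) * branch2

# "equal" mode (shakeshake2_py with equal=True, used at eval time): alpha = 0.5.
shaken_equal = 0.5 * (branch1 + branch2)

In the Defun-wrapped versions the overridden gradient draws an independent alpha, so the backward pass mixes the two branches with different random weights than the forward pass.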
1 change: 1 addition & 0 deletions tensor2tensor/models/models.py
@@ -30,6 +30,7 @@
from tensor2tensor.models import modalities
from tensor2tensor.models import multimodel
from tensor2tensor.models import neural_gpu
from tensor2tensor.models import shake_shake
from tensor2tensor.models import slicenet
from tensor2tensor.models import transformer
from tensor2tensor.models import transformer_alternative
143 changes: 143 additions & 0 deletions tensor2tensor/models/shake_shake.py
@@ -0,0 +1,143 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange # pylint: disable=redefined-builtin

from tensor2tensor.models import common_hparams
from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


def shake_shake_block_branch(x, conv_filters, stride):
  """One branch of a shake-shake block: (ReLU, 3x3 conv, batch norm) twice."""
x = tf.nn.relu(x)
x = tf.layers.conv2d(
x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME')
x = tf.layers.batch_normalization(x)
x = tf.nn.relu(x)
x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME')
x = tf.layers.batch_normalization(x)
return x


def downsampling_residual_branch(x, conv_filters):
  """Halves the spatial resolution of x and projects it to conv_filters maps."""
  x = tf.nn.relu(x)

  x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2))
  x1 = tf.layers.conv2d(x1, conv_filters // 2, (1, 1), padding='SAME')

  # Shift the input by one pixel before pooling so that the second path
  # samples the other half of the spatial positions.
  x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]])
  x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2))
  x2 = tf.layers.conv2d(x2, conv_filters // 2, (1, 1), padding='SAME')

  return tf.concat([x1, x2], axis=3)


def shake_shake_block(x, conv_filters, stride, hparams):
  """A shake-shake block: two conv branches combined by shake-shake, plus a skip."""
with tf.variable_scope('branch_1'):
branch1 = shake_shake_block_branch(x, conv_filters, stride)
with tf.variable_scope('branch_2'):
branch2 = shake_shake_block_branch(x, conv_filters, stride)
if x.shape[-1] == conv_filters:
skip = tf.identity(x)
else:
skip = downsampling_residual_branch(x, conv_filters)

# TODO(rshin): Use different alpha for each image in batch.
if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN:
if hparams.shakeshake_type == 'batch':
shaken = common_layers.shakeshake2(branch1, branch2)
elif hparams.shakeshake_type == 'image':
shaken = common_layers.shakeshake2_indiv(branch1, branch2)
elif hparams.shakeshake_type == 'equal':
shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
    else:
      raise ValueError(
          'Invalid shakeshake_type: {!r}'.format(hparams.shakeshake_type))
  else:
    shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
  # The Defun-based shake-shake ops lose the static shape, so restore it here.
  shaken.set_shape(branch1.get_shape())

return skip + shaken


def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams):
with tf.variable_scope('block_0'):
x = shake_shake_block(x, conv_filters, initial_stride, hparams)
for i in xrange(1, num_blocks):
with tf.variable_scope('block_{}'.format(i)):
x = shake_shake_block(x, conv_filters, 1, hparams)
return x


@registry.register_model
class ShakeShake(t2t_model.T2TModel):
'''Implements the Shake-Shake architecture.

From <https://arxiv.org/pdf/1705.07485.pdf>
This is intended to match the CIFAR-10 version, and correspond to
"Shake-Shake-Batch" in Table 1.
'''

def model_fn_body(self, features):
hparams = self._hparams
    tf.logging.info("learning_rate=%s", hparams.learning_rate)

inputs = features["inputs"]
assert (hparams.num_hidden_layers - 2) % 6 == 0
blocks_per_stage = (hparams.num_hidden_layers - 2) // 6

# For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16
# filters then a batch norm. Instead we will rely on the one in
# SmallImageModality, which seems to instead use a layer norm.
x = inputs
mode = hparams.mode
with tf.variable_scope('shake_shake_stage_1'):
x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1,
hparams)
with tf.variable_scope('shake_shake_stage_2'):
x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2,
hparams)
with tf.variable_scope('shake_shake_stage_3'):
x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2,
hparams)

# For canonical Shake-Shake, we should perform 8x8 average pooling and then
# have a fully-connected layer (which produces the logits for each class).
# Instead, we rely on the Xception exit flow in ClassLabelModality.
#
# Also, this model_fn does not return an extra_loss. However, TensorBoard
# reports an exponential moving average for extra_loss, where the initial
# value for the moving average may be a large number, so extra_loss will
# look large at the beginning of training.
return x


@registry.register_hparams
def shakeshake_cifar10():
hparams = common_hparams.basic_params1()
  # This leads to an effective batch size of 128 when the number of GPUs is 1.
hparams.batch_size = 4096 * 8
hparams.hidden_size = 16
hparams.dropout = 0
hparams.label_smoothing = 0.0
hparams.clip_grad_norm = 2.0
hparams.num_hidden_layers = 26
hparams.kernel_height = -1 # Unused
hparams.kernel_width = -1 # Unused
hparams.learning_rate_decay_scheme = "cosine"
# Model should be run for 700000 steps with batch size 128 (~1800 epochs)
hparams.learning_rate_cosine_cycle_steps = 700000
hparams.learning_rate = 0.2
hparams.learning_rate_warmup_steps = 3000
hparams.initializer = "uniform_unit_scaling"
hparams.initializer_gain = 1.0
# TODO(rshin): Adjust so that effective value becomes ~1e-4
hparams.weight_decay = 3.0
hparams.optimizer = "Momentum"
hparams.optimizer_momentum_momentum = 0.9
hparams.add_hparam('base_filters', 16)
hparams.add_hparam('shakeshake_type', 'batch')
return hparams
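A hypothetical side calculation (not part of the diff) tying together the depth and shape bookkeeping above: with num_hidden_layers = 26 the assert in model_fn_body passes, each of the three stages gets 4 blocks, and counting two 3x3 convolutions per branch in every block plus the entry convolution and classifier handled by the modalities gives a 26-layer network. The shape progression assumes base_filters = 16 and a 32x32 CIFAR-10 input:

num_hidden_layers = 26
assert (num_hidden_layers - 2) % 6 == 0
blocks_per_stage = (num_hidden_layers - 2) // 6  # 4
depth = 3 * blocks_per_stage * 2 + 2             # 26 layers in total

# Shape progression for a [batch, 32, 32, 16] input with base_filters = 16:
#   stage 1: stride 1, 16 filters -> [batch, 32, 32, 16]
#   stage 2: stride 2, 32 filters -> [batch, 16, 16, 32]
#   stage 3: stride 2, 64 filters -> [batch, 8, 8, 64]
# The strided stages use downsampling_residual_branch for the skip path, which
# concatenates two conv_filters // 2 projections of shifted, pooled inputs.
print(blocks_per_stage, depth)  # 4 26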
3 changes: 3 additions & 0 deletions tensor2tensor/utils/trainer_utils.py
@@ -321,6 +321,9 @@ def learning_rate_decay():
(step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
elif hparams.learning_rate_decay_scheme == "exp100k":
return 0.94**(step // 100000)
elif hparams.learning_rate_decay_scheme == "cosine":
cycle_steps = hparams.learning_rate_cosine_cycle_steps
return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps))

inv_base = tf.exp(tf.log(0.01) / warmup_steps)
inv_decay = inv_base**(warmup_steps - step)
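To see what the new "cosine" branch computes, here is a small numpy check (an illustration only, not part of the diff), assuming the default learning_rate_cosine_cycle_steps = 250000 added in common_hparams.py; the multiplier starts at 1, reaches 0.5 halfway through the cycle, approaches 0 at the end, and then restarts:

import numpy as np

cycle_steps = 250000
for step in [0, 62500, 125000, 187500, 249999, 250000]:
  decay = 0.5 * (1 + np.cos(np.pi * (step % cycle_steps) / cycle_steps))
  print(step, round(decay, 3))
# Prints: 0 1.0, 62500 0.854, 125000 0.5, 187500 0.146, 249999 0.0,
# 250000 1.0 (the modulo restarts the cycle).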