
Conversation

@robieta
Contributor

@robieta robieta commented Feb 27, 2018

Previously, the official ResNet provided was the "v2" version from He et al. (2016). However, the "v1" version described in He et al. (2015) is also very popular. This PR allows the official ResNet model to be run using either the v1 or the v2 convention.

@robieta robieta requested a review from karmel February 27, 2018 21:02
@robieta robieta requested review from k-w-w and nealwu as code owners February 27, 2018 21:02
@tensorflow-jenkins
Collaborator

Can one of the admins verify this patch?

@googlebot

We found a Contributor License Agreement for you (the sender of this pull request), but were unable to find agreements for the commit author(s). If you authored these, maybe you used a different email address in the git commits than was used to sign the CLA (login here to double check)? If these were authored by someone else, then they will need to sign a CLA as well, and confirm that they're okay with these being contributed to Google.
In order to pass this check, please resolve this problem and have the pull request author add another comment and the bot will run again. If the bot doesn't comment, it means it doesn't think anything has changed.

Contributor

@karmel karmel left a comment

Thanks, Taylor, and congrats!

A few notes:

  • A PR comment describing the changes is always helpful.
  • I reviewed here for style and structure, but not for the actual changes to the model-- @wun, can you make sure this does the right thing WRT v1 versus v2?
  • Can you add some unit tests as well?

data_format=params['data_format'],
loss_filter_fn=loss_filter_fn)
loss_filter_fn=loss_filter_fn,
version=params['version'])
Contributor

Word of warning: this is going to conflict with #3472. Won't be too hard to resolve, I think, but... be prepared.

Returns:
The output tensor of the block.
"""
if version == 1:
Contributor

I feel like if we are going to have a model class, we may as well use it to store state. Is it feasible to limit the use of version to the Model.init? For example,

if version == 1:
  self.building_block_fn = _building_block_v1
elif version == 2:
  self.building_block_fn = _building_block_v2

and then just reference self.building_block_fn when the time comes?

Contributor Author

I moved the block_fn selection logic into the init. It's cleaner, but it will break things that inherit the resnet model. (Not sure if we make any interface promises on these classes.)

return _building_block_v2(inputs=inputs, filters=filters, training=training,
projection_shortcut=projection_shortcut,
strides=strides, data_format=data_format)
raise ValueError("version should be 1 or 2")
Contributor

This check can also move to the Model.init. May as well be more specific/properly capitalized, too.

Contributor Author

Ditto.
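
For reference, a rough sketch of the pattern these comments converge on, with the block-function selection and the version check both living in Model.__init__. This is a simplified sketch only (the real constructor takes many more arguments); the block-function names are the ones discussed above:

```python
class Model(object):
  """Simplified sketch of the __init__ pattern; not the actual resnet.py class."""

  def __init__(self, resnet_size, version=DEFAULT_VERSION, data_format=None):
    # Select the block implementation once, up front, and raise a specific,
    # properly capitalized error for invalid values.
    if version == 1:
      self.block_fn = _building_block_v1
    elif version == 2:
      self.block_fn = _building_block_v2
    else:
      raise ValueError('ResNet version should be 1 or 2. See README.md.')

    self.version = version
    self.resnet_size = resnet_size
    self.data_format = data_format
```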

training=training,
projection_shortcut=projection_shortcut,
strides=strides, data_format=data_format)
raise ValueError("version should be 1 or 2")
Contributor

Ditto the above.

Contributor Author

Dit... you get the idea.

conv_stride, first_pool_size, first_pool_stride,
second_pool_size, second_pool_stride, block_fn, block_sizes,
block_strides, final_size, data_format=None):
block_strides, final_size, version, data_format=None):
Contributor

Version may as well take a default value here of _DEFAULT_VERSION

Contributor Author

Done. I made the version non-private so cifar10 and imagenet classes could share the default.

# Only the first block per block_layer uses projection_shortcut and strides
inputs = block_fn(inputs, filters, training, projection_shortcut, strides,
data_format)
data_format, version=version)
Contributor

Again, not sure if it's feasible, but not having to pass the version all the way down would keep things cleaner, it seems.

Contributor Author

Indeed, this was a side effect of moving block_fn selection into the init.

tf.summary.image('images', features, max_outputs=6)

model = model_class(resnet_size, data_format)
model = model_class(resnet_size, data_format, version=version)
Contributor

I think the ordering here is inconsistent with the order of args in init? Oh wait, the child classes have data_format and version in a different order than the parent. That seems confusing-- can we standardize the ordering?

Contributor Author

If memory serves, the order got weird because Python doesn't allow args without defaults to come after those with defaults. (I think that's the reason I eventually set the version=None default.) I'll see what I can do to make everything cleaner.


self.add_argument(
'-v', '--version', type=int, choices=[1, 2], dest="version",
metavar="", default=_DEFAULT_VERSION,
Contributor

What is metavar? Why have it if blank?

Contributor Author

metavar is how argparse displays the variable. For whatever reason, when you have multiple spellings of a flag with choices set, rather than display it in -h as:

-v, --version {1, 2}

it is displayed as:
-v {1, 2}, --version {1, 2}

which is ugly. Setting metavar="" suppresses that behavior.
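
For anyone curious, a minimal standalone demonstration of that argparse behavior (a hypothetical script, not part of this PR):

```python
import argparse

parser = argparse.ArgumentParser()

# Without metavar, `-h` renders this flag as:  -v {1,2}, --version {1,2}
# With metavar="", it renders as:              -v , --version
parser.add_argument(
    '-v', '--version', type=int, choices=[1, 2], default=2,
    metavar="", help='Version of ResNet to use. (1 or 2)')

args = parser.parse_args(['--version', '1'])
print(args.version)  # prints: 1
```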

self.add_argument(
'-v', '--version', type=int, choices=[1, 2], dest="version",
metavar="", default=_DEFAULT_VERSION,
help="Version of resnet. (1 or 2)"
Contributor

Please describe more fully. Ditto for the comments above-- somewhere, we will need a more full description of what the differences are. Maybe just citations here, and further details in the README?

@robieta robieta requested a review from sguada as a code owner February 28, 2018 00:52
strides: The block's stride. If greater than 1, this block will ultimately
downsample the input.
data_format: The input format ('channels_last' or 'channels_first').
data_format: The input format ('channels_last' or 'channels_first')..
Contributor

Double period here

# since it performs a 1x1 convolution.
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
shortcut = batch_norm_relu(shortcut, training, data_format)
Contributor

I think you don't want this here. Just add shortcut to inputs below and then call batch_norm_relu afterward on the combined value. In this case you can also remove intermediate_addition from batch_norm_relu above.

Contributor Author

Ok, if you think the BN-ReLU is unnecessary for the projection I'll remove it. As far as removing the intermediate_addition arg, the 2016 v2 paper is clear that the shortcut addition goes between the BN and ReLU for v1, so I don't follow how that can be removed.

Contributor Author

Looks like the TPU version of v1 just repeats the ReLU, which I suppose also works.

Contributor

I see; you have the right idea with intermediate_addition then. In that case, for the current batch_norm_relu after this projection shortcut, I think we actually only want a batch norm, since the relu will come after the shortcut and residual are combined together.

The best thing to do then is likely to add a separate function for just batch norm (without relu), which we can use both here and below, in which case we should no longer need the intermediate_addition argument.
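
Concretely, the suggestion might look something like this. A sketch only, assuming `import tensorflow as tf` and the `conv2d_fixed_padding` helper already defined in resnet.py; the batch-norm hyperparameters are illustrative:

```python
def batch_norm(inputs, training, data_format):
  """Batch normalization only; any ReLU is applied by the caller."""
  return tf.layers.batch_normalization(
      inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
      momentum=0.997, epsilon=1e-5, center=True, scale=True,
      training=training, fused=True)


def _building_block_v1(inputs, filters, training, projection_shortcut,
                       strides, data_format):
  """Sketch of a v1 block: shortcut added between the last BN and the ReLU."""
  shortcut = inputs
  if projection_shortcut is not None:
    shortcut = projection_shortcut(inputs)
    shortcut = batch_norm(shortcut, training, data_format)  # BN only, no ReLU

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=filters, kernel_size=3, strides=strides,
      data_format=data_format)
  inputs = batch_norm(inputs, training, data_format)
  inputs = tf.nn.relu(inputs)

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=filters, kernel_size=3, strides=1,
      data_format=data_format)
  inputs = batch_norm(inputs, training, data_format)
  # Add the shortcut, then apply the final ReLU to the combined value.
  return tf.nn.relu(inputs + shortcut)
```

With a standalone batch_norm like this, batch_norm_relu no longer needs the intermediate_addition argument.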

Contributor

Looks like this is the TPU version you're referring to? https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L171

This doesn't look quite correct to me, since there are too many relus. In fact the relu that comes at the end doesn't do anything, since the two values being added together have both been relu'd themselves. Goes along with what I mentioned about every ResNet implementation having minor differences. :)


if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
shortcut = batch_norm_relu(shortcut, training, data_format)
Contributor

Same thing here.

return inputs


def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
Contributor

Can we actually combine the v1 and v2 functions into a single function? There is a lot of repeated code here, and I think it should also be simpler after the batch_norm_relu fix I suggested above.

Contributor Author

I agree, it doesn't seem ideal. As an exercise I put together an ultra concise function that could do v1 and v2 as well as regular and bottleneck.

_BOTTLENECK_LAYERS = [(1, 1, False, False), (3, 1, False, True), (1, 4, True, False)]
_REGULAR_LAYERS = [(3, 1, False, True), (3, 1, True, False)]
def _build_block(inputs, filters, training, projection_shortcut, strides,
    data_format, version, bottleneck):
  """
  Docstring removed for easier viewing on GitHub.
  """
  shortcut = inputs
  if version == 2:
    inputs = batch_norm_relu(inputs, training, data_format)

  if projection_shortcut is not None:
    shortcut = projection_shortcut(inputs)

  layers = _BOTTLENECK_LAYERS if bottleneck else _REGULAR_LAYERS
  for kernel_size, filter_num_mult, is_final, use_strides in layers:
    inputs = conv2d_fixed_padding(
        inputs=inputs, filters=filters, kernel_size=kernel_size,
        strides=strides if use_strides else 1,
        data_format=data_format)

    if version == 1:
      inputs = batch_norm_relu(inputs, training, data_format,
                       intermediate_addition=shortcut if is_final else None)
    else:  # version in (1, 2) check is performed by the resnet model
      if is_final:
        inputs += shortcut
      else:
        inputs = batch_norm_relu(inputs, training, data_format)
  return inputs

Even though this is < 50 lines (including the full docstring) compared to the ~170 being used now, it's not easy to read. I think this is one of the rare cases where some code duplication may be good simply to improve readability. (That doesn't necessarily mean what is there now is optimal.) @karmel thoughts?

Contributor

I agree, combining both v1 + v2 and regular + bottleneck is likely too much. I think it's fine to combine just v1 + v2.

Contributor

Questions one might use to decide:

  1. Are people more likely to want to read through the code for one version, or to compare the two versions? I feel like the answer to this is probably "read through one version," as most users want to run one version of ResNet, rather than use this as the place to learn about the difference between the two. That would imply the current (separated) version is preferable, because it's easier to parse (and extract) each individual model.
  2. Are developers more likely to make mistakes with duplicated code or branching code? Both are tricky. Duplicated might be slightly more error-prone, but branching code is pretty bad too.
  3. What does the rest of the code do-- duplicate or branch? We use some of each pattern. We duplicate for imagenet versus cifar, but branch in some of the internals in deciding how many strides, etc. etc. We seem to err on the side of duplication, though.

There doesn't seem to be an obvious answer, but I would say I lean slightly in favor of duplication in the name of clarity given the above. Assuming this is correct WRT what the ops are supposed to be, we can go with this, and reassess if we find that we are changing the code a lot and running into maintenance problems.

Contributor

OK, I see why you both feel duplication results in less complication for the reader. My main concern here is that our code is getting very long (resnet.py is ~800 lines as of this PR), and that by adding lots of different features we are also adding a lot of extra code paths which can dilute the value of our code to the reader.

I agree now that combining these functions together with extra logic probably doesn't improve the situation much, but my main point is that keeping our code concise is something we should keep in mind.

srcs_version = "PY2AND3",
deps = [
# "//tensorflow"
# "//tensorflow",
Contributor

Looks like you pulled in some commits from research/slim into this PR. Can you take those out?

Contributor Author

It turns out this stems from me making a commit on this branch before all of my git setup was complete. Should be resolved now.

@robieta robieta force-pushed the multiple_resnet_versions branch from 2445268 to 11cb659 Compare February 28, 2018 17:19
@googlebot

CLAs look good, thanks!

@nealwu nealwu removed the request for review from sguada February 28, 2018 19:08
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
inputs = batch_norm_relu(inputs, training, data_format,
intermediate_addition=shortcut)
Contributor

You can replace this with:

inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs + shortcut)

after which batch_norm_relu doesn't need the intermediate_addition arg anymore (in fact, you can get rid of batch_norm_relu entirely if you like).

[Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
[2] [Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.

In code, v1 refers to the resnet defined in [1], while v2 correspondingly refers to [2]. The principal difference between the two versions is that v1 applies batch normalization and activation after convolution, while v2 applies BN, then activation, and finally convolution. A schematic comparison is presented in Figure 1 (left) of [2].
Contributor

nit: Let's spell out batch normalization for clarity.
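
As an aside for readers of this thread, the ordering difference described in that README paragraph can be sketched roughly like this (a schematic sketch with illustrative tf.layers calls, not the repository's actual block implementations):

```python
import tensorflow as tf

def v1_unit(inputs, filters, training):
  """v1 ordering: convolution, then batch normalization, then ReLU."""
  x = tf.layers.conv2d(inputs, filters, kernel_size=3, padding='same',
                       use_bias=False)
  x = tf.layers.batch_normalization(x, training=training)
  return tf.nn.relu(x)

def v2_unit(inputs, filters, training):
  """v2 (pre-activation) ordering: batch normalization, ReLU, then convolution."""
  x = tf.layers.batch_normalization(inputs, training=training)
  x = tf.nn.relu(x)
  return tf.layers.conv2d(x, filters, kernel_size=3, padding='same',
                          use_bias=False)
```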

data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: ResNet version. See README.md for details.
Contributor

Args should still include a summary and valid values. Something like: Integer representing which version of the ResNet network to use. See README for details. Valid values: [1, 2]


################################################################################
# Functions building the ResNet model.
# Convenience functions building the ResNet model.
Contributor

nit: "for building" might parse better



def batch_norm_relu(inputs, training, data_format, intermediate_addition=None):
"""Performs a batch normalization followed by a ReLU."""
Contributor

Docstring with Args and Returns here would be nice-- what is intermediate_addition?

data_format=data_format)

inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1,
Contributor

Just double checking that it is intentional that the conv2d kernel_size and strides are different than in the non-bottleneck version?

Contributor Author

This is correct.

filters_out = filters
# Bottleneck blocks end with 4x the number of filters as they start with
filters_out = 4 * filters if block_fn is bottleneck_block else filters
if block_fn is _bottleneck_block_v1 or block_fn is _bottleneck_block_v2:
Contributor

This identity check is error-prone. Let's just pass in the bottleneck bool that we already have for this check.
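
That is, something along these lines (a sketch, reusing the `bottleneck` flag the surrounding code already has):

```python
# Bottleneck blocks end with 4x the number of filters they start with.
filters_out = 4 * filters if bottleneck else filters
```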

second_pool_size, second_pool_stride, block_fn, block_sizes,
block_strides, final_size, data_format=None):
second_pool_size, second_pool_stride, block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, bottleneck=False,
Contributor

nit: can we keep the original ordering of the bottleneck arg? I had originally put them in order of use, so bottleneck would belong in the middle there.

Contributor Author

Python doesn't allow args without defaults to follow those with, and I do like having defaults on version and bottleneck.

Contributor

I would argue with so many other params, and a well-defined set of callers of this, we should just remove the default. Especially since the default run with current command line args actually does use bottleneck blocks.

Contributor

For bottleneck. Version we should leave.

Contributor Author

Ok. Changed.

final_size: The expected size of the model after the second pooling.
version: Integer defining the order of layers. A detailed
description can be found in README.md.
bottleneck: Use regular blocks or bottleneck blocks.
Contributor

Make sure to remove unused block_fn arg above.

@robieta robieta force-pushed the multiple_resnet_versions branch from 79494b7 to 7a50163 Compare March 1, 2018 20:54
strides=block["strides"],
data_format=data_format)

return inputs + shortcut
Contributor Author

@karmel @nealwu I'm putting this forward as a plausible middle ground. By not having much duplication, we can be very liberal about what goes in docstrings and hopefully get clarity that way. But it unfortunately isn't as clear as the 4 function status quo and I think that means that we have to live with the duplicate code method. If you have strong feelings I'm willing to be persuaded otherwise though.

Contributor

This is interesting, but I think it really underscores that the readability of the four separate functions is preferable in this case. And this is coming from someone who goes to great lengths to get rid of duplicated code under most circumstances.

Contributor Author

Alright, I guess that settles it.

@robieta robieta requested review from sherrym and shlens as code owners March 2, 2018 18:09
@robieta robieta force-pushed the multiple_resnet_versions branch from 795f7d9 to d78af47 Compare March 2, 2018 18:25
@karmel karmel removed request for sherrym and shlens March 2, 2018 18:47
@robieta robieta force-pushed the multiple_resnet_versions branch 6 times, most recently from 2d10242 to 51e62b3 Compare March 12, 2018 18:51
@robieta
Contributor Author

robieta commented Mar 12, 2018

I ran 3 runs of master and 4 of the branch for ResNet_50 on ImageNet under identical conditions. The final accuracies are:

          mean    |    std
master    0.7503  |    0.0012
branch    0.7488  |    0.0011

@robieta robieta force-pushed the multiple_resnet_versions branch from 51e62b3 to 1a62840 Compare March 12, 2018 19:10
@robieta robieta merged commit f6bba08 into master Mar 12, 2018
@robieta robieta deleted the multiple_resnet_versions branch March 12, 2018 21:27


import numpy as np
import resnet
Contributor

I think this will now have to be `from official.resnet import resnet`. Also, nit, it should come below the tf import. That can be in a follow up PR though.
