Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP/NO_MERGE] Prototype RegularizedShortcut #4549

Closed
wants to merge 10 commits into from

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 6, 2021

This is an early prototype utility based on FX.

The target is to detect Residual Connections in arbitrary Model architectures and modify the network to add regularlization blocks (such as StochasticDepth).

Example usage:

# Before
model = resnet18()
fx.symbolic_trace(model).graph.print_tabular()

# After addition
regularizer_layer = partial(StochasticDepth, p=0.0, mode="row")
model = add_regularized_shortcut(model, BasicBlock, regularizer_layer)
fx.symbolic_trace(model).graph.print_tabular()

# After deletion
model = del_regularized_shortcut(model)
fx.symbolic_trace(model).graph.print_tabular()

Output:

Before
opcode         name                   target                                                   args                                   kwargs
-------------  ---------------------  -------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                        ()                                     {}
call_module    conv1                  conv1                                                    (x,)                                   {}
call_module    bn1                    bn1                                                      (conv1,)                               {}
call_module    relu                   relu                                                     (bn1,)                                 {}
call_module    maxpool                maxpool                                                  (relu,)                                {}
call_module    layer1_0_conv1         layer1.0.conv1                                           (maxpool,)                             {}
call_module    layer1_0_bn1           layer1.0.bn1                                             (layer1_0_conv1,)                      {}
call_module    layer1_0_relu          layer1.0.relu                                            (layer1_0_bn1,)                        {}
call_module    layer1_0_conv2         layer1.0.conv2                                           (layer1_0_relu,)                       {}
call_module    layer1_0_bn2           layer1.0.bn2                                             (layer1_0_conv2,)                      {}
call_function  add                    <built-in function add>                                  (layer1_0_bn2, maxpool)                {}
call_module    layer1_0_relu_1        layer1.0.relu                                            (add,)                                 {}
call_module    layer1_1_conv1         layer1.1.conv1                                           (layer1_0_relu_1,)                     {}
call_module    layer1_1_bn1           layer1.1.bn1                                             (layer1_1_conv1,)                      {}
call_module    layer1_1_relu          layer1.1.relu                                            (layer1_1_bn1,)                        {}
call_module    layer1_1_conv2         layer1.1.conv2                                           (layer1_1_relu,)                       {}
call_module    layer1_1_bn2           layer1.1.bn2                                             (layer1_1_conv2,)                      {}
call_function  add_1                  <built-in function add>                                  (layer1_1_bn2, layer1_0_relu_1)        {}
call_module    layer1_1_relu_1        layer1.1.relu                                            (add_1,)                               {}
call_module    layer2_0_conv1         layer2.0.conv1                                           (layer1_1_relu_1,)                     {}
call_module    layer2_0_bn1           layer2.0.bn1                                             (layer2_0_conv1,)                      {}
call_module    layer2_0_relu          layer2.0.relu                                            (layer2_0_bn1,)                        {}
call_module    layer2_0_conv2         layer2.0.conv2                                           (layer2_0_relu,)                       {}
call_module    layer2_0_bn2           layer2.0.bn2                                             (layer2_0_conv2,)                      {}
call_module    layer2_0_downsample_0  layer2.0.downsample.0                                    (layer1_1_relu_1,)                     {}
call_module    layer2_0_downsample_1  layer2.0.downsample.1                                    (layer2_0_downsample_0,)               {}
call_function  add_2                  <built-in function add>                                  (layer2_0_bn2, layer2_0_downsample_1)  {}
call_module    layer2_0_relu_1        layer2.0.relu                                            (add_2,)                               {}
call_module    layer2_1_conv1         layer2.1.conv1                                           (layer2_0_relu_1,)                     {}
call_module    layer2_1_bn1           layer2.1.bn1                                             (layer2_1_conv1,)                      {}
call_module    layer2_1_relu          layer2.1.relu                                            (layer2_1_bn1,)                        {}
call_module    layer2_1_conv2         layer2.1.conv2                                           (layer2_1_relu,)                       {}
call_module    layer2_1_bn2           layer2.1.bn2                                             (layer2_1_conv2,)                      {}
call_function  add_3                  <built-in function add>                                  (layer2_1_bn2, layer2_0_relu_1)        {}
call_module    layer2_1_relu_1        layer2.1.relu                                            (add_3,)                               {}
call_module    layer3_0_conv1         layer3.0.conv1                                           (layer2_1_relu_1,)                     {}
call_module    layer3_0_bn1           layer3.0.bn1                                             (layer3_0_conv1,)                      {}
call_module    layer3_0_relu          layer3.0.relu                                            (layer3_0_bn1,)                        {}
call_module    layer3_0_conv2         layer3.0.conv2                                           (layer3_0_relu,)                       {}
call_module    layer3_0_bn2           layer3.0.bn2                                             (layer3_0_conv2,)                      {}
call_module    layer3_0_downsample_0  layer3.0.downsample.0                                    (layer2_1_relu_1,)                     {}
call_module    layer3_0_downsample_1  layer3.0.downsample.1                                    (layer3_0_downsample_0,)               {}
call_function  add_4                  <built-in function add>                                  (layer3_0_bn2, layer3_0_downsample_1)  {}
call_module    layer3_0_relu_1        layer3.0.relu                                            (add_4,)                               {}
call_module    layer3_1_conv1         layer3.1.conv1                                           (layer3_0_relu_1,)                     {}
call_module    layer3_1_bn1           layer3.1.bn1                                             (layer3_1_conv1,)                      {}
call_module    layer3_1_relu          layer3.1.relu                                            (layer3_1_bn1,)                        {}
call_module    layer3_1_conv2         layer3.1.conv2                                           (layer3_1_relu,)                       {}
call_module    layer3_1_bn2           layer3.1.bn2                                             (layer3_1_conv2,)                      {}
call_function  add_5                  <built-in function add>                                  (layer3_1_bn2, layer3_0_relu_1)        {}
call_module    layer3_1_relu_1        layer3.1.relu                                            (add_5,)                               {}
call_module    layer4_0_conv1         layer4.0.conv1                                           (layer3_1_relu_1,)                     {}
call_module    layer4_0_bn1           layer4.0.bn1                                             (layer4_0_conv1,)                      {}
call_module    layer4_0_relu          layer4.0.relu                                            (layer4_0_bn1,)                        {}
call_module    layer4_0_conv2         layer4.0.conv2                                           (layer4_0_relu,)                       {}
call_module    layer4_0_bn2           layer4.0.bn2                                             (layer4_0_conv2,)                      {}
call_module    layer4_0_downsample_0  layer4.0.downsample.0                                    (layer3_1_relu_1,)                     {}
call_module    layer4_0_downsample_1  layer4.0.downsample.1                                    (layer4_0_downsample_0,)               {}
call_function  add_6                  <built-in function add>                                  (layer4_0_bn2, layer4_0_downsample_1)  {}
call_module    layer4_0_relu_1        layer4.0.relu                                            (add_6,)                               {}
call_module    layer4_1_conv1         layer4.1.conv1                                           (layer4_0_relu_1,)                     {}
call_module    layer4_1_bn1           layer4.1.bn1                                             (layer4_1_conv1,)                      {}
call_module    layer4_1_relu          layer4.1.relu                                            (layer4_1_bn1,)                        {}
call_module    layer4_1_conv2         layer4.1.conv2                                           (layer4_1_relu,)                       {}
call_module    layer4_1_bn2           layer4.1.bn2                                             (layer4_1_conv2,)                      {}
call_function  add_7                  <built-in function add>                                  (layer4_1_bn2, layer4_0_relu_1)        {}
call_module    layer4_1_relu_1        layer4.1.relu                                            (add_7,)                               {}
call_module    avgpool                avgpool                                                  (layer4_1_relu_1,)                     {}
call_function  flatten                <built-in method flatten of type object at 0x112aac6c0>  (avgpool, 1)                           {}
call_module    fc                     fc                                                       (flatten,)                             {}
output         output                 output                                                   (fc,)                                  {}
After addition
opcode         name                   target                                                   args                                   kwargs
-------------  ---------------------  -------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                        ()                                     {}
call_module    conv1                  conv1                                                    (x,)                                   {}
call_module    bn1                    bn1                                                      (conv1,)                               {}
call_module    relu                   relu                                                     (bn1,)                                 {}
call_module    maxpool                maxpool                                                  (relu,)                                {}
call_module    layer1_0_conv1         layer1.0.conv1                                           (maxpool,)                             {}
call_module    layer1_0_bn1           layer1.0.bn1                                             (layer1_0_conv1,)                      {}
call_module    layer1_0_relu          layer1.0.relu                                            (layer1_0_bn1,)                        {}
call_module    layer1_0_conv2         layer1.0.conv2                                           (layer1_0_relu,)                       {}
call_module    layer1_0_bn2           layer1.0.bn2                                             (layer1_0_conv2,)                      {}
call_function  stochastic_depth       <function stochastic_depth at 0x7fcc18cf60d0>            (layer1_0_bn2, 0.0, 'row', True)       {}
call_function  add                    <built-in function add>                                  (maxpool, stochastic_depth)            {}
call_module    layer1_0_relu_1        layer1.0.relu                                            (add,)                                 {}
call_module    layer1_1_conv1         layer1.1.conv1                                           (layer1_0_relu_1,)                     {}
call_module    layer1_1_bn1           layer1.1.bn1                                             (layer1_1_conv1,)                      {}
call_module    layer1_1_relu          layer1.1.relu                                            (layer1_1_bn1,)                        {}
call_module    layer1_1_conv2         layer1.1.conv2                                           (layer1_1_relu,)                       {}
call_module    layer1_1_bn2           layer1.1.bn2                                             (layer1_1_conv2,)                      {}
call_function  stochastic_depth_1     <function stochastic_depth at 0x7fcc18cf60d0>            (layer1_1_bn2, 0.0, 'row', True)       {}
call_function  add_1                  <built-in function add>                                  (layer1_0_relu_1, stochastic_depth_1)  {}
call_module    layer1_1_relu_1        layer1.1.relu                                            (add_1,)                               {}
call_module    layer2_0_conv1         layer2.0.conv1                                           (layer1_1_relu_1,)                     {}
call_module    layer2_0_bn1           layer2.0.bn1                                             (layer2_0_conv1,)                      {}
call_module    layer2_0_relu          layer2.0.relu                                            (layer2_0_bn1,)                        {}
call_module    layer2_0_conv2         layer2.0.conv2                                           (layer2_0_relu,)                       {}
call_module    layer2_0_bn2           layer2.0.bn2                                             (layer2_0_conv2,)                      {}
call_module    layer2_0_downsample_0  layer2.0.downsample.0                                    (layer1_1_relu_1,)                     {}
call_module    layer2_0_downsample_1  layer2.0.downsample.1                                    (layer2_0_downsample_0,)               {}
call_function  add_2                  <built-in function add>                                  (layer2_0_bn2, layer2_0_downsample_1)  {}
call_module    layer2_0_relu_1        layer2.0.relu                                            (add_2,)                               {}
call_module    layer2_1_conv1         layer2.1.conv1                                           (layer2_0_relu_1,)                     {}
call_module    layer2_1_bn1           layer2.1.bn1                                             (layer2_1_conv1,)                      {}
call_module    layer2_1_relu          layer2.1.relu                                            (layer2_1_bn1,)                        {}
call_module    layer2_1_conv2         layer2.1.conv2                                           (layer2_1_relu,)                       {}
call_module    layer2_1_bn2           layer2.1.bn2                                             (layer2_1_conv2,)                      {}
call_function  stochastic_depth_2     <function stochastic_depth at 0x7fcc18cf60d0>            (layer2_1_bn2, 0.0, 'row', True)       {}
call_function  add_3                  <built-in function add>                                  (layer2_0_relu_1, stochastic_depth_2)  {}
call_module    layer2_1_relu_1        layer2.1.relu                                            (add_3,)                               {}
call_module    layer3_0_conv1         layer3.0.conv1                                           (layer2_1_relu_1,)                     {}
call_module    layer3_0_bn1           layer3.0.bn1                                             (layer3_0_conv1,)                      {}
call_module    layer3_0_relu          layer3.0.relu                                            (layer3_0_bn1,)                        {}
call_module    layer3_0_conv2         layer3.0.conv2                                           (layer3_0_relu,)                       {}
call_module    layer3_0_bn2           layer3.0.bn2                                             (layer3_0_conv2,)                      {}
call_module    layer3_0_downsample_0  layer3.0.downsample.0                                    (layer2_1_relu_1,)                     {}
call_module    layer3_0_downsample_1  layer3.0.downsample.1                                    (layer3_0_downsample_0,)               {}
call_function  add_4                  <built-in function add>                                  (layer3_0_bn2, layer3_0_downsample_1)  {}
call_module    layer3_0_relu_1        layer3.0.relu                                            (add_4,)                               {}
call_module    layer3_1_conv1         layer3.1.conv1                                           (layer3_0_relu_1,)                     {}
call_module    layer3_1_bn1           layer3.1.bn1                                             (layer3_1_conv1,)                      {}
call_module    layer3_1_relu          layer3.1.relu                                            (layer3_1_bn1,)                        {}
call_module    layer3_1_conv2         layer3.1.conv2                                           (layer3_1_relu,)                       {}
call_module    layer3_1_bn2           layer3.1.bn2                                             (layer3_1_conv2,)                      {}
call_function  stochastic_depth_3     <function stochastic_depth at 0x7fcc18cf60d0>            (layer3_1_bn2, 0.0, 'row', True)       {}
call_function  add_5                  <built-in function add>                                  (layer3_0_relu_1, stochastic_depth_3)  {}
call_module    layer3_1_relu_1        layer3.1.relu                                            (add_5,)                               {}
call_module    layer4_0_conv1         layer4.0.conv1                                           (layer3_1_relu_1,)                     {}
call_module    layer4_0_bn1           layer4.0.bn1                                             (layer4_0_conv1,)                      {}
call_module    layer4_0_relu          layer4.0.relu                                            (layer4_0_bn1,)                        {}
call_module    layer4_0_conv2         layer4.0.conv2                                           (layer4_0_relu,)                       {}
call_module    layer4_0_bn2           layer4.0.bn2                                             (layer4_0_conv2,)                      {}
call_module    layer4_0_downsample_0  layer4.0.downsample.0                                    (layer3_1_relu_1,)                     {}
call_module    layer4_0_downsample_1  layer4.0.downsample.1                                    (layer4_0_downsample_0,)               {}
call_function  add_6                  <built-in function add>                                  (layer4_0_bn2, layer4_0_downsample_1)  {}
call_module    layer4_0_relu_1        layer4.0.relu                                            (add_6,)                               {}
call_module    layer4_1_conv1         layer4.1.conv1                                           (layer4_0_relu_1,)                     {}
call_module    layer4_1_bn1           layer4.1.bn1                                             (layer4_1_conv1,)                      {}
call_module    layer4_1_relu          layer4.1.relu                                            (layer4_1_bn1,)                        {}
call_module    layer4_1_conv2         layer4.1.conv2                                           (layer4_1_relu,)                       {}
call_module    layer4_1_bn2           layer4.1.bn2                                             (layer4_1_conv2,)                      {}
call_function  stochastic_depth_4     <function stochastic_depth at 0x7fcc18cf60d0>            (layer4_1_bn2, 0.0, 'row', True)       {}
call_function  add_7                  <built-in function add>                                  (layer4_0_relu_1, stochastic_depth_4)  {}
call_module    layer4_1_relu_1        layer4.1.relu                                            (add_7,)                               {}
call_module    avgpool                avgpool                                                  (layer4_1_relu_1,)                     {}
call_function  flatten                <built-in method flatten of type object at 0x112aac6c0>  (avgpool, 1)                           {}
call_module    fc                     fc                                                       (flatten,)                             {}
output         output                 output                                                   (fc,)                                  {}
After deletion
opcode         name                   target                                                   args                                   kwargs
-------------  ---------------------  -------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                        ()                                     {}
call_module    conv1                  conv1                                                    (x,)                                   {}
call_module    bn1                    bn1                                                      (conv1,)                               {}
call_module    relu                   relu                                                     (bn1,)                                 {}
call_module    maxpool                maxpool                                                  (relu,)                                {}
call_module    layer1_0_conv1         layer1.0.conv1                                           (maxpool,)                             {}
call_module    layer1_0_bn1           layer1.0.bn1                                             (layer1_0_conv1,)                      {}
call_module    layer1_0_relu          layer1.0.relu                                            (layer1_0_bn1,)                        {}
call_module    layer1_0_conv2         layer1.0.conv2                                           (layer1_0_relu,)                       {}
call_module    layer1_0_bn2           layer1.0.bn2                                             (layer1_0_conv2,)                      {}
call_function  add                    <built-in function add>                                  (maxpool, layer1_0_bn2)                {}
call_module    layer1_0_relu_1        layer1.0.relu                                            (add,)                                 {}
call_module    layer1_1_conv1         layer1.1.conv1                                           (layer1_0_relu_1,)                     {}
call_module    layer1_1_bn1           layer1.1.bn1                                             (layer1_1_conv1,)                      {}
call_module    layer1_1_relu          layer1.1.relu                                            (layer1_1_bn1,)                        {}
call_module    layer1_1_conv2         layer1.1.conv2                                           (layer1_1_relu,)                       {}
call_module    layer1_1_bn2           layer1.1.bn2                                             (layer1_1_conv2,)                      {}
call_function  add_1                  <built-in function add>                                  (layer1_0_relu_1, layer1_1_bn2)        {}
call_module    layer1_1_relu_1        layer1.1.relu                                            (add_1,)                               {}
call_module    layer2_0_conv1         layer2.0.conv1                                           (layer1_1_relu_1,)                     {}
call_module    layer2_0_bn1           layer2.0.bn1                                             (layer2_0_conv1,)                      {}
call_module    layer2_0_relu          layer2.0.relu                                            (layer2_0_bn1,)                        {}
call_module    layer2_0_conv2         layer2.0.conv2                                           (layer2_0_relu,)                       {}
call_module    layer2_0_bn2           layer2.0.bn2                                             (layer2_0_conv2,)                      {}
call_module    layer2_0_downsample_0  layer2.0.downsample.0                                    (layer1_1_relu_1,)                     {}
call_module    layer2_0_downsample_1  layer2.0.downsample.1                                    (layer2_0_downsample_0,)               {}
call_function  add_2                  <built-in function add>                                  (layer2_0_bn2, layer2_0_downsample_1)  {}
call_module    layer2_0_relu_1        layer2.0.relu                                            (add_2,)                               {}
call_module    layer2_1_conv1         layer2.1.conv1                                           (layer2_0_relu_1,)                     {}
call_module    layer2_1_bn1           layer2.1.bn1                                             (layer2_1_conv1,)                      {}
call_module    layer2_1_relu          layer2.1.relu                                            (layer2_1_bn1,)                        {}
call_module    layer2_1_conv2         layer2.1.conv2                                           (layer2_1_relu,)                       {}
call_module    layer2_1_bn2           layer2.1.bn2                                             (layer2_1_conv2,)                      {}
call_function  add_3                  <built-in function add>                                  (layer2_0_relu_1, layer2_1_bn2)        {}
call_module    layer2_1_relu_1        layer2.1.relu                                            (add_3,)                               {}
call_module    layer3_0_conv1         layer3.0.conv1                                           (layer2_1_relu_1,)                     {}
call_module    layer3_0_bn1           layer3.0.bn1                                             (layer3_0_conv1,)                      {}
call_module    layer3_0_relu          layer3.0.relu                                            (layer3_0_bn1,)                        {}
call_module    layer3_0_conv2         layer3.0.conv2                                           (layer3_0_relu,)                       {}
call_module    layer3_0_bn2           layer3.0.bn2                                             (layer3_0_conv2,)                      {}
call_module    layer3_0_downsample_0  layer3.0.downsample.0                                    (layer2_1_relu_1,)                     {}
call_module    layer3_0_downsample_1  layer3.0.downsample.1                                    (layer3_0_downsample_0,)               {}
call_function  add_4                  <built-in function add>                                  (layer3_0_bn2, layer3_0_downsample_1)  {}
call_module    layer3_0_relu_1        layer3.0.relu                                            (add_4,)                               {}
call_module    layer3_1_conv1         layer3.1.conv1                                           (layer3_0_relu_1,)                     {}
call_module    layer3_1_bn1           layer3.1.bn1                                             (layer3_1_conv1,)                      {}
call_module    layer3_1_relu          layer3.1.relu                                            (layer3_1_bn1,)                        {}
call_module    layer3_1_conv2         layer3.1.conv2                                           (layer3_1_relu,)                       {}
call_module    layer3_1_bn2           layer3.1.bn2                                             (layer3_1_conv2,)                      {}
call_function  add_5                  <built-in function add>                                  (layer3_0_relu_1, layer3_1_bn2)        {}
call_module    layer3_1_relu_1        layer3.1.relu                                            (add_5,)                               {}
call_module    layer4_0_conv1         layer4.0.conv1                                           (layer3_1_relu_1,)                     {}
call_module    layer4_0_bn1           layer4.0.bn1                                             (layer4_0_conv1,)                      {}
call_module    layer4_0_relu          layer4.0.relu                                            (layer4_0_bn1,)                        {}
call_module    layer4_0_conv2         layer4.0.conv2                                           (layer4_0_relu,)                       {}
call_module    layer4_0_bn2           layer4.0.bn2                                             (layer4_0_conv2,)                      {}
call_module    layer4_0_downsample_0  layer4.0.downsample.0                                    (layer3_1_relu_1,)                     {}
call_module    layer4_0_downsample_1  layer4.0.downsample.1                                    (layer4_0_downsample_0,)               {}
call_function  add_6                  <built-in function add>                                  (layer4_0_bn2, layer4_0_downsample_1)  {}
call_module    layer4_0_relu_1        layer4.0.relu                                            (add_6,)                               {}
call_module    layer4_1_conv1         layer4.1.conv1                                           (layer4_0_relu_1,)                     {}
call_module    layer4_1_bn1           layer4.1.bn1                                             (layer4_1_conv1,)                      {}
call_module    layer4_1_relu          layer4.1.relu                                            (layer4_1_bn1,)                        {}
call_module    layer4_1_conv2         layer4.1.conv2                                           (layer4_1_relu,)                       {}
call_module    layer4_1_bn2           layer4.1.bn2                                             (layer4_1_conv2,)                      {}
call_function  add_7                  <built-in function add>                                  (layer4_0_relu_1, layer4_1_bn2)        {}
call_module    layer4_1_relu_1        layer4.1.relu                                            (add_7,)                               {}
call_module    avgpool                avgpool                                                  (layer4_1_relu_1,)                     {}
call_function  flatten                <built-in method flatten of type object at 0x112aac6c0>  (avgpool, 1)                           {}
call_module    fc                     fc                                                       (flatten,)                             {}
output         output                 output                                                   (fc,)                                  {}

Also tested with:

model = add_regularized_shortcut(resnet50(), Bottleneck, partial(StochasticDepth, p=0.0, mode="row"))
model = add_regularized_shortcut(mobilenet_v2(), InvertedResidual, partial(StochasticDepth, p=0.0, mode="row"))
model = add_regularized_shortcut(mobilenet_v3_small(), InvertedResidual, partial(StochasticDepth, p=0.0, mode="row"))

model = del_regularized_shortcut(efficientnet_b0(), block_types=StochasticDepth, op=None) # First delete original StochasticDepth
model = add_regularized_shortcut(model, MBConv, partial(StochasticDepth, p=0.0, mode="row"))

Affected by pytorch/pytorch#66197 and pytorch/pytorch#66335

@datumbox datumbox requested a review from Chillee October 6, 2021 18:02
Copy link

@jamesr66a jamesr66a left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think overall this looks OK. If I understand correctly, the procedure is:

  1. Iterate through the named modules in the module hierarchy, and for each module that's part of the block_types of interest:
    a. Add the shortcut module
    b. trace the module and search for a residual connection (i.e. add node with two input and a placeholder input)
    c. Replace the residual connection with the shortcut module

torchvision/prototype/ops/_utils.py Outdated Show resolved Hide resolved
torchvision/prototype/ops/_utils.py Show resolved Hide resolved
Copy link
Contributor Author

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think overall this looks OK. If I understand correctly, the procedure is...

@jamesr66a Thanks a lot for reviewing. Your description of the approach is correct.

I was worried that looping through named_modules, tracing independently the graphs of the submodules and then overwriting the original modules would be problematic. Just to be safe, below I highlight the bits that concerned me. If you have any thoughts on how to improve it I'm happy to adopt it.

if isinstance(m, block_types):
# Add the Layer directly on submodule prior tracing
# workaround due to https://github.com/pytorch/pytorch/issues/66197
m.add_module(_MODULE_NAME, RegularizedShortcut(regularizer_layer))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I wanted to create the Layer on the fly and attach it directly on the graph but the pytorch/pytorch#66197 issue prohibits me from doing this. Here I attach it to the module just before tracing it as a workaround. Any concerns?

It will be removed once the issue is fixed.

with graph.inserting_after(node):
# Always put the shortcut value first
args = node.args if node.args[0] == input else node.args[::-1]
node.replace_all_uses_with(graph.call_module(_MODULE_NAME, args))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling the previously created module by name. Hopefully this will be replaced with something like the following:

fn_impl_traced = torch.fx.symbolic_trace(RegularizedShortcut(regularizer_layer))
args = node.args if node.args[0] == input else node.args[::-1]
fn_impl_output_node = fn_impl_traced(*map_arg(args, Proxy))
node.replace_all_uses_with(fn_impl_output_node.node)

torchvision/prototype/ops/_utils.py Show resolved Hide resolved
for node in graph.nodes:
# The isinstance() won't work if the model has already been traced before because it loses
# the class info of submodules. See https://github.com/pytorch/pytorch/issues/66335
if node.op == "call_module" and isinstance(model.get_submodule(node.target), block_types):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a We just figured out that FX traced models lose their submodule class information. This means that for a model that has been traced before, we can't use isinstance() to identify its Block type. Is this intentional or a bug?

@datumbox datumbox marked this pull request as draft October 14, 2021 09:58
@datumbox datumbox closed this Oct 27, 2021
@datumbox datumbox deleted the prototype/regularized_shortcut branch October 27, 2021 15:56
@datumbox datumbox mentioned this pull request Feb 13, 2022
24 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants