
Conversation

@assafshocher (Contributor) commented Jan 30, 2020

Initializing the weights of a grouped convolution with init.dirac_ and then applying the layer previously produced an output that makes no sense:

```
x = torch.randn([1, 3, 3, 3])
print('input:\n', x)
conv_layer = torch.nn.Conv2d(3, 3, 3, padding=1, groups=3, bias=False)
torch.nn.init.dirac_(conv_layer.weight.data)
print('\noutput (before this PR):\n', conv_layer(x))
```


input:
 tensor([[[[ 0.5369, -1.1428,  0.1031],
          [ 0.4638, -0.0854, -0.6553],
          [ 0.8321, -2.5926, -0.3214]],

         [[-0.2289, -0.0895,  0.4407],
          [ 1.2309, -1.2096, -1.5216],
          [-0.1798,  1.1694,  0.3469]],

         [[ 0.1905,  0.8095,  0.5490],
          [-0.4525, -0.4284, -0.1141],
          [ 1.1857, -0.9246, -0.5119]]]])

output (before this PR):
 tensor([[[[ 0.5369, -1.1428,  0.1031],
          [ 0.4638, -0.0854, -0.6553],
          [ 0.8321, -2.5926, -0.3214]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MkldnnConvolutionBackward>)

The zeros arise because the grouped weight has shape [3, 1, 3, 3], and dirac_ places a 1 at weight[d, d, center, center] only for d < min(3, 1) = 1, leaving output channels 1 and 2 with all-zero kernels. This PR adds a `groups` argument to the initialization:

```
torch.nn.init.dirac_(conv_layer.weight.data, groups=3)
print('output (after this PR):\n', conv_layer(x))
```

output (after this PR):
 tensor([[[[ 0.5369, -1.1428,  0.1031],
          [ 0.4638, -0.0854, -0.6553],
          [ 0.8321, -2.5926, -0.3214]],

         [[-0.2289, -0.0895,  0.4407],
          [ 1.2309, -1.2096, -1.5216],
          [-0.1798,  1.1694,  0.3469]],

         [[ 0.1905,  0.8095,  0.5490],
          [-0.4525, -0.4284, -0.1141],
          [ 1.1857, -0.9246, -0.5119]]]], grad_fn=<MkldnnConvolutionBackward>)

When out_channels differs from in_channels, it does the natural thing, which is applying the identity within each group separately:

```
x = torch.randn([1, 2, 3, 3])
print('input:\n', x)
conv_layer = torch.nn.Conv2d(2, 4, 3, padding=1, groups=2, bias=False)
torch.nn.init.dirac_(conv_layer.weight.data, groups=2)
print('\noutput:\n', conv_layer(x))
```


input:
 tensor([[[[ 1.2205, -0.6608,  0.8640],
          [-0.5464,  1.1288,  1.4726],
          [-0.6693,  0.4000, -1.7613]],

         [[-0.8760, -0.8814, -0.4705],
          [ 0.6283, -0.5943,  0.6873],
          [-0.6852,  1.4723,  0.3325]]]])

output:
 tensor([[[[ 1.2205, -0.6608,  0.8640],
          [-0.5464,  1.1288,  1.4726],
          [-0.6693,  0.4000, -1.7613]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[-0.8760, -0.8814, -0.4705],
          [ 0.6283, -0.5943,  0.6873],
          [-0.6852,  1.4723,  0.3325]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MkldnnConvolutionBackward>)

The `groups` argument defaults to 1, so the change is backward compatible.

Tests are modified to include cases with groups > 1 while retaining groups = 1 cases.
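
For intuition, here is a minimal sketch of what a groups-aware dirac_ does, for the common 4D conv-weight case (a reconstruction, not the actual diff; the real dirac_ also handles 3D and 5D weights):

```
import torch

def dirac_grouped_(tensor, groups=1):
    # Sketch only: assumes a 4D conv weight of shape
    # [out_channels, in_channels / groups, kH, kW].
    out_channels, in_per_group, k_h, k_w = tensor.shape
    assert out_channels % groups == 0, "out_channels must be divisible by groups"
    out_per_group = out_channels // groups
    with torch.no_grad():
        tensor.zero_()
        # Within each group, put a 1 at the kernel center so that output
        # channel d of the group copies input channel d of the same group.
        for g in range(groups):
            for d in range(min(out_per_group, in_per_group)):
                tensor[g * out_per_group + d, d, k_h // 2, k_w // 2] = 1
    return tensor
```

With groups=1 this reduces to the old behavior, which is why the default is backward compatible. Applied to the Conv2d(2, 4, 3, groups=2) weight above, it makes output channels 0 and 2 identities of the two input channels, matching the output shown.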

@kostmo (Member) commented Jan 30, 2020

💊 CircleCI build failures summary and remediations

As of commit 94b539d:

  • 3/4 failures introduced in this PR
  • 1/4 recognized as flaky ❄️

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 3 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (1/3)

Step: "Build" (full log | pattern match details)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/workflows-pytorch-ios-builds.yml 
Auto-merging .circleci/verbatim-sources/workflows-pytorch-ios-builds.yml 
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/workflows-binary-builds-smoke-subset.yml 
Auto-merging .circleci/verbatim-sources/workflows-binary-builds-smoke-subset.yml 
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/pytorch-build-params.yml 
Auto-merging .circleci/verbatim-sources/pytorch-build-params.yml 
CONFLICT (add/add): Merge conflict in .circleci/config.yml 
Auto-merging .circleci/config.yml 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py 
Auto-merging .circleci/cimodel/data/pytorch_build_data.py 
Automatic merge failed; fix conflicts and then commit the result. 

See CircleCI build pytorch_xla_linux_xenial_py3_6_clang7_build (2/3)

Step: "Build" (full log | pattern match details)

Feb 12 20:27:49 Failed to generate ATEN bindings: ['/var/lib/jenkins/workspace/xla/scripts/generate_code.sh']
Feb 12 20:27:49 Generated 880 wrappers for /var/lib/jenkins/workspace/xla/scripts/../../torch/csrc/autograd/generated/RegistrationDeclarations.h 
Feb 12 20:27:49 AtenXlaType function missed override: Tensor argmax(const Tensor& self, c10::optional<int64_t> dim, bool keepdim); // argmax(Tensor,c10::optional<int64_t>,bool)->Tensor 
Feb 12 20:27:49 AtenXlaType function missed override: Tensor argmin(const Tensor& self, c10::optional<int64_t> dim, bool keepdim); // argmin(Tensor,c10::optional<int64_t>,bool)->Tensor 
Feb 12 20:27:49 Traceback (most recent call last): 
Feb 12 20:27:49   File "/var/lib/jenkins/workspace/xla/scripts/gen.py", line 1053, in <module> 
Feb 12 20:27:49     generate(args) 
Feb 12 20:27:49   File "/var/lib/jenkins/workspace/xla/scripts/gen.py", line 1023, in generate 
Feb 12 20:27:49     assert check_overrides(overrides, overridden) 
Feb 12 20:27:49 AssertionError 
Feb 12 20:27:49 Building torch_xla version: 0.8 
Feb 12 20:27:49 Failed to generate ATEN bindings: ['/var/lib/jenkins/workspace/xla/scripts/generate_code.sh'] 
Feb 12 20:27:49 + cleanup 
Feb 12 20:27:49 + retcode=1 
Feb 12 20:27:49 + set +x 
Feb 12 20:27:49 =================== sccache compilation log =================== 
Feb 12 20:27:49 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 12 20:27:49 Compile requests               2544 
Feb 12 20:27:49 Compile requests executed      2284 
Feb 12 20:27:49 Cache hits                     1205 
Feb 12 20:27:49 Cache misses                   1067 
Feb 12 20:27:49 Cache timeouts                    0 

See CircleCI build pytorch_linux_xenial_py2_7_9_test (3/3)

Step: "Test" (full log | pattern match details)

Feb 12 20:33:21 RuntimeError: test_nn failed!
Feb 12 20:33:21   File "/opt/python/2.7.9/lib/python2.7/site-packages/xmlrunner/runner.py", line 7, in <module> 
Feb 12 20:33:21     from .result import _XMLTestResult 
Feb 12 20:33:21   File "/opt/python/2.7.9/lib/python2.7/site-packages/xmlrunner/result.py", line 42, in <module> 
Feb 12 20:33:21     for (low, high) in _illegal_unichrs 
Feb 12 20:33:21 ValueError: chr() arg not in range(256) 
Feb 12 20:33:21 Traceback (most recent call last): 
Feb 12 20:33:21   File "test/run_test.py", line 457, in <module> 
Feb 12 20:33:21     main() 
Feb 12 20:33:21   File "test/run_test.py", line 450, in main 
Feb 12 20:33:21     raise RuntimeError(message) 
Feb 12 20:33:21 RuntimeError: test_nn failed! 
Feb 12 20:33:21 =================== sccache compilation log =================== 
Feb 12 20:33:21 + cleanup 
Feb 12 20:33:21 + retcode=1 
Feb 12 20:33:21 + set +x 
Feb 12 20:33:21 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 12 20:33:21 Compile requests                 10 
Feb 12 20:33:21 Compile requests executed         8 
Feb 12 20:33:21 Cache hits                        2 
Feb 12 20:33:21 Cache misses                      6 
Feb 12 20:33:21 Cache timeouts                    0 

❄️ 1 failure recognized as flaky

The following build failures have been detected as flaky and may not be your fault:

See CircleCI build caffe2_onnx_py3_6_clang7_ubuntu16_04_test (1/1)

Step: "Test" (full log | pattern match details) ❄️

Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_reshape Fatal Python error: Segmentation fault
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns_opset11::test_constant_fold_transpose_matmul PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns_opset11::test_constant_fold_unsqueeze PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns_opset11::test_error_on_data_parallel PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns_opset11::test_is_in_onnx_export PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns_opset11::test_strip_doc_string PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns_opset11::test_validate_dynamic_axes_invalid_input_output_name PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_concat PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_div PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_lstm PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_mul PASSED [ 99%] 
Feb 12 20:43:28 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_reshape Fatal Python error: Segmentation fault 
Feb 12 20:43:28  
Feb 12 20:43:28 Current thread 0x00007f758f355700 (most recent call first): 
Feb 12 20:43:28   File "/usr/local/lib/python3.6/dist-packages/_pytest/_io/saferepr.py", line 43 in repr_instance 
Feb 12 20:43:28   File "/usr/lib/python3.6/reprlib.py", line 65 in repr1 
Feb 12 20:43:28   File "/usr/lib/python3.6/reprlib.py", line 55 in repr 
Feb 12 20:43:28   File "/usr/local/lib/python3.6/dist-packages/_pytest/_io/saferepr.py", line 36 in repr 
Feb 12 20:43:28   File "/usr/local/lib/python3.6/dist-packages/_pytest/_io/saferepr.py", line 67 in saferepr 
Feb 12 20:43:28   File "/usr/local/lib/python3.6/dist-packages/_pytest/_code/code.py", line 655 in repr_args 
Feb 12 20:43:28   File "/usr/local/lib/python3.6/dist-packages/_pytest/_code/code.py", line 736 in repr_traceback_entry 
Feb 12 20:43:28   File "/usr/local/lib/python3.6/dist-packages/_pytest/_code/code.py", line 777 in repr_traceback 

This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

This comment has been revised 7 times.

@vincentqb added the "module: nn" (Related to torch.nn) label Jan 30, 2020
@ezyang requested a review from vincentqb February 3, 2020 15:54
@ezyang added the "triaged" (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 3, 2020
@ezyang (Contributor) commented Feb 3, 2020

@vincentqb do you think you would be able to review this? (No is a fine answer)

@vincentqb (Contributor) left a comment

Couldn't you do the following instead?

```
x = torch.randn([1, 3, 3, 3])
print('input:\n', x)
conv_layer = torch.nn.Conv2d(3, 3, 3, padding=1, groups=1, bias=False)
torch.nn.init.dirac_(conv_layer.weight)
print('\noutput (before this PR):\n', conv_layer(x))
```
input:
 tensor([[[[-0.4437, -0.2448,  1.4811],
          [ 0.2103,  0.2234,  0.2437],
          [-2.6258,  0.2899, -0.1680]],

         [[-0.3531, -0.3139, -3.1256],
          [ 0.5868,  0.3861, -0.0573],
          [ 1.8914,  0.2475,  0.2421]],

         [[-0.8827,  0.4264, -0.4944],
          [ 1.8532,  0.0400,  1.4679],
          [ 1.8510,  2.3384,  0.8307]]]])

output (before this PR):
 tensor([[[[-0.4437, -0.2448,  1.4811],
          [ 0.2103,  0.2234,  0.2437],
          [-2.6258,  0.2899, -0.1680]],

         [[-0.3531, -0.3139, -3.1256],
          [ 0.5868,  0.3861, -0.0573],
          [ 1.8914,  0.2475,  0.2421]],

         [[-0.8827,  0.4264, -0.4944],
          [ 1.8532,  0.0400,  1.4679],
          [ 1.8510,  2.3384,  0.8307]]]], grad_fn=<MkldnnConvolutionBackward>)

I'm assuming the win is to do the identity by groups when out_channels is different from input_channels? But then this can be explicitly achieved by assigning with advanced indexing on a zero matrix of the final shape, no?

i.e., assign to the sub-tensor of

 tensor([[[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MkldnnConvolutionBackward>)

the values of

 tensor([[[[ 1.2205, -0.6608,  0.8640],
          [-0.5464,  1.1288,  1.4726],
          [-0.6693,  0.4000, -1.7613]],

         [[-0.8760, -0.8814, -0.4705],
          [ 0.6283, -0.5943,  0.6873],
          [-0.6852,  1.4723,  0.3325]]]])

to get

 tensor([[[[ 1.2205, -0.6608,  0.8640],
          [-0.5464,  1.1288,  1.4726],
          [-0.6693,  0.4000, -1.7613]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[-0.8760, -0.8814, -0.4705],
          [ 0.6283, -0.5943,  0.6873],
          [-0.6852,  1.4723,  0.3325]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MkldnnConvolutionBackward>)
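
A minimal sketch of that indexing idea, valid only for the exact shapes in this example (assuming `x` is the 2-channel input above):

```
import torch

# One-off construction of the grouped identity via indexing:
# start from zeros of the final shape, then copy each group's input
# channel into the first output slot of its group.
out = torch.zeros(1, 4, 3, 3)
out[:, 0::2] = x  # input channels 0, 1 -> output channels 0, 2
```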

Moreover, if torch.nn.init.dirac_ supports groups, should torch.nn.init.* also?

@assafshocher (Contributor, Author) commented Feb 8, 2020

@vincentqb
Your suggestions are valid for a one-time operation, but if you want to initialize a grouped conv layer, changing back to groups=1 misses the point, and so does indexing. There is no problem copying a tensor; the point is to initialize the weights of a grouped conv through the init API, and to do it consistently.
Moreover, what you currently get for a grouped conv initialized with Dirac doesn't make sense.
Regarding init.*: the other initializations don't suffer from this inconsistency, since they don't need to create a channel-aware structure like a Dirac matrix; they simply fill all channels with numbers, without cross-channel dependencies (see the small illustration below). So I think Dirac is the only one that requires group support.
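
A minimal illustration of that point, assuming the grouped weight shape from the earlier example:

```
import torch

w = torch.empty(4, 1, 3, 3)       # grouped-conv weight: [out, in/groups, kH, kW]
torch.nn.init.kaiming_normal_(w)  # fills every entry independently from one
                                  # distribution; there is no cross-channel
                                  # structure for the grouping to break
```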
Thanks.

@vincentqb (Contributor) left a comment

> Your suggestions are valid for a one-time operation, but if you want to initialize a grouped conv layer, changing back to groups=1 misses the point, and so does indexing. There is no problem copying a tensor; the point is to initialize the weights of a grouped conv through the init API, and to do it consistently.
> Moreover, what you currently get for a grouped conv initialized with Dirac doesn't make sense.
> Regarding init.*: the other initializations don't suffer from this inconsistency, since they don't need to create a channel-aware structure like a Dirac matrix; they simply fill all channels with numbers, without cross-channel dependencies. So I think Dirac is the only one that requires group support.

Sounds good, thanks!

@facebook-github-bot (Contributor) left a comment

@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@assafshocher (Contributor, Author) commented

@vincentqb
Thanks,
I see that some tests are failing now (they didn't before). Should I do anything?

@vincentqb (Contributor) commented

> @vincentqb
> Thanks,
> I see that some tests are failing now (they didn't before). Should I do anything?

None of the failing tests seem related. Landing.

@facebook-github-bot (Contributor) commented

@vincentqb merged this pull request in 2c99ea8.

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary: Add a `groups` argument to torch.nn.init.dirac_ so grouped convolutions can be initialized as per-group identities; `groups` defaults to 1, keeping the old behavior (full examples in the PR description above).
Pull Request resolved: pytorch#32825

Differential Revision: D19859926

Pulled By: vincentqb

fbshipit-source-id: 9dfdd24471ff14d79c442dfd28c1891aff812fdf