
Conversation

kurtamohler
Collaborator

@kurtamohler kurtamohler commented Jul 2, 2020

Implement matrix norm for orders +/- 1, 2, inf

This PR contains BC-breaking changes to torch.norm, torch.functional.norm, and its underlying aten functions. The deprecation plan is to add a warning to torch.functional.norm, explaining what will change. This warning will be added in a different PR (#41193), and should exist in PyTorch 1.6.0. Then, we can release the norm changes in PyTorch 1.7.0.

Issue #24802
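
For reference, here is what each of these orders computes for a 2-D input, as defined by numpy.linalg.norm (which the new behavior is meant to match). The snippet below is plain NumPy, included only as an illustration:

import numpy as np

# Matrix norm orders covered by this PR, as numpy.linalg.norm defines them.
a = np.random.randn(4, 5)
np.linalg.norm(a, ord=1)        # max absolute column sum
np.linalg.norm(a, ord=-1)       # min absolute column sum
np.linalg.norm(a, ord=np.inf)   # max absolute row sum
np.linalg.norm(a, ord=-np.inf)  # min absolute row sum
np.linalg.norm(a, ord=2)        # largest singular value
np.linalg.norm(a, ord=-2)       # smallest singular value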

@facebook-github-bot facebook-github-bot added the oncall: jit label Jul 2, 2020
@dr-ci

dr-ci bot commented Jul 2, 2020

💊 CI failures summary and remediations

As of commit c080981 (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 1/4 non-CircleCI failure(s)

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages (reran 1 job to discount flakiness):

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (1/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

RuntimeError: test_torch failed!
Generated XML report: test-reports\python-unittest\TEST-TestTorchDeviceTypeCPU-20200709201948.xml 
Generated XML report: test-reports\python-unittest\TEST-TestTorchDeviceTypeCUDA-20200709201948.xml 
Generated XML report: test-reports\python-unittest\TEST-TestTorchMathOpsCPU-20200709201948.xml 
Generated XML report: test-reports\python-unittest\TEST-TestViewOpsCPU-20200709201948.xml 
Generated XML report: test-reports\python-unittest\TEST-TestViewOpsCUDA-20200709201948.xml 
Traceback (most recent call last): 
  File "run_test.py", line 728, in <module> 
    main() 
  File "run_test.py", line 721, in main 
    raise RuntimeError(message) 
RuntimeError: test_torch failed! 
 
(base) circleci@PACKER-5EFB90C2 C:\Users\circleci\project\test>if ERRORLEVEL 1 exit /b 1  
+ cleanup
+ retcode=1
+ set +x

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (2/3)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jul 09 21:40:18 ERROR: test_norm_p0 (__main__.TestAtenXlaTensor)
Jul 09 21:39:05 test_conditional (__main__.TestOpBuilder) ... ok 
Jul 09 21:39:05 test_mul (__main__.TestOpBuilder) ... ok 
Jul 09 21:39:05 test_while (__main__.TestOpBuilder) ... ok 
Jul 09 21:39:05 test (__main__.TestParallelLoader) ... ok 
Jul 09 21:39:09 test (__main__.TestParallelTensorMNIST) ... ok 
Jul 09 21:40:15 test (__main__.TestParallelTensorResnet18) ... ok 
Jul 09 21:40:18 test_get_xla_tensor (__main__.TestSelect) ... ok 
Jul 09 21:40:18 test (__main__.TestToXlaTensorArena) ... ok 
Jul 09 21:40:18  
Jul 09 21:40:18 ====================================================================== 
Jul 09 21:40:18 ERROR: test_norm_p0 (__main__.TestAtenXlaTensor) 
Jul 09 21:40:18 ---------------------------------------------------------------------- 
Jul 09 21:40:18 Traceback (most recent call last): 
Jul 09 21:40:18   File "/var/lib/jenkins/workspace/xla/test/test_operations.py", line 1255, in test_norm_p0 
Jul 09 21:40:18     norm = a.norm(p=0) 
Jul 09 21:40:18   File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 329, in norm 
Jul 09 21:40:18     return torch.norm(self, p, dim, keepdim, dtype=dtype) 
Jul 09 21:40:18   File "/opt/conda/lib/python3.6/site-packages/torch/functional.py", line 1101, in norm 
Jul 09 21:40:18     return _VF._norm_matrix(input, p, dim=(0, 1), keepdim=keepdim) 
Jul 09 21:40:18 RuntimeError: Order 0 not supported for matrix norm 
Jul 09 21:40:18  

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (3/3)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) <confirmed not flaky by 2 failures>

Jul 09 21:28:27 ConnectionResetError: [Errno 104] Connection reset by peer
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 493, in Client 
Jul 09 21:28:27     answer_challenge(c, authkey) 
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 737, in answer_challenge 
Jul 09 21:28:27     response = connection.recv_bytes(256)        # reject large message 
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Jul 09 21:28:27     buf = self._recv_bytes(maxlength) 
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Jul 09 21:28:27     buf = self._recv(4) 
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Jul 09 21:28:27     chunk = read(handle, remaining) 
Jul 09 21:28:27 ConnectionResetError: [Errno 104] Connection reset by peer 
Jul 09 21:28:27  
Jul 09 21:28:27 Process ErrorTrackingProcess-120: 
Jul 09 21:28:27 Traceback (most recent call last): 
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Jul 09 21:28:27     self.run() 
Jul 09 21:28:27   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 361, in run 
Jul 09 21:28:27     super(ErrorTrackingProcess, self).run() 
Jul 09 21:28:27   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Jul 09 21:28:27     self._target(*self._args, **self._kwargs) 
Jul 09 21:28:27   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 629, in _test_proper_exit 

1 failure confirmed as flaky and can be ignored:

  • pytorch_windows_vs2017_14.13_py36_cuda10.1_build

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@kurtamohler
Collaborator Author

kurtamohler commented Jul 2, 2020

I'm seeing some issues with some of the JIT tests. Apparently, the scripted function is calling the vector norm when it should be calling the matrix norm.

For instance, when I run the following test (I added a print() to show the tensor being operated on), the non-JIT reference function that the test calls gives the correct matrix norm result, torch._VF._norm_matrix(i0, p=-math.inf) = 1.710742. But the JITed call gives the vector norm instead, which matches torch._VF.norm(i0, p=-math.inf) = 0.054511.

I'm not sure yet how to fix it. I have a hunch that my changes to torch.functional.norm are tripping up the JIT scripts because the JIT doesn't know that the shape of the input affects the execution path inside that function. But I don't know that for sure, and I don't know what the solution would be if that turns out to be the issue.

$ python test/test_jit.py TestJitGeneratedAutogradCPU.test_norm_matrix_neg_inf_cpu
tensors:
[tensor([[-0.7860, -0.4389,  1.3664, -0.5094, -1.6146],
        [ 0.4352, -0.0545, -0.5263, -2.0913, -0.2500],
        [-0.3201, -0.1385, -0.4392,  0.6292, -0.1838],
        [ 0.3869, -1.0295,  0.4756, -2.0045,  0.1760],
        [ 0.4706, -0.8826,  0.1569,  0.4725,  0.7419]])]
actuals:
['i0', '-math.inf']
F
======================================================================
FAIL: test_norm_matrix_neg_inf_cpu (__main__.TestJitGeneratedAutogradCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/work/kurtamohler/development/pytorch-matrix-norm/torch/testing/_internal/common_device_type.py", line 217, in instantiated_test
    return test(self, device_arg)
  File "test/test_jit.py", line 18019, in do_test
    check(name)
  File "/work/kurtamohler/development/pytorch-matrix-norm/torch/testing/_internal/jit_utils.py", line 634, in wrapper
    fn(*args, **kwargs)
  File "test/test_jit.py", line 17982, in check
    check_against_reference(self, script_fn,
  File "test/test_jit.py", line 17786, in check_against_reference
    self.assertEqual(outputs, outputs_test)
  File "/work/kurtamohler/development/pytorch-matrix-norm/torch/testing/_internal/common_utils.py", line 1075, in assertEqual
    self.assertTrue(result, msg=msg)
AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1e-07 and atol=1e-07, found 1 element(s) (out of 1) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 1.6562303556477995 (1.7107421740030309 vs. 0.05451181835523127), which occurred at index 0.

----------------------------------------------------------------------
Ran 1 test in 0.227s

FAILED (failures=1)
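
For what it's worth, the mismatch can be reproduced directly outside the JIT harness. A minimal sketch (torch._VF._norm_matrix only exists on this branch, and the numbers quoted above came from the test's random input, so a fresh run prints different values):

import math
import torch

i0 = torch.randn(5, 5)
matrix_result = torch._VF._norm_matrix(i0, p=-math.inf)  # min absolute row sum
vector_result = torch._VF.norm(i0, p=-math.inf)          # min absolute entry
print(matrix_result, vector_result)  # the scripted test is effectively comparing these two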

@kurtamohler
Collaborator Author

kurtamohler commented Jul 2, 2020

@kurtamohler , the tests most likely fail because they expect norm to produce vector norms, not operator norms I presume...

I checked what the JIT tests were expecting, and they do expect the correct thing. When _norm_matrix is supposed to be used, the non-JIT reference function is actually calling the _norm_matrix function. The JITed function is really where the failure is.

@kurtamohler
Collaborator Author

Looks like some of the CI builds can't use SVD: RuntimeError: svd: MAGMA library not found in compilation. Please rebuild with MAGMA. I wonder if this PR should just disable the failing test cases if MAGMA isn't found, and then in a future PR we can replace the SVD call with an eigenvalue approximation that doesn't depend on MAGMA. Or would it be better to just combine it all into one PR so that we don't have to go through a gap period where _norm_matrix(p=+/-2) is only supported in a subset of environments? Maybe this is a question for @ezyang.

@vishwakftw
Contributor

You can add @skipCUDAIfNoMagma to the tests that use MAGMA.
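
A minimal, self-contained sketch of how that decorator is typically used with a device-generic test (the class, test name, and the _norm_matrix call are illustrative assumptions, not code from this PR):

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, skipCUDAIfNoMagma)
from torch.testing._internal.common_utils import TestCase, run_tests

class TestMatrixNorm(TestCase):
    @skipCUDAIfNoMagma
    def test_norm_matrix_2(self, device):
        # the order-2 matrix norm is computed via SVD, which needs MAGMA on CUDA
        a = torch.randn(5, 5, device=device)
        expected = torch.svd(a).S.max()
        self.assertEqual(torch._VF._norm_matrix(a, p=2), expected)

instantiate_device_type_tests(TestMatrixNorm, globals())

if __name__ == '__main__':
    run_tests()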

@ezyang
Contributor

ezyang commented Jul 6, 2020

Looks like some of the CI builds can't use SVD: RuntimeError: svd: MAGMA library not found in compilation. Please rebuild with MAGMA. I wonder if this PR should just disable the failing test cases if MAGMA isn't found, and then in a future PR we can replace the SVD call with an eigenvalue approximation that doesn't depend on MAGMA. Or would it be better to just combine it all into one PR so that we don't have to go through a gap period where _norm_matrix(p=+/-2) is only supported in a subset of environments?

Without knowing any of the context, it is totally reasonable to disable the tests when MAGMA is not available, we have plenty of tests that do this.

Whether or not you should do it now: isn't this a BC-breaking change? You have bigger fish to fry than whether or not it works without MAGMA :)

@ezyang
Contributor

ezyang commented Jul 6, 2020

This is BC-breaking, right? What's the deprecation plan (if there is one)? Can we get more detail about it in the PR summary?

@kurtamohler
Collaborator Author

kurtamohler commented Jul 6, 2020

You have bigger fish to fry than whether or not it works without MAGMA

Yep, agreed.

Yes, this is BC-breaking. As for a deprecation plan, I'm not entirely sure what it should be, but I believe you've recommended something like this in the past:

  • The next release will not change behavior, but will add a warning that the interface will change in the following release
  • The release after next will have the new interface

Do you think that's a fair plan in this case, @ezyang?

Alternatively, we could do something like this (just brainstorming, maybe not a good idea):

  • The next release will add the new behavior under torch.functional._norm_new (or some other name) and keep torch.functional.norm's current behavior, but torch.functional.norm will throw a warning that the next release is going to remove the current torch.functional.norm interface and rename torch.functional._norm_new to torch.functional.norm.
  • The release after next will remove the current torch.functional.norm interface and rename torch.functional._norm_new to torch.functional.norm.

@kurtamohler kurtamohler force-pushed the matrix-norm-consistency-24802 branch from 6e358b1 to 155dcaa on July 6, 2020 22:01
@kurtamohler
Collaborator Author

It looks like the reason for the JIT issue I'm seeing is that when torch.norm is JITed, the input's shape is not taken into account, so the call unconditionally gets translated to aten::norm when torch.jit.CompilationUnit() parses it here: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/jit_metaprogramming_utils.py#L278

For that call, in the matrix norm test case TestJitGeneratedAutogradCPU.test_norm_matrix_1_cpu, the generated script is the following:

def the_method(i0):
    return i0.norm(1)

This is the exact same script that gets generated for the corresponding vector norm case. Since the script does not tell the compilation unit whether i0 is 1-D or 2-D, CU.the_method.graph becomes the following in both the matrix norm and the vector norm cases:

graph(%i0.1 : Tensor):
  %2 : int = prim::Constant[value=1]() # <string>:3:19
  %3 : Tensor = aten::norm(%i0.1, %2) # <string>:3:11
  return (%3)

Evidently, nuclear and Frobenius norm have this same issue, because those test cases are registered in torch/testing/_internal/jit_metaprogramming_utils.py's EXCLUDE_SCRIPT. So I'm just going to register the matrix norm cases there as well. It would be nice to make the compilation unit take the input tensor's shape into account and correctly resolve to at::_norm_matrix when it needs to, but I imagine that would not be a simple task (correct me if I'm wrong).
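
Roughly, that registration just adds the generated test names to the set, along these lines (the entry names are guesses based on the generated test names above, so treat this as a sketch rather than the exact diff):

# torch/testing/_internal/jit_metaprogramming_utils.py (sketch)
# The scripted variant always resolves to aten::norm regardless of input rank,
# so skip scripting for the generated matrix norm tests, the same way the
# frobenius/nuclear norm cases are already skipped.
EXCLUDE_SCRIPT.update({
    'test_norm_matrix_1',
    'test_norm_matrix_neg_1',
    'test_norm_matrix_inf',
    'test_norm_matrix_neg_inf',
})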

@kurtamohler kurtamohler force-pushed the matrix-norm-consistency-24802 branch from d1159dc to a959d21 on July 7, 2020 17:53
@kurtamohler
Collaborator Author

Sparse tensor inputs were only supported by the original implementation of torch.norm in a subset of cases:

  • Vector norm is supported if the dtype, dim, and out arguments are all unspecified. keepdim=True does not work properly; it always returns a scalar.
  • Frobenius norm is supported if the dtype, dim, and out arguments are all unspecified. keepdim=True does not work properly; it always returns a scalar.
  • Nuclear norm doesn't work at all because SVD does not support sparse tensors.

My current changes broke all sparse tensor support, so at a minimum I'll need to fix it to behave like it did before, although I should probably throw an error if keepdim=True rather than returning an incorrect result.

It would be nice to add support for sparse tensors in all cases (at least for everything except nuclear norm), but that should be done in a future PR.
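
To make the supported and unsupported combinations concrete, this is roughly how they look at the call site (a sketch based on the description above; the exact error messages for the unsupported calls are not reproduced here):

import torch

s = torch.randn(3, 4).to_sparse()

# Supported: full reductions with dtype, dim, and out left unspecified.
print(torch.norm(s))            # vector 2-norm
print(torch.norm(s, p='fro'))   # Frobenius norm

# Broken: keepdim=True is silently ignored and a scalar comes back.
print(torch.norm(s, keepdim=True).shape)

# Unsupported (these raise): reducing over a dim, and nuclear norm (no sparse SVD).
# torch.norm(s, dim=0)
# torch.norm(s, p='nuc')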

@ezyang
Contributor

ezyang commented Jul 7, 2020

As for a deprecation plan, I'm not entirely sure what it should be, but I believe you've recommended something like this in the past:

Yes, this plan is reasonable. You may also consider exposing the new behavior in some way (as you suggest below) so people can opt into it early.

@mruberry mruberry added the triaged label Jul 7, 2020
@mruberry mruberry self-requested a review July 7, 2020 23:20
@mruberry mruberry added the module: numpy label Jul 7, 2020
@kurtamohler kurtamohler force-pushed the matrix-norm-consistency-24802 branch from 1462c85 to d977051 on July 8, 2020 21:11
@kurtamohler
Collaborator Author

I noticed that there are no tests confirming that torch.norm's results for complex types are correct, so I decided to see for myself. I wrote a script to compare with numpy: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/master/matrix-norm/complex_test.py
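
The script is essentially a loop of this shape (a simplified sketch, not the linked file; the sizes and tolerance here are made up):

import numpy as np
import torch

v = torch.randn(100, dtype=torch.complex64)
for p in [float('inf'), 2, 1, 0.5, 0, -0.5, -1, -2, float('-inf')]:
    torch_result = torch.norm(v, p=p).item()
    numpy_result = np.linalg.norm(v.numpy(), ord=p)
    # flag cases where torch and numpy disagree beyond a small tolerance
    print(p, torch_result, numpy_result, abs(torch_result - numpy_result) < 1e-4)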

Here's what I get when I run it:

Complex vector norms:
      p            torch_result numpy_result  equal
0  None      4.856037-2.452530j      13.8782  False
1   inf      3.685238+0.000000j      3.68524   True
2     2      4.856037-2.452530j      13.8782  False
3     1    122.005386+0.000000j      122.005   True
4   0.5  11315.891602+0.000000j      11315.9   True
5     0    100.000000+0.000000j          100   True
6  -0.5      0.000095-0.000000j  9.47431e-05   True
7    -1      0.008568-0.000000j   0.00856754   True
8    -2      0.069691-0.000000j     0.069691   True
9  -inf      0.195269+0.000000j     0.195269   True

Complex matrix norms:
      p        torch_result  numpy_result  equal
0  None  3.468106+4.901812j      8.000000  False
1   fro  3.468106+4.901812j      8.000000  False
2   nuc  8.839804+0.000000j      8.839804   True
3   inf  9.830952+0.000000j      9.830952   True
4     2  7.950407+0.000000j      7.950407   True
5     1  8.830952+0.000000j      8.830952   True
6    -1  6.236068+0.000000j      6.236068   True
7    -2  0.889397+0.000000j      0.889397   True
8  -inf  5.236068+0.000000j      5.236068   True

So vector norms with p=2 and matrix Frobenius norms are both incorrect. I'd be happy to work on them, but I feel like that should be a different PR. These cases were incorrect before my changes; here's what I get when I run it on the main branch:

Complex vector norms:
  match = (torch_result - numpy_result).abs().lt(eps).item()
      p            torch_result numpy_result  equal
0  None      4.856037-2.452530j      13.8782  False
1   inf      3.685238+0.000000j      3.68524   True
2     2      4.856037-2.452530j      13.8782  False
3     1    122.005386+0.000000j      122.005   True
4   0.5  11315.891602+0.000000j      11315.9   True
5     0    100.000000+0.000000j          100   True
6  -0.5      0.000095-0.000000j  9.47431e-05   True
7    -1      0.008568-0.000000j   0.00856754   True
8    -2      0.069691-0.000000j     0.069691   True
9  -inf      0.195269+0.000000j     0.195269   True

Complex matrix norms:
      p         torch_result  numpy_result  equal
0  None   3.468106+4.901812j      8.000000  False
1   fro   3.468106+4.901812j      8.000000  False
2   nuc   8.839804+0.000000j      8.839804   True
3   inf   5.830952+0.000000j      9.830952  False
4     2   3.468106+4.901812j      7.950407  False
5     1  15.067019+0.000000j      8.830952  False
6    -1   0.831915-0.000000j      6.236068  False
7    -2   1.575198+0.000000j      0.889397  False
8  -inf   2.236068+0.000000j      5.236068  False

Nevertheless, I will add tests to this PR for the cases that give correct results, since we really should have those tests.

@kurtamohler kurtamohler added the module: bc-breaking label Jul 9, 2020
@kurtamohler kurtamohler force-pushed the matrix-norm-consistency-24802 branch from 8a20124 to 57622f0 on July 9, 2020 17:42
@kurtamohler
Collaborator Author

@ezyang or @mruberry , I have just two quick questions remaining for this PR:

  1. Complex frobenius norm and vector 2-norm both give incorrect results. This was true of the original implementation, not introduced by my changes. Should I fix those in this PR or a subsequent PR? I don't know yet how much effort is required to make this change.

  2. In light of making torch.norm look more like numpy.linalg.norm, would it be a good idea if I rename the p argument to ord?

@ezyang
Contributor

ezyang commented Jul 9, 2020

Complex frobenius norm and vector 2-norm both give incorrect results. This was true of the original implementation, not introduced by my changes. Should I fix those in this PR or a subsequent PR? I don't know yet how much effort is required to make this change.

Don't worry about it for this PR. If you want, raise an error in these cases.

In light of making torch.norm look more like numpy.linalg.norm, would it be a good idea if I rename the p argument to ord?

Maybe, but we should probably keep p working to avoid gratuitously breaking BC.
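
One way that could look, purely as a sketch of the suggestion (the signature and warning text here are assumptions, not a committed design):

import warnings

def norm(input, ord=None, dim=None, keepdim=False, out=None, dtype=None, p=None):
    # Accept the old argument name but steer callers toward the new one.
    if p is not None:
        warnings.warn("the 'p' argument is deprecated; use 'ord' instead")
        ord = p
    ...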

@mruberry
Collaborator

mruberry commented Jul 9, 2020

Implement matrix norm for orders +/- 1, 2, inf

This PR contains BC-breaking changes to torch.norm, torch.functional.norm, and its underlying aten functions. The deprecation plan is to add a warning to torch.functional.norm, explaining what will change. This warning will be added in a different PR (#41193), and should exist in PyTorch 1.6.0. Then, we can release the norm changes in PyTorch 1.7.0.

Issue #24802

Hey Kurt, sorry it took me a couple days to jump in here. My question is: if our plan of record is to support a torch.linalg namespace with torch.linalg.norm, should we modify the existing torch.norm at all (except to have it throw a deprecation warning telling users to use torch.linalg.norm instead)?

If we do still want to modify torch.norm, I suppose we don't have any way to prevent silent breakages other than hoping the user sees a warning?

@kurtamohler
Collaborator Author

if our plan of record is to support a torch.linalg namespace with torch.linalg.norm, should we modify the existing torch.norm at all (except to have it throw a deprecation warning telling users to use torch.linalg.norm instead)?

Oh darn, I misunderstood the implications of torch.linalg.norm when you mentioned it here: #24802 (comment)

Thinking from the perspective of the user, I think your suggestion makes sense. It does seem better to leave the original torch.norm unchanged but deprecated so we can maintain backward compatibility. Is someone currently working on adding the torch.linalg module?

@mruberry
Collaborator

mruberry commented Jul 9, 2020

if our plan of record is to support a torch.linalg namespace with torch.linalg.norm, should we modify the existing torch.norm at all (except to have it throw a deprecation warning telling users to use torch.linalg.norm instead)?

Oh darn, I misunderstood the implications of torch.linalg.norm when you mentioned it here: #24802 (comment)

Thinking from the perspective of the user, I think your suggestion makes sense. It does seem better to leave the original torch.norm unchanged but deprecated so we can maintain backward compatibility. Is someone currently working on adding the torch.linalg module?

No worries, I think I misunderstood your response, too ;)

Let's think about it for a day before committing to any plan of action? I don't want to rush your thinking.

Yes, I'm planning on building torch.linalg for 1.7 and should have it available in a few weeks. I'll move that timeline up to unblock you, in case that's the direction you decide to go in.

@kurtamohler
Collaborator Author

Let's think about it for a day before committing to any plan of action? I don't want to rush your thinking.

Sounds good. I'll tally up the changes I would have to make to support this plan.

@mruberry
Collaborator

mruberry commented Jul 9, 2020

Let's think about it for a day before committing to any plan of action? I don't want to rush your thinking.

Sounds good. I'll tally up the changes I would have to make to support this plan.

To be clear, torch.linalg.norm would be a totally new function. It could internally reuse parts of the existing norm functions if desired, but its goal is to be consistent with NumPy.

@kurtamohler
Collaborator Author

kurtamohler commented Jul 14, 2020

@mruberry, I do think that avoiding changes to the torch.functional.norm interface is the right way to go, and the new interface should go in torch.linalg once that module is available. However, there were a few other changes I made to the existing at::frobenius_norm, at::nuclear_norm, and at::norm functions to get their results to match numpy in these cases:

  • frobenius_norm of a 0-D scalar
  • nuclear_norm with keepdim=True

Also, vector norm with sparse tensors was only supported if doing a full reduction with keepdim=False, and dtype unspecified. The error message did not make that clear, so I added messages for that.

It seems like these functionality changes should still be made, even if we don't change the torch.functional.norm interface itself. Perhaps I could split those changes out to their own PR(s), if we agree that they're worth keeping.

@mruberry
Collaborator

@mruberry, I do think that avoiding changes to the torch.functional.norm interface is the right way to go, and the new interface should go in torch.linalg once that module is available. However, there were a few other changes I made to the existing at::frobenius_norm, at::nuclear_norm, and at::norm functions to get their results to match numpy in these cases:

  • frobenius_norm of a 0-D scalar
  • nuclear_norm with keepdim=True

Also, vector norm with sparse tensors was only supported if doing a full reduction with keepdim=False, and dtype unspecified. The error message did not make that clear, so I added messages for that.

It seems like these functionality changes should still be made, even if we don't change the torch.functional.norm interface itself. Perhaps I could split those changes out to their own PR(s), if we agree that they're worth keeping.

Anything non-BC-breaking sounds great, of course. Would the idea be to preserve the behavior of the user-facing torch.functional.norm but fix the internal functions? If so, let's do it!

@kurtamohler
Collaborator Author

Closing this for now. Will open a new PR when torch.linalg exists.

facebook-github-bot pushed a commit that referenced this pull request Aug 1, 2020
Summary:
**BC-Breaking Note:**
BC breaking changes in the case where keepdim=True. Before this change, when calling `torch.norm` with keepdim=True and p='fro' or p=number, leaving all other optional arguments as their default values, the keepdim argument would be ignored. Also, any time `torch.norm` was called with p='nuc', the result would have one fewer dimension than the input, and the dimensions could be out of order depending on which dimensions were being reduced. After the change, for each of these cases, the result has the same number and order of dimensions as the input.

**PR Summary:**

* Fix keepdim behavior
* Throw descriptive errors for unsupported sparse norm args
* Increase unit test coverage for these cases and for complex inputs

These changes were taken from part of PR #40924. That PR is not going to be merged because it overrides `torch.norm`'s interface, which we want to avoid. But these improvements are still useful.

Issue #24802

Pull Request resolved: #41956

Reviewed By: albanD

Differential Revision: D22837455

Pulled By: mruberry

fbshipit-source-id: 509ecabfa63b93737996f48a58c7188b005b7217