
Conversation

@gchanan (Contributor) commented Nov 3, 2020

Stack from ghstack:

Fixes #47127.

Ideally this would just use diag and sum (as the CUDA implementation does), but that seems to have performance problems, which I'll link in the GitHub PR.

Differential Revision: D24729627

BC-breaking message:

Previously, the CPU variant of trace did not properly type promote, but the CUDA variant did. Both variants now type promote correctly:

1.7:

>>> import torch
>>> x = torch.ones(5, 5, dtype=torch.uint8)
>>> x_cuda = x.cuda()
>>> x.trace().dtype
torch.uint8
>>> x_cuda.trace().dtype
torch.int64

1.8:

>>> import torch
>>> x = torch.ones(5, 5, dtype=torch.uint8)
>>> x_cuda = x.cuda()
>>> x.trace().dtype
torch.int64
>>> x_cuda.trace().dtype
torch.int64
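
For context, the promoted dtype follows the default reduction rule used by torch.sum, which accumulates integral inputs in int64; with x as above:

>>> x.sum().dtype
torch.int64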

gchanan added a commit that referenced this pull request Nov 3, 2020 (ghstack-source-id: 2435e13, Pull Request resolved: #47305)
gchanan requested a review from zou3519 on November 3, 2020 19:51
@gchanan (Contributor, Author) commented Nov 3, 2020

Here are some benchmarking results for this implementation vs. the diag().sum() option:

This implementation:

Size 1:
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.2587152142077684
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.240397651679814
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.2281987024471164
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.2414425648748875

Size 300:
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.571804345585406
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.574404393322766
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.676772426813841
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.566413162276149
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
2.5903642047196627

Size 100000:
>>> t.timeit.timeit(setup='import torch; x=torch.ones(100000, 100000, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000)
2.908080894500017
>>> t.timeit.timeit(setup='import torch; x=torch.ones(100000, 100000, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000)
3.2296120524406433

diag().sum():

Size 1:
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
8.96580315567553
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
9.246467645280063
>>> t.timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
9.167060194537044

Size 300:
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
10.824241297319531
>>> t.timeit.timeit(setup='import torch; x=torch.ones(300, 300, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000000)
10.658952249214053

Size 100000:
>>> t.timeit.timeit(setup='import torch; x=torch.ones(100000, 100000, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000)
2.981750411912799
>>> t.timeit.timeit(setup='import torch; x=torch.ones(100000, 100000, dtype=torch.uint8)', stmt='torch.trace(x)', number=1000)
3.28736359719187

@gchanan (Contributor, Author) commented Nov 3, 2020

So the overhead is much higher, although for large sizes it doesn't matter.
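
For anyone reproducing this, a minimal sketch of the same comparison from Python, timing the eager diag().sum() composition against the built-in trace (exact numbers will vary by machine; the author benchmarked two builds of torch.trace, so this is only an approximation):

>>> import timeit
>>> timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='torch.trace(x)', number=100000)
>>> timeit.timeit(setup='import torch; x=torch.ones(1, 1, dtype=torch.uint8)', stmt='x.diag().sum()', number=100000)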

return result;
}

Tensor trace_cpu(const Tensor& self) {
Contributor commented:

nit: The best place to put this might be TriangularOps.cpp (trace_cuda is in TriangularOps.cu). Alternatively, we can move trace_cuda out of TriangularOps.cu for consistency.

Contributor Author replied:

The problem is that you need the specialized promotion logic from here. In theory you could split things up, but that doesn't seem worth it given that the promotion logic is really specialized for reductions.
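
To illustrate the reduction promotion rule in question (it matches torch.sum's default: integral inputs accumulate in int64, while floating-point inputs keep their dtype):

>>> import torch
>>> torch.ones(3, 3, dtype=torch.uint8).sum().dtype
torch.int64
>>> torch.ones(3, 3, dtype=torch.float32).sum().dtype
torch.float32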

});

return result;
}
Contributor commented:

We're going to have to update test_trace:

pytorch/test/test_torch.py, lines 5996 to 6029 at 774b638:

def _test_trace(self, device, dtype, legacy):
    def test(shape):
        tensor = make_tensor(shape, device, dtype, low=-9, high=9)
        diag = tensor.diag()
        if legacy:
            # NB: trace on cpu doesn't do type promotion... #47127
            expected_dtype = dtype
        else:
            expected_dtype = tensor.sum().dtype
        expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]
        result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
        expected = torch.tensor(result, device=device)
        self.assertEqual(tensor.trace(), expected)

    shapes = (
        [10, 1],
        [1, 10],
        [100, 100],
        [20, 100],
        [100, 20],
    )
    for shape in shapes:
        test(shape)

@onlyCPU
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))
def test_trace_legacy(self, device, dtype):
    self._test_trace(device, dtype, legacy=True)

@onlyCUDA
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
def test_trace(self, device, dtype):
    self._test_trace(device, dtype, legacy=False)

@zou3519 (Contributor) commented Nov 3, 2020

Should we mark this as BC-breaking? In older versions of PyTorch (<1.5), we did not do any type promotion for trace.

@dr-ci bot commented Nov 3, 2020

💊 CI failures summary and remediations

As of commit f72d684 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build (1/1)

Step: "Checkout pytorch/builder repo" (full log | diagnosis details | 🔁 rerun)

fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
+ sleep 2 
+ git submodule update --init --recursive 
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878 
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl' 
+ sleep 4 
+ git submodule update --init --recursive 
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878 
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl' 
+ sleep 8 
+ git submodule update --init --recursive 
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878 
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl' 


@gchanan (Contributor, Author) commented Nov 3, 2020

Ya, I'll mark it BC-breaking.

gchanan added the "module: bc-breaking" (Related to a BC-breaking change) label Nov 3, 2020
@codecov bot commented Nov 4, 2020

Codecov Report

Merging #47305 (848af86) into gh/gchanan/338/base (5c8896f) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@                   Coverage Diff                   @@
##           gh/gchanan/338/base   #47305      +/-   ##
=======================================================
- Coverage                60.81%   60.81%   -0.01%     
=======================================================
  Files                     2749     2749              
  Lines                   254098   254099       +1     
=======================================================
- Hits                    154535   154529       -6     
- Misses                   99563    99570       +7     

def test_trace(self, device, dtype):
    def test(shape):
        tensor = make_tensor(shape, device, dtype, low=-9, high=9)
        diag = tensor.diag()
@zou3519 (Contributor) commented Nov 4, 2020:

The diag = tensor.diag() line is bogus (it is unused); can you remove it, please? (Sorry, I was the one who added it, but I am only noticing it now.)

        self.assertEqual(expected, result)

def _test_trace(self, device, dtype, legacy):
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))
Contributor commented:

We've dropped testing for half on CUDA; is there a good way around this?

The dtypesIfCUDA and dtypesIfCPU decorators look like they handle this (albeit very verbosely):

@dtypesIfCPU(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))
@dtypesIfCUDA(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))

gchanan added a commit that referenced this pull request Nov 4, 2020 (ghstack-source-id: d1a2b43, Pull Request resolved: #47305)

gchanan added a commit that referenced this pull request Nov 5, 2020 (ghstack-source-id: 9dcfdf8, Pull Request resolved: #47305)

gchanan added a commit that referenced this pull request Nov 9, 2020 (ghstack-source-id: 0e94887, Pull Request resolved: #47305)

gchanan added a commit that referenced this pull request Nov 9, 2020 (ghstack-source-id: a6c59f2, Pull Request resolved: #47305)
@facebook-github-bot (Contributor) commented:

@gchanan merged this pull request in 65a72ca.

Labels: cla signed, Merged, module: bc-breaking (Related to a BC-breaking change)