Conversation

@xuhdev (Collaborator) commented Oct 23, 2019

Stack from ghstack:

Benchmark (Debian Buster, CUDA 9.2, Quadro P400, turbo off, Release, gcc 7.4):

```python
import timeit

for n, t in [(10_000, 20000),
             (100_000, 20000)]:
    for dtype in ('torch.half', 'torch.float', 'torch.double'):
        print(f'torch.sinh(a) a.numel() == {n} for {t} times {dtype}')
        print(timeit.timeit('torch.sinh(a); torch.cuda.synchronize()',
                            setup=f'import torch; a=torch.arange({n}, dtype={dtype}, device="cuda")',
                            number=t))
```
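
Because CUDA kernels launch asynchronously, the `torch.cuda.synchronize()` inside the timed statement is what makes these wall-clock numbers meaningful; without it, `timeit` would largely measure kernel launch overhead. As a cross-check (not part of this PR), roughly the same measurement can be made with CUDA events, which record timestamps on the GPU stream itself. A minimal sketch, with `time_op` a hypothetical helper:

```python
import torch

def time_op(fn, warmup=100, iters=20000):
    # CUDA events record timestamps on the GPU stream, so a single
    # host-side synchronize at the end suffices.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000  # elapsed_time() returns milliseconds

a = torch.arange(100_000, dtype=torch.float, device='cuda')
print(time_op(lambda: torch.sinh(a)))
```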

Before:

```
torch.sinh(a) a.numel() == 10000 for 20000 times torch.half
0.3807680979998622
torch.sinh(a) a.numel() == 10000 for 20000 times torch.float
0.37430476099962107
torch.sinh(a) a.numel() == 10000 for 20000 times torch.double
1.0580407639999976
torch.sinh(a) a.numel() == 100000 for 20000 times torch.half
0.7996397469996737
torch.sinh(a) a.numel() == 100000 for 20000 times torch.float
1.010930432999885
torch.sinh(a) a.numel() == 100000 for 20000 times torch.double
7.310400856999877
```

After:

```
torch.sinh(a) a.numel() == 10000 for 20000 times torch.half
0.3720399889998589
torch.sinh(a) a.numel() == 10000 for 20000 times torch.float
0.3694016069994177
torch.sinh(a) a.numel() == 10000 for 20000 times torch.double
1.0551542660004998
torch.sinh(a) a.numel() == 100000 for 20000 times torch.half
0.7431191599998783
torch.sinh(a) a.numel() == 100000 for 20000 times torch.float
0.9953043630002867
torch.sinh(a) a.numel() == 100000 for 20000 times torch.double
7.3146168890007175
```
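
Read together, the numbers show no regression anywhere, a roughly 7% win for half and 2% win for float at 100k elements, and double essentially unchanged. One caveat about the inputs: `torch.arange(n)` quickly drives `sinh` into overflow (`sinh(90)` already exceeds the float32 maximum), which is harmless for a throughput benchmark but means a correctness spot check should use a bounded range. A minimal sketch of such a check against the float64 CPU reference (an illustration, not this PR's test plan):

```python
import torch

# Compare the CUDA kernel against the float64 CPU result on inputs where
# sinh stays finite for every dtype, including half.
for dtype in (torch.half, torch.float, torch.double):
    a = torch.linspace(-5, 5, 10_000, device='cuda').to(dtype)
    ref = torch.sinh(a.cpu().double())
    # Loose tolerances so the half-precision case passes with few-ulp error.
    torch.testing.assert_allclose(torch.sinh(a).cpu().double(), ref,
                                  rtol=1e-2, atol=1e-3)
```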

Close #24628

Differential Revision: D18124732

xuhdev added a commit that referenced this pull request Oct 23, 2019

ghstack-source-id: 501674a
Pull Request resolved: #28527
@xuhdev requested a review from @VitalyFedyunin on October 23, 2019 18:22
zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 30, 2019
Summary:
Pull Request resolved: pytorch/pytorch#28527

Test Plan: Imported from OSS

Differential Revision: D18124732

Pulled By: VitalyFedyunin

fbshipit-source-id: 054b0c0884ac12de2dd1a92c5de916aaf047f9e9
@facebook-github-bot (Contributor)

@VitalyFedyunin merged this pull request in e0009fd.

@facebook-github-bot deleted the gh/xuhdev/45/head branch on November 3, 2019 15:15