
Conversation

@cloudhan (Contributor) commented May 1, 2020

Fixed #24544

Reference #24507

dr-ci bot commented May 1, 2020

💊 Build failures summary and remediations

As of commit 431540e (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@cloudhan marked this pull request as ready for review May 1, 2020 13:15
@cloudhan force-pushed the port-cuda-clamp branch from b2f510b to 431540e May 1, 2020 13:36
@cloudhan changed the title from "port clamp from th to cuda" to "Migrate clamp from the TH to Aten (CUDA)" May 1, 2020
@cloudhan (Contributor, Author) commented May 3, 2020

Benchmarked with:

import timeit

for n, t in [(10000, 10000),
             (100000, 10000)]:
    for dtype in ('torch.half', 'torch.float', 'torch.double'):
        print(f'torch.clamp(a, 0.25, 0.75) a.numel() == {n} for {t} times {dtype}')
        print(timeit.timeit(f'torch.clamp(a, 0.25, 0.75); torch.cuda.synchronize()',
                            setup=f'import torch; a=torch.randn({n}, dtype={dtype}, device="cuda")',
                            number=t))

for name in ('clamp_min', 'clamp_max'):
    for n, t in [(10000, 10000),
                 (32767, 10000)]:
        for dtype in ('torch.int16', 'torch.int32', 'torch.int64'):
            print(f'torch.{name}(a, 5000) a.numel() == {n} for {t} times {dtype}')
            print(timeit.timeit(f'torch.{name}(a, 5000); torch.cuda.synchronize()',
                                setup=f'import torch; a=torch.randint(0, {n}, ({n},), dtype={dtype}, device="cuda")',
                                number=t))
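
Note that this script synchronizes inside the timed statement, so per-call launch overhead dominates at these tensor sizes. A lower-overhead alternative is to time the whole loop with CUDA events; the sketch below (the time_cuda helper is illustrative only, not part of this PR) shows one way to do that:

import torch

def time_cuda(fn, iters=10000, warmup=100):
    # Warm up so one-time costs (context init, caching) are excluded.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # elapsed_time() reports milliseconds

a = torch.randn(100000, dtype=torch.float, device="cuda")
print(time_cuda(lambda: torch.clamp(a, 0.25, 0.75)))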

After port:

torch.clamp(a, 0.25, 0.75) a.numel() == 10000 for 10000 times torch.half
0.2015269324183464
torch.clamp(a, 0.25, 0.75) a.numel() == 10000 for 10000 times torch.float
0.2018002513796091
torch.clamp(a, 0.25, 0.75) a.numel() == 10000 for 10000 times torch.double
0.2028426956385374
torch.clamp(a, 0.25, 0.75) a.numel() == 100000 for 10000 times torch.half
0.20375624112784863
torch.clamp(a, 0.25, 0.75) a.numel() == 100000 for 10000 times torch.float
0.20094937086105347
torch.clamp(a, 0.25, 0.75) a.numel() == 100000 for 10000 times torch.double
0.20316258445382118
torch.clamp_min(a, 5000) a.numel() == 10000 for 10000 times torch.int16
0.19486145116388798
torch.clamp_min(a, 5000) a.numel() == 10000 for 10000 times torch.int32
0.19561058655381203
torch.clamp_min(a, 5000) a.numel() == 10000 for 10000 times torch.int64
0.19365259259939194
torch.clamp_min(a, 5000) a.numel() == 32767 for 10000 times torch.int16
0.19426318630576134
torch.clamp_min(a, 5000) a.numel() == 32767 for 10000 times torch.int32
0.19408364593982697
torch.clamp_min(a, 5000) a.numel() == 32767 for 10000 times torch.int64
0.19407162815332413
torch.clamp_max(a, 5000) a.numel() == 10000 for 10000 times torch.int16
0.1960984691977501
torch.clamp_max(a, 5000) a.numel() == 10000 for 10000 times torch.int32
0.19451416097581387
torch.clamp_max(a, 5000) a.numel() == 10000 for 10000 times torch.int64
0.1952181402593851
torch.clamp_max(a, 5000) a.numel() == 32767 for 10000 times torch.int16
0.19869455881416798
torch.clamp_max(a, 5000) a.numel() == 32767 for 10000 times torch.int32
0.19556890800595284
torch.clamp_max(a, 5000) a.numel() == 32767 for 10000 times torch.int64
0.1993284411728382

Original:

torch.clamp(a, 0.25, 0.75) a.numel() == 10000 for 10000 times torch.half
0.20965042151510715
torch.clamp(a, 0.25, 0.75) a.numel() == 10000 for 10000 times torch.float
0.21045211143791676
torch.clamp(a, 0.25, 0.75) a.numel() == 10000 for 10000 times torch.double
0.21295203268527985
torch.clamp(a, 0.25, 0.75) a.numel() == 100000 for 10000 times torch.half
0.2085997760295868
torch.clamp(a, 0.25, 0.75) a.numel() == 100000 for 10000 times torch.float
0.20977089367806911
torch.clamp(a, 0.25, 0.75) a.numel() == 100000 for 10000 times torch.double
0.21144748851656914
torch.clamp_min(a, 5000) a.numel() == 10000 for 10000 times torch.int16
0.19735403917729855
torch.clamp_min(a, 5000) a.numel() == 10000 for 10000 times torch.int32
0.19943231716752052
torch.clamp_min(a, 5000) a.numel() == 10000 for 10000 times torch.int64
0.19904308579862118
torch.clamp_min(a, 5000) a.numel() == 32767 for 10000 times torch.int16
0.19618243724107742
torch.clamp_min(a, 5000) a.numel() == 32767 for 10000 times torch.int32
0.19913179613649845
torch.clamp_min(a, 5000) a.numel() == 32767 for 10000 times torch.int64
0.19725298509001732
torch.clamp_max(a, 5000) a.numel() == 10000 for 10000 times torch.int16
0.20088727213442326
torch.clamp_max(a, 5000) a.numel() == 10000 for 10000 times torch.int32
0.20195786096155643
torch.clamp_max(a, 5000) a.numel() == 10000 for 10000 times torch.int64
0.1999555043876171
torch.clamp_max(a, 5000) a.numel() == 32767 for 10000 times torch.int16
0.19893217459321022
torch.clamp_max(a, 5000) a.numel() == 32767 for 10000 times torch.int32
0.2005235254764557
torch.clamp_max(a, 5000) a.numel() == 32767 for 10000 times torch.int64
0.19949866831302643
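
For completeness, a quick way to sanity-check that the ported CUDA kernels agree with the CPU implementation (an illustrative check, not something run as part of this PR):

import torch

# Illustrative correctness check (not part of this PR): clamp only selects
# among existing values and exactly representable bounds, so CPU and CUDA
# results should match bitwise.
for dtype in (torch.half, torch.float, torch.double):
    a = torch.randn(10000, dtype=dtype, device="cuda")
    assert torch.equal(torch.clamp(a, 0.25, 0.75).cpu(), torch.clamp(a.cpu(), 0.25, 0.75))

for dtype in (torch.int16, torch.int32, torch.int64):
    a = torch.randint(0, 32767, (10000,), dtype=dtype, device="cuda")
    assert torch.equal(torch.clamp_min(a, 5000).cpu(), torch.clamp_min(a.cpu(), 5000))
    assert torch.equal(torch.clamp_max(a, 5000).cpu(), torch.clamp_max(a.cpu(), 5000))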

@cloudhan (Contributor, Author) commented May 3, 2020

@VitalyFedyunin, requesting review.

@ngimel requested a review from VitalyFedyunin May 5, 2020 00:21
@ngimel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) May 5, 2020
@facebook-github-bot (Contributor) left a comment


@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@VitalyFedyunin merged this pull request in 12e6491.

bharatr21 pushed a commit to bharatr21/pytorch that referenced this pull request May 5, 2020
Summary:
Fixed pytorch#24544

Reference pytorch#24507
Pull Request resolved: pytorch#37646

Differential Revision: D21395824

Pulled By: VitalyFedyunin

fbshipit-source-id: 111889023d60e3361b5a646bcfb6fb7d5ec969d1
@cloudhan deleted the port-cuda-clamp branch September 26, 2021 05:29
