Migrate neg's CUDA implementation to ATen. #23617
Closed
Conversation
Doesn't seem to cause any performance regression. Performance difference in the benchmarks is negligible.

Benchmark script:

```python
import timeit
for n, t in [(10, 100000), (1000, 10000)]:
    print('a.neg() (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32',
                      'torch.int64', 'torch.float', 'torch.double') + (('torch.half',) if device == 'cuda' else ()):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit('a.neg()',
                                setup=f'import torch; a = torch.zeros({n}, device="{device}", dtype={dtype})',
                                number=t))
```

Before:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times  2.6863291729987395
device: cpu, dtype: torch.uint8, 100000 times  2.652089002000139
device: cpu, dtype: torch.int16, 100000 times  2.656129287001022
device: cpu, dtype: torch.int32, 100000 times  2.690395133000493
device: cpu, dtype: torch.int64, 100000 times  2.7907447709985718
device: cpu, dtype: torch.float, 100000 times  2.7426670610002475
device: cpu, dtype: torch.double, 100000 times  2.839107595000314
device: cuda, dtype: torch.int8, 100000 times  4.7421167499996955
device: cuda, dtype: torch.uint8, 100000 times  4.732283645998905
device: cuda, dtype: torch.int16, 100000 times  4.794625207998251
device: cuda, dtype: torch.int32, 100000 times  4.763449829999445
device: cuda, dtype: torch.int64, 100000 times  4.832622486999753
device: cuda, dtype: torch.float, 100000 times  4.859277726000073
device: cuda, dtype: torch.double, 100000 times  4.920464308999726
device: cuda, dtype: torch.half, 100000 times  4.84502355999939
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times  0.39148091199967894
device: cpu, dtype: torch.uint8, 10000 times  0.3914154279991635
device: cpu, dtype: torch.int16, 10000 times  0.3592291830009344
device: cpu, dtype: torch.int32, 10000 times  0.4305302360007772
device: cpu, dtype: torch.int64, 10000 times  0.5731948379998357
device: cpu, dtype: torch.float, 10000 times  0.40393425100046443
device: cpu, dtype: torch.double, 10000 times  0.5108613129996229
device: cuda, dtype: torch.int8, 10000 times  0.47522059500079195
device: cuda, dtype: torch.uint8, 10000 times  0.4748163679996651
device: cuda, dtype: torch.int16, 10000 times  0.48025749100088433
device: cuda, dtype: torch.int32, 10000 times  0.47739119099969685
device: cuda, dtype: torch.int64, 10000 times  0.4862670579987025
device: cuda, dtype: torch.float, 10000 times  0.4882351270007348
device: cuda, dtype: torch.double, 10000 times  0.49148828299985325
device: cuda, dtype: torch.half, 10000 times  0.48497576499903516
```

After:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times  2.690098836001198
device: cpu, dtype: torch.uint8, 100000 times  2.6437115690005157
device: cpu, dtype: torch.int16, 100000 times  2.658629939000093
device: cpu, dtype: torch.int32, 100000 times  2.7232703419995232
device: cpu, dtype: torch.int64, 100000 times  2.781715623999844
device: cpu, dtype: torch.float, 100000 times  2.7466302209995774
device: cpu, dtype: torch.double, 100000 times  2.8326373519994377
device: cuda, dtype: torch.int8, 100000 times  4.7760227950002445
device: cuda, dtype: torch.uint8, 100000 times  4.788483000998895
device: cuda, dtype: torch.int16, 100000 times  4.826825819000078
device: cuda, dtype: torch.int32, 100000 times  4.838881628998934
device: cuda, dtype: torch.int64, 100000 times  4.8777323520007485
device: cuda, dtype: torch.float, 100000 times  4.902277535000394
device: cuda, dtype: torch.double, 100000 times  4.971408750001501
device: cuda, dtype: torch.half, 100000 times  4.909047918001306
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times  0.3891995970006974
device: cpu, dtype: torch.uint8, 10000 times  0.38925370699871564
device: cpu, dtype: torch.int16, 10000 times  0.3580240920000506
device: cpu, dtype: torch.int32, 10000 times  0.4262140860009822
device: cpu, dtype: torch.int64, 10000 times  0.5766096090010251
device: cpu, dtype: torch.float, 10000 times  0.40679733400065743
device: cpu, dtype: torch.double, 10000 times  0.5093677469994873
device: cuda, dtype: torch.int8, 10000 times  0.4803728469996713
device: cuda, dtype: torch.uint8, 10000 times  0.47961040299924207
device: cuda, dtype: torch.int16, 10000 times  0.4822084730003553
device: cuda, dtype: torch.int32, 10000 times  0.4869117329999426
device: cuda, dtype: torch.int64, 10000 times  0.48833372000081
device: cuda, dtype: torch.float, 10000 times  0.4896292170014931
device: cuda, dtype: torch.double, 10000 times  0.4972619459986163
device: cuda, dtype: torch.half, 10000 times  0.49118665199966927
```
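The harness pattern above (a `timeit` statement string, a `setup` string that builds the tensor, and a loop over sizes) can be exercised without PyTorch or a GPU by swapping in a plain-Python stand-in workload. This is only a sketch of the benchmarking pattern; the list negation here is a hypothetical substitute for `a.neg()`:

```python
import timeit

# Stand-in workload: negate a list of n numbers. This mirrors the
# structure of the torch benchmark above (setup builds the data once,
# the timed statement runs `number` times), but needs no GPU.
for n, number in [(10, 1000), (1000, 100)]:
    elapsed = timeit.timeit(
        'b = [-x for x in a]',
        setup=f'a = list(range({n}))',
        number=number,
    )
    print(f'n={n}, {number} runs: {elapsed:.6f}s')
```

Note that `timeit.timeit` returns total wall-clock seconds for all `number` runs, which is why the tables above compare totals rather than per-call times.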
pytorchbot added the module: cuda (Related to torch.cuda, and CUDA support in general), module: internals (Related to internal abstractions in c10 and ATen), and module: operators labels on Jul 31, 2019.
xuhdev added a commit that referenced this pull request on Jul 31, 2019:
ghstack-source-id: b8939eda48e63068df71951cd4065f5652154669
Pull Request resolved: #23617
A bigger question is whether we want to do this for other unary ops in general.
xuhdev requested review from gchanan, colesbury, zdevito, jamesr66a, and VitalyFedyunin on July 31, 2019 at 21:59.
colesbury approved these changes on Jul 31, 2019.
Doesn't seem to cause any performance regression. Performance difference in the benchmarks is negligible.

Benchmark script:

```python
import timeit
for n, t in [(10, 100000), (1000, 10000)]:
    print('a.neg() (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32',
                      'torch.int64', 'torch.float', 'torch.double') + (('torch.half',) if device == 'cuda' else ()):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit('a.neg()',
                                setup=f'import torch; a = torch.ones({n}, device="{device}", dtype={dtype})',
                                number=t))
```

Before:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times  2.642898898997373
device: cpu, dtype: torch.uint8, 100000 times  2.6144280709995655
device: cpu, dtype: torch.int16, 100000 times  2.6500500490001286
device: cpu, dtype: torch.int32, 100000 times  2.707123526997748
device: cpu, dtype: torch.int64, 100000 times  2.7521983230035403
device: cpu, dtype: torch.float, 100000 times  2.766398545998527
device: cpu, dtype: torch.double, 100000 times  2.8156379600004584
device: cuda, dtype: torch.int8, 100000 times  4.66132030800145
device: cuda, dtype: torch.uint8, 100000 times  4.643385174997093
device: cuda, dtype: torch.int16, 100000 times  4.683008575000713
device: cuda, dtype: torch.int32, 100000 times  4.739808276000986
device: cuda, dtype: torch.int64, 100000 times  4.752713688001677
device: cuda, dtype: torch.float, 100000 times  4.828282921000209
device: cuda, dtype: torch.double, 100000 times  4.807285109000077
device: cuda, dtype: torch.half, 100000 times  4.774916842001403
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times  0.3856488650017127
device: cpu, dtype: torch.uint8, 10000 times  0.38432079300037003
device: cpu, dtype: torch.int16, 10000 times  0.35588691600059974
device: cpu, dtype: torch.int32, 10000 times  0.42739917199651245
device: cpu, dtype: torch.int64, 10000 times  0.5718168579987832
device: cpu, dtype: torch.float, 10000 times  0.40673828999933903
device: cpu, dtype: torch.double, 10000 times  0.50664389299709
device: cuda, dtype: torch.int8, 10000 times  0.47066202399946633
device: cuda, dtype: torch.uint8, 10000 times  0.46319492099792114
device: cuda, dtype: torch.int16, 10000 times  0.46734901899981196
device: cuda, dtype: torch.int32, 10000 times  0.47492316399802803
device: cuda, dtype: torch.int64, 10000 times  0.47535407499890425
device: cuda, dtype: torch.float, 10000 times  0.48102769600154716
device: cuda, dtype: torch.double, 10000 times  0.47957396499987226
device: cuda, dtype: torch.half, 10000 times  0.4789546009997139
```

After:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times  2.5973585530009586
device: cpu, dtype: torch.uint8, 100000 times  2.5876065160009603
device: cpu, dtype: torch.int16, 100000 times  2.6277998949990433
device: cpu, dtype: torch.int32, 100000 times  2.6755344299999706
device: cpu, dtype: torch.int64, 100000 times  2.7128256750002038
device: cpu, dtype: torch.float, 100000 times  2.7338339269990684
device: cpu, dtype: torch.double, 100000 times  2.8125362460014003
device: cuda, dtype: torch.int8, 100000 times  4.5937904180027544
device: cuda, dtype: torch.uint8, 100000 times  4.56467357099973
device: cuda, dtype: torch.int16, 100000 times  4.612064369001018
device: cuda, dtype: torch.int32, 100000 times  4.644249272998422
device: cuda, dtype: torch.int64, 100000 times  4.69729050299793
device: cuda, dtype: torch.float, 100000 times  4.7240724049988785
device: cuda, dtype: torch.double, 100000 times  4.734548320000613
device: cuda, dtype: torch.half, 100000 times  4.6907191919999605
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times  0.38785348699821043
device: cpu, dtype: torch.uint8, 10000 times  0.3826022310022381
device: cpu, dtype: torch.int16, 10000 times  0.35425406899958034
device: cpu, dtype: torch.int32, 10000 times  0.4246119660019758
device: cpu, dtype: torch.int64, 10000 times  0.5690692410025804
device: cpu, dtype: torch.float, 10000 times  0.40096976999848266
device: cpu, dtype: torch.double, 10000 times  0.5034001249987341
device: cuda, dtype: torch.int8, 10000 times  0.4584954599995399
device: cuda, dtype: torch.uint8, 10000 times  0.4570061539998278
device: cuda, dtype: torch.int16, 10000 times  0.4609103500006313
device: cuda, dtype: torch.int32, 10000 times  0.4647650409970083
device: cuda, dtype: torch.int64, 10000 times  0.46961031799946795
device: cuda, dtype: torch.float, 10000 times  0.472893533999013
device: cuda, dtype: torch.double, 10000 times  0.4730295919980563
device: cuda, dtype: torch.half, 10000 times  0.46740938700168044
```
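The "negligible" claim can be spot-checked directly from the quoted timings. A small sketch computing the percent change for two representative configurations (the values are copied verbatim from the Before/After tables above):

```python
# Timings copied from the benchmark tables above (total seconds over
# all runs for each configuration).
before = {
    'cuda/float/n=10':  4.828282921000209,
    'cpu/int64/n=1000': 0.5718168579987832,
}
after = {
    'cuda/float/n=10':  4.7240724049988785,
    'cpu/int64/n=1000': 0.5690692410025804,
}

# Percent change relative to the pre-migration timing; negative means
# the ATen version was faster in this run.
for key in before:
    pct = (after[key] - before[key]) / before[key] * 100
    print(f'{key}: {pct:+.2f}%')
```

Both changes come out within a few percent, consistent with ordinary run-to-run benchmark noise.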
@pytorchbot merge this please
pytorchbot added the merge-this-please label (Was marked for merge with @pytorchbot merge this please) on Aug 1, 2019.
zdevito pushed a commit to zdevito/ATen that referenced this pull request on Aug 2, 2019:
Summary:
Pull Request resolved: pytorch/pytorch#23617

Doesn't seem to cause any performance regression. Performance difference in the benchmarks is negligible.

Benchmark script:

```python
import timeit
for n, t in [(10, 100000), (1000, 10000)]:
    print('a.neg() (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32',
                      'torch.int64', 'torch.float', 'torch.double') + (('torch.half',) if device == 'cuda' else ()):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a.neg()\nif "{device}" == "cuda": torch.cuda.synchronize()',
                                setup=f'import torch; a = torch.ones({n}, device="{device}", dtype={dtype})',
                                number=t))
```

Before:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times  2.5537249100016197
device: cpu, dtype: torch.uint8, 100000 times  2.512518662999355
device: cpu, dtype: torch.int16, 100000 times  2.548207502000878
device: cpu, dtype: torch.int32, 100000 times  2.5974994509997487
device: cpu, dtype: torch.int64, 100000 times  2.6533011499996064
device: cpu, dtype: torch.float, 100000 times  2.6474813019995054
device: cpu, dtype: torch.double, 100000 times  2.6949866009999823
device: cuda, dtype: torch.int8, 100000 times  5.820120684998983
device: cuda, dtype: torch.uint8, 100000 times  5.732108927997615
device: cuda, dtype: torch.int16, 100000 times  5.791249125999457
device: cuda, dtype: torch.int32, 100000 times  5.816761754998879
device: cuda, dtype: torch.int64, 100000 times  5.935873205999087
device: cuda, dtype: torch.float, 100000 times  6.276509613999224
device: cuda, dtype: torch.double, 100000 times  6.122782447000645
device: cuda, dtype: torch.half, 100000 times  6.161522764999972
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times  0.3766637519984215
device: cpu, dtype: torch.uint8, 10000 times  0.37288786600038293
device: cpu, dtype: torch.int16, 10000 times  0.3485262310023245
device: cpu, dtype: torch.int32, 10000 times  0.41810554200128536
device: cpu, dtype: torch.int64, 10000 times  0.5609612200023548
device: cpu, dtype: torch.float, 10000 times  0.39054008099992643
device: cpu, dtype: torch.double, 10000 times  0.4946578170020075
device: cuda, dtype: torch.int8, 10000 times  0.5843639539998549
device: cuda, dtype: torch.uint8, 10000 times  0.5780841570012853
device: cuda, dtype: torch.int16, 10000 times  0.5819949180004187
device: cuda, dtype: torch.int32, 10000 times  0.5827294059999986
device: cuda, dtype: torch.int64, 10000 times  0.5861426519986708
device: cuda, dtype: torch.float, 10000 times  0.5929420489992481
device: cuda, dtype: torch.double, 10000 times  0.594638443999429
device: cuda, dtype: torch.half, 10000 times  0.5903799709994928
```

After:

```
a.neg() (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times  2.4983287129980454
device: cpu, dtype: torch.uint8, 100000 times  2.479393904999597
device: cpu, dtype: torch.int16, 100000 times  2.5382055320005747
device: cpu, dtype: torch.int32, 100000 times  2.5587980189993687
device: cpu, dtype: torch.int64, 100000 times  2.637738788002025
device: cpu, dtype: torch.float, 100000 times  2.602799075997609
device: cpu, dtype: torch.double, 100000 times  2.6648931070012623
device: cuda, dtype: torch.int8, 100000 times  5.793338211999071
device: cuda, dtype: torch.uint8, 100000 times  5.782462584000314
device: cuda, dtype: torch.int16, 100000 times  5.824340334998851
device: cuda, dtype: torch.int32, 100000 times  5.851659068001027
device: cuda, dtype: torch.int64, 100000 times  5.8898071570001775
device: cuda, dtype: torch.float, 100000 times  5.913144636000652
device: cuda, dtype: torch.double, 100000 times  5.963339805999567
device: cuda, dtype: torch.half, 100000 times  5.87889370099947
a.neg() (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times  0.37244726499920944
device: cpu, dtype: torch.uint8, 10000 times  0.36641623199830065
device: cpu, dtype: torch.int16, 10000 times  0.3449854829996184
device: cpu, dtype: torch.int32, 10000 times  0.4127863069988962
device: cpu, dtype: torch.int64, 10000 times  0.5551902160004829
device: cpu, dtype: torch.float, 10000 times  0.38593814199703047
device: cpu, dtype: torch.double, 10000 times  0.48877579500185675
device: cuda, dtype: torch.int8, 10000 times  0.5862828740027908
device: cuda, dtype: torch.uint8, 10000 times  0.5836667540024791
device: cuda, dtype: torch.int16, 10000 times  0.5918155769977602
device: cuda, dtype: torch.int32, 10000 times  0.5961457039993547
device: cuda, dtype: torch.int64, 10000 times  0.5963898690024507
device: cuda, dtype: torch.float, 10000 times  0.5985483309996198
device: cuda, dtype: torch.double, 10000 times  0.6027148480025062
device: cuda, dtype: torch.half, 10000 times  0.5961164370019105
```

Test Plan: Imported from OSS
Differential Revision: D16617574
Pulled By: ezyang
fbshipit-source-id: c90aa410f6385ce94fe6b84ebeceffa5effd0267
Labels: merge-this-please (Was marked for merge with @pytorchbot merge this please), Merged, module: cuda (Related to torch.cuda, and CUDA support in general), module: internals (Related to internal abstractions in c10 and ATen), open source
Stack from ghstack:

- ~~Recommend `bitwise_not()` when user tries to apply neg (`-`) on a bool tensor. #23621~~
- Recommend `bitwise_not()` when user tries to apply neg (`-`) on a bool tensor. #23621

Doesn't seem to cause any performance regression. Performance difference in the benchmarks is negligible.
Differential Revision: D16617574
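For context on the related #23621 recommendation: in plain Python, `bool` is a subclass of `int`, so unary minus on a boolean silently promotes to integer arithmetic rather than flipping the value. That ambiguity is why a dedicated `bitwise_not()` is the clearer operation on boolean data. A minimal pure-Python illustration (this shows the underlying int semantics only, not the ATen code path):

```python
# bool is an int subclass, so unary minus is arithmetic negation:
assert -True == -1       # not a logical NOT
assert -False == 0

# Bitwise NOT operates on the underlying int (two's complement):
assert ~True == -2
assert ~False == -1

# Logical NOT is a separate operation entirely:
assert (not True) is False
```

Because `-b` on a bool gives `-1`/`0` rather than `False`/`True`, pointing users at `bitwise_not()` avoids a silent semantic surprise.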