Skip to content

Commit

Permalink
Update on "[pt][quant] Optimized qadd_scalar"
Browse files Browse the repository at this point in the history
Optimized path for qadd scalar. qadd_scalar time goes down from 55.840ms for a model to 4.637ms.

### Before
```
  -------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
quantize_per_tensor        0.12%            155.807us        0.12%            155.807us        155.807us        1
quantized::conv2d          25.50%           31.981ms         25.50%           31.981ms         273.343us        117
quantized::add_scalar      44.53%           55.840ms         44.53%           55.840ms         809.281us        69
quantized::relu6           1.25%            1.570ms          1.25%            1.570ms          22.749us         69
quantized::mul_scalar      10.73%           13.449ms         10.73%           13.449ms         194.914us        69
quantized::mul             16.67%           20.904ms         16.67%           20.904ms         227.220us        92
adaptive_avg_pool2d        0.03%            41.713us         0.69%            862.922us        35.955us         24
_adaptive_avg_pool2d       0.65%            821.209us        0.65%            821.209us        34.217us         24
sigmoid                    0.15%            182.344us        0.15%            182.344us        7.928us          23
quantized::add             0.34%            431.939us        0.34%            431.939us        26.996us         16
dropout                    0.00%            1.936us          0.00%            1.936us          1.936us          1
view                       0.01%            10.281us         0.01%            10.281us         10.281us         1
dequantize                 0.00%            4.562us          0.00%            4.562us          4.562us          1
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Self CPU time total: 125.394ms
```
### After
```
 -------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
quantize_per_tensor        0.18%            130.534us        0.18%            130.534us        130.534us        1
quantized::conv2d          42.29%           31.267ms         42.29%           31.267ms         267.243us        117
quantized::add_scalar      6.27%            4.637ms          6.27%            4.637ms          67.205us         69
quantized::relu6           1.77%            1.312ms          1.77%            1.312ms          19.008us         69
quantized::mul_scalar      18.92%           13.991ms         18.92%           13.991ms         202.768us        69
quantized::mul             28.49%           21.059ms         28.49%           21.059ms         228.904us        92
adaptive_avg_pool2d        0.06%            45.242us         1.27%            942.522us        39.272us         24
_adaptive_avg_pool2d       1.21%            897.280us        1.21%            897.280us        37.387us         24
sigmoid                    0.22%            160.282us        0.22%            160.282us        6.969us          23
quantized::add             0.56%            416.276us        0.56%            416.276us        26.017us         16
dropout                    0.00%            1.245us          0.00%            1.245us          1.245us          1
view                       0.01%            7.122us          0.01%            7.122us          7.122us          1
dequantize                 0.01%            5.952us          0.01%            5.952us          5.952us          1
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Self CPU time total: 73.930ms
```

Differential Revision: [D20500848](https://our.internmc.facebook.com/intern/diff/D20500848/)

[ghstack-poisoned]
  • Loading branch information
dskhudia committed Mar 20, 2020
2 parents 8721b2c + a20b9f1 commit 82de5b5
Show file tree
Hide file tree
Showing 33 changed files with 2,849 additions and 39 deletions.
8 changes: 7 additions & 1 deletion android/gradle/android_maven_install.gradle
Expand Up @@ -19,7 +19,13 @@ install {
developerConnection scmDeveloperConnection
}

licenses projectLicenses
licenses {
license {
name = POM_LICENSE_NAME
url = POM_LICENSE_URL
distribution = POM_LICENSE_DIST
}
}

developers {
developer {
Expand Down
33 changes: 33 additions & 0 deletions aten/src/ATen/native/ComplexHelper.h
@@ -0,0 +1,33 @@
#pragma once

#include <ATen/ATen.h>

namespace at { namespace native {

inline std::vector<int64_t> computeStrideForComplex(IntArrayRef oldstride) {
auto res = oldstride.vec();
for(size_t i = 0; i < res.size(); i++) {
res[i] = res[i] * 2;
}
res.emplace_back(1);
return res;
}

// expects as input a complex tensor and returns back a float tensor
// containing the complex values in the last two dimensions
inline Tensor view_complex_as_float(const Tensor& self) {
TORCH_INTERNAL_ASSERT(self.is_complex());
auto new_sizes = self.sizes().vec();
// last dimension will always have two elements containing the real and imag vals
new_sizes.emplace_back(2);
auto new_strides = computeStrideForComplex(self.strides());
if(self.scalar_type() == at::kComplexFloat) {
float* data = reinterpret_cast<float*>(self.data_ptr<std::complex<float>>());
return at::from_blob(data, new_sizes, new_strides, dtype(at::kFloat));
} else {
double* data = reinterpret_cast<double*>(self.data_ptr<std::complex<double>>());
return at::from_blob(data, new_sizes, new_strides, dtype(at::kDouble));
}
}

}} // namespace at::native
27 changes: 1 addition & 26 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Expand Up @@ -14,6 +14,7 @@
#include <ATen/native/Distributions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/ComplexHelper.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
Expand Down Expand Up @@ -440,32 +441,6 @@ void normal_fill(Tensor& self, const scalar_t mean, const scalar_t std, Generato
}
}

std::vector<int64_t> computeStrideForComplex(IntArrayRef oldstride) {
auto res = oldstride.vec();
for(size_t i = 0; i < res.size(); i++) {
res[i] = res[i] * 2;
}
res.emplace_back(1);
return res;
}

// expects as input a complex tensor and returns back a float tensor
// containing the complex values in the last two dimensions
Tensor view_complex_as_float(const Tensor& self) {
TORCH_INTERNAL_ASSERT(self.is_complex());
auto new_sizes = self.sizes().vec();
// last dimension will always have two elements containing the real and imag vals
new_sizes.emplace_back(2);
auto new_strides = computeStrideForComplex(self.strides());
if(self.scalar_type() == at::kComplexFloat) {
float* data = reinterpret_cast<float*>(self.data_ptr<std::complex<float>>());
return at::from_blob(data, new_sizes, new_strides, dtype(at::kFloat));
} else {
double* data = reinterpret_cast<double*>(self.data_ptr<std::complex<double>>());
return at::from_blob(data, new_sizes, new_strides, dtype(at::kDouble));
}
}

void normal_kernel(Tensor& self, double mean, double std, Generator* gen) {
if(self.is_complex()) {
// note: float_tensor lives only as long as the self tensor lives
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/CUDAScalar.cu
Expand Up @@ -9,7 +9,7 @@ namespace native {

Scalar _local_scalar_dense_cuda(const Tensor& self) {
Scalar r;
AT_DISPATCH_ALL_TYPES_AND3(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
scalar_t value;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cuda/DistributionNormal.cu
Expand Up @@ -6,6 +6,7 @@
#include <ATen/CUDAGenerator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/DistributionTemplates.h>
#include <ATen/native/ComplexHelper.h>

#include <curand.h>
#include <curand_kernel.h>
Expand Down Expand Up @@ -55,6 +56,14 @@ void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generat

Tensor& normal_cuda_(Tensor& self, double mean, double std, Generator* gen) {
TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std);
if(self.is_complex()) {
// note: float_tensor lives only as long as the self tensor lives
auto float_tensor = at::native::view_complex_as_float(self);
// variance for normal distribution of the real and imaginary values
// is half of the input variance
normal_cuda_(float_tensor, mean, std/(std::sqrt(2)), gen);
return self;
}
auto iter = TensorIterator::nullary_op(self);
normal_kernel_cuda(iter, mean, std, gen);
return self;
Expand Down
28 changes: 28 additions & 0 deletions aten/src/ATen/test/complex_test.cpp
@@ -0,0 +1,28 @@
#include <gtest/gtest.h>
#include "c10/util/Complex.h"

template<typename T, typename int_t>
static void TestBinaryOpsForIntType(T real, T img, int_t num) {
std::complex<T> c(real, img);
ASSERT_EQ(c + num, std::complex<T>(real + num, img + num));
ASSERT_EQ(num + c, std::complex<T>(num + real, num + img));
ASSERT_EQ(c - num, std::complex<T>(real - num, num - num));
ASSERT_EQ(num - c, std::complex<T>(num - real, num - img));
ASSERT_EQ(c * num, std::complex<T>(real * num, img * num));
ASSERT_EQ(num * c, std::complex<T>(num * real, num * img));
ASSERT_EQ(c / num, std::complex<T>(real / num, img / num));
ASSERT_EQ(num / c, std::complex<T>(num / real, num / img));
}

template<typename T>
static void TestBinaryOpsForAllIntTypes(T real, T img, int8_t i) {
TestBinaryOpsForIntType<T, int8_t>(real, img, i, op);
TestBinaryOpsForIntType<T, int16_t>(real, img, i, op);
TestBinaryOpsForIntType<T, int32_t>(real, img, i, op);
TestBinaryOpsForIntType<T, int64_t>(real, img, i, op);
}

TEST(ComplexTest, Integer) {
TestBinaryOpsForAllIntTypes<float>(1.0, 0.1, 1);
TestBinaryOpsForAllIntTypes<double>(-1.3, -0.2, -2);
}
20 changes: 20 additions & 0 deletions benchmarks/fastrnns/bench.py
Expand Up @@ -199,8 +199,24 @@ def bench_group(model_list, bench_name, bench_group, bench_args):
parser.add_argument('--cnns', nargs='*',
help='What to run. resnet18, resnet18_jit, resnet50, etc')
parser.add_argument('--group', nargs='*', default=default_groups, help='Which group to run. cnns, rnns, etc.')
parser.add_argument('--fuser', default='te', type=str,
help='The fuser backend to use. One of: te, old, or none')
parser.add_argument('--cuda_pointwise_loop_level', default=None, type=int)
parser.add_argument('--cuda_pointwise_block_count', default=None, type=int)
parser.add_argument('--cuda_pointwise_block_size', default=None, type=int)

args = parser.parse_args()
assert args.fuser in ['te', 'old', 'none']
torch._C._jit_set_texpr_fuser_enabled(args.fuser == 'te')
torch._C._jit_override_can_fuse_on_gpu(args.fuser == 'old')
torch._C._jit_set_bailout_depth(20)
if args.cuda_pointwise_loop_level:
torch._C._jit_set_te_cuda_pointwise_loop_levels(args.cuda_pointwise_loop_level)
if args.cuda_pointwise_block_count:
torch._C._jit_set_te_cuda_pointwise_block_count(args.cuda_pointwise_block_count)
if args.cuda_pointwise_block_size:
torch._C._jit_set_te_cuda_pointwise_block_size(args.cuda_pointwise_block_size)

rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_premul_bias', 'jit_simple',
'jit_multilayer', 'py']
cnns = args.cnns or ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit']
Expand All @@ -219,6 +235,10 @@ def bench_group(model_list, bench_name, bench_group, bench_args):
del bench_args['rnns']
del bench_args['cnns']
del bench_args['variable_lstms']
del bench_args['fuser']
del bench_args['cuda_pointwise_loop_level']
del bench_args['cuda_pointwise_block_count']
del bench_args['cuda_pointwise_block_size']

results = {}
if should_bench_varlen_lstms:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/qactivation_test.py
Expand Up @@ -45,7 +45,7 @@
('relu', nnq.ReLU),
('relu6', nnq.ReLU6),
('functional.hardtanh', nnq.functional.hardtanh),
('functional.elu', nnq.functional.elu)
('functional.elu', nnq.functional.elu),
('functional.hardsigmoid', nnq.functional.hardsigmoid),
),
attr_names=('op_name', 'op_func'),
Expand Down
10 changes: 10 additions & 0 deletions benchmarks/tensorexpr/HowToRun.md
@@ -0,0 +1,10 @@
From the root of pytorch repo, run:
```
python -m benchmarks.tensorexpr --help
```
to show documentation.

An example of an actual command line:
```
python -m benchmarks.tensorexpr broadcast --device gpu --mode fwd --jit_mode trace
```

0 comments on commit 82de5b5

Please sign in to comment.