Update on "Port all and any full reductions to structured kernels."
Tracking issue: #55070 

This PR creates `out=` overloads for both the `all` and `any` kernels (full-reduction overload),
and ports them to structured kernels.

Differential Revision: [D30867354](https://our.internmc.facebook.com/intern/diff/D30867354)

[ghstack-poisoned]
ysiraichi committed Sep 13, 2021
2 parents f583683 + 38a6d35 commit 7a0a3e9
Showing 32 changed files with 1,557 additions and 1,236 deletions.
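For context (not part of the diff): a minimal sketch of the user-visible behavior this port adds, the `out=` variants of the full reductions. The overload availability is taken from the PR description above, so treat this as illustrative rather than definitive.

```python
import torch

x = torch.tensor([[True, True], [True, False]])

# Full-reduction out= overloads introduced by this PR.
out = torch.empty((), dtype=torch.bool)
torch.all(x, out=out)   # -> tensor(False)
torch.any(x, out=out)   # -> tensor(True)
```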
2 changes: 1 addition & 1 deletion aten/src/ATen/DLConvertor.cpp
@@ -39,7 +39,7 @@ DLDataType getDLDataType(const Tensor& t) {
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
     case ScalarType::Bool:
-      dtype.code = DLDataTypeCode::kDLUInt;
+      TORCH_CHECK(false, "Bool type is not supported by dlpack");
      break;
     case ScalarType::ComplexHalf:
       dtype.code = DLDataTypeCode::kDLComplex;
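User-visible effect of the DLConvertor change above, as a sketch: `TORCH_CHECK(false, ...)` surfaces in Python as a `RuntimeError`, assuming the standard `torch.utils.dlpack` entry point.

```python
import torch
from torch.utils.dlpack import to_dlpack

try:
    to_dlpack(torch.tensor([True, False]))  # bool tensors are now rejected
except RuntimeError as e:
    print(e)  # "Bool type is not supported by dlpack"
```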
72 changes: 29 additions & 43 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -103,7 +103,13 @@ void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor
   }
 }
 
-void allany_meta(
+// Note [all, any : uint8 compatibility]:
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// For NumPy compatibility, `all` and `any` return
+// a Tensor of dtype `bool`. However, for compatibility reasons,
+// for `uint8` they return a Tensor of the same dtype `uint8`.
+// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
+static void allany_meta(
     impl::MetaBase& meta,
     const char* name,
     const Tensor& self,
@@ -1303,24 +1309,6 @@ Tensor norm(const Tensor& self, const Scalar& p) {
   return at::norm(self, p, IntArrayRef{}, false);
 }
 
-// Note [all, any : uint8 compatibility]:
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// For NumPy comptability, `all` and `any` return
-// Tensor of dtype `bool`. However for compatibility reason,
-// for `uint8`, they return Tensor of same dtype `uint8`.
-// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
-inline const Tensor & _all(const Tensor & self, const Tensor & result, TensorIterator & iter) {
-  if (iter.numel() == 0) {
-    result.fill_(1);
-  } else if (iter.numel() == 1) {
-    result.fill_(self.item());
-  } else {
-    and_stub(iter.device_type(), iter);
-  }
-
-  return result;
-}
 
 inline TensorIterator get_allany_iter(
     const Tensor& self,
     const Tensor& result,
@@ -1337,41 +1325,39 @@ inline TensorIterator get_allany_iter(
       self, result, dims, keepdim, result.scalar_type());
 }
 
+template <int identity, typename Stub>
+inline void allany_impl(
+    const Tensor& self,
+    const Tensor& result,
+    IntArrayRef dims,
+    bool keepdim,
+    Stub& stub) {
+  if (self.numel() == 0) {
+    result.fill_(identity);
+  } else if (self.numel() == 1) {
+    result.fill_(self.item().toBool());
+  } else {
+    auto iter = get_allany_iter(self, result, dims, keepdim);
+    stub(iter.device_type(), iter);
+  }
+}
+
 TORCH_IMPL_FUNC(all_out)
 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
-  auto iter = get_allany_iter(self, result, dim, keepdim);
-  _all(self, result, iter);
+  allany_impl<1>(self, result, dim, keepdim, and_stub);
 }
 
 TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
-  auto iter = get_allany_iter(self, result, {}, false);
-  _all(self, result, iter);
+  allany_impl<1>(self, result, {}, false, and_stub);
 }
 
-inline const Tensor & _any(const Tensor & self, const Tensor & result, TensorIterator & iter) {
-  if (iter.numel() == 0) {
-    result.fill_(0);
-  } else if (iter.numel() == 1) {
-    result.fill_(self.item());
-  } else {
-    or_stub(iter.device_type(), iter);
-  }
-
-  return result;
-}
-
 TORCH_IMPL_FUNC(any_out)
-(const Tensor& self,
- int64_t dim,
- bool keepdim,
- const Tensor& result) {
-  auto iter = get_allany_iter(self, result, dim, keepdim);
-  _any(self, result, iter);
+(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
+  allany_impl<0>(self, result, dim, keepdim, or_stub);
 }
 
 TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
-  auto iter = get_allany_iter(self, result, {}, false);
-  _any(self, result, iter);
+  allany_impl<0>(self, result, {}, false, or_stub);
 }
 
 Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
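The `identity` template argument above encodes the value returned for an empty reduction: 1 for `all` (via `and_stub`), 0 for `any` (via `or_stub`). A quick sketch of that semantics, plus the uint8 dtype rule from Note [all, any : uint8 compatibility]; the snippet is illustrative, not part of the diff.

```python
import torch

empty = torch.empty(0)
assert torch.all(empty).item() is True    # and_stub identity = 1
assert torch.any(empty).item() is False   # or_stub identity = 0

# bool result dtype in general, but uint8 inputs keep uint8:
assert torch.all(torch.tensor([True])).dtype == torch.bool
assert torch.all(torch.tensor([1], dtype=torch.uint8)).dtype == torch.uint8
```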
29 changes: 17 additions & 12 deletions aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -163,24 +163,29 @@ static void std_var_kernel_impl(TensorIterator& iter, int64_t correction, bool t
 }
 
 static void prod_kernel_impl(TensorIterator& iter) {
-  // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+  // Workaround for the error: '*' in boolean context, suggest '&&' instead
+  // [-Werror=int-in-bool-context]
   if (iter.dtype() == ScalarType::Bool) {
     using scalar_t = bool;
     binary_kernel_reduce_vec(
         iter,
-        [=](scalar_t a, scalar_t b) -> scalar_t { return a && b; },
-        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a && b; },
-        // NOLINTNEXTLINE(bugprone-argument-comment)
-        /*identity=*/1);
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
-      binary_kernel_reduce_vec(
-          iter,
-          [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
-          [=](Vectorized <scalar_t> a, Vectorized <scalar_t> b) { return a * b; },
+        [=](scalar_t a, scalar_t b)
+            __ubsan_ignore_undefined__ -> scalar_t { return a && b; },
+        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
+            __ubsan_ignore_undefined__ { return a && b; },
+        // NOLINTNEXTLINE(bugprone-argument-comment)
+        /*identity=*/1);
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+      binary_kernel_reduce_vec(
+          iter,
+          [=](scalar_t a, scalar_t b)
+              __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
+          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
+              __ubsan_ignore_undefined__ { return a * b; },
           // NOLINTNEXTLINE(bugprone-argument-comment)
           /*identity=*/1);
     });
   }
 }

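The `Bool` branch above reduces with logical AND instead of `*` (the `-Werror=int-in-bool-context` workaround). A sketch of the Python-level semantics: passing `dtype=torch.bool` keeps the reduction in this bool kernel instead of the default integer promotion.

```python
import torch

b = torch.tensor([True, True, False])
print(torch.prod(b, dtype=torch.bool))  # tensor(False): True && True && False
print(torch.prod(b))                    # tensor(0): bool promotes to int64 by default
```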
2 changes: 1 addition & 1 deletion test/quantization/eager/test_bias_correction_eager.py
@@ -5,7 +5,7 @@
 
 from torch.quantization import default_qconfig
 from torch.quantization import QuantWrapper
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
 
 from torch.quantization._correct_bias import (
     _supported_modules,
2 changes: 1 addition & 1 deletion test/quantization/eager/test_numeric_suite_eager.py
@@ -10,7 +10,7 @@
     quantize,
     quantize_dynamic,
 )
-from torch.quantization._numeric_suite import (
+from torch.ao.ns._numeric_suite import (
     OutputLogger,
     Shadow,
     ShadowLogger,
14 changes: 7 additions & 7 deletions test/quantization/fx/test_numeric_suite_fx.py
@@ -34,29 +34,29 @@
 from torch.testing._internal.common_quantization import NodeSpec as ns
 from torch.quantization.fx.pattern_utils import get_default_quant_patterns
 import torch.quantization.fx.quantization_patterns as qp
-from torch.quantization.ns.pattern_utils import (
+from torch.ao.ns.fx.pattern_utils import (
     get_type_a_related_to_b,
 )
-from torch.quantization.ns.graph_matcher import (
+from torch.ao.ns.fx.graph_matcher import (
     get_matching_subgraph_pairs,
     GraphMatchingException,
 )
-from torch.quantization.ns.utils import (
+from torch.ao.ns.fx.utils import (
     compute_sqnr,
     compute_normalized_l2_error,
     compute_cosine_similarity,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
     get_unmatchable_types_map,
     get_base_name_to_sets_of_related_ops,
     get_base_name_for_op,
     add_op_to_sets_of_related_ops,
 )
-from torch.quantization.ns.weight_utils import (
+from torch.ao.ns.fx.weight_utils import (
     get_op_to_type_to_weight_extraction_fn,
 )
-from torch.quantization._numeric_suite_fx import (
+from torch.ao.ns._numeric_suite_fx import (
     extract_weights,
     _extract_weights_impl,
     add_loggers,
@@ -1634,7 +1634,7 @@ def forward(self, x):
         op_to_type_to_weight_extraction_fn = \
             get_op_to_type_to_weight_extraction_fn()
         op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
-            torch.quantization.ns.weight_utils.get_linear_fun_weight
+            torch.ao.ns.fx.weight_utils.get_linear_fun_weight
 
         # test compare weights
         results = extract_weights(
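The three quantization test changes above are a pure import migration: the numeric-suite modules now live under `torch.ao.ns`. The new canonical paths, taken directly from the diff:

```python
import torch.ao.ns._numeric_suite as ns                    # eager-mode numeric suite
from torch.ao.ns._numeric_suite_fx import extract_weights  # FX-mode numeric suite
from torch.ao.ns.fx.utils import compute_sqnr
```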
54 changes: 1 addition & 53 deletions test/test_modules.py
@@ -6,7 +6,7 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
+    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
 from unittest.mock import patch
 
 
@@ -206,58 +206,6 @@ def test_check_inplace(self, device, dtype, module_info):
         self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 
 
-    def _test_gradients_helper(self, device, dtype, module_info, check):
-        # Check gradients
-        module_cls = module_info.module_cls
-        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
-
-        for module_input in module_inputs:
-            if module_input.forward_input is None:
-                continue
-
-            # === Instantiate the module. ===
-            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
-            m = module_cls(*args, **kwargs)
-            m.to(device).to(dtype)
-
-            params = tuple(m.parameters())
-
-            # === Perform gradient check on the input_args ===
-            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
-
-            other_kwargs = {}
-            kwarg_tensors = []
-            for name, obj in input_kwargs.items():
-                if isinstance(obj, torch.Tensor):
-                    kwarg_tensors.append((name, obj))
-                else:
-                    other_kwargs[name] = obj
-
-            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
-
-            def fn_to_gradcheck(*input_and_params):
-                new_input_args = input_and_params[:len(input_args)]
-                kwarg_args = input_and_params[-len(kwarg_tensors):]
-                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
-
-                with freeze_rng_state():
-                    return m(*new_input_args, **new_kwargs, **other_kwargs)
-
-            self.assertTrue(check(fn_to_gradcheck, grad_input))
-
-
-    @modules(module_db, allowed_dtypes=[torch.double])
-    def test_grad(self, device, dtype, module_info):
-        self._test_gradients_helper(device, dtype, module_info, gradcheck)
-
-    @modules(module_db, allowed_dtypes=[torch.double])
-    def test_gradgrad(self, device, dtype, module_info):
-        if not module_info.supports_gradgrad:
-            self.skipTest("Skipped! Module does not support gradgrad")
-        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
-
-
 instantiate_device_type_tests(TestModule, globals())
 
 if __name__ == '__main__':
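The removed `_test_gradients_helper` was a thin wrapper over `gradcheck`/`gradgradcheck` (presumably relocated rather than dropped; the diff shown here does not say where). A simplified sketch of the underlying check it performed, ignoring the kwarg-tensor plumbing; the `Linear` module and input shapes are placeholders.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

m = torch.nn.Linear(3, 2).double()
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

fn = lambda inp: m(inp)
assert gradcheck(fn, (x,))      # analytic vs. numeric Jacobian, first order
assert gradgradcheck(fn, (x,))  # same check for second-order gradients
```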
1 change: 1 addition & 0 deletions test/test_python_dispatch.py
@@ -95,6 +95,7 @@ def capture_logs() -> Iterator[List[str]]:
     handler = LoggingTensorHandler(log_list)
     logger.addHandler(handler)
     logger.setLevel(logging.INFO)
+    logger.propagate = False
     try:
         yield log_list
     finally:
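Why the added `logger.propagate = False` matters: without it, records handled by `LoggingTensorHandler` also propagate up to the root logger's handlers, so each message can be emitted twice. A standalone stdlib illustration (the logger name here is generic, not taken from the test):

```python
import logging

logging.basicConfig(level=logging.INFO)   # root logger gets a console handler
log = logging.getLogger("LoggingTensor")
log.addHandler(logging.StreamHandler())
log.setLevel(logging.INFO)

log.info("printed twice: own handler plus root handler")
log.propagate = False
log.info("printed once: propagation to root is stopped")
```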
