diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py
index a315bbc230db..affadc8ef72d 100644
--- a/test/stablehlo/test_stablehlo_custom_call.py
+++ b/test/stablehlo/test_stablehlo_custom_call.py
@@ -1,3 +1,4 @@
+import expecttest
 import sys
 import re
 import unittest
@@ -16,7 +17,7 @@
 m = Library("my_custom_library", "DEF")
 
 
-class StableHLOCustomCallExportTest(unittest.TestCase):
+class StableHLOCustomCallExportTest(expecttest.TestCase):
 
   def test_single_output(self):
diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py
index 42e0bb8cb760..579a2f456605 100644
--- a/test/test_ops_error_message.py
+++ b/test/test_ops_error_message.py
@@ -1,3 +1,4 @@
+from typing import Callable
 import expecttest
 import os
 import torch
@@ -357,6 +358,56 @@ def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
         expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
     )
 
+  def _get_custom_call_properties(self, mode):
+    match mode:
+      case "tpu":
+        return (torch_xla._XLAC._xla_tpu_custom_call, "", [])
+      case "stablehlo":
+        return (torch_xla._XLAC._xla_custom_call, "custom_op_target",
+                [False, "", 0, {}])
+
+    self.fail(f"expected `mode` ({mode}) to be either of ['tpu', 'stablehlo'].")
+
+  def _gen_custom_call_no_input(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    return lambda: lib_custom_call([], payload, [[1]], [torch.int8], *args)
+
+  def _gen_custom_call_output_properties_size_mismatch(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    input = torch.rand(10, device=torch_xla.device())
+    return lambda: lib_custom_call(
+        (input,), payload, [[1], [1]], [torch.int8], *args)
+
+  def test_stablehlo_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("stablehlo"),
+        expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch(
+            "stablehlo"),
+        expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
+  def test_tpu_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("tpu"),
+        expect="""tpu_custom_call(): expected at least 1 input tensor.""")
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch("tpu"),
+        expect="""tpu_custom_call(): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 0f3c1a19e1f6..5e8d791aff9d 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -347,21 +347,6 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }
 
-std::vector<at::Tensor> TpuCustomCall(
-    const std::vector<at::Tensor>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes) {
-  std::vector<at::ScalarType> dtypes;
-  dtypes.reserve(output_dtypes.size());
-  for (auto& dtype : output_dtypes) {
-    dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-  }
-  XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_inputs,
-                      bridge::GetXlaTensors(inputs));
-  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-      xla_inputs, payload, output_shapes, dtypes));
-}
-
 std::vector<std::vector<int64_t>> ExtractXlaDotGeneralDimVectors(
     const py::tuple& dimension_numbers) {
   // Expect Python arg `dimension_numbers` to be
@@ -3116,30 +3101,33 @@
           "_xla_custom_call",
           [](const std::vector<at::Tensor>& inputs, const std::string& target,
              const std::vector<std::vector<int64_t>>& output_shapes,
-             const std::vector<py::object>& output_dtypes, bool has_side_effect,
+             const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
              const std::string& backend_config, const int api_version,
              const std::unordered_map<std::string, std::string>&
                 frontend_attributes) -> std::vector<at::Tensor> {
-            std::vector<at::ScalarType> dtypes;
-            dtypes.reserve(output_dtypes.size());
-            for (auto& dtype : output_dtypes) {
-              dtypes.push_back(
-                  reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-            }
-            XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_inputs, bridge::GetXlaTensors(inputs));
-            auto xtensors = tensor_methods::custom_call(
-                xla_inputs, target,
-                output_shapes, dtypes, has_side_effect, backend_config,
-                api_version, frontend_attributes);
-            return bridge::AtenFromXlaTensors(std::move(xtensors));
+            XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_inputs,
+                                bridge::GetXlaTensors(inputs));
+            XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_outputs,
+                                tensor_methods::custom_call(
+                                    xla_inputs, target, output_shapes, output_dtypes,
+                                    has_side_effect, backend_config, api_version,
+                                    frontend_attributes));
+
+            return bridge::AtenFromXlaTensors(std::move(xla_outputs));
           })
       .def("_xla_tpu_custom_call",
           [](const std::vector<at::Tensor>& inputs, const std::string& payload,
              const std::vector<std::vector<int64_t>>& output_shapes,
-             const std::vector<py::object>& output_dtypes)
+             const std::vector<at::ScalarType>& output_dtypes)
              -> std::vector<at::Tensor> {
-            return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
+            XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_inputs,
+                                bridge::GetXlaTensors(inputs));
+            XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_outputs,
+                                tensor_methods::tpu_custom_call(
+                                    xla_inputs, payload, output_shapes,
+                                    output_dtypes));
+
+            return bridge::AtenFromXlaTensors(std::move(xla_outputs));
          })
      .def("_xla_register_custom_call_target",
          [](const std::string& fn_name, const py::capsule& function_ptr,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index d81b812bd59a..be12823fb7ec 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -299,27 +299,6 @@ absl::StatusOr<PoolNdInputsOwner> FillAndCheckPoolNdInputs(
   return PoolNdInputsOwner{kernel_size, stride, padding};
 }
 
-// Resizes and / or checks whether a list is of the given size. The list is
-// only resized if its size is 1. If it's empty, it's replaced with the
-// provided default first.
-std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
-                                  const std::string& name,
-                                  std::vector<int64_t> def = {}) {
-  std::vector<int64_t> result;
-  if (list.empty()) {
-    result = std::move(def);
-  } else {
-    result = torch::lazy::ToVector<int64_t>(list);
-  }
-  if (result.size() == 1 && length > 1) {
-    result.resize(length, result[0]);
-    return result;
-  }
-  XLA_CHECK_EQ(result.size(), length)
-      << "Invalid length for the '" << name << "' attribute";
-  return result;
-}
-
 // Returns a 1-D shape for batch norm weight or bias based on the input shape.
 xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
   xla::PrimitiveType input_element_type =
@@ -666,6 +645,92 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }
 
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallNonEmptyInputs(
+    const std::vector<XLATensorPtr>& inputs,
+    const std::optional<std::string>& target) {
+  if (inputs.empty()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat(op, ": expected at least 1 input tensor.")));
+  }
+  return absl::OkStatus();
+}
+
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallOutputPropertiesSize(
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes,
+    const std::optional<std::string>& target) {
+  if (output_shapes.size() != output_dtypes.size()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        op, ": expected the given output shapes (size=", output_shapes.size(),
+        ") to be of the same size as the given output dtypes (size=",
+        output_dtypes.size(), ").")));
+  }
+  return absl::OkStatus();
+}
+
+// Common implementation shared by `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+template <class F>
+absl::StatusOr<std::vector<XLATensorPtr>> CustomCallImpl(
+    const std::vector<XLATensorPtr>& inputs,
+    const std::optional<std::string>& target,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
+  XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
+  XLA_RETURN_IF_ERROR(CheckCustomCallOutputPropertiesSize(
+      output_shapes, output_dtypes, target));
+
+  const auto& first = inputs.front();
+  auto device = first->GetDevice();
+  auto output_range = c10::irange(output_shapes.size());
+
+  // `values`: vector with the lazy IR values of `inputs`.
+  std::vector<torch::lazy::Value> values(inputs.size());
+  std::transform(
+      inputs.begin(), inputs.end(), values.begin(),
+      [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });
+
+  // `output_xla_shapes`: `xla::Shape` instances created from `output_shapes`
+  // and `output_dtypes`.
+  std::vector<xla::Shape> output_xla_shapes(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(),
+                 output_xla_shapes.begin(), [&](std::size_t i) {
+                   return xla::ShapeUtil::MakeShape(
+                       MakeXlaPrimitiveType(output_dtypes[i], &device),
+                       output_shapes[i]);
+                 });
+
+  auto node = make_node(values, output_xla_shapes);
+
+  // `outputs`: `XLATensorPtr` instances created from the `i`-th output of
+  // the `node` lazy IR `Node`.
+  std::vector<XLATensorPtr> outputs(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(), outputs.begin(),
+                 [&](std::size_t i) {
+                   return first->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_execution=*/true);
+                 });
+
+  return outputs;
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -886,40 +951,26 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
       torch::lazy::Value(node, 1)};
 }
 
-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<XLATensorPtr>> custom_call(
+    const std::vector<XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
     const std::unordered_map<std::string, std::string>& frontend_attributes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<CustomCall>(
-      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version, frontend_attributes);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<XLATensorPtr> outputs,
+      CustomCallImpl(inputs, target, output_shapes, output_dtypes,
+                     /* make_node= */
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<CustomCall>(
+                           values, target,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           has_side_effect, backend_config, api_version,
+                           frontend_attributes);
+                     }));
 
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `customcall` and in one graph
@@ -954,37 +1005,23 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }
 
-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<XLATensorPtr>> tpu_custom_call(
+    const std::vector<XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<TpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<XLATensorPtr> outputs,
+      CustomCallImpl(
+          inputs, /* target= */ std::nullopt, output_shapes, output_dtypes,
+          /* make_node= */
+          [&](const std::vector<torch::lazy::Value>& values,
+              const std::vector<xla::Shape>& output_xla_shapes) {
+            return torch_xla::MakeNode<TpuCustomCall>(
+                values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                payload);
+          }));
 
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `custom` and in one hlo
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index b08ad948fa32..37827e341fc8 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -92,8 +92,9 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
 
-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<XLATensorPtr>> custom_call(
+    const std::vector<XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
@@ -104,8 +105,9 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);
 
-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<XLATensorPtr>> tpu_custom_call(
+    const std::vector<XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes);