diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ffff87ee0a1..32ddb81cda1 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -38,7 +38,6 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
-#include "status.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_fallback.h"
@@ -345,10 +344,10 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }
 
-std::vector<at::Tensor> XlaCustomCall(
+std::vector<at::Tensor> TpuCustomCall(
     const std::vector<at::Tensor>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes, bool is_tpu) {
+    const std::vector<py::object>& output_dtypes) {
   std::vector<at::ScalarType> dtypes;
   dtypes.reserve(output_dtypes.size());
   for (auto& dtype : output_dtypes) {
@@ -356,11 +355,7 @@ std::vector<at::Tensor> XlaCustomCall(
   }
   XLA_ASSIGN_OR_THROW(std::vector<XLATensorPtr> xla_inputs,
                       bridge::GetXlaTensors(inputs));
-  if (is_tpu) {
-    return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-        xla_inputs, payload, output_shapes, dtypes));
-  }
-  return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call(
+  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
       xla_inputs, payload, output_shapes, dtypes));
 }
 
@@ -3063,8 +3058,7 @@ void InitXlaModuleBindings(py::module m) {
              const std::vector<std::vector<int64_t>>& output_shapes,
              const std::vector<py::object>& output_dtypes)
               -> std::vector<at::Tensor> {
-            return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
-                                 /*is_tpu=*/true);
+            return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
           })
       .def("_has_cuda_support",
            []() {
@@ -3074,14 +3068,6 @@ void InitXlaModuleBindings(py::module m) {
             return false;
 #endif
           })
-      .def("_xla_gpu_custom_call",
-           [](const std::vector<at::Tensor>& inputs, const std::string& payload,
-              const std::vector<std::vector<int64_t>>& output_shapes,
-              const std::vector<py::object>& output_dtypes)
-               -> std::vector<at::Tensor> {
-            return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
-                                 /*is_tpu=*/false);
-          })
       .def("_xla_register_custom_call_target",
            [](const std::string& fn_name, const py::capsule& function_ptr,
               const std::string& platform) {
diff --git a/torch_xla/csrc/ops/gpu_custom_call.cpp b/torch_xla/csrc/ops/gpu_custom_call.cpp
deleted file mode 100644
index 26581f94899..00000000000
--- a/torch_xla/csrc/ops/gpu_custom_call.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-#include "torch_xla/csrc/ops/gpu_custom_call.h"
-
-#include "torch_xla/csrc/lowering_context.h"
-#include "torch_xla/csrc/ops/xla_ops.h"
-#include "torch_xla/csrc/xla_lower_util.h"
-
-namespace torch_xla {
-
-GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs,
-                             xla::Shape output_shape,
-                             const std::string& payload)
-    : XlaNode(xla_gpu_custom_call, inputs, output_shape,
-              /*num_outputs=*/output_shape.tuple_shapes_size(),
-              torch::lazy::MHash(payload)),
-      payload_(payload) {}
-
-torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const {
-  return torch_xla::MakeNode<GpuCustomCall>(operands, xla_shape(), payload_);
-}
-
-XlaOpVector GpuCustomCall::Lower(LoweringContext* loctx) const {
-  std::vector<xla::XlaOp> inputs;
-  inputs.reserve(operands().size());
-  for (auto& operand : operands()) {
-    inputs.push_back(loctx->GetOutputOp(operand));
-  }
-  auto output = BuildGpuCustomCall(inputs, xla_shape(), payload_);
-  return ReturnOps(output, loctx);
-}
-
-std::string GpuCustomCall::ToString() const {
-  std::stringstream ss;
-  ss << XlaNode::ToString() << ", " << payload_;
-  return ss.str();
-}
-
-}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/gpu_custom_call.h b/torch_xla/csrc/ops/gpu_custom_call.h
deleted file mode 100644
index fa08d62be67..00000000000
--- a/torch_xla/csrc/ops/gpu_custom_call.h
+++ /dev/null
@@ -1,25 +0,0 @@
-#ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
-#define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
-
-#include "torch_xla/csrc/ir.h"
-
-namespace torch_xla {
-class GpuCustomCall : public XlaNode {
- public:
-  // Make a GPU custom call with payload, e.g., Triton.
-  GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
-                const std::string& payload);
-
-  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
-
-  XlaOpVector Lower(LoweringContext* loctx) const override;
-
-  std::string ToString() const override;
-
- private:
-  std::string payload_;
-};
-
-}  // namespace torch_xla
-
-#endif  // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp
index 9187ee64fa9..a253d9cad8b 100644
--- a/torch_xla/csrc/ops/xla_ops.cpp
+++ b/torch_xla/csrc/ops/xla_ops.cpp
@@ -39,6 +39,5 @@ const OpKindWrapper xla_unselect("xla::unselect");
 const OpKindWrapper xla_update_slice("xla::update_slice");
 const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
 const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
-const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h
index 86ab2c57d4d..042de15e5cc 100644
--- a/torch_xla/csrc/ops/xla_ops.h
+++ b/torch_xla/csrc/ops/xla_ops.h
@@ -64,8 +64,7 @@ extern const OpKindWrapper xla_unselect;
 extern const OpKindWrapper xla_update_slice;
 extern const OpKindWrapper xla_custom_sharding;
 extern const OpKindWrapper xla_tpu_custom_call;
-extern const OpKindWrapper xla_gpu_custom_call;
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
\ No newline at end of file
+#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index bfd67b59de2..e7814ce517d 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -65,7 +65,6 @@
 #include "torch_xla/csrc/ops/generic.h"
 #include "torch_xla/csrc/ops/generic_slice.h"
 #include "torch_xla/csrc/ops/get_dimensions_size.h"
-#include "torch_xla/csrc/ops/gpu_custom_call.h"
 #include "torch_xla/csrc/ops/hardtanh_backward.h"
 #include "torch_xla/csrc/ops/index_ops.h"
 #include "torch_xla/csrc/ops/index_select.h"
@@ -767,45 +766,6 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }
 
-std::vector<XLATensorPtr> gpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<GpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
-
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
-  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
-  if (graph_executor->UseEagerMode()) {
-    // Execute the HLO that will run the `custom` and in one hlo
-    graph_executor->ApplyEagerSync(outputs);
-  }
-  return outputs;
-}
-
 std::vector<XLATensorPtr> tpu_custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index c28d7f2165e..597640bf4c4 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -103,11 +103,6 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);
 
-std::vector<XLATensorPtr> gpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<at::ScalarType>& output_dtypes);
-
 std::vector<XLATensorPtr> tpu_custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 1f8327acc36..e34c0d90adc 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1279,31 +1279,6 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
                                output_shape);
 }
 
-std::vector<xla::XlaOp> BuildGpuCustomCall(
-    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
-    const std::string& payload) {
-  std::vector<xla::Shape> input_shapes;
-  input_shapes.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
-  }
-
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-  xla::XlaOp outputs = xla::CustomCallWithLayout(
-      inputs[0].builder(),
-      /*call_target_name=*/"triton_kernel_call", inputs, output_shape,
-      input_shapes, payload, false, {}, nullptr,
-      xla::CustomCallSchedule::SCHEDULE_NONE,
-      xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
-  std::vector<xla::XlaOp> result;
-  int num_outputs = output_shape.tuple_shapes_size();
-  result.reserve(num_outputs);
-  for (int i = 0; i < num_outputs; ++i) {
-    result.push_back(xla::GetTupleElement(outputs, i));
-  }
-  return result;
-}
-
 std::vector<xla::XlaOp> BuildTpuCustomCall(
     const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
     const std::string& payload) {
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index 60ebad6dcd6..f2dc8a1915e 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -162,10 +162,6 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
 xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
                     xla::XlaOp iou_threshold);
 
-std::vector<xla::XlaOp> BuildGpuCustomCall(
-    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
-    const std::string& payload);
-
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_
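
Note on usage after this change: the `_xla_gpu_custom_call` binding is gone, and only the
TPU custom-call path remains exposed to Python. Below is a minimal sketch of how the
surviving binding can be invoked, assuming the `_xla_tpu_custom_call(inputs, payload,
output_shapes, output_dtypes)` signature shown in the diff above. The payload string is a
placeholder: a real payload must be a serialized TPU kernel call (e.g. produced by
compiling a Pallas/Mosaic kernel), so the snippet is illustrative rather than runnable
end-to-end.

    import torch
    import torch_xla

    device = torch_xla.device()
    x = torch.randn(8, 8, device=device)
    y = torch.randn(8, 8, device=device)

    # Placeholder: a real payload is a serialized TPU kernel call produced by a
    # kernel compiler (e.g. Mosaic/Pallas); execution would fail without one.
    payload = "<serialized-tpu-kernel-payload>"

    # output_shapes/output_dtypes describe the tensors the custom call returns;
    # the binding returns a list of at::Tensor wrapped as torch tensors.
    outputs = torch_xla._XLAC._xla_tpu_custom_call(
        [x, y], payload, [[8, 8]], [torch.float32])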