22 changes: 4 additions & 18 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -38,7 +38,6 @@
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "status.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_fallback.h"
@@ -345,22 +344,18 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
return replica_groups;
}

std::vector<at::Tensor> XlaCustomCall(
std::vector<at::Tensor> TpuCustomCall(
const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes, bool is_tpu) {
const std::vector<py::object>& output_dtypes) {
std::vector<at::ScalarType> dtypes;
dtypes.reserve(output_dtypes.size());
for (auto& dtype : output_dtypes) {
dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
}
XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
bridge::GetXlaTensors(inputs));
if (is_tpu) {
return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
xla_inputs, payload, output_shapes, dtypes));
}
return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call(
return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
xla_inputs, payload, output_shapes, dtypes));
}

@@ -3063,8 +3058,7 @@ void InitXlaModuleBindings(py::module m) {
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes)
-> std::vector<at::Tensor> {
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/true);
return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
})
.def("_has_cuda_support",
[]() {
@@ -3074,14 +3068,6 @@
return false;
#endif
})
.def("_xla_gpu_custom_call",
[](const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes)
-> std::vector<at::Tensor> {
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/false);
})
.def("_xla_register_custom_call_target",
[](const std::string& fn_name, const py::capsule& function_ptr,
const std::string& platform) {
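For readability, the net effect of the init_python_bindings.cpp hunks above is sketched below: the XlaCustomCall helper loses its is_tpu parameter and GPU branch and becomes TpuCustomCall. This is assembled from the diff hunks with the surrounding context assumed unchanged; it is a reading aid, not the verbatim post-merge file.

    std::vector<at::Tensor> TpuCustomCall(
        const std::vector<at::Tensor>& inputs, const std::string& payload,
        const std::vector<std::vector<int64_t>>& output_shapes,
        const std::vector<py::object>& output_dtypes) {
      // Convert the Python dtype objects into ATen scalar types.
      std::vector<at::ScalarType> dtypes;
      dtypes.reserve(output_dtypes.size());
      for (auto& dtype : output_dtypes) {
        dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
      }
      // Resolve the XLA tensors behind the ATen inputs and dispatch straight to
      // the TPU custom-call lowering; the former is_tpu branch is gone.
      XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
                          bridge::GetXlaTensors(inputs));
      return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
          xla_inputs, payload, output_shapes, dtypes));
    }

The binding in the -3063 hunk then forwards to this helper directly (return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);), and the _xla_gpu_custom_call binding is removed altogether.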
37 changes: 0 additions & 37 deletions torch_xla/csrc/ops/gpu_custom_call.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions torch_xla/csrc/ops/gpu_custom_call.h

This file was deleted.

1 change: 0 additions & 1 deletion torch_xla/csrc/ops/xla_ops.cpp
@@ -39,6 +39,5 @@ const OpKindWrapper xla_unselect("xla::unselect");
const OpKindWrapper xla_update_slice("xla::update_slice");
const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");

} // namespace torch_xla
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/xla_ops.h
@@ -64,8 +64,7 @@ extern const OpKindWrapper xla_unselect;
extern const OpKindWrapper xla_update_slice;
extern const OpKindWrapper xla_custom_sharding;
extern const OpKindWrapper xla_tpu_custom_call;
extern const OpKindWrapper xla_gpu_custom_call;

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
#endif // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
40 changes: 0 additions & 40 deletions torch_xla/csrc/tensor_methods.cpp
@@ -65,7 +65,6 @@
#include "torch_xla/csrc/ops/generic.h"
#include "torch_xla/csrc/ops/generic_slice.h"
#include "torch_xla/csrc/ops/get_dimensions_size.h"
#include "torch_xla/csrc/ops/gpu_custom_call.h"
#include "torch_xla/csrc/ops/hardtanh_backward.h"
#include "torch_xla/csrc/ops/index_ops.h"
#include "torch_xla/csrc/ops/index_select.h"
@@ -767,45 +766,6 @@ void custom_sharding_(
input->SetShardingSpec(*sharding_spec);
}

std::vector<XLATensorPtr> gpu_custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<at::ScalarType>& output_dtypes) {
XLA_CHECK(inputs.size() > 0) << "inputs are empty";

std::vector<torch::lazy::Value> values;
values.reserve(inputs.size());
for (const auto& input : inputs) {
values.push_back(input->GetIrValue());
}

XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
std::vector<xla::Shape> output_xla_shapes;
output_xla_shapes.reserve(output_shapes.size());
for (size_t i = 0; i < output_shapes.size(); ++i) {
output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
output_shapes[i]));
}

auto node = torch_xla::MakeNode<GpuCustomCall>(
values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);

std::vector<XLATensorPtr> outputs;
outputs.reserve(output_shapes.size());
for (size_t i = 0; i < output_shapes.size(); ++i) {
outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
output_dtypes[i],
/*delay_eager_execution=*/true));
}
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
if (graph_executor->UseEagerMode()) {
// Execute the HLO that will run the `custom` and in one hlo
graph_executor->ApplyEagerSync(outputs);
}
return outputs;
}

std::vector<XLATensorPtr> tpu_custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
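The deleted gpu_custom_call.cpp and gpu_custom_call.h are collapsed above, so their contents are not shown. From the removed call site torch_xla::MakeNode<GpuCustomCall>(values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload) in this hunk, the deleted node's interface was presumably along these lines; this is a hypothetical reconstruction for orientation, not the actual deleted header.

    // Hypothetical sketch of the removed IR node (names inferred from the
    // call site above, not copied from the deleted file).
    class GpuCustomCall : public XlaNode {
     public:
      // Operands, tuple-typed output shape, and the serialized kernel-call
      // payload forwarded to the "triton_kernel_call" custom-call target.
      GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
                    const std::string& payload);

      torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
      XlaOpVector Lower(LoweringContext* loctx) const override;
      std::string ToString() const override;

     private:
      std::string payload_;
    };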
5 changes: 0 additions & 5 deletions torch_xla/csrc/tensor_methods.h
@@ -103,11 +103,6 @@ void custom_sharding_(
const std::shared_ptr<XLATensor::ShardingSpec>& spec,
const CustomSharding::Type& type = CustomSharding::Type::kSharding);

std::vector<XLATensorPtr> gpu_custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<at::ScalarType>& output_dtypes);

std::vector<XLATensorPtr> tpu_custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
25 changes: 0 additions & 25 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1279,31 +1279,6 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
output_shape);
}

std::vector<xla::XlaOp> BuildGpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload) {
std::vector<xla::Shape> input_shapes;
input_shapes.reserve(inputs.size());
for (const auto& input : inputs) {
input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
}

XLA_CHECK(inputs.size() > 0) << "inputs are empty";
xla::XlaOp outputs = xla::CustomCallWithLayout(
inputs[0].builder(),
/*call_target_name=*/"triton_kernel_call", inputs, output_shape,
input_shapes, payload, false, {}, nullptr,
xla::CustomCallSchedule::SCHEDULE_NONE,
xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
std::vector<xla::XlaOp> result;
int num_outputs = output_shape.tuple_shapes_size();
result.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
result.push_back(xla::GetTupleElement(outputs, i));
}
return result;
}

std::vector<xla::XlaOp> BuildTpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload) {
4 changes: 0 additions & 4 deletions torch_xla/csrc/xla_lower_util.h
@@ -162,10 +162,6 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
xla::XlaOp iou_threshold);

std::vector<xla::XlaOp> BuildGpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_