Skip to content

Commit bae1134

Browse files
committed
tmp
1 parent de5d764 commit bae1134

File tree

8 files changed

+33
-0
lines changed

8 files changed

+33
-0
lines changed

test/test_operations.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,6 +1867,16 @@ def test_patched_linear_1D_bias(self):
18671867
self.assertTrue(
18681868
torch.allclose(linear.bias.grad.cpu(), linear_cpu.bias.grad))
18691869

1870+
def test_tpu_custom_call(self):
  # Serialized Mosaic/Pallas kernel payload. The embedded metadata references
  # `pallas_add.py` and an `/add` op, so the kernel is presumably an
  # elementwise add of its two operands — TODO confirm against the kernel
  # source if it changes.
  payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAExCQEDBQcBAwkDBQMLBQkNDxETBwkVFxkbA31bEwFXBw8LCw8PEwsLGwszCw8LMwsLCwsLCwsTHwsTHxsPCxNlDwsTDxMPIwtTEwUFYVkBExsfDw8fGwcLGwJeBB8dMw0BAQUdHT0/EQcBFw8VAQUfBSEDBRUXGRsFIyMPEREBAAAAAAAAAAUlEQUFBScDCyEFIyUnCykLKy0FKQUrDREFLQUvBTEFMwMDBzElAQkIAAAABTUDAwc3JQEJAAAAAAMFEQUHOxENAQU3Fw8TAWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAdRQ0FOQMDHUkRBxUDAx1NEQcJAwdRUxUXGRsFOyMFCSEBAAAAAQAAAAEAAAABAAAAAwMRBSN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN0cHUudGlsZWQ8KDEyOCksWzFdPgAnBSECBAUXVwMCBAVZAQICAQIEF1cDAgQFQScFIQIEDwMBCQUHAwMDAQQSAgUBEAEHAwEFAxEBHwcDIT8HAwEDAQMBDwMDLwMBDwMDNQMBDwMJOQMNBwYBAwkDAQcGAQMJAwMHBgEDCQMFCQcJEwMBBQ0LCQcJEwMBBQ8LEwZDAwEFExULAgMDAREHA0cDCwUZCREHA0sDCwUZBxUGAwMLBRsdDQUDTwkXEQsfBQEBVQYDAQUBAOoFPSkLOTsLEyMhHSkVHxsVQw0LCwsTDQsLKQ8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAbW9kdWxlAHJldHVybgBlcmFzZV9tZW1yZWZfbGF5b3V0AGxvYWQAaW90YQBzdG9yZQBjb25zdGFudABjbXBpAGFkZGkAYW5kaQB2YWx1ZQAvaG9tZS9qd3Rhbi9wYWxsYXMvcGFsbGFzX2FkZC5weQBpbl9sYXlvdXQAc3VibGFuZV9tYXNrAHN1YmxhbmVfc3RyaWRlAHByZWRpY2F0ZQBkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgAvc3dhcFtpbmRleGVkX2RpbXM9KEZhbHNlLCldAC9nZXRbaW5kZXhlZF9kaW1zPShGYWxzZSwpXQAvYWRkAG9wZXJhbmRTZWdtZW50U2l6ZXMA\"}}"

  x = torch.arange(8, dtype=torch.int).to("xla")
  y = torch.arange(8, dtype=torch.int).to("xla")
  # `output` is written in place by the custom call; its initial values are
  # irrelevant, only its shape/dtype matter.
  output = torch.arange(8, dtype=torch.int).to("xla")

  torch_xla._XLAC._xla_tpu_custom_call_(output, x, y, payload)
  # Fix: the original only printed `output`, so the test could never fail.
  # Assert the kernel's expected elementwise sum instead.
  expected = (x.cpu() + y.cpu())
  self.assertTrue(torch.allclose(output.cpu(), expected))
1879+
18701880

18711881
class MNISTComparator(nn.Module):
18721882

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2170,6 +2170,11 @@ void InitXlaModuleBindings(py::module m) {
21702170
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
21712171
return XLANativeFunctions::set_(self, source);
21722172
});
2173+
// In-place TPU custom call: runs the kernel described by `payload` on the XLA
// tensors behind `x` and `y`, storing the result into `output`.
m.def("_xla_tpu_custom_call_",
      [](at::Tensor& output, const at::Tensor& x, const at::Tensor& y,
         const std::string& payload) {
        XLATensorPtr output_xtensor = bridge::GetXlaTensor(output);
        tensor_methods::tpu_custom_call_(output_xtensor,
                                         bridge::GetXlaTensor(x),
                                         bridge::GetXlaTensor(y), payload);
      });
21732178
m.def("_set_xla_custom_op_name_prefix",
21742179
[](const at::Tensor& input, const std::string& op_name_prefix,
21752180
size_t max_call_stack_depth) -> bool {

torch_xla/csrc/ops/xla_ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@ const OpKindWrapper xla_tensor_data("xla::tensor_data");
3333
const OpKindWrapper xla_unselect("xla::unselect");
3434
const OpKindWrapper xla_update_slice("xla::update_slice");
3535
const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
36+
// Op kind for the TPU custom-call IR node (see ops/tpu_custom_call.h).
const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
3637

3738
} // namespace torch_xla

torch_xla/csrc/ops/xla_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ extern const OpKindWrapper xla_tensor_data;
5858
extern const OpKindWrapper xla_unselect;
5959
extern const OpKindWrapper xla_update_slice;
6060
extern const OpKindWrapper xla_custom_sharding;
61+
// Op kind for the TPU custom-call IR node; defined in xla_ops.cpp.
extern const OpKindWrapper xla_tpu_custom_call;
6162

6263
} // namespace torch_xla
6364

torch_xla/csrc/tensor_methods.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "torch_xla/csrc/ops/cumprod.h"
4242
#include "torch_xla/csrc/ops/cumsum.h"
4343
#include "torch_xla/csrc/ops/custom_sharding.h"
44+
#include "torch_xla/csrc/ops/tpu_custom_call.h"
4445
#include "torch_xla/csrc/ops/dequant_tensor.h"
4546
#include "torch_xla/csrc/ops/device_data.h"
4647
#include "torch_xla/csrc/ops/diagonal.h"
@@ -523,6 +524,11 @@ void custom_sharding_(
523524
input->SetShardingSpec(*sharding_spec);
524525
}
525526

527+
// Builds a TpuCustomCall IR node over `x` and `y` with the serialized
// `payload`, and installs it as `output`'s in-place value.
void tpu_custom_call_(XLATensorPtr& output, const XLATensorPtr& x,
                      const XLATensorPtr& y, const std::string& payload) {
  auto call_node = torch::lazy::MakeNode<TpuCustomCall>(
      x->GetIrValue(), y->GetIrValue(), payload);
  output->SetInPlaceIrValue(std::move(call_node));
}
531+
526532
XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
527533
std::vector<int64_t> dimensions) {
528534
return input->CreateFrom(torch::lazy::MakeNode<GetDimensionsSize>(

torch_xla/csrc/tensor_methods.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
8282
void custom_sharding_(const XLATensorPtr& input,
8383
const std::shared_ptr<XLATensor::ShardingSpec>& spec);
8484

85+
// In-place TPU custom call: lowers to a "tpu_custom_call" CustomCall over
// `x` and `y` with the serialized kernel `payload`, writing into `output`.
void tpu_custom_call_(XLATensorPtr& output, const XLATensorPtr& x, const XLATensorPtr& y, const std::string& payload);
86+
8587
XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
8688
std::vector<int64_t> dimensions);
8789

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,4 +1224,10 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
12241224
{input}, ShapeHelper::ShapeOfXlaOp(input));
12251225
}
12261226

1227+
// Emits a "tpu_custom_call" CustomCall over {x, y} carrying `payload` as the
// opaque kernel descriptor. The operand shapes are also passed as the
// required operand layouts.
xla::XlaOp BuildTpuCustomCall(const xla::XlaOp& x, const xla::XlaOp& y,
                              const std::string& payload) {
  // Hoist the shape lookups: the original recomputed ShapeOfXlaOp(x) twice.
  const xla::Shape& x_shape = ShapeHelper::ShapeOfXlaOp(x);
  const xla::Shape& y_shape = ShapeHelper::ShapeOfXlaOp(y);
  // TODO: the result shape is assumed identical to x's shape; derive the real
  // output shape (e.g. from the payload or an explicit argument) instead.
  return xla::CustomCallWithLayout(
      x.builder(), /*call_target_name=*/"tpu_custom_call", /*operands=*/{x, y},
      /*shape_with_layout=*/x_shape,
      /*operand_shapes_with_layout=*/{x_shape, y_shape}, /*opaque=*/payload);
}
1232+
12271233
} // namespace torch_xla

torch_xla/csrc/xla_lower_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,
150150

151151
xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);
152152

153+
// Emits a "tpu_custom_call" CustomCall over {x, y} with `payload` as the
// opaque kernel descriptor; output shape currently mirrors x's shape.
xla::XlaOp BuildTpuCustomCall(const xla::XlaOp& x, const xla::XlaOp& y, const std::string& payload);
154+
153155
} // namespace torch_xla
154156

155157
#endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_

0 commit comments

Comments
 (0)