From 52ed576d3c3d18d1a677124f9ae0efe261f84cdd Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Sat, 23 May 2020 13:20:48 -0700 Subject: [PATCH] User XLA computations support. --- test/test_operations.py | 59 ++- torch_xla/core/xla_builder.py | 331 ++++++++++++++ torch_xla/core/xla_op_registry.py | 81 ++++ torch_xla/csrc/computation.cpp | 14 + torch_xla/csrc/computation.h | 33 ++ torch_xla/csrc/init_python_bindings.cpp | 91 ++++ torch_xla/csrc/ops/user_computation.cpp | 54 +++ torch_xla/csrc/ops/user_computation.h | 28 ++ torch_xla/csrc/tensor.h | 5 + torch_xla/csrc/tensor_methods.cpp | 14 + torch_xla/csrc/xla_op_builder.cpp | 550 ++++++++++++++++++++++++ torch_xla/csrc/xla_op_builder.h | 33 ++ 12 files changed, 1286 insertions(+), 7 deletions(-) create mode 100644 torch_xla/core/xla_builder.py create mode 100644 torch_xla/core/xla_op_registry.py create mode 100644 torch_xla/csrc/computation.cpp create mode 100644 torch_xla/csrc/computation.h create mode 100644 torch_xla/csrc/ops/user_computation.cpp create mode 100644 torch_xla/csrc/ops/user_computation.h create mode 100644 torch_xla/csrc/xla_op_builder.cpp create mode 100644 torch_xla/csrc/xla_op_builder.h diff --git a/test/test_operations.py b/test/test_operations.py index 29f46c2e4d3..5e67c01ff70 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -29,6 +29,8 @@ import torch.nn.functional as F import torch.optim as optim import torch_xla +import torch_xla.core.xla_builder as xb +import torch_xla.core.xla_op_registry as xor import torch_xla.distributed.data_parallel as dp import torch_xla.debug.metrics as met import torch_xla.debug.model_comparator as mc @@ -459,6 +461,15 @@ def maybePrintGraph(self, tensors): else: raise RuntimeError('Invalid TEST_PRINT_GRAPH value: {}'.format(env)) + def compareResults(self, results, xla_results, rel_err=1e-2, abs_err=1e-5): + self.maybePrintGraph(xla_results) + for at, xt in zip(results, xla_results): + self.assertEqualRel( + self.makeComparable(xt), + self.makeComparable(at), + rel_err=rel_err, + abs_err=abs_err) + def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): if device is None: device = xm.xla_device() @@ -468,13 +479,7 @@ def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): ] results = xu.as_list(fn(*tensors)) xla_results = xu.as_list(fn(*xla_tensors)) - self.maybePrintGraph(xla_results) - for at, xt in zip(results, xla_results): - self.assertEqualRel( - self.makeComparable(xt), - self.makeComparable(at), - rel_err=rel_err, - abs_err=abs_err) + self.compareResults(results, xla_results, rel_err=rel_err, abs_err=abs_err) class TestToXlaTensorArena(XlaTestCase): @@ -1680,6 +1685,46 @@ def test(self): self.assertEqual(len(report), 0) +class TestOpBuilder(XlaTestCase): + + def runOpBuilderTest(self, + name, + tensors, + opfn, + aten_fn=None, + device=None, + rel_err=1e-2, + abs_err=1e-5): + op = xor.register(name, opfn) + if device is None: + device = xm.xla_device() + if aten_fn is None: + aten_fn = opfn + tensors = xu.as_list(tensors) + xla_tensors = [ + x.to(device).detach().requires_grad_(x.requires_grad) for x in tensors + ] + results = xu.as_list(aten_fn(*tensors)) + xla_results = xu.as_list(op(*xla_tensors)) + self.compareResults(results, xla_results, rel_err=rel_err, abs_err=abs_err) + + def test_add(self): + + def op_fn(a, b, **kwargs): + return a + b + + self.runOpBuilderTest( + 'test_add', [torch.randn(2, 2), torch.randn(2, 2)], op_fn) + + def test_mul(self): + + def op_fn(a, b, **kwargs): + return a * b + + 
self.runOpBuilderTest( + 'test_mul', [torch.randn(2, 2), torch.randn(2, 2)], op_fn) + + class TestGeneric(XlaTestCase): def test_zeros_like_patch(self): diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py new file mode 100644 index 00000000000..32d56975825 --- /dev/null +++ b/torch_xla/core/xla_builder.py @@ -0,0 +1,331 @@ +from __future__ import division +from __future__ import print_function + +import torch_xla + + +class Op(object): + + def __init__(self, op): + self.op = op + + def shape(self): + return torch_xla._XLAC._xla_op_shape(self.op) + + def builder(self): + return torch_xla._XLAC._xla_op_builder(self.op) + + def build(self, name): + return torch_xla._XLAC._xla_op_build(name, self.op) + + def __add__(self, rhs): + return mkop('Add', (self.op, rhs.op)) + + def __sub__(self, rhs): + return mkop('Sub', (self.op, rhs.op)) + + def __mul__(self, rhs): + return mkop('Mul', (self.op, rhs.op)) + + def __matmul__(self, rhs): + return mkop('Dot', (self.op, rhs.op)) + + def __truediv__(self, rhs): + return mkop('Div', (self.op, rhs.op)) + + def __pow__(self, rhs): + return mkop('Pow', (self.op, rhs.op)) + + def __mod__(self, rhs): + return mkop('Rem', (self.op, rhs.op)) + + def __neg__(self): + return mkop('Neg', (self.op,)) + + def __not__(self): + return mkop('Not', (self.op,)) + + def __and__(self, rhs): + return mkop('And', (self.op, rhs.op)) + + def __or__(self, rhs): + return mkop('Or', (self.op, rhs.op)) + + def __xor__(self, rhs): + return mkop('Xor', (self.op, rhs.op)) + + def __eq__(self, rhs): + return mkop('Eq', (self.op, rhs.op)) + + def __ne__(self, rhs): + return mkop('Ne', (self.op, rhs.op)) + + def __le__(self, rhs): + return mkop('Le', (self.op, rhs.op)) + + def __lt__(self, rhs): + return mkop('Lt', (self.op, rhs.op)) + + def __ge__(self, rhs): + return mkop('Ge', (self.op, rhs.op)) + + def __gt__(self, rhs): + return mkop('Gt', (self.op, rhs.op)) + + def __lshift__(self, rhs): + return mkop('ShiftLeft', (self.op, rhs.op)) + + def __rshift__(self, rhs): + return mkop('ShiftRight', (self.op, rhs.op)) + + def reshape(self, sizes, dimensions=None, inferred_dimension=None): + return mkop( + 'Reshape', (self.op,), + sizes=sizes, + dimensions=dimensions, + inferred_dimension=inferred_dimension) + + def dynamic_reshape(self, sizes): + return mkop('DynamicReshape', (self.op,), sizes=sizes) + + def broadcast(self, sizes): + return mkop('Broadcast', (self.op,), sizes=sizes) + + def broadcast_in_dim(self, sizes, dimensions): + return mkop( + 'BroadcastInDim', (self.op,), sizes=sizes, dimensions=dimensions) + + def slice(self, start_indices, limit_indices, strides=None): + if strides is None: + strides = [1] * len(start_indices) + return mkop( + 'Slice', (self.op,), + start_indices=start_indices, + limit_indices=limit_indices, + strides=strides) + + def slice_in_dim(self, start_index, limit_index, dimno, stride=1): + return mkop( + 'SliceInDim', (self.op,), + start_index=start_index, + limit_index=limit_index, + dimno=dimno, + stride=stride) + + def dynamic_slice(self, start_indices, slice_sizes): + start_indices = [x.op for x in start_indices] + return mkop( + 'DynamicSlice', (self.op,), + start_indices=start_indices, + slice_sizes=slice_sizes) + + def dynamic_update_slice(self, update, start_indices): + start_indices = [x.op for x in start_indices] + return mkop( + 'DynamicUpdateSlice', (self.op, update.op), start_indices=start_indices) + + def gather(self, + start_indices, + offset_dims, + collapsed_slice_dims, + start_index_map, + index_vector_dim, + 
indices_are_sorted=None):
+    return mkop(
+        'Gather', (self.op, start_indices.op),
+        offset_dims=offset_dims,
+        collapsed_slice_dims=collapsed_slice_dims,
+        start_index_map=start_index_map,
+        index_vector_dim=index_vector_dim,
+        indices_are_sorted=indices_are_sorted)
+
+  def scatter(self,
+              scatter_indices,
+              updates,
+              update_window_dims,
+              inserted_window_dims,
+              index_vector_dim,
+              indices_are_sorted=None,
+              unique_indices=None):
+    return mkop(
+        'Scatter', (self.op, scatter_indices.op, updates.op),
+        update_window_dims=update_window_dims,
+        inserted_window_dims=inserted_window_dims,
+        index_vector_dim=index_vector_dim,
+        indices_are_sorted=indices_are_sorted,
+        unique_indices=unique_indices)
+
+  def conv(self,
+           kernel,
+           window_strides,
+           feature_group_count=1,
+           batch_group_count=1,
+           padding='valid',
+           precision_config=None):
+    return mkop(
+        'Conv', (self.op, kernel.op),
+        window_strides=window_strides,
+        feature_group_count=feature_group_count,
+        batch_group_count=batch_group_count,
+        padding=padding,
+        precision_config=precision_config)
+
+  def cast(self, to_type):
+    return mkop('Convert', (self.op,), to_type=to_type)
+
+  def bitcast(self, to_type):
+    return mkop('BitcastConvert', (self.op,), to_type=to_type)
+
+  def pad(self, value, config):
+    return mkop('Pad', (self.op, value.op), config=config)
+
+  def reduce(self, init_value, computation, dimensions):
+    return mkop(
+        'Reduce', (self.op, init_value.op),
+        computation=computation,
+        dimensions=dimensions)
+
+  def select(self, true_value, false_value):
+    return mkop('Select', (self.op, true_value.op, false_value.op))
+
+  def transpose(self, permutation):
+    return mkop('Transpose', (self.op,), permutation=permutation)
+
+  def acos(self):
+    return mkop('Acos', (self.op,))
+
+  def asin(self):
+    return mkop('Asin', (self.op,))
+
+  def atan(self):
+    return mkop('Atan', (self.op,))
+
+  def ceil(self):
+    return mkop('Ceil', (self.op,))
+
+  def cos(self):
+    return mkop('Cos', (self.op,))
+
+  def cosh(self):
+    return mkop('Cosh', (self.op,))
+
+  def erf(self):
+    return mkop('Erf', (self.op,))
+
+  def erfc(self):
+    return mkop('Erfc', (self.op,))
+
+  def erfinv(self):
+    return mkop('ErfInv', (self.op,))
+
+  def exp(self):
+    return mkop('Exp', (self.op,))
+
+  def expm1(self):
+    return mkop('Expm1', (self.op,))
+
+  def floor(self):
+    return mkop('Floor', (self.op,))
+
+  def log(self):
+    return mkop('Log', (self.op,))
+
+  def log1p(self):
+    return mkop('Log1p', (self.op,))
+
+  def sqrt(self):
+    return mkop('Sqrt', (self.op,))
+
+  def rsqrt(self):
+    return mkop('Rsqrt', (self.op,))
+
+  def sin(self):
+    return mkop('Sin', (self.op,))
+
+  def sinh(self):
+    return mkop('Sinh', (self.op,))
+
+  def tan(self):
+    return mkop('Tan', (self.op,))
+
+  def tanh(self):
+    return mkop('Tanh', (self.op,))
+
+  def atan2(self, other):
+    return mkop('Atan2', (self.op, other.op))
+
+  def max(self, other):
+    return mkop('Max', (self.op, other.op))
+
+  def min(self, other):
+    return mkop('Min', (self.op, other.op))
+
+  @classmethod
+  def tuple(cls, ops, builder=None):
+    return mkop('Tuple', [x.op for x in ops], builder=builder)
+
+  @classmethod
+  def call(cls, computation, ops, builder=None):
+    return mkop(
+        'Call', [x.op for x in ops], computation=computation, builder=builder)
+
+  @classmethod
+  def constant(cls, builder, value):
+    return mkleaf('Constant', builder, value=value)
+
+  @classmethod
+  def iota(cls, builder, shape, iota_dimension):
+    return mkleaf('Iota', builder, shape=shape, iota_dimension=iota_dimension)
+
+  @classmethod
+  def sort(cls, ops, comparator, dimension=None, is_stable=None):
+    return mkop(
+        'Sort', [x.op for x in ops],
+        comparator=comparator,
+        dimension=dimension,
+        is_stable=is_stable)
+
+
+def create_builder(name):
+  return torch_xla._XLAC._xla_op_create_builder(name)
+
+
+def mkshape(stype, dims):
+  return (str(stype), tuple(dims))
+
+
+def mkop(name, ops, **kwargs):
+  builder = kwargs.get('builder', None)
+  if builder is None:
+    assert ops
+    builder = torch_xla._XLAC._xla_op_builder(ops[0])
+  return Op(torch_xla._XLAC._xla_op_create(builder, name, ops, kwargs))
+
+
+def mkleaf(name, builder, **kwargs):
+  return Op(torch_xla._XLAC._xla_op_create(builder, name, (), kwargs))
+
+
+def mkparam(builder, param_no, shape):
+  return Op(torch_xla._XLAC._xla_op_param(builder, param_no, shape))
+
+
+def tensor_shape(tensor, device=''):
+  if isinstance(tensor, (list, tuple)):
+    return [torch_xla._XLAC._xla_op_tensor_shape(t, device) for t in tensor]
+  return torch_xla._XLAC._xla_op_tensor_shape(tensor, device)
+
+
+def create_computation(name, fn, shapes, **kwargs):
+  builder = create_builder(name)
+  params = []
+  for shape in shapes:
+    p = mkparam(builder, len(params), shape)
+    params.append(p)
+
+  root = fn(*params, **kwargs)
+  return root.build(name)
+
+
+def get_computation_hlo(computation):
+  return torch_xla._XLAC._xla_computation_text(computation)
diff --git a/torch_xla/core/xla_op_registry.py b/torch_xla/core/xla_op_registry.py
new file mode 100644
index 00000000000..7686574e696
--- /dev/null
+++ b/torch_xla/core/xla_op_registry.py
@@ -0,0 +1,81 @@
+from __future__ import division
+from __future__ import print_function
+
+import pickle
+import sys
+import threading
+import torch_xla
+import torch_xla.core.xla_builder as xb
+import torch_xla.utils.utils as xu
+
+
+class Op(object):
+  """Creates a PyTorch operation with an XLA lowering function.
+
+  Args:
+    name (str): The name of the operation.
+    opfn (callable): The function implementing the XLA lowering.
+  """
+
+  def __init__(self, name, opfn):
+    self._name = name
+    self._opfn = opfn
+    self._opname = 'xla::_op_' + name
+    self._lock = threading.Lock()
+    self._computations = dict()
+
+  def __call__(self, *args, **kwargs):
+    """Performs the PyTorch operation on the given XLA tensors.
+
+    Args:
+      args: The PyTorch XLA tensors which are inputs of the operation.
+      kwargs: Keyword arguments passed to the lowering function. These are
+        Python scalars and cannot be XLA tensors.
+    Returns:
+      The PyTorch tensors wrapping the values returned by the XLA lowering
+      function.
+    """
+    shapes = xb.tensor_shape(args)
+    key = pickle.dumps([shapes, kwargs])
+    with self._lock:
+      computation = self._computations.get(key, None)
+      if computation is None:
+        computation = xb.create_computation(self._name, self._opfn, shapes,
+                                            **kwargs)
+        self._computations[key] = computation
+        if xu.getenv_as('XLA_OP_PRINT_COMPUTATIONS', bool, False):
+          print(xb.get_computation_hlo(computation), file=sys.stderr)
+    result = torch_xla._XLAC._xla_user_computation(self._opname, args,
+                                                   computation)
+    return result[0] if len(result) == 1 else result
+
+
+def register(name, opfn):
+  """Registers a PyTorch operation with an XLA lowering function.
+
+  Example::
+
+    import torch
+    import torch_xla
+    import torch_xla.core.xla_op_registry as xor
+    import torch_xla.core.xla_model as xm
+
+    def slice_and_add(a, b, **kwargs):
+      sa = a.slice_in_dim(start_index=0, limit_index=1, dimno=0)
+      sb = b.slice_in_dim(start_index=1, limit_index=2, dimno=0)
+      return sa + sb
+
+    xadd = xor.register('slice_and_add', slice_and_add)
+    device = xm.xla_device()
+    x = torch.randn(2, 2).to(device)
+    y = torch.randn(2, 2).to(device)
+    z = xadd(x, y)
+    print(z.cpu())
+
+  Args:
+    name (str): The name of the operation.
+    opfn (callable): The function implementing the XLA lowering.
+  Returns:
+    The `Op` object which can be called to perform the XLA-driven PyTorch
+    operation.
+  """
+  return Op(name, opfn)
diff --git a/torch_xla/csrc/computation.cpp b/torch_xla/csrc/computation.cpp
new file mode 100644
index 00000000000..81600c613fb
--- /dev/null
+++ b/torch_xla/csrc/computation.cpp
@@ -0,0 +1,14 @@
+#include "torch_xla/csrc/computation.h"
+
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+
+namespace torch_xla {
+
+Computation::Computation(std::string name, xla::XlaComputation computation)
+    : name_(std::move(name)), computation_(std::move(computation)) {
+  program_shape_ = ConsumeValue(computation_.GetProgramShape());
+  hash_ = xla::util::MHash(name_, computation_.proto().SerializeAsString());
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/computation.h b/torch_xla/csrc/computation.h
new file mode 100644
index 00000000000..b93634c4c2d
--- /dev/null
+++ b/torch_xla/csrc/computation.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_client/types.h"
+
+namespace torch_xla {
+
+class Computation {
+ public:
+  Computation(std::string name, xla::XlaComputation computation);
+
+  const std::string& name() const { return name_; }
+
+  const xla::XlaComputation& computation() const { return computation_; }
+
+  const xla::ProgramShape& program_shape() const { return program_shape_; }
+
+  const xla::hash_t& hash() const { return hash_; }
+
+ private:
+  std::string name_;
+  xla::XlaComputation computation_;
+  xla::ProgramShape program_shape_;
+  xla::hash_t hash_;
+};
+
+using ComputationPtr = std::shared_ptr<Computation>;
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index d907d80f034..1da16fefb7b 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -7,6 +7,7 @@
 #include <sstream>
 #include <string>
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/xla_client/computation_client.h"
 #include "tensorflow/compiler/xla/xla_client/mesh_service.h"
 #include "tensorflow/compiler/xla/xla_client/metrics.h"
@@ -15,6 +16,7 @@
 #include "tensorflow/compiler/xla/xla_client/record_reader.h"
 #include "tensorflow/compiler/xla/xla_client/thread_pool.h"
 #include "tensorflow/compiler/xla/xla_client/util.h"
+#include "tensorflow/compiler/xla/xla_client/xla_util.h"
 #include "tensorflow/core/example/example.pb.h"
 #include "tensorflow/core/example/feature.pb.h"
 #include "tensorflow/core/platform/env.h"
@@ -23,6 +25,7 @@
 #include "torch/csrc/jit/python/pybind.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/aten_xla_type.h"
+#include "torch_xla/csrc/computation.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/helpers.h"
"torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir_dump_util.h" @@ -33,6 +36,7 @@ #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/version.h" +#include "torch_xla/csrc/xla_op_builder.h" namespace torch_xla { namespace { @@ -537,6 +541,36 @@ py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores, return result_tuple; } +std::vector XlaUserComputation( + const std::string& opname, const std::vector& inputs, + ComputationPtr computation) { + std::vector xinputs = GetXlaTensors(inputs, /*want_all=*/true); + std::vector xresults = + XLATensor::user_computation(opname, xinputs, std::move(computation)); + std::vector results; + for (auto& xresult : xresults) { + at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult)); + results.push_back( + torch::autograd::make_variable(tensor, /*requires_grad=*/false)); + } + return results; +} + +ComputationPtr CreateComputation(const std::string& name, xla::XlaOp root) { + xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root)); + return std::make_shared(name, std::move(computation)); +} + +xla::Shape GetTensorShape(const at::Tensor& tensor, + const std::string& device_str) { + auto xtensor = bridge::TryGetXlaTensor(tensor); + if (xtensor) { + return xtensor->shape(); + } + Device device = GetDeviceOrCurrent(device_str); + return CreateComputationShapeFromTensor(tensor, &device); +} + void InitXlaModuleBindings(py::module m) { m.def("_initialize_aten_bindings", []() { AtenXlaType::InitializeAtenBindings(); }); @@ -551,6 +585,16 @@ void InitXlaModuleBindings(py::module m) { xla::int64 output_size) { return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size); }); + m.def("_xla_user_computation", + [](const std::string& opname, const std::vector& inputs, + const ComputationPtr& computation) { + std::vector results; + { + NoGilSection nogil; + results = XlaUserComputation(opname, inputs, computation); + } + return results; + }); m.def("_get_xla_tensors_dot", [](const std::vector& tensors) -> std::string { auto coverter = [](absl::Span nodes) { @@ -838,6 +882,53 @@ void InitXlaModuleBindings(py::module m) { NoGilSection nogil; RemoveTfFile(path); }); + + py::class_(m, "XlaBuilder"); + py::class_(m, "XlaOp"); + py::class_(m, "XlaComputation"); + m.def("_xla_op_create_builder", [](const std::string& name) { + return std::make_shared(name); + }); + m.def("_xla_op_tensor_shape", + [](const at::Tensor& tensor, const std::string& device) { + xla::Shape tensor_shape = GetTensorShape(tensor, device); + return op_builder::ShapeToPyShape(tensor_shape); + }); + m.def("_xla_op_param", [](op_builder::BuilderPtr builder, xla::int64 param_no, + py::object py_shape) { + xla::Shape shape = op_builder::PyShapeToShape(py_shape); + xla::XlaOp param = xla::Parameter(builder.get(), param_no, shape, + absl::StrCat("p", param_no)); + return std::make_shared(std::move(builder), + std::move(param)); + }); + m.def("_xla_op_build", [](const std::string& name, op_builder::OpPtr root) { + ComputationPtr computation; + { + NoGilSection nogil; + computation = CreateComputation(name, root->op); + } + return computation; + }); + m.def("_xla_computation_text", [](const ComputationPtr& computation) { + std::string hlo_text; + { + NoGilSection nogil; + hlo_text = ConsumeValue( + xla::util::GetComputationHloText(computation->computation())); + } + return hlo_text; + }); + m.def("_xla_op_shape", [](op_builder::OpPtr op) { + const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(op->op); + 
+    return op_builder::ShapeToPyShape(shape);
+  });
+  m.def("_xla_op_builder", [](op_builder::OpPtr op) { return op->builder; });
+  m.def("_xla_op_create",
+        [](op_builder::BuilderPtr builder, const std::string& opname,
+           const std::vector<op_builder::OpPtr>& operands, py::dict args) {
+          return op_builder::CreateOp(builder, opname, operands, args);
+        });
 }
 
 }  // namespace
diff --git a/torch_xla/csrc/ops/user_computation.cpp b/torch_xla/csrc/ops/user_computation.cpp
new file mode 100644
index 00000000000..f70298d0452
--- /dev/null
+++ b/torch_xla/csrc/ops/user_computation.cpp
@@ -0,0 +1,54 @@
+#include "torch_xla/csrc/ops/user_computation.h"
+
+#include "torch_xla/csrc/lowering_context.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+size_t GetNumOutputs(const xla::Shape& shape) {
+  return shape.IsTuple() ? shape.tuple_shapes_size() : 1;
+}
+
+}  // namespace
+
+UserComputation::UserComputation(OpKind op, OpList operands,
+                                 ComputationPtr computation)
+    : Node(std::move(op), operands, computation->program_shape().result(),
+           GetNumOutputs(computation->program_shape().result()),
+           computation->hash()),
+      computation_(std::move(computation)) {}
+
+NodePtr UserComputation::Clone(OpList operands) const {
+  return MakeNode<UserComputation>(op(), operands, computation_);
+}
+
+XlaOpVector UserComputation::Lower(LoweringContext* loctx) const {
+  std::vector<xla::XlaOp> inputs;
+  for (auto& op : operands()) {
+    inputs.push_back(loctx->GetOutputOp(op));
+  }
+  xla::XlaOp output =
+      xla::Call(loctx->builder(), computation_->computation(), inputs);
+  XlaOpVector results;
+  const xla::Shape& result_shape = computation_->program_shape().result();
+  if (result_shape.IsTuple()) {
+    for (xla::int64 i = 0; i < result_shape.tuple_shapes_size(); ++i) {
+      results.push_back(xla::GetTupleElement(output, i));
+    }
+  } else {
+    results.push_back(output);
+  }
+  return ReturnOps(results, loctx);
+}
+
+std::string UserComputation::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", computation=" << computation_->name();
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/user_computation.h b/torch_xla/csrc/ops/user_computation.h
new file mode 100644
index 00000000000..84fd55d9c80
--- /dev/null
+++ b/torch_xla/csrc/ops/user_computation.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "torch_xla/csrc/computation.h"
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class UserComputation : public Node {
+ public:
+  UserComputation(OpKind op, OpList operands, ComputationPtr computation);
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  const ComputationPtr& computation() const { return computation_; }
+
+ private:
+  ComputationPtr computation_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 615b71521b7..acd19b28944 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -14,6 +14,7 @@
 #include "tensorflow/compiler/xla/xla_client/multi_wait.h"
 #include "tensorflow/compiler/xla/xla_client/util.h"
 #include "torch/csrc/autograd/variable.h"
+#include "torch_xla/csrc/computation.h"
 #include "torch_xla/csrc/cross_replica_reduces.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/ir.h"
@@ -203,6 +204,10 @@ class XLATensor {
   static XLATensor get_dimensions_size(const XLATensor& input,
                                        std::vector<xla::int64> dimensions);
+
+  static std::vector<XLATensor> user_computation(
+      const std::string& opname, absl::Span<const XLATensor> inputs,
+      ComputationPtr computation);
+
   //////////////////////////////////////////////////////////////////////////////
   // ATEN operators follows here, listed in alphabetical order.
   //////////////////////////////////////////////////////////////////////////////
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 71c9377ff46..1b6dfc0337d 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -119,6 +119,7 @@
 #include "torch_xla/csrc/ops/upsample_bilinear2d_backward.h"
 #include "torch_xla/csrc/ops/upsample_nearest2d.h"
 #include "torch_xla/csrc/ops/upsample_nearest2d_backward.h"
+#include "torch_xla/csrc/ops/user_computation.h"
 #include "torch_xla/csrc/ops/view.h"
 #include "torch_xla/csrc/shape_builder.h"
 #include "torch_xla/csrc/tensor.h"
@@ -350,6 +351,19 @@ XLATensor XLATensor::get_dimensions_size(const XLATensor& input,
                                          at::ScalarType::Int);
 }
 
+std::vector<XLATensor> XLATensor::user_computation(
+    const std::string& opname, absl::Span<const XLATensor> inputs,
+    ComputationPtr computation) {
+  XLA_CHECK(!inputs.empty());
+  std::vector<ir::Value> input_values;
+  for (auto& input : inputs) {
+    input_values.push_back(input.GetIrValue());
+  }
+  ir::NodePtr node = ir::MakeNode<ir::ops::UserComputation>(
+      ir::OpKind::Get(opname), input_values, std::move(computation));
+  return inputs.front().MakeOutputTensors(node);
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // ATEN operators follows here, listed in alphabetical order.
 //////////////////////////////////////////////////////////////////////////////
diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp
new file mode 100644
index 00000000000..fd4deffefb3
--- /dev/null
+++ b/torch_xla/csrc/xla_op_builder.cpp
@@ -0,0 +1,550 @@
+#include "torch_xla/csrc/xla_op_builder.h"
+
+#include <map>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/client/lib/logdet.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "torch_xla/csrc/computation.h"
+#include "torch_xla/csrc/convert_ops.h"
+#include "torch_xla/csrc/helpers.h"
+#include "torch_xla/csrc/tensor_util.h"
+
+namespace torch_xla {
+namespace op_builder {
+namespace {
+
+typedef xla::XlaOp (*XlaOpFunction)(const BuilderPtr&,
+                                    const std::vector<OpPtr>&, py::dict);
+using XlaOpFunctionMap = std::map<std::string, XlaOpFunction>;
+
+#define XLA_UNARY_OP(name)                                               \
+  xla::XlaOp name(const BuilderPtr&, const std::vector<OpPtr>& operands, \
+                  py::dict /* args */) {                                 \
+    return xla::name(operands.at(0)->op);                                \
+  }
+
+#define XLA_BINARY_OP(name)                                              \
+  xla::XlaOp name(const BuilderPtr&, const std::vector<OpPtr>& operands, \
+                  py::dict /* args */) {                                 \
+    return xla::name(operands.at(0)->op, operands.at(1)->op);            \
+  }
+
+XLA_UNARY_OP(Abs);
+XLA_UNARY_OP(Acos);
+XLA_UNARY_OP(Asin);
+XLA_UNARY_OP(Atan);
+XLA_UNARY_OP(Ceil);
+XLA_UNARY_OP(Cos);
+XLA_UNARY_OP(Cosh);
+XLA_UNARY_OP(Erf);
+XLA_UNARY_OP(Erfc);
+XLA_UNARY_OP(ErfInv);
+XLA_UNARY_OP(Exp);
+XLA_UNARY_OP(Expm1);
+XLA_UNARY_OP(Floor);
+XLA_UNARY_OP(Log);
+XLA_UNARY_OP(Log1p);
+XLA_UNARY_OP(Neg);
+XLA_UNARY_OP(Not);
+XLA_UNARY_OP(Sqrt);
+XLA_UNARY_OP(Rsqrt);
+XLA_UNARY_OP(Sin);
+XLA_UNARY_OP(Sinh);
+XLA_UNARY_OP(Tan);
+XLA_UNARY_OP(Tanh);
+
+XLA_BINARY_OP(Add);
+XLA_BINARY_OP(And);
+XLA_BINARY_OP(Atan2);
+XLA_BINARY_OP(Div);
+XLA_BINARY_OP(Eq);
+XLA_BINARY_OP(Ge);
+XLA_BINARY_OP(Gt);
+XLA_BINARY_OP(Le);
+XLA_BINARY_OP(Lt);
+XLA_BINARY_OP(Max);
+XLA_BINARY_OP(Min);
+XLA_BINARY_OP(Mul);
+XLA_BINARY_OP(Ne);
+XLA_BINARY_OP(Or);
+XLA_BINARY_OP(Pow);
+XLA_BINARY_OP(Rem);
+XLA_BINARY_OP(Sub);
+XLA_BINARY_OP(Xor);
+
+template <typename T>
+std::vector<T> GetTupleVector(py::tuple tuple) {
+  std::vector<T> values;
+  values.reserve(tuple.size());
+  for (auto& v : tuple) {
+    values.push_back(v.cast<T>());
+  }
+  return values;
+}
+
+template <typename T>
+absl::optional<T> ArgOptional(py::dict args, const char* name) {
+  if (!args.contains(name)) {
+    return absl::nullopt;
+  }
+  auto value = args[name];
+  if (value.is_none()) {
+    return absl::nullopt;
+  }
+  return value.cast<T>();
+}
+
+template <typename T>
+T ArgOrDefault(py::dict args, const char* name, T defval) {
+  absl::optional<T> value = ArgOptional<T>(args, name);
+  return value.value_or(defval);
+}
+
+std::vector<xla::XlaOp> ExtractXlaOps(const std::vector<OpPtr>& operands) {
+  std::vector<xla::XlaOp> ops;
+  for (auto& operand : operands) {
+    ops.push_back(operand->op);
+  }
+  return ops;
+}
+
+std::vector<xla::XlaOp> GetOpVector(py::tuple tuple) {
+  std::vector<xla::XlaOp> ops;
+  for (auto& op : tuple) {
+    ops.push_back(op.cast<OpPtr>()->op);
+  }
+  return ops;
+}
+
+xla::XlaOp Reshape(const BuilderPtr& builder,
+                   const std::vector<OpPtr>& operands, py::dict args) {
+  std::vector<xla::int64> sizes = GetTupleVector<xla::int64>(args["sizes"]);
+  absl::optional<py::tuple> arg_dimensions =
+      ArgOptional<py::tuple>(args, "dimensions");
+  if (arg_dimensions) {
+    std::vector<xla::int64> dimensions =
+        GetTupleVector<xla::int64>(*arg_dimensions);
+    return xla::Reshape(operands.at(0)->op, dimensions, sizes);
+  }
+  xla::int64 inferred_dimension =
+      ArgOrDefault<xla::int64>(args, "inferred_dimension", -1);
+  if (inferred_dimension >= 0) {
+    return xla::ReshapeWithInferredDimension(operands.at(0)->op, sizes,
+                                             inferred_dimension);
+  }
+  return xla::Reshape(operands.at(0)->op, sizes);
+}
+
+xla::XlaOp DynamicReshape(const BuilderPtr& builder,
+                          const std::vector<OpPtr>& operands, py::dict args) {
+  std::vector<xla::int64> sizes = GetTupleVector<xla::int64>(args["sizes"]);
+  return XlaHelpers::DynamicReshape(operands.at(0)->op, sizes);
+}
+
+xla::XlaOp Broadcast(const BuilderPtr& builder,
+                     const std::vector<OpPtr>& operands, py::dict args) {
+  std::vector<xla::int64> sizes = GetTupleVector<xla::int64>(args["sizes"]);
+  return xla::Broadcast(operands.at(0)->op, sizes);
+}
+
+xla::XlaOp BroadcastInDim(const BuilderPtr& builder,
+                          const std::vector<OpPtr>& operands, py::dict args) {
+  std::vector<xla::int64> sizes = GetTupleVector<xla::int64>(args["sizes"]);
+  std::vector<xla::int64> dimensions =
+      GetTupleVector<xla::int64>(args["dimensions"]);
+  return xla::BroadcastInDim(operands.at(0)->op, sizes, dimensions);
+}
+
+xla::XlaOp Tuple(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                 py::dict args) {
+  std::vector<xla::XlaOp> ops = ExtractXlaOps(operands);
+  return xla::Tuple(builder.get(), ops);
+}
+
+xla::PrecisionConfig DotPrecisionConfig(py::dict args) {
+  xla::PrecisionConfig::Precision precision = XlaHelpers::mat_mul_precision();
+  absl::optional<std::string> arg_precision_config =
+      ArgOptional<std::string>(args, "precision_config");
+  if (arg_precision_config) {
+    if (*arg_precision_config == "default") {
+      precision = xla::PrecisionConfig::DEFAULT;
+    } else if (*arg_precision_config == "high") {
+      precision = xla::PrecisionConfig::HIGH;
+    } else if (*arg_precision_config == "highest") {
+      precision = xla::PrecisionConfig::HIGHEST;
+    }
+  }
+  return XlaHelpers::BuildPrecisionConfig(precision);
+}
+
+xla::XlaOp Dot(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+               py::dict args) {
+  xla::PrecisionConfig precision_config = DotPrecisionConfig(args);
+  return xla::Dot(operands.at(0)->op, operands.at(1)->op, &precision_config);
+}
+
+xla::XlaOp Constant(const BuilderPtr& builder,
+                    const std::vector<OpPtr>& operands, py::dict args) {
+  at::Tensor tensor = args["value"].cast<at::Tensor>();
+  xla::Literal literal =
+      GetTensorLiteral(tensor, /*shape=*/nullptr, /*device=*/nullptr);
+  return xla::ConstantLiteral(builder.get(), literal);
+}
+
+xla::PaddingConfig ParsePaddingConfig(py::tuple cfg) {
+  xla::PaddingConfig pad_config;
+  for (auto& dimp : cfg) {
+    py::tuple dims = dimp.cast<py::tuple>();
+    XLA_CHECK_EQ(dims.size(), 3);
+    auto dim = pad_config.add_dimensions();
+    dim->set_edge_padding_low(dims[0].cast<xla::int64>());
+    dim->set_edge_padding_high(dims[1].cast<xla::int64>());
+    dim->set_interior_padding(dims[2].cast<xla::int64>());
+  }
+  return pad_config;
+}
+
+xla::XlaOp Pad(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+               py::dict args) {
+  xla::PaddingConfig pad_config = ParsePaddingConfig(args["config"]);
+  return xla::Pad(operands.at(0)->op, operands.at(1)->op, pad_config);
+}
+
+xla::XlaOp Transpose(const BuilderPtr& builder,
+                     const std::vector<OpPtr>& operands, py::dict args) {
+  std::vector<xla::int64> permutation =
+      GetTupleVector<xla::int64>(args["permutation"]);
+  return xla::Transpose(operands.at(0)->op, permutation);
+}
+
+xla::Padding ParseConvPadding(const std::string& padding_str) {
+  if (padding_str == "same") {
+    return xla::Padding::kSame;
+  }
+  if (padding_str == "valid") {
+    return xla::Padding::kValid;
+  }
+  XLA_ERROR() << "Invalid padding: " << padding_str;
+}
+
+xla::XlaOp Conv(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                py::dict args) {
+  std::vector<xla::int64> window_strides =
+      GetTupleVector<xla::int64>(args["window_strides"]);
+  xla::int64 feature_group_count =
+      ArgOrDefault<xla::int64>(args, "feature_group_count", 1);
+  xla::int64 batch_group_count =
+      ArgOrDefault<xla::int64>(args, "batch_group_count", 1);
+  xla::Padding padding = ParseConvPadding(args["padding"].cast<std::string>());
+  xla::PrecisionConfig precision_config = DotPrecisionConfig(args);
+  return xla::Conv(operands.at(0)->op, operands.at(1)->op, window_strides,
+                   padding, feature_group_count, batch_group_count,
+                   &precision_config);
+}
+
+xla::XlaOp Slice(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                 py::dict args) {
+  std::vector<xla::int64> start_indices =
+      GetTupleVector<xla::int64>(args["start_indices"]);
+  std::vector<xla::int64> limit_indices =
+      GetTupleVector<xla::int64>(args["limit_indices"]);
+  std::vector<xla::int64> strides =
+      GetTupleVector<xla::int64>(args["strides"]);
+  return xla::Slice(operands.at(0)->op, start_indices, limit_indices, strides);
+}
+
+xla::XlaOp SliceInDim(const BuilderPtr& builder,
+                      const std::vector<OpPtr>& operands, py::dict args) {
+  xla::int64 start_index = args["start_index"].cast<xla::int64>();
+  xla::int64 limit_index = args["limit_index"].cast<xla::int64>();
+  xla::int64 dimno = args["dimno"].cast<xla::int64>();
+  xla::int64 stride = ArgOrDefault<xla::int64>(args, "stride", 1);
+  return xla::SliceInDim(operands.at(0)->op, start_index, limit_index, stride,
+                         dimno);
+}
+
+xla::XlaOp DynamicSlice(const BuilderPtr& builder,
+                        const std::vector<OpPtr>& operands, py::dict args) {
+  std::vector<xla::int64> slice_sizes =
+      GetTupleVector<xla::int64>(args["slice_sizes"]);
+  std::vector<xla::XlaOp> start_indices =
+      GetOpVector(args["start_indices"].cast<py::tuple>());
+  return xla::DynamicSlice(operands.at(0)->op, start_indices, slice_sizes);
+}
+
+xla::XlaOp DynamicUpdateSlice(const BuilderPtr& builder,
+                              const std::vector<OpPtr>& operands,
+                              py::dict args) {
+  std::vector<xla::XlaOp> start_indices =
+      GetOpVector(args["start_indices"].cast<py::tuple>());
+  return xla::DynamicUpdateSlice(operands.at(0)->op, operands.at(1)->op,
+                                 start_indices);
+}
+
+xla::XlaOp Reduce(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                  py::dict args) {
+  std::vector<xla::int64> dimensions =
+      GetTupleVector<xla::int64>(args["dimensions"]);
+  ComputationPtr computation = args["computation"].cast<ComputationPtr>();
+  return xla::Reduce(operands.at(0)->op, operands.at(1)->op,
+                     computation->computation(), dimensions);
+}
+
+xla::XlaOp Call(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                py::dict args) {
+  ComputationPtr computation = args["computation"].cast<ComputationPtr>();
+  std::vector<xla::XlaOp> ops = ExtractXlaOps(operands);
+  return xla::Call(builder.get(), computation->computation(), ops);
+}
+
+xla::XlaOp Select(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                  py::dict args) {
+  return xla::Select(operands.at(0)->op, operands.at(1)->op,
+                     operands.at(2)->op);
+}
+
+xla::XlaOp ShiftLeft(const BuilderPtr& builder,
+                     const std::vector<OpPtr>& operands, py::dict args) {
+  return xla::ShiftLeft(operands.at(0)->op, operands.at(1)->op);
+}
+
+xla::XlaOp ShiftRight(const BuilderPtr& builder,
+                      const std::vector<OpPtr>& operands, py::dict args) {
+  return operands.at(0)->op >> operands.at(1)->op;
+}
+
+xla::GatherDimensionNumbers ParseGatherDimensionNumbers(py::dict args) {
+  xla::GatherDimensionNumbers dimension_numbers;
+  absl::optional<py::tuple> arg_offset_dims =
+      ArgOptional<py::tuple>(args, "offset_dims");
+  if (arg_offset_dims) {
+    for (auto& dim : *arg_offset_dims) {
+      dimension_numbers.add_offset_dims(dim.cast<xla::int64>());
+    }
+  }
+  absl::optional<py::tuple> arg_collapsed_slice_dims =
+      ArgOptional<py::tuple>(args, "collapsed_slice_dims");
+  if (arg_collapsed_slice_dims) {
+    for (auto& dim : *arg_collapsed_slice_dims) {
+      dimension_numbers.add_collapsed_slice_dims(dim.cast<xla::int64>());
+    }
+  }
+  absl::optional<py::tuple> arg_start_index_map =
+      ArgOptional<py::tuple>(args, "start_index_map");
+  if (arg_start_index_map) {
+    for (auto& dim : *arg_start_index_map) {
+      dimension_numbers.add_start_index_map(dim.cast<xla::int64>());
+    }
+  }
+  absl::optional<xla::int64> arg_index_vector_dim =
+      ArgOptional<xla::int64>(args, "index_vector_dim");
+  if (arg_index_vector_dim) {
+    dimension_numbers.set_index_vector_dim(*arg_index_vector_dim);
+  }
+  return dimension_numbers;
+}
+
+xla::XlaOp Gather(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                  py::dict args) {
+  std::vector<xla::int64> slice_sizes =
+      GetTupleVector<xla::int64>(args["slice_sizes"]);
+  bool indices_are_sorted =
+      ArgOrDefault<bool>(args, "indices_are_sorted", false);
+  xla::GatherDimensionNumbers dimension_numbers =
+      ParseGatherDimensionNumbers(args);
+  return xla::Gather(operands.at(0)->op, operands.at(1)->op, dimension_numbers,
+                     slice_sizes, indices_are_sorted);
+}
+
+xla::ScatterDimensionNumbers ParseScatterDimensionNumbers(py::dict args) {
+  xla::ScatterDimensionNumbers dimension_numbers;
+  absl::optional<py::tuple> arg_update_window_dims =
+      ArgOptional<py::tuple>(args, "update_window_dims");
+  if (arg_update_window_dims) {
+    for (auto& dim : *arg_update_window_dims) {
+      dimension_numbers.add_update_window_dims(dim.cast<xla::int64>());
+    }
+  }
+  absl::optional<py::tuple> arg_inserted_window_dims =
+      ArgOptional<py::tuple>(args, "inserted_window_dims");
+  if (arg_inserted_window_dims) {
+    for (auto& dim : *arg_inserted_window_dims) {
+      dimension_numbers.add_inserted_window_dims(dim.cast<xla::int64>());
+    }
+  }
+  absl::optional<xla::int64> arg_index_vector_dim =
+      ArgOptional<xla::int64>(args, "index_vector_dim");
+  if (arg_index_vector_dim) {
+    dimension_numbers.set_index_vector_dim(*arg_index_vector_dim);
+  }
+  return dimension_numbers;
+}
+
+xla::XlaOp Scatter(const BuilderPtr& builder,
+                   const std::vector<OpPtr>& operands, py::dict args) {
+  bool indices_are_sorted =
+      ArgOrDefault<bool>(args, "indices_are_sorted", false);
+  bool unique_indices = ArgOrDefault<bool>(args, "unique_indices", false);
+  ComputationPtr computation = args["computation"].cast<ComputationPtr>();
+  xla::ScatterDimensionNumbers dimension_numbers =
+      ParseScatterDimensionNumbers(args);
+  return xla::Scatter(operands.at(0)->op, operands.at(1)->op,
+                      operands.at(2)->op, computation->computation(),
+                      dimension_numbers, indices_are_sorted, unique_indices);
+}
+
+xla::XlaOp Sort(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                py::dict args) {
+  bool is_stable = ArgOrDefault<bool>(args, "is_stable", false);
+  xla::int64 dimension = ArgOrDefault<xla::int64>(args, "dimension", -1);
+  ComputationPtr comparator = args["comparator"].cast<ComputationPtr>();
+  std::vector<xla::XlaOp> ops = ExtractXlaOps(operands);
+  return xla::Sort(ops, comparator->computation(), dimension, is_stable);
+}
+
+xla::XlaOp Iota(const BuilderPtr& builder, const std::vector<OpPtr>& operands,
+                py::dict args) {
+  xla::Shape shape = PyShapeToShape(args["shape"]);
+  xla::int64 iota_dimension = args["iota_dimension"].cast<xla::int64>();
+  return xla::Iota(builder.get(), shape, iota_dimension);
+}
+
+xla::XlaOp Convert(const BuilderPtr& builder,
+                   const std::vector<OpPtr>& operands, py::dict args) {
+  std::string type = args["to_type"].cast<std::string>();
+  xla::PrimitiveType xla_type =
+      ConsumeValue(xla::primitive_util::StringToPrimitiveType(type));
+  return MaybeConvertTo(operands.at(0)->op, xla_type);
+}
+
+xla::XlaOp BitcastConvert(const BuilderPtr& builder,
+                          const std::vector<OpPtr>& operands, py::dict args) {
+  std::string type = args["to_type"].cast<std::string>();
+  xla::PrimitiveType xla_type =
+      ConsumeValue(xla::primitive_util::StringToPrimitiveType(type));
+  return xla::BitcastConvertType(operands.at(0)->op, xla_type);
+}
+
+const XlaOpFunctionMap* CreateXlaOpFunctionMap() {
+  XlaOpFunctionMap* fn_map = new XlaOpFunctionMap();
+
+#define XLA_OPADD(name) fn_map->emplace(#name, name)
+
+  XLA_OPADD(Abs);
+  XLA_OPADD(Add);
+  XLA_OPADD(And);
+  XLA_OPADD(Acos);
+  XLA_OPADD(Asin);
+  XLA_OPADD(Atan2);
+  XLA_OPADD(Atan);
+  XLA_OPADD(BitcastConvert);
+  XLA_OPADD(Broadcast);
+  XLA_OPADD(BroadcastInDim);
+  XLA_OPADD(Call);
+  XLA_OPADD(Ceil);
+  XLA_OPADD(Constant);
+  XLA_OPADD(Conv);
+  XLA_OPADD(Convert);
+  XLA_OPADD(Cos);
+  XLA_OPADD(Cosh);
+  XLA_OPADD(Div);
+  XLA_OPADD(Dot);
+  XLA_OPADD(DynamicReshape);
+  XLA_OPADD(DynamicSlice);
+  XLA_OPADD(DynamicUpdateSlice);
+  XLA_OPADD(Eq);
+  XLA_OPADD(Erf);
+  XLA_OPADD(Erfc);
+  XLA_OPADD(ErfInv);
+  XLA_OPADD(Exp);
+  XLA_OPADD(Expm1);
+  XLA_OPADD(Floor);
+  XLA_OPADD(Gather);
+  XLA_OPADD(Ge);
+  XLA_OPADD(Gt);
+  XLA_OPADD(Iota);
+  XLA_OPADD(Le);
+  XLA_OPADD(Log);
+  XLA_OPADD(Log1p);
+  XLA_OPADD(Lt);
+  XLA_OPADD(Max);
+  XLA_OPADD(Min);
+  XLA_OPADD(Mul);
+  XLA_OPADD(Ne);
+  XLA_OPADD(Neg);
+  XLA_OPADD(Not);
+  XLA_OPADD(Or);
+  XLA_OPADD(Pad);
+  XLA_OPADD(Pow);
+  XLA_OPADD(Reduce);
+  XLA_OPADD(Rem);
+  XLA_OPADD(Reshape);
+  XLA_OPADD(Rsqrt);
+  XLA_OPADD(Scatter);
+  XLA_OPADD(Select);
+  XLA_OPADD(ShiftLeft);
+  XLA_OPADD(ShiftRight);
+  XLA_OPADD(Sin);
+  XLA_OPADD(Sinh);
+  XLA_OPADD(Slice);
+  XLA_OPADD(SliceInDim);
+  XLA_OPADD(Sort);
+  XLA_OPADD(Sqrt);
+  XLA_OPADD(Sub);
+  XLA_OPADD(Tan);
+  XLA_OPADD(Tanh);
+  XLA_OPADD(Transpose);
+  XLA_OPADD(Tuple);
+  XLA_OPADD(Xor);
+
+#undef XLA_OPADD
+
+  return fn_map;
+}
+
+const XlaOpFunctionMap* GetXlaOpFunctionMap() {
+  static const XlaOpFunctionMap* fn_map = CreateXlaOpFunctionMap();
+  return fn_map;
+}
+
+}  // namespace
+
+py::object ShapeToPyShape(const xla::Shape& shape) {
+  py::tuple py_shape(2);
+  py_shape[0] = py::cast(
+      xla::primitive_util::LowercasePrimitiveTypeName(shape.element_type()));
+  auto sizes = py::tuple(shape.rank());
+  for (xla::int64 i = 0; i < shape.rank(); ++i) {
+    sizes[i] = py::cast(shape.dimensions(i));
+  }
+  py_shape[1] = sizes;
+  return py_shape;
+}
+
+xla::Shape PyShapeToShape(py::object shape) {
+  py::tuple py_shape = shape.cast<py::tuple>();
+  std::string type = py_shape[0].cast<std::string>();
+  std::vector<xla::int64> dimensions =
+      GetTupleVector<xla::int64>(py_shape[1].cast<py::tuple>());
+  xla::PrimitiveType xla_type =
+      ConsumeValue(xla::primitive_util::StringToPrimitiveType(type));
+  return xla::ShapeUtil::MakeShape(xla_type, dimensions);
+}
+
+OpPtr CreateOp(BuilderPtr builder, const std::string& opname,
+               const std::vector<OpPtr>& operands, py::dict args) {
+  const XlaOpFunctionMap* fn_map = GetXlaOpFunctionMap();
+  auto it = fn_map->find(opname);
+  if (it == fn_map->end()) {
+    XLA_ERROR() << "Unknown XLA op name: " << opname;
+  }
+  xla::XlaOp result = (*it->second)(builder, operands, args);
+  return std::make_shared<Op>(std::move(builder), std::move(result));
+}
+
+}  // namespace op_builder
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_op_builder.h b/torch_xla/csrc/xla_op_builder.h
new file mode 100644
index 00000000000..34daf0cfe59
--- /dev/null
+++ b/torch_xla/csrc/xla_op_builder.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "torch/csrc/jit/python/pybind.h"
+
+namespace torch_xla {
+namespace op_builder {
+
+using BuilderPtr = std::shared_ptr<xla::XlaBuilder>;
+
+struct Op {
+  Op(BuilderPtr builder, xla::XlaOp op)
+      : builder(std::move(builder)), op(std::move(op)) {}
+
+  BuilderPtr builder;
+  xla::XlaOp op;
+};
+
+using OpPtr = std::shared_ptr<Op>;
+
+py::object ShapeToPyShape(const xla::Shape& shape);
+
+xla::Shape PyShapeToShape(py::object shape);
+
+OpPtr CreateOp(BuilderPtr builder, const std::string& opname,
+               const std::vector<OpPtr>& operands, py::dict args);
+
+}  // namespace op_builder
+}  // namespace torch_xla
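
Usage sketch of the API this patch adds (a minimal example, assuming an XLA
device is available; the `scaled_dot` and `scale` names are illustrative, not
part of the patch). Keyword arguments reach the lowering function as Python
scalars and are pickled into the per-shape computation cache key, so each
distinct `scale` value builds and caches its own XLA computation::

    import torch
    import torch_xla.core.xla_builder as xb
    import torch_xla.core.xla_model as xm
    import torch_xla.core.xla_op_registry as xor

    def scaled_dot(a, b, scale=1.0, **kwargs):
      # `a` and `b` arrive as xla_builder.Op parameters; `@` lowers to an XLA
      # Dot. The rank-0 constant broadcasts against the Dot result under XLA's
      # implicit scalar broadcast rules.
      c = xb.Op.constant(a.builder(), torch.tensor(scale, dtype=torch.float32))
      return (a @ b) * c

    scaled_dot_op = xor.register('scaled_dot', scaled_dot)
    device = xm.xla_device()
    x = torch.randn(2, 3).to(device)
    y = torch.randn(3, 2).to(device)
    z = scaled_dot_op(x, y, scale=0.5)
    print(z.cpu())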