diff --git a/test/custom_debug_lowering.py b/test/custom_debug_lowering.py
index cc19dd79b81f..4cc16fa5609d 100644
--- a/test/custom_debug_lowering.py
+++ b/test/custom_debug_lowering.py
@@ -9,6 +9,12 @@
 class_count = defaultdict(int)
 instance_count = dict()
 
+# This is a sample implementation for reading object
+# hierarchies from a source stack using a TorchDispatch
+# interceptor. We then set the node op_name in XLA
+# via the output tensor and direct XLA to ignore stack
+# frames added (due to TorchDispatch) during lowering.
+
 
 def GetInstancePlaceHolder(class_type, obj):
   global class_count
 
@@ -172,11 +178,21 @@ def CleanNames(names):
 
 def GetAllObjectAndClassNames(frame):
   names = []
+  frame_count = 0
+  self_found = False
   while frame is not None:
+    if __file__ == frame.f_code.co_filename:
+      self_found = True
+
+    if not self_found:
+      frame = frame.f_back
+      continue
+
     name = GetClassNameAndObjFromFrame(frame)
     if len(name) > 0:
       names.append(name)
     frame = frame.f_back
+    frame_count += 1
 
   names.reverse()
 
@@ -187,7 +203,24 @@ def GetAllObjectAndClassNames(frame):
   if len(output) > 0:
     output += "/"
 
-  return output
+  return output, frame_count - 1
+
+
+class StackLayerSignature:
+
+  def __init__(self, filename, func, line):
+    self.filename = filename
+    self.func = func
+    self.line = line
+
+  def __str__(self):
+    return f"{self.filename}|{self.func}|{self.line}"
+
+  def __repr__(self):
+    return str(self)
+
+  def __eq__(self, ref):
+    return self.filename == ref.filename and self.func == ref.func and self.line == ref.line
 
 
 class CustomOpNameLowering(TorchDispatchMode):
@@ -198,16 +231,38 @@ def __init__(self):
   def __enter__(self):
     self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
     torch_xla._XLAC._set_ir_debug(True)
+    self.stack_sigs = []
     return super().__enter__()
 
   def __exit__(self, exc_type, exc_val, exc_tb):
     torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
+    del self.stack_sigs
     super().__exit__(exc_type, exc_val, exc_tb)
 
+  def add_stack_sig(self, frame, depth):
+    stack = []
+    for s in inspect.getouterframes(frame):
+      sls = StackLayerSignature(s.filename, s.function, s.lineno)
+      stack.append(sls)
+
+    # Drop the innermost frames (added by TorchDispatch) until `depth` remain
+    while len(stack) > depth:
+      stack.pop(0)
+
+    assert len(stack) == depth
+
+    self.stack_sigs.append(stack)
+
+    return stack
+
   def __torch_dispatch__(self, func, types, args=(), kwargs={}):
     res = func(*args, **kwargs)
     if 'xla' in str(res.device):
       frame = inspect.currentframe()
-      prefix = GetAllObjectAndClassNames(frame)
-      torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
+      prefix, depth = GetAllObjectAndClassNames(frame)
+      self.depth = depth
+      self.add_stack_sig(frame, self.depth)
+
+      assert torch_xla._XLAC._set_xla_custom_op_name_prefix(
+          res, prefix, self.depth), "Custom op set failed"
     return res
diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py
index 5f5ac186395d..82eebd9f3ada 100644
--- a/test/test_hlo_metadata.py
+++ b/test/test_hlo_metadata.py
@@ -8,7 +8,52 @@
 import torch_xla.debug.metrics as met
 import unittest
 import json
-from custom_debug_lowering import CustomOpNameLowering
+import inspect
+import copy
+from custom_debug_lowering import CustomOpNameLowering, StackLayerSignature
+
+
+class HloStackExtractor:
+
+  def __init__(self, hlo_json):
+    assert 'stackFrameIndex' in hlo_json
+    assert 'fileLocations' in hlo_json['stackFrameIndex']
+    assert 'stackFrames' in hlo_json['stackFrameIndex']
+    assert 'fileNames' in hlo_json['stackFrameIndex']
+    assert 'functionNames' in hlo_json['stackFrameIndex']
+
+    self.file_locations = hlo_json['stackFrameIndex']['fileLocations']
+    self.stack_frames = hlo_json['stackFrameIndex']['stackFrames']
+    self.file_names = hlo_json['stackFrameIndex']['fileNames']
+    self.function_names = hlo_json['stackFrameIndex']['functionNames']
+
+  def extract(self, stack_frame_id):
+    stack_sigs = []
+
+    stack_frame = self.stack_frames[stack_frame_id - 1]
+
+    while True:
+      file_location_id = stack_frame['fileLocationId']
+      file_location = self.file_locations[file_location_id - 1]
+      file_name_id = file_location['fileNameId']
+      function_name_id = file_location['functionNameId']
+      line = file_location['line']
+      file_name = self.file_names[file_name_id - 1]
+      function_name = self.function_names[function_name_id - 1]
+
+      sig = StackLayerSignature(file_name, function_name, line)
+      stack_sigs.append(sig)
+
+      stack_frame_id = 0
+      if 'parentFrameId' in stack_frame:
+        stack_frame_id = stack_frame['parentFrameId']
+
+      if stack_frame_id == 0:
+        break
+      else:
+        stack_frame = self.stack_frames[stack_frame_id - 1]
+
+    return stack_sigs
 
 
 class TestHloMetaData(unittest.TestCase):
@@ -32,21 +77,25 @@ def test_metadata(self):
     nl2 = torch.nn.Tanh()
     model = torch.nn.Sequential(layer1, nl1, layer2, nl2)
 
-    with CustomOpNameLowering():
+    with CustomOpNameLowering() as c:
       model = model.to(device=xm.xla_device())
       inp = torch.rand(4, 4, device=xm.xla_device())
+      #inp = torch.rand(4, 4)
+      #inp = inp.to(device=xm.xla_device())
       out = model(inp)
 
+    # Get outer frames
+    stack_sigs = c.stack_sigs
+
     ctx = torch_xla._XLAC.lowering.LoweringContext()
     ctx.build([out])
     hlo_text = ctx.hlo_json()
 
     # Strings to match in the lowering
     bingo = {
-        "torch/_ops.py": False,
-        #"torch/nn/modules/linear.py": False,
-        #"torch/nn/modules/activation.py": False,
-        #"torch/nn/functional.py": False,
+        "torch/nn/modules/linear.py": False,
+        "torch/nn/modules/activation.py": False,
+        "torch/nn/functional.py": False,
         "Sequential[model]/Linear[0]": False,
         "Sequential[model]/ReLU[1]": False,
         "Sequential[model]/Linear[2]": False,
@@ -60,10 +109,17 @@ def test_metadata(self):
     non_zero_metadata = False
 
     local_json = json.loads(hlo_text)
+
+    #with open("./hlo.json", "w") as f:
+    #  f.write(json.dumps(local_json, indent=2))
+
+    hloEx = HloStackExtractor(local_json)
+
     assert "computations" in local_json
     for c in local_json["computations"]:
       if "instructions" in c:
         i = c["instructions"]
+
         for op in i:
           if 'metadata' in op:
             meta = op["metadata"]
@@ -75,6 +131,27 @@ def test_metadata(self):
                 if isinstance(vm, str) and k in vm:
                   bingo[k] = True
 
+            # Decode the stack frame id and check that it matches one of
+            # the passed-in stacks
+            stack_frame_match = False
+            if 'stackFrameId' in meta:
+              hlo_stack_sig = hloEx.extract(meta['stackFrameId'])
+
+              for t_sig in stack_sigs:
+                if len(hlo_stack_sig) == len(t_sig) and hlo_stack_sig == t_sig:
+                  stack_frame_match = True
+                  break
+                elif len(hlo_stack_sig) > len(t_sig):
+                  hlo_stack_sig_copy = copy.copy(hlo_stack_sig)
+                  discards = []
+                  while len(hlo_stack_sig_copy) > len(t_sig):
+                    discards.append(hlo_stack_sig_copy.pop(0))
+                  # Print an error message on a partial match
+                  if hlo_stack_sig_copy == t_sig:
+                    print(f"** PARTIAL MATCH: Discarded {discards}")
+
+              assert stack_frame_match, f"Stack\n{hlo_stack_sig} does not match any of\n{stack_sigs}"
+
     assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"
 
     for k, v in bingo.items():
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index b18014ab2dfb..b7076102db01 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -294,10 +294,12 @@ ptxla_cc_library(
     srcs = [
         "ir.cpp",
         "lowering_context.cpp",
+        "stack_frame_index_builder.cpp",
     ],
     hdrs = [
         "ir.h",
         "lowering_context.h",
+        "stack_frame_index_builder.h",
     ],
     deps = [
         ":device",
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index fde4ee1c9d70..757585bd18d6 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1918,15 +1918,14 @@ void InitXlaModuleBindings(py::module m) {
         [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
           return XLANativeFunctions::set_(self, source);
         });
-  m.def("_set_xla_custom_op_name",
-        [](const at::Tensor& input, const std::string& op_name) {
+  m.def("_set_xla_custom_op_name_prefix",
+        [](const at::Tensor& input, const std::string& op_name_prefix,
+           size_t max_call_stack_depth) -> bool {
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-          xtensor->SetCustomOpName(op_name);
-        });
-  m.def("_get_xla_custom_op_name",
-        [](const at::Tensor& input) -> const std::string& {
-          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-          return xtensor->GetCustomOpName();
+          std::shared_ptr<torch::lazy::UserMetaData> user_meta =
+              std::make_shared<CustomOpNameMetaData>(op_name_prefix,
+                                                     max_call_stack_depth);
+          return xtensor->SetNodeUserMetadata(user_meta);
         });
   m.def("_get_all_reduce_token",
         [](const std::string& device_str) -> const torch::lazy::Value& {
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index b7cd2025bd34..4381cf164941 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -230,8 +230,16 @@ void XlaNode::UpdateShardingHash() {
   }
 }
 
-void XlaNode::SetCustomOpName(const std::string& op_name) {
-  custom_op_name_ = op_name;
+std::shared_ptr<torch::lazy::UserMetaData> XlaNode::SetUserMetadataForSubGraph(
+    std::shared_ptr<torch::lazy::UserMetaData> user_meta) {
+  for (auto np : operands_) {
+    XlaNode* xnp = dynamic_cast<XlaNode*>(np.get());
+    if (xnp != nullptr && xnp->user_metadata() == nullptr) {
+      xnp->SetUserMetadataForSubGraph(user_meta);
+    }
+  }
+  // Operands are only updated above when unset; this node is always set
+  return SetUserMetadata(user_meta);
 }
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index 1e4a0439e235..ac6ae9ed428b 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -5,6 +5,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include <...>
 
@@ -146,8 +147,8 @@ class XlaNode : public torch::lazy::Node {
     return unbounded_dynamic_dims_;
   }
 
-  void SetCustomOpName(const std::string& op_name);
-  const std::string& custom_op_name() const { return custom_op_name_; }
+  std::shared_ptr<torch::lazy::UserMetaData> SetUserMetadataForSubGraph(
+      std::shared_ptr<torch::lazy::UserMetaData> user_meta);
 
  protected:
   std::unordered_set<uint32_t> unbounded_dynamic_dims_;
 
@@ -170,8 +171,6 @@ class XlaNode : public torch::lazy::Node {
 
   // Experimental sharding annotations attached to the IR node.
   std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
-
-  std::string custom_op_name_;
 };
 
 inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
@@ -195,6 +194,15 @@ T* NodeCast(const torch::lazy::Node* node, torch::lazy::OpKind op) {
   return const_cast<T*>(casted);
 }
 
+struct CustomOpNameMetaData : public torch::lazy::UserMetaData {
+  CustomOpNameMetaData(const std::string& input_op_name_prefix,
+                       int input_max_stack_depth)
+      : op_name_prefix(input_op_name_prefix),
+        max_stack_depth(input_max_stack_depth) {}
+  std::string op_name_prefix;
+  size_t max_stack_depth;
+};
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_IR_H_
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index 622e2ef7dd96..474a3f3c0ec7 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -2,8 +2,10 @@
 
 #include <...>
+#include <...>
 
 #include <...>
+#include <...>
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -14,9 +16,11 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/stack_frame_index_builder.h"
 #include "torch_xla/csrc/unwrap_data.h"
 
 namespace torch_xla {
+
 namespace {
 
 class HloMetadataSetter {
@@ -51,13 +55,20 @@ class HloMetadataSetter {
     std::string op_type =
        absl::StrReplaceAll(node->op().ToString(), {{":", "_"}});
     metadata.set_op_type(op_type);
+
     const torch::lazy::MetaData& nmeta = node->metadata();
 
-    std::string op_name_prefix;
-    const XlaNode* xla_node_cast = dynamic_cast<const XlaNode*>(node);
+    const CustomOpNameMetaData* custom_opname_meta =
+        dynamic_cast<const CustomOpNameMetaData*>(node->user_metadata());
 
-    if (xla_node_cast != nullptr && !xla_node_cast->custom_op_name().empty()) {
-      op_name_prefix = xla_node_cast->custom_op_name();
+    std::string op_name_prefix;
+    size_t max_stack_depth = nmeta.frame_info.size();
+
+    if (custom_opname_meta != nullptr) {
+      op_name_prefix = custom_opname_meta->op_name_prefix;
+      max_stack_depth = custom_opname_meta->max_stack_depth;
+    } else {
+      TF_LOG(WARNING) << "No custom opname metadata! op_type=" << op_type;
+    }
 
     if (!nmeta.scope.empty()) {
@@ -66,12 +77,10 @@ class HloMetadataSetter {
     }
     metadata.set_op_name(absl::StrCat(op_name_prefix, op_type));
 
-    if (!nmeta.frame_info.empty()) {
-      const torch::lazy::SourceLocation& frame = nmeta.frame_info.front();
+    // Sets file, line and stack_frame_id in metadata
+    loctx->stack_frame_index_builder()->AddStackFrameLocations(
+        nmeta.frame_info, max_stack_depth, metadata);
 
-      metadata.set_source_file(frame.file);
-      metadata.set_source_line(frame.line);
-    }
     loctx->builder()->SetOpMetadata(std::move(metadata));
   }
 
@@ -82,14 +91,17 @@ class HloMetadataSetter {
 
 LoweringContext::LoweringContext(const std::string& name,
                                  torch::lazy::BackendDevice device)
-    : torch::lazy::LoweringContext(name, device), builder_(name) {}
+    : torch::lazy::LoweringContext(name, device),
+      builder_(name),
+      stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {}
 
 LoweringContext::LoweringContext(
     const std::string& name, torch::lazy::BackendDevice device,
     c10::ArrayRef<const torch::lazy::Node*> post_order,
     torch::lazy::Util::EmissionMap emit_status)
     : torch::lazy::LoweringContext(name, device, {}, emit_status),
-      builder_(name) {
+      builder_(name),
+      stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {
   for (auto node : post_order) {
     LowerNode(node);
   }
@@ -143,16 +155,32 @@ void LoweringContext::SetResult(size_t index, xla::XlaOp op) {
 }
 
 xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
+  xla::StatusOr<xla::XlaComputation> xla;
   if (!root_tuple_.empty()) {
     xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
-    return builder()->Build(root);
+    xla = builder()->Build(root);
+  } else {
+    xla = builder()->Build();
   }
-  return builder()->Build();
+
+  if (xla.ok()) {
+    (*xla->mutable_proto()->mutable_stack_frame_index()) =
+        stack_frame_index_builder()->stack_frame_index();
+  }
+
+  return xla;
 }
 
 xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
   XLA_CHECK(root_tuple_.empty());
-  return builder()->Build(root);
+  auto xla = builder()->Build(root);
+
+  if (xla.ok()) {
+    (*xla->mutable_proto()->mutable_stack_frame_index()) =
+        stack_frame_index_builder()->stack_frame_index();
+  }
+
+  return xla;
 }
 
 void LoweringContext::AssignOutputOp(const torch::lazy::Output& output,
@@ -265,6 +293,7 @@ void LoweringContext::AddParameter(const torch::lazy::Output& output,
 
 torch::lazy::ComputationPtr LoweringContext::Build() {
   xla::XlaComputation xla_computation = ConsumeValue(BuildXla());
+
   return std::make_shared<runtime::ComputationClient::Computation>(
       builder_.name(), std::move(xla_computation), device_);
 }
diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h
index b46d91874b02..b8751673fb68 100644
--- a/torch_xla/csrc/lowering_context.h
+++ b/torch_xla/csrc/lowering_context.h
@@ -7,6 +7,7 @@
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include <...>
 #include <...>
 
@@ -21,6 +22,8 @@
 
 namespace torch_xla {
 
+class StackFrameIndexBuilder;
+
 class LoweringContext : public torch::lazy::LoweringContext {
  public:
   explicit LoweringContext(const std::string& name,
@@ -31,6 +34,10 @@ class LoweringContext : public torch::lazy::LoweringContext {
 
   xla::XlaBuilder* builder() { return &builder_; }
 
+  StackFrameIndexBuilder* stack_frame_index_builder() {
+    return stack_frame_index_builder_.get();
+  }
+
   const torch::lazy::BackendDevice& device() const { return device_; };
 
   // If a parameter associated with data has already been declared, it will be
@@ -95,6 +102,10 @@ class LoweringContext : public torch::lazy::LoweringContext {
     return emitted_outputs_;
   }
 
+  // Return stack frame id
+  int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
+                                int64_t parent_id);
+
  private:
   struct Parameter {
     xla::XlaOp param;
 
@@ -110,6 +121,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
       parameters_map_;
   std::vector<xla::XlaOp> root_tuple_;
   OutputMap<xla::XlaOp> emitted_outputs_;
+
+  std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
 };  // namespace torch_xla
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/stack_frame_index_builder.cpp b/torch_xla/csrc/stack_frame_index_builder.cpp
new file mode 100644
index 000000000000..9624a5de2830
--- /dev/null
+++ b/torch_xla/csrc/stack_frame_index_builder.cpp
@@ -0,0 +1,95 @@
+#include "torch_xla/csrc/stack_frame_index_builder.h"
+
+namespace torch_xla {
+
+// Stack frame ids are 1-based; 0 marks an invalid id / missing parent frame
+static int kInvalidIndex = 0;
+
+int FindId(std::string_view key,
+           std::map<std::string, int, std::less<>>& index) {
+  auto entry_iterator = index.find(key);
+  if (entry_iterator == index.end()) {
+    return 0;
+  } else {
+    return entry_iterator->second;
+  }
+}
+
+void StackFrameIndexBuilder::AddStackFrameLocations(
+    const std::vector<torch::lazy::SourceLocation>& frame_info,
+    int max_stack_depth, xla::OpMetadata& metadata_to_populate) {
+  if (!frame_info.empty()) {
+    auto frame_it = frame_info.rbegin();
+    int parent_frame_id = kInvalidIndex;
+    int depth = 0;
+    for (; frame_it != frame_info.rend() && depth < max_stack_depth;
+         ++frame_it) {
+      parent_frame_id = AddStackFrameLocation(*frame_it, parent_frame_id);
+      ++depth;
+    }
+
+    // Point to first entry / deepest call / top frame in call stack
+    --frame_it;
+
+    metadata_to_populate.set_source_file(frame_it->file);
+    metadata_to_populate.set_source_line(frame_it->line);
+    metadata_to_populate.set_stack_frame_id(parent_frame_id);
+  }
+}
+
+int StackFrameIndexBuilder::AddStackFrameLocation(
+    const torch::lazy::SourceLocation& frame, int parent_frame_id) {
+  int line = frame.line;
+  int column = 0;  // Not provided in torch lazy source location - set to zero
+  std::string filename = frame.file;
+  std::string function_name = frame.function;
+
+  int filename_id = FindId(filename, file_name_to_id_);
+  if (filename_id == 0) {
+    indexes_.add_file_names(std::move(filename));
+    filename_id = indexes_.file_names_size();
+    file_name_to_id_[indexes_.file_names(filename_id - 1)] = filename_id;
+  }
+
+  int function_name_id = FindId(function_name, function_name_to_id_);
+  if (function_name_id == 0) {
+    indexes_.add_function_names(std::move(function_name));
+    function_name_id = indexes_.function_names_size();
+    function_name_to_id_[indexes_.function_names(function_name_id - 1)] =
+        function_name_id;
+  }
+
+  auto location_tuple =
+      std::make_tuple(filename_id, function_name_id, line, column);
+  auto file_location_iterator = file_location_to_id_.find(location_tuple);
+  int file_location_id = 0;
+  if (file_location_iterator == file_location_to_id_.end()) {
+    auto file_location = indexes_.add_file_locations();
+    file_location->set_file_name_id(filename_id);
+    file_location->set_function_name_id(function_name_id);
+    file_location->set_line(line);
+    file_location->set_column(column);
+
+    file_location_id = indexes_.file_locations_size();
+    file_location_to_id_[location_tuple] = file_location_id;
+  } else {
+    file_location_id = file_location_iterator->second;
+  }
+
+  auto frame_tuple = std::make_tuple(file_location_id, parent_frame_id);
+  auto stack_frame_iterator = frame_to_id_.find(frame_tuple);
+  int stack_frame_id = 0;
+  if (stack_frame_iterator == frame_to_id_.end()) {
+    auto frame = indexes_.add_stack_frames();
+    frame->set_file_location_id(file_location_id);
+    frame->set_parent_frame_id(parent_frame_id);
+
+    stack_frame_id = indexes_.stack_frames_size();
+    frame_to_id_[frame_tuple] = stack_frame_id;
+  } else {
+    stack_frame_id = stack_frame_iterator->second;
+  }
+
+  return stack_frame_id;
+}
+
+}  // namespace torch_xla
\ No newline at end of file
diff --git a/torch_xla/csrc/stack_frame_index_builder.h b/torch_xla/csrc/stack_frame_index_builder.h
new file mode 100644
index 000000000000..d7a8e6cb440b
--- /dev/null
+++ b/torch_xla/csrc/stack_frame_index_builder.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <torch/csrc/lazy/core/ir_metadata.h>  // SourceLocation
+
+#include <map>
+#include <string>
+#include <tuple>
+
+#include "xla/service/hlo.pb.h"
+#include "xla/types.h"
+
+namespace torch_xla {
+
+// TODO: Deduplicate with
+// https://github.com/openxla/xla/blob/952d3cf39c3e3eeaa790cc1dd53423c8eb27d473/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc#L40
+// in openxla/xla
+class StackFrameIndexBuilder {
+ public:
+  StackFrameIndexBuilder() {}
+
+  void AddStackFrameLocations(
+      const std::vector<torch::lazy::SourceLocation>& f, int max_stack_depth,
+      xla::OpMetadata& metadata_to_populate);
+
+  const xla::StackFrameIndexProto& stack_frame_index() const {
+    return indexes_;
+  }
+
+ private:
+  int AddStackFrameLocation(const torch::lazy::SourceLocation& source,
+                            int parent_id);
+
+  // Stack frame index tables - we accumulate and write these to the HloModule
+  xla::StackFrameIndexProto indexes_;
+
+  std::map<std::string, int, std::less<>> function_name_to_id_;
+  std::map<std::string, int, std::less<>> file_name_to_id_;
+  std::map<std::tuple<int, int, int, int>, int> file_location_to_id_;
+  std::map<std::tuple<int, int>, int> frame_to_id_;
+};  // StackFrameIndexBuilder
+
+}  // namespace torch_xla
\ No newline at end of file
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 4a97aad68b77..9b6dfe585e1d 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -896,20 +896,15 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) {
   xla_node->MarkDynamicDimension(dim);
 }
 
-void XLATensor::SetCustomOpName(const std::string& op_name) {
-  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
-  if (xla_node != nullptr) {
-    xla_node->SetCustomOpName(op_name);
-  }
-}
-
-const std::string& XLATensor::GetCustomOpName() const {
-  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
-  if (xla_node != nullptr) {
-    return xla_node->custom_op_name();
-  } else {
-    return "";
-  }
+bool XLATensor::SetNodeUserMetadata(
+    std::shared_ptr<torch::lazy::UserMetaData> metadata) {
+  auto* node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
+  // auto* node = dynamic_cast<XlaNode*>(GetIrValue().node.get());
+  if (node != nullptr) {
+    node->SetUserMetadataForSubGraph(metadata);
+    return true;
+  }
+  return false;
 }
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 83db2e95df61..0452c6f00b51 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -3,6 +3,7 @@
 
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include <...>
 
@@ -282,9 +283,11 @@ class XLATensor : public torch::lazy::LazyTensor {
   // Override to enable SPMD.
   void AssignIrValue(torch::lazy::Value ir_value) const final;
 
-  // Set custom op name on XlaNode
-  void SetCustomOpName(const std::string& op_name);
-  const std::string& GetCustomOpName() const;
+  // Set the custom op name on the base Node type (since not all nodes are
+  // XlaNode). When TorchDispatch is used to set a custom op name, extra
+  // frames end up in the stack frame debug info, so this also limits the
+  // recorded stack depth.
+  bool SetNodeUserMetadata(std::shared_ptr<torch::lazy::UserMetaData> metadata);
 
  private:
   XLATensor(const at::Tensor& tensor, const torch::lazy::BackendDevice& device);