Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions test/custom_debug_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
class_count = defaultdict(int)
instance_count = dict()

# This is a sample implementation for reading object
# hierarchies from a source stack using a TorchDispatch
# interceptor. We then set the node op_name in XLA
# via the output tensor and direct XLA to ignore stack
# frames added (due to TorchDispatch) during lowering.


def GetInstancePlaceHolder(class_type, obj):
global class_count
Expand Down Expand Up @@ -172,11 +178,21 @@ def CleanNames(names):

def GetAllObjectAndClassNames(frame):
names = []
frame_count = 0
self_found = False
while frame is not None:
if __file__ == frame.f_code.co_filename:
self_found = True

if not self_found:
frame = frame.f_back
continue

name = GetClassNameAndObjFromFrame(frame)
if len(name) > 0:
names.append(name)
frame = frame.f_back
frame_count += 1

names.reverse()

Expand All @@ -187,7 +203,24 @@ def GetAllObjectAndClassNames(frame):
if len(output) > 0:
output += "/"

return output
return output, frame_count - 1


class StackLayerSignature:
  """One layer of a call stack, identified by file name, function and line.

  Two signatures compare equal when all three fields are equal. The string
  form is "<filename>|<func>|<line>".
  """

  def __init__(self, filename, func, line):
    # filename: source file of the stack layer.
    # func: name of the function executing in that layer.
    # line: line number within `filename`.
    self.filename = filename
    self.func = func
    self.line = line

  def _key(self):
    # Single source of truth for equality and hashing.
    return (self.filename, self.func, self.line)

  def __str__(self):
    return f"{self.filename}|{self.func}|{self.line}"

  def __repr__(self):
    return str(self)

  def __eq__(self, ref):
    # Defer to the other operand (and ultimately compare unequal) instead of
    # raising AttributeError when `ref` is not a signature.
    if not isinstance(ref, StackLayerSignature):
      return NotImplemented
    return self._key() == ref._key()

  def __hash__(self):
    # Defining __eq__ alone would make instances unhashable; keep the hash
    # consistent with equality. Do not mutate an instance while it is stored
    # in a set or used as a dict key.
    return hash(self._key())


class CustomOpNameLowering(TorchDispatchMode):
Expand All @@ -198,16 +231,38 @@ def __init__(self):
# Enter the dispatch mode: remember the current IR debug setting so it can
# be restored later, switch IR debug on, and start with an empty list of
# recorded stack signatures. Returns whatever the base class __enter__ returns.
def __enter__(self):
self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
torch_xla._XLAC._set_ir_debug(True)
self.stack_sigs = []
return super().__enter__()

# Leave the dispatch mode: restore the IR debug setting saved in
# self._old_ir_debug, drop the recorded stack signatures (the attribute no
# longer exists afterwards), then delegate to the base class __exit__.
# The base class result is not returned.
def __exit__(self, exc_type, exc_val, exc_tb):
torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
del self.stack_sigs
super().__exit__(exc_type, exc_val, exc_tb)

def add_stack_sig(self, frame, depth):
  """Record and return the signature of the call stack ending at `frame`.

  The stack is listed from `frame` outwards and trimmed from the front, so
  only the last `depth` entries (the outermost callers) are kept.

  Args:
    frame: frame object to start from.
    depth: number of stack layers to keep.

  Returns:
    The list of StackLayerSignature that was appended to self.stack_sigs.

  Raises:
    AssertionError: if the stack does not contain exactly `depth` layers
      after trimming (i.e. it was shorter than `depth`).
  """
  stack = [
      StackLayerSignature(s.filename, s.function, s.lineno)
      for s in inspect.getouterframes(frame)
  ]

  # Drop the leading (innermost) layers in one step so that only `depth`
  # layers remain; avoids repeated pop(0), which shifts the list each time.
  excess = len(stack) - depth
  if excess > 0:
    del stack[:excess]

  assert len(stack) == depth, (
      f"Expected {depth} stack layers, found {len(stack)}")

  self.stack_sigs.append(stack)

  return stack

def __torch_dispatch__(self, func, types, args=(), kwargs={}):
res = func(*args, **kwargs)
if 'xla' in str(res.device):
frame = inspect.currentframe()
prefix = GetAllObjectAndClassNames(frame)
torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
prefix, depth = GetAllObjectAndClassNames(frame)
self.depth = depth
self.add_stack_sig(frame, self.depth)

assert torch_xla._XLAC._set_xla_custom_op_name_prefix(
res, prefix, self.depth), "Custom op set failed"
return res
89 changes: 83 additions & 6 deletions test/test_hlo_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,52 @@
import torch_xla.debug.metrics as met
import unittest
import json
from custom_debug_lowering import CustomOpNameLowering
import inspect
import copy
from custom_debug_lowering import CustomOpNameLowering, StackLayerSignature


class HloStackExtractor:
  """Decodes call stacks from the 'stackFrameIndex' table of an HLO JSON dict.

  All ids used by the table are 1-based; a frame id of 0 means "no frame".
  """

  def __init__(self, hlo_json):
    # Fail early, key by key, if the expected tables are missing.
    assert 'stackFrameIndex' in hlo_json
    index = hlo_json['stackFrameIndex']
    assert 'fileLocations' in index
    assert 'stackFrames' in index
    assert 'fileNames' in index
    assert 'functionNames' in index

    self.file_locations = index['fileLocations']
    self.stack_frames = index['stackFrames']
    self.file_names = index['fileNames']
    self.function_names = index['functionNames']

  def extract(self, stack_frame_id):
    """Return the stack for `stack_frame_id` as a list of StackLayerSignature.

    The list starts with the given frame and follows 'parentFrameId' links
    until a frame has no parent (missing key or id 0). An id of 0 yields an
    empty list rather than wrapping around to the last table entry.
    """
    stack_sigs = []

    while stack_frame_id != 0:
      stack_frame = self.stack_frames[stack_frame_id - 1]
      file_location = self.file_locations[stack_frame['fileLocationId'] - 1]
      file_name = self.file_names[file_location['fileNameId'] - 1]
      function_name = self.function_names[file_location['functionNameId'] - 1]

      stack_sigs.append(
          StackLayerSignature(file_name, function_name, file_location['line']))

      # A missing parent id terminates the walk, same as an explicit 0.
      stack_frame_id = stack_frame.get('parentFrameId', 0)

    return stack_sigs


class TestHloMetaData(unittest.TestCase):
Expand All @@ -32,21 +77,25 @@ def test_metadata(self):
nl2 = torch.nn.Tanh()
model = torch.nn.Sequential(layer1, nl1, layer2, nl2)

with CustomOpNameLowering():
with CustomOpNameLowering() as c:
model = model.to(device=xm.xla_device())
inp = torch.rand(4, 4, device=xm.xla_device())
#inp = torch.rand(4, 4)
#inp = inp.to(device=xm.xla_device())
out = model(inp)

# Get outer frames
stack_sigs = c.stack_sigs

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
hlo_text = ctx.hlo_json()

# Strings to match in the lowering
bingo = {
"torch/_ops.py": False,
#"torch/nn/modules/linear.py": False,
#"torch/nn/modules/activation.py": False,
#"torch/nn/functional.py": False,
"torch/nn/modules/linear.py": False,
"torch/nn/modules/activation.py": False,
"torch/nn/functional.py": False,
"Sequential[model]/Linear[0]": False,
"Sequential[model]/ReLU[1]": False,
"Sequential[model]/Linear[2]": False,
Expand All @@ -60,10 +109,17 @@ def test_metadata(self):
non_zero_metadata = False

local_json = json.loads(hlo_text)

#with open("./hlo.json", "w") as f:
# f.write(json.dumps(local_json, indent=2))

hloEx = HloStackExtractor(local_json)

assert "computations" in local_json
for c in local_json["computations"]:
if "instructions" in c:
i = c["instructions"]

for op in i:
if 'metadata' in op:
meta = op["metadata"]
Expand All @@ -75,6 +131,27 @@ def test_metadata(self):
if isinstance(vm, str) and k in vm:
bingo[k] = True

# Decode stack frame id and check it matches one of
# the passed in stacks
stack_frame_match = False
if 'stackFrameId' in meta:
hlo_stack_sig = hloEx.extract(meta['stackFrameId'])

for t_sig in stack_sigs:
if len(hlo_stack_sig) == len(t_sig) and hlo_stack_sig == t_sig:
stack_frame_match = True
break
elif len(hlo_stack_sig) > len(t_sig):
hlo_stack_sig_copy = copy.copy(hlo_stack_sig)
discards = []
while len(hlo_stack_sig_copy) > len(t_sig):
discards.append(hlo_stack_sig_copy.pop(0))
# Print an error message on a partial match
if hlo_stack_sig_copy == t_sig:
print(f"** PARTIAL MATCH: Discarded {discards}")

assert stack_frame_match, f"Stack\n{hlo_stack_sig} does not match any of\n{stack_sigs}"

assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"

for k, v in bingo.items():
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,12 @@ ptxla_cc_library(
srcs = [
"ir.cpp",
"lowering_context.cpp",
"stack_frame_index_builder.cpp",
],
hdrs = [
"ir.h",
"lowering_context.h",
"stack_frame_index_builder.h",
],
deps = [
":device",
Expand Down
15 changes: 7 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1918,15 +1918,14 @@ void InitXlaModuleBindings(py::module m) {
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
});
m.def("_set_xla_custom_op_name",
[](const at::Tensor& input, const std::string& op_name) {
m.def("_set_xla_custom_op_name_prefix",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename to _set_user_metadata and take a torch::lazy::UserMetadata if possible.

Let's pass in everything that can be forwarded to HLO here (like filename, line num, whatever Indexes etc); so no computation happens in C++.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can pass in a python dict (which becomes a map here) with all needed content and do creation of CustomOpNameMetadata if you feel exposing CustomOpNameMetadata through pybind11 is too much work.

[](const at::Tensor& input, const std::string& op_name_prefix,
size_t max_call_stack_depth) -> bool {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->SetCustomOpName(op_name);
});
m.def("_get_xla_custom_op_name",
[](const at::Tensor& input) -> const std::string& {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->GetCustomOpName();
std::shared_ptr<torch::lazy::UserMetaData> user_meta =
std::make_shared<CustomOpNameMetaData>(op_name_prefix,
max_call_stack_depth);
return xtensor->SetNodeUserMetadata(user_meta);
});
m.def("_get_all_reduce_token",
[](const std::string& device_str) -> const torch::lazy::Value& {
Expand Down
12 changes: 10 additions & 2 deletions torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,16 @@ void XlaNode::UpdateShardingHash() {
}
}

void XlaNode::SetCustomOpName(const std::string& op_name) {
custom_op_name_ = op_name;
// Attaches `user_meta` to this node and, recursively, to every operand
// sub-graph rooted at an XlaNode that has no user metadata yet.
// Returns the result of SetUserMetadata() on this node.
std::shared_ptr<torch::lazy::UserMetaData> XlaNode::SetUserMetadataForSubGraph(
    std::shared_ptr<torch::lazy::UserMetaData> user_meta) {
  // Bind by const reference so iterating does not copy each operand handle.
  for (const auto& np : operands_) {
    XlaNode* xnp = dynamic_cast<XlaNode*>(np.get());
    // Only descend into operands with no metadata already set; this leaves
    // previously annotated sub-graphs untouched and stops repeated visits.
    if (xnp != nullptr && xnp->user_metadata() == nullptr) {
      xnp->SetUserMetadataForSubGraph(user_meta);
    }
  }
  // The metadata of this node itself is set unconditionally.
  return SetUserMetadata(user_meta);
}

} // namespace torch_xla
16 changes: 12 additions & 4 deletions torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_metadata.h>

#include <functional>
#include <iostream>
Expand Down Expand Up @@ -146,8 +147,8 @@ class XlaNode : public torch::lazy::Node {
return unbounded_dynamic_dims_;
}

void SetCustomOpName(const std::string& op_name);
const std::string& custom_op_name() const { return custom_op_name_; }
std::shared_ptr<torch::lazy::UserMetaData> SetUserMetadataForSubGraph(
std::shared_ptr<torch::lazy::UserMetaData> user_meta);

protected:
std::unordered_set<uint32_t> unbounded_dynamic_dims_;
Expand All @@ -170,8 +171,6 @@ class XlaNode : public torch::lazy::Node {

// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;

std::string custom_op_name_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
Expand All @@ -195,6 +194,15 @@ T* NodeCast(const torch::lazy::Node* node, torch::lazy::OpKind op) {
return const_cast<T*>(casted);
}

struct CustomOpNameMetaData : public torch::lazy::UserMetaData {
// Stores the op-name prefix and the maximum call-stack depth.
// The depth parameter is size_t to match the `max_stack_depth` member, so a
// size_t argument is no longer narrowed through int on the way in.
CustomOpNameMetaData(const std::string& input_op_name_prefix,
                     size_t input_max_stack_depth)
    : op_name_prefix(input_op_name_prefix),
      max_stack_depth(input_max_stack_depth) {}
std::string op_name_prefix;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this struct to contain the exact things you need to override in HLO proto, i.e.
source_file, source_line, stack_frame_id etc.

size_t max_stack_depth;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_IR_H_
Loading