Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
6ec9ac3
[Inlined Callstack Fix] Fix inlined callstack for blocks of the node.
kimishpatel Apr 21, 2021
2d424ca
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 21, 2021
537b458
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 21, 2021
971bca0
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 21, 2021
38f7083
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 22, 2021
eaa7e48
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 22, 2021
fffb963
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 23, 2021
be37864
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 27, 2021
d69a566
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 27, 2021
f5fdc0b
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 28, 2021
a0be9bc
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 29, 2021
69b46cb
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 30, 2021
5db2bc6
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel Apr 30, 2021
2a0da64
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel May 2, 2021
17134a4
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel May 3, 2021
2ced766
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel May 3, 2021
c224410
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel May 3, 2021
0b8b13a
Update on "[Inlined Callstack Fix] Fix inlined callstack for blocks o…
kimishpatel May 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions test/cpp/jit/test_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,63 @@ def c(x):
ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlinedCallStackTest, BlockAnnotation) {
Module a("A");
a.define(R"(
def forward(self, x, y, z: int):
if (z == 1):
return x + y
else:
return x * y
)");
Module b("B");
b.define(R"(
def forward(self, x):
return x + 2
)");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"(
def forward(self, x, y, z: int):
return self.A0.forward(x, y, z) + self.B0.forward(x)
)");

auto graph = c.get_method("forward").function().optimized_graph();
std::stringstream add_ss, mul_ss;
for (Node* n : graph->nodes()) {
if (n->kind() == prim::If) {
for (Block* block : n->blocks()) {
for (Node* if_node : block->nodes()) {
if (if_node->kind() == aten::add) {
Comment on lines +2154 to +2158

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this could be factored out as a function findNodeOfGivenKind or something like this.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, you could try to use subgraph matcher for this, but i'm not sure that with the required boilerplate code it would be beneficial.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment still applies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. But looking at this, it looks like it is just a 10ish lines of code without much resuse. If more similar code is added I suppose we can refactor. So for now I will leave it as is unless you have a strong opinion about it.

for (const auto e : if_node->callstack().value()->vec()) {
add_ss << std::get<1>(e);
}
add_ss << if_node->sourceRange();
}
if (if_node->kind() == aten::mul) {
for (const auto e : if_node->callstack().value()->vec()) {
mul_ss << std::get<1>(e);
}
mul_ss << if_node->sourceRange();
}
}
}
}
}
ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
ASSERT_NE(
add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
ASSERT_NE(
mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(AutogradSymbolsTest, Basic) {
Symbol sym = Symbol::fromQualString("aten::test_symbol");
Expand Down
73 changes: 53 additions & 20 deletions torch/csrc/jit/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1932,6 +1932,53 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {
return g.insertNode(g.createTupleUnpack(v))->outputs();
}

void inlineCallStackOfNode(
Node* n,
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
Function* callee,
Node* to_replace,
c10::optional<ModuleInstanceInfo> m_info);

void inlineCallStackOfBlock(
Block* b,
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
Function* callee,
Node* to_replace,
c10::optional<ModuleInstanceInfo> m_info) {
for (auto n : b->nodes()) {
inlineCallStackOfNode(n, new_cs_entries, callee, to_replace, m_info);
}
}

void inlineCallStackOfNode(
Node* new_node,
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
Function* callee,
Node* to_replace,
c10::optional<ModuleInstanceInfo> m_info) {
auto new_node_cs = new_node->callstack();

InlinedCallStack* raw_callstack_ptr =
new_node_cs ? new_node_cs->get() : nullptr;

if (!new_cs_entries.count(raw_callstack_ptr)) {
if (new_node_cs) {
new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
*new_node_cs, callee, to_replace->sourceRange(), m_info);
} else {
new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
callee, to_replace->sourceRange(), m_info);
}
}
new_node->setCallStack(new_cs_entries.at(raw_callstack_ptr));
// We updated the inlined callstack of new_node.
// Same must be done for the nodes of the blocks of new_node.
// For example If node's block otherwise is not annotated appropriately.
for (auto block : new_node->blocks()) {
inlineCallStackOfBlock(block, new_cs_entries, callee, to_replace, m_info);
}
}

// inline_optimized_graph argument is used in substitute function call for
// ONNX conversion
std::vector<Value*> inlineCallTo(
Expand Down Expand Up @@ -2008,26 +2055,12 @@ std::vector<Value*> inlineCallTo(
continue;
}

auto new_node_cs = new_node->callstack();

InlinedCallStack* raw_callstack_ptr =
new_node_cs ? new_node_cs->get() : nullptr;

if (!new_callstack_entries.count(raw_callstack_ptr)) {
if (new_node_cs) {
new_callstack_entries[raw_callstack_ptr] =
c10::make_intrusive<InlinedCallStack>(
*new_node_cs,
callee,
to_replace->sourceRange(),
module_instance_info);
} else {
new_callstack_entries[raw_callstack_ptr] =
c10::make_intrusive<InlinedCallStack>(
callee, to_replace->sourceRange(), module_instance_info);
}
}
new_node->setCallStack(new_callstack_entries.at(raw_callstack_ptr));
inlineCallStackOfNode(
new_node,
new_callstack_entries,
callee,
to_replace,
module_instance_info);
}
const auto& old_outputs = to_replace->outputs();

Expand Down