Skip to content

Commit

Permalink
Allow more inserts before reIndexTopology (#102312)
Browse files Browse the repository at this point in the history
Summary:
Currently if you are inserting into JIT IR at the same point in the middle of the graph,
it only allows for 40 inserts before it has to reindex. Reindexing is N**2 behavior, which can
lead to slow load times. This changes it so that it keeps track of how many insertions happen
at single point (like when a function is being inlined) to predict how many future insertions will happen
there. It then adjusts how it assigns topology to make sure there is enough room for those predicted insertions.
In practice this will allow around 2M inserts at a single point before it reindexes.

Test Plan: test_jit.py

Differential Revision: [D46206617](https://our.internmc.facebook.com/intern/diff/D46206617)
Pull Request resolved: #102312
Approved by: https://github.com/eellison
  • Loading branch information
zdevito authored and pytorchmergebot committed Jun 1, 2023
1 parent 6b8e68c commit b9294c7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 10 additions & 3 deletions torch/csrc/jit/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1283,13 +1283,20 @@ void Node::assignTopoPosition() {

// insert between two existing nodes
} else {
const auto posBetween = prevPos + (nextPos - prevPos) / 2;
if (posBetween == prevPos) {
int64_t remaining = nextPos - prevPos;
AT_ASSERT(remaining > 0);
if (remaining == 1) {
// There was no room
owningBlock()->reIndexTopology();
return;
}
topo_position_ = posBetween;
int64_t predicted_future_insertions = 0;
if (next() == graph_->insertPoint()) {
predicted_future_insertions = graph_->predicted_insert_count_++;
}
topo_position_ = prevPos +
std::max(int64_t(1), remaining / (2 + predicted_future_insertions));
AT_ASSERT(prevPos < topo_position_ && topo_position_ < nextPos);
}
}

Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,7 @@ struct Graph : std::enable_shared_from_this<Graph> {
// when insertNode() is called, the node is inserted before this node
// by default this is set to append to the top level block
Node* insert_before_;
int64_t predicted_insert_count_ = 0;

c10::optional<size_t> op_version_;

Expand Down Expand Up @@ -1403,14 +1404,15 @@ struct Graph : std::enable_shared_from_this<Graph> {
// set where nodes are inserted to append to the end of this block
void setInsertPoint(Block* b) {
AT_ASSERT(b->owningGraph() == this);
insert_before_ = b->return_node();
setInsertPoint(b->return_node());
}
// set where nodes are inserted to insert _before_ this node
// for implementation simplicity we only support inserting before a node for
// now
void setInsertPoint(Node* n) {
AT_ASSERT(n->owningGraph() == this && n->inBlockList());
insert_before_ = n;
predicted_insert_count_ = 0;
}
Node* insertPoint() {
return insert_before_;
Expand Down

0 comments on commit b9294c7

Please sign in to comment.