[ONNX] Redesign inplace conversion (#55033)
* Create `InplaceConverter` and `ValueTracker` to keep track of aliases of values throughout the graph. For a given value, a new alias is created whenever there is an inplace operation or a SetAttr, or when the value passes through nested blocks owned by If/Loop nodes (see the sketch below).
* Fix a bug where control-flow node output types were not set when the node as a whole could not run ONNX shape inference because it contained a non-ONNX node.
* Add a symbolic for `__not__` ~~and `prim_min`~~ (update: moved to a separate PR), and update the opset 9 `index_put` symbolic to support assignment without providing indices.
* Bump the ORT version in the CI test.
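A rough, hypothetical sketch (not part of this diff; the module name, shapes, and opset below are made up for illustration) of the kind of pattern the new InplaceConverter must track: `a` is mutated in place inside a loop, so a fresh alias of `a` has to be recorded at each mutation and threaded through the Loop subblock:

import io
import torch

class InplaceLoop(torch.nn.Module):
    def forward(self, x):
        a = torch.zeros(12)
        for _ in range(10):
            a.add_(torch.ones(12))  # inplace op: each call creates a new alias of `a`
        return a + x

# Scripting keeps the loop as a subblock, so exporting exercises the alias tracking.
# (On the PyTorch version of this commit, scripted-module export may also require
# the since-removed `example_outputs` argument.)
torch.onnx.export(torch.jit.script(InplaceLoop()), (torch.randn(12),),
                  io.BytesIO(), opset_version=11)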

ghstack-source-id: 338781ce164798fd375f1a2d0e9f1267deb7746b
Pull Request resolved: #56173
BowenBao committed Apr 16, 2021
1 parent 0698bb7 commit 597846f
Showing 6 changed files with 874 additions and 592 deletions.
2 changes: 1 addition & 1 deletion .jenkins/caffe2/test.sh
@@ -170,7 +170,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
  # JIT C++ extensions require ninja, so put it into PATH.
  export PATH="/var/lib/jenkins/.local/bin:$PATH"
  if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
-   pip install -q --user onnxruntime==1.6.0
+   pip install -q --user onnxruntime==1.7.0
  fi
  "$ROOT_DIR/scripts/onnx/test.sh"
fi
239 changes: 238 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -4569,6 +4569,24 @@ def forward(self, x, y):
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::identity of sequence in opset 14
    def test_list_append_nested_2(self):
        class ListModel(torch.nn.Module):
            def forward(self, x):
                res = []
                res_replicate = []
                for i in range(x.size(0)):
                    if len(res) > 2:
                        for j in range(x.size(1)):
                            res.append(x[i][j])
                        res_replicate.append(res[-1])
                        res.append(res_replicate[-1])
                return res, res_replicate

        model = torch.jit.script(ListModel())
        x = torch.randn(4, 4, 3, 4)
        self.run_test(model, (x,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_pop(self):
        class ListModel(torch.nn.Module):
@@ -4651,6 +4669,36 @@ def forward(self, x, y):
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_set(self):
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    res.append(x[i])
                res[y] = x[y]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(12, 4)
        y = torch.tensor(2, dtype=torch.long)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_list_idx_sum(self):
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                indices = torch.arange(x.size(0))
                res = []
                for i in range(x.size(0)):
                    res.append(x[i])
                return res[torch.sum(indices[:y])]

        model = torch.jit.script(ListModel())
        x = torch.randn(12, 4)
        y = torch.tensor(2, dtype=torch.long)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories(self):
        class TensorFactory(torch.nn.Module):
@@ -4830,6 +4878,125 @@ def forward(self, x, y):
        self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
        self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_inplace_with_loop(self):
        class M(torch.nn.Module):
            def forward(self, x):
                a = torch.ones(12,)
                for i in range(10):
                    a.add_(torch.ones(12,))
                return a + x

        x = torch.randn(12,)
        self.run_test(torch.jit.script(M()), (x,))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_inplace_with_loop_2(self):
        class M(torch.nn.Module):
            def forward(self, x):
                _bias = torch.ones(12,)
                a = torch.ones(12,)  # used in loop, altered.
                a_ref = a  # not used in loop, should be altered.
                b = x.clone()  # used in loop, not altered in place.
                b_ref = b  # not used in loop, should not be altered.
                for i in range(10):
                    if i == 3:
                        for j in range(5):
                            a += _bias
                            _bias.add_(torch.ones(12,))
                            b = b + torch.ones(12,)

                    _bias.add_(torch.ones(12,))
                    a += _bias
                # TODO: value for a_ref is incorrect.
                # a_ref += torch.ones(12,)
                b_ref += torch.ones(12,)
                return _bias + x, a, b, b_ref

        x = torch.zeros(12,)
        self.run_test(torch.jit.script(M()), (x,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_inplace_attr_with_loop(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._bias = torch.arange(12,)

            def forward(self, x):
                self._bias = torch.arange(12,)
                for i in range(10):
                    if i == 3:
                        for j in range(5):
                            self._bias += torch.arange(12,)
                return self._bias + x

        x = torch.zeros(12,)
        self.run_test(torch.jit.script(M()), (x,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_inplace_attr_copy_with_loop(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._bias = torch.arange(12,)

            def forward(self, x):
                self._bias = torch.arange(12,)
                for i in range(10):
                    if i == 3:
                        for j in range(5):
                            self._bias.copy_(torch.arange(12,))
                        self._bias.copy_(self._bias + torch.arange(12,))

                    self._bias.copy_(self._bias + torch.arange(12,))
                return self._bias + x

        x = torch.zeros(12,)
        self.run_test(torch.jit.script(M()), (x,))

    @skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::identity of sequence in opset 14
    def test_inplace_sequence_with_loop(self):
        class M(torch.nn.Module):
            def process(self, beam_hyps: List[torch.Tensor], done: torch.Tensor, x):
                batch_size = x.shape[0]
                for i in range(batch_size):
                    if done[i]:
                        continue

                    beam_idx = 0
                    for _, token in enumerate(x[i]):
                        beam_hyps.append(token)
                        beam_idx += 1

                        if beam_idx == 6:
                            break

                    done[i] = len(beam_hyps) > 4

                return beam_hyps, done

            def forward(self, x):
                beam_hyps: List[torch.Tensor] = []
                batch_size = x.shape[0]
                cur_len = 0
                max_len = x.shape[1]
                done = torch.zeros(batch_size, dtype=torch.bool)
                while cur_len < max_len:
                    beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
                    cur_len = cur_len + 1

                return beam_hyps

        m = torch.jit.script(M())
        x = torch.randn(8, 4, 3)
        self.run_test(m, (x,))

    @disableScriptTest()  # Sort with dynamic dim not supported in ONNX
    def test_sort(self):
        class SortModel(torch.nn.Module):
@@ -7601,6 +7768,37 @@ def forward(self, feature_maps, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
        anchors = torch.ones(3, 10, 3)
        self.run_test(model, (x, anchors))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_set_attr_5(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

            def set_cell_anchors(self, anchors):
                self.conv.weight = torch.arange(10)
                for i in range(10):
                    if i == 3:
                        for j in range(10):
                            w = self.conv.weight
                            self.conv.weight = torch.arange(10) + w

                    self.conv.weight = self.conv.weight + torch.arange(10)
                # NOTE: the `is not None` check and `assert` are needed for TorchScript compilation.
                if self.conv.bias is not None:
                    a = self.conv.bias
                    assert a is not None
                    self.conv.bias = anchors + a

            def forward(self, anchors):
                self.set_cell_anchors(anchors)
                return self.conv.weight, self.conv.bias

        model = torch.jit.script(MyModule())
        anchors = torch.ones(3, 10, 3)
        self.run_test(model, (anchors,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_set_attr_in_loop(self):
        class MyModule(torch.nn.Module):
@@ -7698,7 +7896,11 @@ def forward(self, input_data, prev_state):
        model = Example(10)
        random_data = torch.rand((1, 5, 30, 30))
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
-       self.run_test(model, (random_data, empty_tensor))
+       random_state = torch.rand((1, 1, 10, 30, 30))
+       self.run_test(model, (random_data, empty_tensor),
+                     input_names=['data', 'state'],
+                     dynamic_axes={'state': [0, 1, 2, 3, 4]},
+                     test_with_inputs=[(random_data, random_state)])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_if_3(self):
@@ -7768,6 +7970,41 @@ def forward(self, input_data, prev_state):
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        self.run_test(model, (random_data, empty_tensor))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_if_5(self):
        @torch.jit.script
        def check_init(input_data, hidden_size, prev_state):
            # type: (torch.Tensor, int, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # generate empty prev_state, if None is provided
            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(state_size, device=input_data.device)
            state_ref = state
            if prev_state.size(0) == 0:
                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
                state = state + 3
                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
            else:
                state = state + 2
            return state, state_ref

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data, prev_state):
                prev_state, state_ref = check_init(input_data, self.hidden_size, prev_state)
                return prev_state, state_ref

        model = Example(4)
        random_data = torch.rand((1, 5, 4, 4))
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        self.run_test(model, (random_data, empty_tensor))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_append_in_block(self):
        class ListModel(torch.nn.Module):
10 changes: 10 additions & 0 deletions torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
@@ -236,6 +236,11 @@ std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
  // NOTE: the output order is deliberately changed to match expected order
  //       since onnx loop requires scan outputs to be the last outputs.
  auto new_outputs = ConvertSequenceDependencies(node, opset_version);

  // Copy type of block output to node output.
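  // (Block output 0 is the Loop termination condition, so block output i + 1
  // corresponds to node output i.)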
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type());
  }
  TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
  return new_outputs;
}
@@ -375,6 +380,11 @@ std::vector<Value*> FixupONNXIfNode(Node* node, int opset_version) {
  auto* graph = if_node->owningGraph();
  FixupONNXSubblockOutputs(node);
  ONNXFixupUninitializedOutput(if_node);
  // Copy type of block output to node output.
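  // (Both branch outputs are expected to have matching types after the fixups
  // above, so the then-branch, block 0, is used.)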
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type());
  }

  GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
  return if_node->outputs().vec();
}
