Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Fix for sequence of mutations in blocks #51577

Merged
merged 34 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dc81f3d
Fix for feedback
neginraoof Jan 28, 2021
e368129
clang
neginraoof Jan 28, 2021
c2370ed
Fix for clang
neginraoof Jan 28, 2021
4d0a47e
[ONNX] Update constant-folding of Gather op (#50554)
KsenijaS Jan 28, 2021
def794f
Fix for feedback
neginraoof Jan 28, 2021
3766f32
[ONNX] Fix bug in unfold symbolic (#50504)
KsenijaS Jan 29, 2021
35b8f35
[ONNX] Improve error message for parse_arg in symbolic functions (#50…
BowenBao Jan 29, 2021
f67519e
Update test.sh
neginraoof Jan 29, 2021
651db7e
Fixed tests
neginraoof Jan 29, 2021
bd35e20
Merge branch 'neraoof/fixSetAttribute' of github.com:neginraoof/pytor…
neginraoof Jan 29, 2021
2a4a1f4
Fix mypy
neginraoof Jan 29, 2021
4fab460
[ONNX] Export get/set attribute nodes (#50768)
neginraoof Jan 29, 2021
36eabc8
[ONNX] Enable remaining failed tests in opset13 (#50806)
hwangdeyu Jan 29, 2021
5bbb160
[ONNX] Add silu operator support for onnx (#51193)
hwangdeyu Jan 29, 2021
6a00ffe
[ONNX] Fix graph position to insert clone node for inplace op removal…
BowenBao Jan 31, 2021
c6c6535
[ONNX] Fix graph sequence output from loop node (#51305)
BowenBao Jan 31, 2021
687a9a2
Update error message that displays when encountering an op unsupporte…
Feb 1, 2021
8ba250c
[ONNX] Enable Constant Folding for ONNX Opset 13 (#51096)
hwangdeyu Feb 1, 2021
0015b95
[ONNX] Update unsafe_chunk() method to support new version 13 of Spli…
fatcat-z Feb 1, 2021
f15d3f4
[ONNX] Fix opset 11 ConstantChunk with negative dim (#51396)
BowenBao Feb 2, 2021
8221155
[ONNX] Support list remove for onnx export (#51373)
BowenBao Feb 2, 2021
1b2190d
fix bug (#51222)
KsenijaS Feb 2, 2021
ef4e3b7
[ONNX] Modifications in remove inplace ops passes to better handle bi…
shubhambhokare1 Feb 2, 2021
a7d103e
Fix multiple mutations
neginraoof Feb 2, 2021
104bbd0
Merge branch 'onnx_ms_1' of https://github.com/pytorch/pytorch into n…
neginraoof Feb 2, 2021
e5a91c3
merge
neginraoof Feb 2, 2021
8007894
Clear up the code
neginraoof Feb 3, 2021
ee28dd1
Fix for feedback
neginraoof Feb 4, 2021
58004e2
Adding test for opset 13
neginraoof Feb 4, 2021
4750169
Update test_pytorch_onnx_onnxruntime.py
neginraoof Feb 5, 2021
2b12a05
Merge branch 'onnx_ms_1' into neraoof/fixSetAttribute2
neginraoof Feb 5, 2021
305c96b
merge
neginraoof Feb 5, 2021
f808f8a
Fix for del, append and insert
neginraoof Feb 5, 2021
76e562e
Remove unused mutation remover
neginraoof Feb 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
175 changes: 161 additions & 14 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def run_ort(ort_sess, input):

ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
ort_outs = ort_sess.run(None, ort_inputs)

return inline_flatten_list(ort_outs, [])


Expand Down Expand Up @@ -1777,7 +1776,15 @@ def ngram_attention_bias(sequence_length: int, ngram: int, device: torch.device,
bias = torch.ones((ngram, sequence_length), device=device, dtype=dtype) * float("-inf")
for stream_idx in range(ngram):
for i in range(sequence_length):
bias = bias * 2
bias[stream_idx, i] = 5
bias = bias * 5
bias[0, 0] = 5

for stream_idx in range(ngram):
for i in range(sequence_length):
bias[stream_idx, i] = 5
bias[0, i] = 5
return bias

class ScriptModel(torch.nn.Module):
Expand Down Expand Up @@ -3911,8 +3918,8 @@ def forward(self, x):
res3 = []
res4 = []
for i in range(len(arr)):
res = res.append(arr[i].sum(0, False))
res1 = res1.append(arr[-1 - i].sum(0, False))
res.append(arr[i].sum(0, False))
res1.append(arr[-1 - i].sum(0, False))
res2 += 1
res3 = res3 + [arr[i].sum(0, False)]
res4 += [arr[-1 - i].sum(0, False)]
Expand Down Expand Up @@ -4261,7 +4268,7 @@ def forward(self, x, y):
self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)

@disableScriptTest()
@disableScriptTest() # Sort with dynamic dim not supported in ONNX
def test_sort(self):
class SortModel(torch.nn.Module):
def forward(self, x):
Expand All @@ -4274,7 +4281,7 @@ def forward(self, x):
self.run_test(SortModel(), x)

@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
@disableScriptTest() # Sort with dynamic dim not supported in ONNX
def test_sort_ascending(self):
class SortModel(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -5799,8 +5806,7 @@ def forward(self, x, y, cond):
res = res + [x]
else:
res = res + [y]
# TODO: remove torch.stack once graph sequence output is supported.
return torch.stack(res)
return res

x = torch.randn(2, 3)
y = torch.randn(3, 3)
Expand Down Expand Up @@ -6565,17 +6571,19 @@ def __init__(self):
self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

def set_cell_anchors(self, anchors):
def set_cell_anchors(self, anchors, boxes):
self.conv.weight = torch.ones(3, 10)
if self.conv.bias is not None:
self.conv.bias = torch.randn(3, 10, 3)
self.conv.weight = anchors + self.conv.weight
boxes[:] = torch.zeros(2, 3)

def forward(self, anchors) -> torch.Tensor:
self.set_cell_anchors(anchors)
def forward(self, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
boxes = torch.ones(2, 2, 3)
self.set_cell_anchors(anchors, boxes)
if self.conv.bias is not None:
return self.conv.weight
return anchors
return self.conv.weight, boxes
return anchors, boxes

model = torch.jit.script(MyModule())
anchors = torch.rand(3, 10)
Expand Down Expand Up @@ -6629,6 +6637,7 @@ def check_init(input_data, hidden_size, prev_state):
if prev_state.size(0) == 0:
state[:] = torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1) + state[:]
state_copy[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 2
state_copy[:] = torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 2
else:
state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
return state, state_copy
Expand Down Expand Up @@ -6660,8 +6669,9 @@ def check_init(input_data, hidden_size, prev_state):
state = torch.zeros(state_size, device=input_data.device)
state_copy = torch.zeros(state_size, device=input_data.device)
if prev_state.size(0) == 0:
state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
state_copy[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 2
for i in range(2):
state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * i
state_copy[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * i
elif prev_state.size(0) == 1:
s = state[:]
state[:] = prev_state + s
Expand All @@ -6684,6 +6694,143 @@ def forward(self, input_data, prev_state):
self.run_test(model, (random_data, empty_tensor))

@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_3(self):
    """Export a slice-assignment (index_put) that occurs inside a nested if-block.

    Exercises the exporter's handling of a mutation (``state[:] = ...``)
    one block deeper than a plain reassignment of the same variable.
    """
    @torch.jit.script
    def check_init(input_data, hidden_size, prev_state):
        # type: (torch.Tensor, int, torch.Tensor) -> torch.Tensor
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        if prev_state.size(0) < 2:
            state = state * 3
            # Inplace mutation inside the inner block; the outer block only
            # reassigns `state`. Both must be reconciled on export.
            if prev_state.size(0) == 0:
                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
            else:
                state = state + 2

        return state

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state = check_init(input_data, self.hidden_size, prev_state)
            return prev_state

    model = Example(4)
    random_data = torch.rand((1, 5, 4, 4))
    # An empty prev_state (size(0) == 0) drives execution into the
    # index_put branch above.
    empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
    # run_test presumably exports `model` to ONNX and compares ORT output
    # against eager PyTorch — defined elsewhere in this test suite.
    self.run_test(model, (random_data, empty_tensor))

@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_4(self):
    """Export a *sequence* of mutations and reassignments within one if-branch.

    The branch alternates ``state = state + 3`` (reassignment) with
    ``state[:] = ...`` (index_put mutation) twice in a row, which is the
    mutation-ordering case this test pins down.
    """
    @torch.jit.script
    def check_init(input_data, hidden_size, prev_state):
        # type: (torch.Tensor, int, torch.Tensor) -> torch.Tensor
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        if prev_state.size(0) == 0:
            # Interleaved reassignment / inplace mutation — order matters.
            state = state + 3
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
            state = state + 3
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
        else:
            state = state + 2
        return state

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state = check_init(input_data, self.hidden_size, prev_state)
            return prev_state

    model = Example(4)
    random_data = torch.rand((1, 5, 4, 4))
    # Empty prev_state selects the mutation-heavy branch.
    empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
    self.run_test(model, (random_data, empty_tensor))

@skipIfUnsupportedMinOpsetVersion(11)
def test_list_append_in_block(self):
    """Export a loop-carried ``list.append`` inside a scripted loop body.

    The append mutates a list declared outside the loop block, so the
    exporter must thread the sequence through the Loop node.
    """
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            # Deliberately a statement loop (not a comprehension): the
            # in-block mutation of `res` is the pattern under test.
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            return res

    model = torch.jit.script(ListModel())
    x = torch.randn(16, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(model, (x, y))

@skipIfUnsupportedMinOpsetVersion(13)
def test_list_append_in_nested_block(self):
    """Export ``list.append`` two loop-blocks deep.

    Same as test_list_append_in_block but with the mutation nested in a
    second loop; requires opset 13 (sequence ops in nested Loop bodies).
    """
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                for j in range(x.size(1)):
                    # Mutation of the outer list from the innermost block.
                    res.append(torch.matmul(x[i][j], y))
            return res

    model = torch.jit.script(ListModel())
    x = torch.randn(4, 4, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(model, (x, y))

@skipIfUnsupportedMinOpsetVersion(13)
def test_list_pop_in_block(self):
    """Export a mix of ``append`` and ``pop`` mutations, in and out of loop blocks."""
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            elem = torch.matmul(x[0], y)
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            # Drain the list again via pop inside a loop block.
            for i in range(x.size(0)):
                elem = res.pop()
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            # Mutations at function scope, after the loops.
            elem = res.pop()
            # NOTE(review): under torch.jit.script, aten::append returns the
            # mutated list (unlike eager Python, where it returns None), so
            # this returns `res` — verify against the scripted schema.
            return res.append(elem)

    model = torch.jit.script(ListModel())
    x = torch.randn(16, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(model, (x, y))


@skipIfUnsupportedMinOpsetVersion(13)
def test_list_del_in_block(self):
    """Export ``del list[index]`` mutations, both inside and after loop blocks."""
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            elem = torch.matmul(x[0], y)
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            # Delete-from-front inside a loop block.
            for i in range(x.size(0)):
                del res[0]
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            # del at function scope, after the loops.
            del res[0]
            # NOTE(review): under torch.jit.script, aten::append returns the
            # mutated list (unlike eager Python), so this returns `res` —
            # verify against the scripted schema.
            return res.append(elem)

    model = torch.jit.script(ListModel())
    x = torch.randn(16, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(model, (x, y))

@disableScriptTest()
def test_unsafe_chunk(self):
class ChunkModel(torch.nn.Module):
Expand Down