Skip to content

Commit

Permalink
flake8 & clang
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenBao committed Oct 3, 2020
1 parent 6268036 commit bc02b94
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,9 @@ def forward(self, x):

x2 = [] if x2 is None else [x2]
if len(x2) > 0:
self.run_test(Squeeze(d), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
self.run_test(Squeeze(d), x1,
input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}},
test_with_inputs=x2)
else:
self.run_test(Squeeze(d), x1)

Expand Down
6 changes: 4 additions & 2 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,16 @@ Node* CloneNodeToGraph(Node* n, std::shared_ptr<Graph> n_graph) {
// order to be consumed as inputs
std::vector<Value*> unsqueezed;
for (auto* input : lc_node->inputs()) {
Node* unsqueezed_node = n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1));
Node* unsqueezed_node =
n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1));
auto new_input = n_graph->addInput();
new_input->copyMetadata(input);
unsqueezed_node->addInput(new_input);
unsqueezed_node->is_(attr::axes, {0});
unsqueezed.emplace_back(unsqueezed_node->output());
}
Node* concat_node = n_graph->insertNode(n_graph->create(::c10::onnx::Concat, 1));
Node* concat_node =
n_graph->insertNode(n_graph->create(::c10::onnx::Concat, 1));
concat_node->i_(attr::axis, 0);
for (auto v : unsqueezed) {
concat_node->addInput(v);
Expand Down

0 comments on commit bc02b94

Please sign in to comment.