Skip to content

Commit

Permalink
Add peephole to aten::device(str, int)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thiago Crepaldi committed Mar 24, 2023
1 parent 37f612d commit 959e5f3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
1 change: 0 additions & 1 deletion torch/_dynamo/backends/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def onnxrt(gm, example_inputs, *, filename=None, provider=None):
filename,
input_names=input_names,
output_names=output_names,
opset_version=17,
)
del example_inputs, example_outputs

Expand Down
18 changes: 18 additions & 0 deletions torch/csrc/jit/passes/peephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,24 @@ struct PeepholeOptimizeImpl {
node->output()->replaceAllUsesWith(output);
changed = true;
}
} else if (
node->matches("aten::device(str type, int index) -> Device") &&
shape_peepholes_) {
auto string_type = node->inputs().at(0)->type()->expect<StringType>();
if (string_type) {
WithInsertPoint guard(node);
std::string type_str = node->inputs().at(0)->node()->s(attr::value);
auto index = toIValue(node->inputs().at(1))->toInt();
auto device = c10::Device(type_str + ":" + std::to_string(index));
auto output = node->owningGraph()->insertConstant(device);
GRAPH_UPDATE(
"Replacing ",
getHeader(node),
" with a device constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
changed = true;
}
} else if (
node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
auto ptt = node->input()->type()->expect<TensorType>();
Expand Down
7 changes: 0 additions & 7 deletions torch/onnx/symbolic_opset17.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,6 @@ def _compute_edge_sizes(n_fft, window_size):
return left, right


# @_onnx_symbolic("aten::device")
# # @symbolic_helper.parse_args("s", "i")
# def device(g: jit_utils.GraphContext, type, index):
# # device as a noop
# return type


@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
@_beartype.beartype
Expand Down

0 comments on commit 959e5f3

Please sign in to comment.