Skip to content

Commit

Permalink
Add aten::device variant instead of prim::device
Browse files Browse the repository at this point in the history
  • Loading branch information
Thiago Crepaldi committed Mar 20, 2023
1 parent 99b3fb9 commit 37f612d
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 43 deletions.
1 change: 1 addition & 0 deletions torch/_dynamo/backends/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def onnxrt(gm, example_inputs, *, filename=None, provider=None):
filename,
input_names=input_names,
output_names=output_names,
opset_version=17,
)
del example_inputs, example_outputs

Expand Down
9 changes: 4 additions & 5 deletions torch/csrc/jit/mobile/promoted_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,10 @@ void device(Stack& stack) {
}

// Implements aten::device(str type, int index) -> Device.
// Builds a c10::Device from a device-type string (e.g. "cuda") and an
// integer device index by joining them as "<type>:<index>" and letting
// the c10::Device string constructor parse the result.
void device_with_index(Stack& stack) {
  // Arguments are pushed left-to-right, so the LAST argument (index) sits on
  // top of the stack and must be popped first (same convention as e.g.
  // prim::TupleIndex). Popping type first would call toStringRef() on the
  // int IValue and throw.
  int index = pop(stack).toInt();
  std::string type = pop(stack).toStringRef();
  std::string device_str = type + ":" + std::to_string(index);
  auto device = c10::Device(device_str);
  push(stack, device);
}

Expand Down
17 changes: 0 additions & 17 deletions torch/csrc/jit/passes/peephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,23 +269,6 @@ struct PeepholeOptimizeImpl {
node->output()->replaceAllUsesWith(output);
changed = true;
}
} else if (
node->matches("prim::device(Tensor a, int index) -> Device") &&
shape_peepholes_) {
auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
auto index = toIValue(node->inputs().at(1))->toInt();;
if (ptt->device()) {
WithInsertPoint guard(node);
ptt->device()->set_index(index);
auto output = node->owningGraph()->insertConstant(*ptt->device());
GRAPH_UPDATE(
"Replacing ",
getHeader(node),
" with a device constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
changed = true;
}
} else if (
node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
auto ptt = node->input()->type()->expect<TensorType>();
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,10 +493,6 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a) -> Device"),
device,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a, int index) -> Device"),
device_with_index,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"),
dtype,
Expand Down Expand Up @@ -2296,6 +2292,10 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::device(str type, int index) -> Device"),
device_with_index,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"),
[](Stack& stack) {
Expand Down
22 changes: 5 additions & 17 deletions torch/csrc/jit/runtime/static/native_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1310,25 +1310,13 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::device,
prim_device,
[](Node* n) -> SROperator {
if (!sr_schema_check(
n,
"prim::device(Tensor a) -> Device",
"prim::device(Tensor a, int index) -> Device")) {
if (!sr_schema_check(n, "prim::device(Tensor a) -> Device")) {
return nullptr;
}

if (n->inputs().size() == 1) {
return [](ProcessedNode* pnode) {
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = input.device();
};
} else {
return [](ProcessedNode* pnode) {
const auto& input = pnode->Input(0).toTensor();
input.device().set_index(pnode->Input(1).toInt());
pnode->Output(0) = input.device();
};
}
return [](ProcessedNode* pnode) {
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = input.device();
};
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
Expand Down
7 changes: 7 additions & 0 deletions torch/onnx/symbolic_opset17.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def _compute_edge_sizes(n_fft, window_size):
return left, right


# @_onnx_symbolic("aten::device")
# @symbolic_helper.parse_args("s", "i")
# def device(g: jit_utils.GraphContext, type, index):
# # device as a noop
# return type


@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
@_beartype.beartype
Expand Down

0 comments on commit 37f612d

Please sign in to comment.