Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
if (q > 1) {
return c10::nullopt;
}
// If the device of indices tensor is not the same with it of the input
// tensor, move it to the device of the input tensor
auto indices_val = node->input(1);
if (inputTensorValues[0].device() != indices.device()) {
indices = indices.to(inputTensorValues[0].device());
}
// If indices input for onnx::Gather has a value less than 0,
// It needs to be adjusted (+= dim value) for aten op
auto less_mask = at::lt(indices, 0);
Expand Down