diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index d515afa5d8e1e..b4065f177251c 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -487,6 +487,12 @@ c10::optional runTorchBackendForOnnx( if (q > 1) { return c10::nullopt; } + // If the device of indices tensor is not the same with it of the input + // tensor, move it to the device of the input tensor + auto indices_val = node->input(1); + if (inputTensorValues[0].device() != indices.device()) { + indices = indices.to(inputTensorValues[0].device()); + } // If indices input for onnx::Gather has a value less than 0, // It needs to be adjusted (+= dim value) for aten op auto less_mask = at::lt(indices, 0);