Skip to content

Commit

Permalink
[ONNX] Update constant-folding of Gather op (#50554)
Browse files Browse the repository at this point in the history
Update constant-folding of Gather operator so it also includes cases where rank of indices input is 0.
Currently it only support cases where rank of indices is 1.

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Feb 2, 2021
1 parent ec3aae8 commit 3856bf8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
32 changes: 32 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,38 @@ def forward(self, input, indices):
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
self.run_test(GatherModel(), input=(input, indices))

@skipIfUnsupportedMinOpsetVersion(11)
def test_gather_constant_fold(self):
class GatherModule(torch.nn.Module):
def __init__(self):
super(GatherModule, self).__init__()
self.register_buffer("weight", torch.ones(5))

def forward(self, x):
# shape is of rank 0
shape = self.weight.shape[0]
m = 5 - shape
return x.clamp(min=m)

x = torch.randn(1)
self.run_test(GatherModule(), (x,))

class GatherModule(torch.nn.Module):
def __init__(self):
super(GatherModule, self).__init__()
self.register_buffer("weight", torch.ones(2))

def forward(self, x):
# shape is of rank 0
shape = self.weight.shape[0]
pad = [1, shape, shape, shape]
zero_pad = torch.nn.ZeroPad2d(pad)
return zero_pad(x)

x = torch.randn(1, 3, 2)
self.run_test(GatherModule(), (x,))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_expand(self):
class ExpandModel(torch.nn.Module):
Expand Down
10 changes: 10 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,16 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
auto indices_corr = at::add(indices, inputTensorValues[0].sizes()[axis]);
auto indices_masked = at::where(less_mask, indices_corr, indices);
updated_val = at::index_select(inputTensorValues[0], axis, indices_masked);
auto q = indices.dim();
// Cases where rank of indices > 1 are not currently supported.
if (q > 1) {
return c10::nullopt;
}
// If rank of indices is 0, rank of output tensor should be
// rank_of_input - 1.
if (q < 1) {
updated_val = updated_val.squeeze();
}
return c10::optional<at::Tensor>(updated_val);
} else {
return c10::nullopt;
Expand Down

0 comments on commit 3856bf8

Please sign in to comment.