[ONNX] Preprocess index_put with bool inputs to masked_scatter/masked_fill #45584
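
For context: the pass added here relies on a PyTorch-level equivalence. An in-place indexed assignment through a boolean mask is the same operation as masked_fill_ (scalar value) or masked_scatter_ (tensor value). A minimal sketch of that equivalence (standalone illustration, not code from this PR):

    import torch

    x = torch.randn(2, 2, 2)
    mask = x != 5.0  # Bool tensor, same shape as x

    # Scalar value: index_put_ through a bool mask == masked_fill_
    a = x.clone()
    a[mask] = 1.0  # lowers to aten::index_put_ in the JIT graph
    assert torch.equal(a, x.clone().masked_fill_(mask, 1.0))

    # Tensor value: index_put_ through a bool mask == masked_scatter_
    src = torch.ones(int(mask.sum()))  # one element per True entry
    b = x.clone()
    b[mask] = src
    assert torch.equal(b, x.clone().masked_scatter_(mask, src))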

25 changes: 25 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py

@@ -3243,6 +3243,31 @@ def forward(self, x):
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(MaskedSelectModel(), x)

@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_to_masked_fill(self):
class MaskedFillModel(torch.nn.Module):
def forward(self, input_mask, some_const):
mask = input_mask.clone()
mask[mask != some_const] = 1
mask[mask == some_const] = 0
return mask

mask = torch.randn(2, 2, 2, requires_grad=True)
constant = torch.tensor(5, dtype=torch.float)
self.run_test(MaskedFillModel(), (mask, constant))

@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_to_masked_scatter(self):
class MaskedScatterModel(torch.nn.Module):
def forward(self, input_mask, some_const):
mask = input_mask.clone()
mask[mask != some_const] = torch.ones(8)
return mask

mask = torch.randn(2, 2, 2, requires_grad=True)
constant = torch.tensor(5, dtype=torch.float)
self.run_test(MaskedScatterModel(), (mask, constant))

@skipIfUnsupportedMinOpsetVersion(9)
def test_pixel_shuffle(self):
class PixelShuffle(torch.nn.Module):
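
Both new tests exercise the exact pattern the pass matches: scripting an indexed assignment whose index is a comparison result produces an aten::index_put_ node fed by a prim::ListConstruct wrapping a Bool tensor. A quick way to inspect the IR the pass receives (illustrative sketch, not part of the test suite):

    import torch

    @torch.jit.script
    def fill(x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        y = x.clone()
        y[y != c] = 1.0
        return y

    # Expect aten::index_put_ fed by a prim::ListConstruct of the
    # Bool output of aten::ne.
    print(fill.graph)
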
78 changes: 78 additions & 0 deletions torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp

@@ -158,11 +158,89 @@ static void ReplaceAddWithConcat(Block* b) {
}
}

// Replace aten::index_put_ with aten::masked_scatter or aten::masked_fill
// when the inputs to the index_put node contain a boolean mask
//
// before the pass (index_put -> masked_fill):
// graph(%0 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu)):
//   %mask.1 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu) = ...
//   %22 : Tensor?[] = prim::ListConstruct(%21)
//   %23 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
//   %24 : bool = prim::Constant[value=0]()
//   %mask : Float(2:4, 2:2, 2:1) = aten::index_put_(%mask.1, %22, %23, %24)
//
// after the pass (index_put -> masked_fill):
// graph(%0 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu)):
//   %46 : Float(requires_grad=0, device=cpu) = prim::Constant[value={5}]()
//   %mask.1 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu) = ...
//   %23 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
//   %24 : bool = prim::Constant[value=0]()
//   %49 : Tensor = aten::masked_fill(%mask.1, %21, %23)
//
// before the pass (index_put -> masked_scatter):
//   %48 : Float(8:1, requires_grad=0, device=cpu) = prim::Constant[value=
//         1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
//   %42 : Tensor?[] = prim::ListConstruct(%41)
//   %43 : bool = prim::Constant[value=0]()
//   %44 : Float(2:4, 2:2, 2:1) = aten::index_put_(%mask, %42, %48, %43)
//   return (%44)
//
// after the pass (index_put -> masked_scatter):
//   %48 : Float(8:1, requires_grad=0, device=cpu) = prim::Constant[value=
//         1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
//   %49 : Tensor = aten::masked_fill(%mask.1, %21, %23)
//   %41 : Bool(2:4, 2:2, 2:1) = aten::to(...)
//   %50 : Tensor = aten::masked_scatter(%49, %41, %48)
//   return (%50)
static void ReplaceIndexPutWithMaskedScatter(Block* b) {
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
for (auto* child_block : it->blocks()) {
ReplaceIndexPutWithMaskedScatter(child_block);
}
    if (it->kind() == aten::index_put_) {
      auto* lc_node = it->input(1)->node();

      // Only handle the case where the index list holds exactly one
      // entry: the boolean mask.
      if (lc_node->inputs().size() != 1) {
        continue;
      }

      auto mask_type = lc_node->input(0)->type()->cast<TensorType>();
      if (!mask_type || !mask_type->scalarType().has_value() ||
          mask_type->scalarType().value() != c10::ScalarType::Bool) {
        continue;
      }

      // If the value being written is a single scalar (a 0-dim tensor),
      // convert to masked_fill; if it is a tensor of appropriate size,
      // convert to masked_scatter.
      Node* masked_node = nullptr;
      auto value_type = it->input(2)->type()->cast<TensorType>();
      if (value_type && value_type->sizes().size().has_value() &&
          value_type->sizes().size().value() == 0) {
        masked_node = b->owningGraph()->create(aten::masked_fill, 1);
      } else {
        masked_node = b->owningGraph()->create(aten::masked_scatter, 1);
      }

masked_node->insertBefore(*it);
masked_node->addInput(it->input(0));
masked_node->addInput(lc_node->input(0));
masked_node->addInput(it->input(2));
it->replaceAllUsesWith(masked_node);
it->removeAllInputs();
it.destroyCurrent();
}
}
}
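
The dispatch above keys on the rank of the value tensor: a 0-dim value becomes aten::masked_fill, anything else becomes aten::masked_scatter. A Python-level mirror of that rule, useful for sanity-checking the rewrite (hypothetical helper, assuming the value is always a tensor as in the matched graphs):

    import torch

    def rewrite_index_put(self, mask, value):
        # Mirror of the C++ dispatch: 0-dim value -> masked_fill,
        # otherwise -> masked_scatter.
        if value.dim() == 0:
            return self.masked_fill(mask, value)
        return self.masked_scatter(mask, value)

    x = torch.randn(2, 2, 2)
    mask = x != 5.0
    for v in (torch.tensor(1.0), torch.ones(int(mask.sum()))):
        assert torch.equal(rewrite_index_put(x, mask, v),
                           x.index_put((mask,), v))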

} // namespace

void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
FuseWithListUnpack(graph->block());
ReplaceAddWithConcat(graph->block());
ReplaceIndexPutWithMaskedScatter(graph->block());
}

} // namespace jit
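
Since PreprocessForONNX runs automatically during export, the new path is exercised from Python simply by exporting a model containing the pattern at opset 11 or higher, matching the tests above. A minimal export sketch (the model and buffer names are illustrative):

    import io
    import torch

    class MaskedFillModel(torch.nn.Module):
        def forward(self, x, c):
            y = x.clone()
            y[y != c] = 1.0
            return y

    buf = io.BytesIO()
    torch.onnx.export(
        MaskedFillModel(),
        (torch.randn(2, 2, 2), torch.tensor(5.0)),
        buf,
        opset_version=11,  # tests are gated on opset >= 11
    )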