
[AOTI] SymInt used in torch.cond isn't codegen'd #139798

@henrylhtsang

Description


🐛 Describe the bug

repro:

import torch


class M(torch.nn.Module):
    def forward(self, x, flag):
        # Without the following line (i.e. with flag kept as a tensor), this runs fine.
        flag = flag.item()

        def true_fn(x):
            return x.clone()

        return torch.cond(flag > 0, true_fn, true_fn, [x])

input = (
    torch.rand(28, 28, device="cuda"),
    torch.tensor(1),
)
model = M().cuda()

_ = model(*input)

ep = torch.export.export(model, input, strict=False)
path = torch._inductor.aot_compile(ep.module(), input)
aot_model = torch._export.aot_load(path, device="cuda")
torch.testing.assert_close(aot_model(*input), model(*input))

error:

/tmp/torchinductor_henrylhtsang/cn73xvreaax4vbhjkxy5zj6rclkmj5jek2srq4cinsa47qiaq55n/ckt6c7aylzejddrwjsp7duxcohlwlzlmqqy7r7er3vb3tcapck7h.cpp:520:9: error: use of undeclared identifier 'u1'
  520 |     if (u1 > 0L) {
      |         ^
1 error generated.

codegen:

void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {

    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
    auto arg0_1 = std::move(inputs[0]);
    auto arg1_1 = std::move(inputs[1]);
    inputs.clear();
    auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());

    AOTICudaStreamGuard stream_guard(stream, this->device_idx_);
    RAIIAtenTensorHandle buf2;
    if (u1 > 0L) {
        // subgraph: true_graph_0
        AtenTensorHandle true_graph_0_arg0_1_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(arg0_1, &true_graph_0_arg0_1_handle));
        RAIIAtenTensorHandle true_graph_0_arg0_1(true_graph_0_arg0_1_handle);

Versions

trunk

cc @ezyang @chauhang @penguinwu @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78
