-
Notifications
You must be signed in to change notification settings - Fork 25.9k
Closed
Description
🐛 Describe the bug
repro:
class M(torch.nn.Module):
    def forward(self, x, flag):
        # without the following line, it can run fine
        flag = flag.item()

        def true_fn(x):
            return x.clone()

        return torch.cond(flag > 0, true_fn, true_fn, [x])

input = (
    torch.rand(28, 28, device="cuda"),
    torch.tensor(1),
)
model = M().cuda()
_ = model(*input)
ep = torch.export.export(model, input, strict=False)
path = torch._inductor.aot_compile(ep.module(), input)
aot_model = torch._export.aot_load(path, device="cuda")
torch.testing.assert_close(aot_model(*input), model(*input))
error:
/tmp/torchinductor_henrylhtsang/cn73xvreaax4vbhjkxy5zj6rclkmj5jek2srq4cinsa47qiaq55n/ckt6c7aylzejddrwjsp7duxcohlwlzlmqqy7r7er3vb3tcapck7h.cpp:520:9: error: use of undeclared identifier 'u1'
520 | if (u1 > 0L) {
| ^
1 error generated.
codegen:
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                       // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
    auto arg0_1 = std::move(inputs[0]);
    auto arg1_1 = std::move(inputs[1]);
    inputs.clear();
    auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    AOTICudaStreamGuard stream_guard(stream, this->device_idx_);
    RAIIAtenTensorHandle buf2;
    if (u1 > 0L) {
        // subgraph: true_graph_0
        AtenTensorHandle true_graph_0_arg0_1_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(arg0_1, &true_graph_0_arg0_1_handle));
        RAIIAtenTensorHandle true_graph_0_arg0_1(true_graph_0_arg0_1_handle);
Versions
trunk
cc @ezyang @chauhang @penguinwu @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78