
Using int(shape) in export would result in silent specialization #138853

@henrylhtsang

Description


🐛 Describe the bug

Hi team, just reporting this problem. I can work around it by replacing int with math.trunc.
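
For reference, a sketch of the workaround (based on the sentence above; as far as I can tell, math.trunc keeps the size symbolic instead of forcing a concrete Python int, so the spatial dims are not baked in):

import math
import torch
import torch.nn.functional as F

class M(torch.nn.Module):
    def forward(self, x):
        # math.trunc on a symbolic size stays symbolic, so export does not
        # specialize to the example input's 28x28 shape
        ori_size = (
            math.trunc(x.shape[-2] / 1),
            math.trunc(x.shape[-1] / 1),
        )
        return F.interpolate(x, size=ori_size, mode="bilinear")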

repro:

import torch
import torch.nn.functional as F

class M(torch.nn.Module):
    def forward(self, x):
        # int() on a symbolic size forces a concrete value, silently
        # specializing the graph to the example input's 28x28 shape
        ori_size = (
            int(x.shape[-2] / 1),
            int(x.shape[-1] / 1),
        )
        x = F.interpolate(x, size=ori_size, mode="bilinear")
        return x

input1 = (torch.rand(1, 3, 28, 28, device="cuda"),)
input2 = (torch.rand(1, 3, 56, 56, device="cuda"),)
inputs = [input1, input2]
model = M().cuda()

_ = model(*input1)

# Both spatial dims are marked dynamic, so export should not specialize them
dynamic_shapes = {
    "x": {2: torch.export.Dim.DYNAMIC, 3: torch.export.Dim.DYNAMIC},
}
ep = torch.export.export(model, input1, dynamic_shapes=dynamic_shapes, strict=False)
path = torch._inductor.aot_compile(ep.module(), input1)
aot_model = torch._export.aot_load(path, device="cuda")
for inp in inputs:
    torch.testing.assert_close(aot_model(*inp), model(*inp))

error:

torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: The values for attribute 'shape' do not match: torch.Size([1, 3, 28, 28]) != torch.Size([1, 3, 56, 56]).
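
To see the specialization before running the AOT model, one can inspect the exported program (a sketch; I have not checked the exact printout on trunk):

# after ep = torch.export.export(...) in the repro above
print(ep.range_constraints)  # symbolic dims that survived export; check whether
                             # the two spatial dims still appear here
print(ep.graph)              # check whether the interpolate node's size argument
                             # is symbolic or a baked-in constant like 28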

Versions

trunk

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4


Labels

oncall: export, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
