
Dynamo and cond with free variables creates malformed graph #90469

@ezyang

Description


🐛 Describe the bug

Repro:

diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index b0640f6511..8fb8754c51 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -1443,8 +1443,10 @@ class ExportTests(torch._dynamo.test_case.TestCase):
                 self.linear = torch.nn.Linear(3, 3)
 
             def forward(self, pred, x):
+                y = x * 2
+
                 def true_fn(val):
-                    return self.linear(val) * torch.tensor(2)
+                    return self.linear(val) * torch.tensor(2) * y
 
                 def false_fn(val):
                     return self.linear(val) * torch.tensor(-1)

The graph traced for `true_fn` ends up being:

def forward(self, x):
    self_linear = self.self_linear(x);  x = None
    tensor = torch.tensor(2)
    mul = self_linear * tensor;  self_linear = tensor = None
    mul_1 = mul * mul;  mul = mul = None
    return mul_1

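This graph is wrong. It takes only `x` as input, so the captured free variable `y` is never passed in. The reference to `y` has instead been bound to the subgraph's own `mul` node, so the graph computes `mul * mul` rather than `mul * y` (note the telltale `mul = mul = None`).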
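The symptom is consistent with a name collision when `y` is lifted into the `true_fn` subgraph. In the outer graph, `y = x * 2` is presumably a node named `mul`, which is also the name the subgraph assigns to `self_linear * tensor`, so looking `y` up by its outer name finds the wrong node. The sketch below is a hypothetical toy IR, not torch.fx or Dynamo code: the `Graph` class, `build_true_fn`, and the `add 1` stand-in for `self.linear` are all made up for illustration. It reproduces the same `mul * mul` symptom and contrasts it with lifting `y` as an explicit subgraph input.

```python
import operator
from typing import Callable, Dict, List, Tuple, Union

Arg = Union[str, int]  # a node name or an integer constant
OPS: Dict[str, Callable[[int, int], int]] = {"add": operator.add, "mul": operator.mul}


class Graph:
    """Hypothetical toy IR (NOT torch.fx / Dynamo): a straight-line list of binary ops."""

    def __init__(self) -> None:
        self.inputs: List[str] = []
        self.nodes: List[Tuple[str, str, Arg, Arg]] = []
        self.counts: Dict[str, int] = {}

    def _fresh(self, base: str) -> str:
        # Names are unique only within *this* graph: "mul", "mul_1", ...
        n = self.counts.get(base, 0)
        self.counts[base] = n + 1
        return base if n == 0 else f"{base}_{n}"

    def placeholder(self, base: str) -> str:
        name = self._fresh(base)
        self.inputs.append(name)
        return name

    def call(self, op: str, a: Arg, b: Arg) -> str:
        name = self._fresh(op)
        self.nodes.append((name, op, a, b))
        return name

    def run(self, *args: int) -> int:
        env: Dict[str, int] = dict(zip(self.inputs, args))

        def lookup(a: Arg) -> int:
            return env[a] if isinstance(a, str) else a

        for name, op, a, b in self.nodes:
            env[name] = OPS[op](lookup(a), lookup(b))
        return env[self.nodes[-1][0]]


# Outer graph: y = x * 2. Its node happens to be named "mul".
outer = Graph()
outer.call("mul", outer.placeholder("x"), 2)
y_outer_name = outer.nodes[-1][0]  # "mul"


def build_true_fn(lift_free_vars: bool) -> Graph:
    sub = Graph()
    val = sub.placeholder("val")
    # "add 1" is a hypothetical stand-in for self.linear(val).
    t = sub.call("mul", sub.call("add", val, 1), 2)  # named "mul" in this graph
    if lift_free_vars:
        y = sub.placeholder("y")  # y becomes an explicit subgraph input
    else:
        y = y_outer_name  # reuses the outer name, which collides with t
    sub.call("mul", t, y)
    return sub


x = 3
y = outer.run(x)  # 6
buggy = build_true_fn(lift_free_vars=False)
fixed = build_true_fn(lift_free_vars=True)
print(buggy.nodes[-1])  # ('mul_1', 'mul', 'mul', 'mul')
print(buggy.run(x), fixed.run(x, y))  # 64 48
```

In the toy, resolving the free variable by name rather than by identity is exactly what lets it alias an unrelated local node; giving the lifted value its own placeholder avoids the collision regardless of how nodes are named.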

Versions

master

cc @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

Metadata


Labels

module: dynamo, oncall: pt2, triaged
