diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 5d0e2c2ae6078..1c6b2fada97b0 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10158,21 +10158,36 @@ def test_linear_module_free(self): def test_outside_linear_module_free(self): # Compared to test_linear_module_free, the linear # layer is not the code object that is directly compiled. - def model_inp_ctr(): - fc = torch.nn.Linear(100, 100) - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.fc_ref = fc + # This test does not use _test_compile_model_free because of difficulty + # in handling variable fc. - def forward(self, x): - return fc(x[0]) + fc = torch.nn.Linear(100, 100) + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_ref = fc + + def forward(self, x): + return fc(x[0]) - # return fc to keep it alive in _test_compile_model_free - return Mod(), (torch.randn(100, 100), fc) + cleared = False + + def finalize(): + nonlocal cleared + cleared = True - self._test_compile_model_free(model_inp_ctr, lambda mod: mod.fc_ref) + def run(): + mod = Mod() + inp = torch.randn(100, 100) + weakref.finalize(mod.fc_ref, finalize) + torch.compile(mod, backend="eager")(inp) + + run() + del fc # This should delete all the references + gc.collect() + self.assertTrue(cleared) @unittest.skipIf(sys.version_info >= (3, 12), "leaks in 3.12+") def test_parameter_free(self):