diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index e275d4a9f1c87..82cc5915bc7c4 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -108,7 +108,7 @@ def forward(self, x): m.eval() input = torch.randn(2, 2) output_s = m.forward(input) - mf = torch._C._freeze_module(m._c) + mf = torch.jit.freeze(m) # Check if frozen module looks as below: # module m { @@ -127,6 +127,7 @@ def forward(self, x): # } # } # } + mf = mf._c self.assertFalse(mf.hasattr('sub1')) self.assertFalse(mf.hasattr('a')) self.assertTrue(mf.hasattr('b')) diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index ec926e1afb1a5..7ceca7f52759d 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -43,6 +43,8 @@ from torch.jit._serialization import save, load from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph +from torch.jit._freeze import freeze + # For backwards compatibility _fork = fork _wait = wait