diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index d7af37e9470a9..dd76042f60599 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -22,8 +22,10 @@ def setUp(self): super(TestCudaFuser, self).setUp() self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() + self.old_te_fuse = torch._C._jit_texpr_fuser_enabled() torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) if(RUN_CUDA): torch._C._jit_register_cuda_fuser() @@ -33,6 +35,7 @@ def tearDown(self): torch._C._jit_clear_cuda_fuser() torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse) + torch._C._jit_set_texpr_fuser_enabled(self.old_te_fuse) super(TestCudaFuser, self).tearDown() def _has_cuda_fusion_group(self, graph):