diff --git a/test/test_fused_rms_norm.py b/test/test_fused_rms_norm.py index 9bd7e3732c..d5c353c2f1 100644 --- a/test/test_fused_rms_norm.py +++ b/test/test_fused_rms_norm.py @@ -11,7 +11,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase,