diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py index ae7c32cdeb..a9f9d2afeb 100644 --- a/tests/integration_tests/__init__.py +++ b/tests/integration_tests/__init__.py @@ -22,6 +22,7 @@ class OverrideDefinitions: test_descr: str = "default" test_name: str = "default" ngpu: int = 4 + disabled: bool = False def __repr__(self): return self.test_descr diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index a0aa1903a6..31c15017d1 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -65,17 +65,18 @@ def build_features_test_list() -> list[OverrideDefinitions]: "2d_compile", ), # TODO: re-enable this test once the async TP CI issue is fixed - # OverrideDefinitions( - # [ - # [ - # "--compile.enable", - # "--parallelism.tensor_parallel_degree 2", - # "--parallelism.enable_async_tensor_parallel", - # ], - # ], - # "2D async TP compile", - # "2d_asynctp_compile", - # ), + OverrideDefinitions( + [ + [ + "--compile.enable", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.enable_async_tensor_parallel", + ], + ], + "2D async TP compile", + "2d_asynctp_compile", + disabled=True, + ), OverrideDefinitions( [ [ @@ -432,16 +433,17 @@ def build_features_test_list() -> list[OverrideDefinitions]: "cpu_offload+opt_in_bwd+TP+DP+CP", ngpu=8, ), - # OverrideDefinitions( - # [ - # [ - # "--memory_estimation.enable", - # ] - # ], - # "FSDP2 Memory Tracking and Estimation", - # "fsdp2_memory_estimation", - # ngpu=2, - # ), + OverrideDefinitions( + [ + [ + "--memory_estimation.enable", + ] + ], + "FSDP2 Memory Tracking and Estimation", + "fsdp2_memory_estimation", + ngpu=2, + disabled=True, + ), OverrideDefinitions( [ [ diff --git a/tests/integration_tests/h100.py b/tests/integration_tests/h100.py index ae1fb5b597..53af87a2ea 100755 --- a/tests/integration_tests/h100.py +++ b/tests/integration_tests/h100.py @@ -19,6 +19,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]: same root config file. """ integration_tests_flavors = [ + # TODO: re-enable this test once the async TP issue is fixed OverrideDefinitions( [ [ @@ -29,6 +30,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]: ], "2D async TP compile", "2d_asynctp_compile", + disabled=True, ), OverrideDefinitions( [ @@ -41,6 +43,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]: "Float8 test", "float8", ), + # TODO: re-enable this test once the async TP issue is fixed OverrideDefinitions( [ [ @@ -57,6 +60,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]: "FSDP+async TP+PP+torch.compile+Float8", "fsdp+tp+cp+compile+float8", ngpu=8, + disabled=True, ), OverrideDefinitions( [ diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index dff179e4a5..a64c69eb61 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -84,6 +84,9 @@ def run_tests(args, test_list: list[OverrideDefinitions]): if args.test_name != "all" and test_flavor.test_name != args.test_name: continue + if test_flavor.disabled: + continue + # Check if we have enough GPUs if args.ngpu < test_flavor.ngpu: logger.info( diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index c680f84a73..33ccbd8903 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -63,18 +63,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: "2d", ), # TODO: re-enable this test once the async TP issue is fixed - # OverrideDefinitions( - # [ - # [ - # "--model.name simple_fsdp", - # "--compile.enable", - # "--parallelism.tensor_parallel_degree 2", - # "--parallelism.enable_async_tensor_parallel", - # ], - # ], - # "2D async TP", - # "2d_asynctp", - # ), + OverrideDefinitions( + [ + [ + "--model.name simple_fsdp", + "--compile.enable", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.enable_async_tensor_parallel", + ], + ], + "2D async TP", + "2d_asynctp", + disabled=True, + ), OverrideDefinitions( [ [