diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index c6e8ed30d5..e97d22c3b7 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -26,6 +26,10 @@ permissions: jobs: build-test: + if: | + matrix.gpu-arch-type == 'cuda' || + (matrix.gpu-arch-type == 'rocm' && + (github.event_name == 'push' && github.ref == 'refs/heads/main' || github.event_name == 'schedule')) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main strategy: fail-fast: false @@ -73,8 +77,7 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - export TEST_WITH_ROCM=$([[ "${{ matrix.gpu-arch-type }}" == "rocm" ]] && echo 1 || echo 0) - python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint rm -rf artifacts-to-be-uploaded/*/checkpoint diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index 011fa25554..b2cb8ea503 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -25,9 +25,6 @@ } -TEST_WITH_ROCM = os.getenv("TEST_WITH_ROCM", "0") == "1" - - def _run_cmd(cmd): return subprocess.run([cmd], text=True, shell=True) @@ -92,7 +89,7 @@ def run_tests(args, test_list: list[OverrideDefinitions]): continue # Skip the test for ROCm - if TEST_WITH_ROCM and test_flavor.skip_rocm_test: + if args.gpu_arch_type == "rocm" and test_flavor.skip_rocm_test: continue # Check if we have enough GPUs @@ -110,6 +107,12 @@ def main(): parser.add_argument( "output_dir", help="Directory to dump results generated by tests" ) + parser.add_argument( + "--gpu_arch_type", + default="cuda", + choices=["cuda", "rocm"], + help="GPU architecture type. Must be specified as either 'cuda' or 'rocm'.", + ) parser.add_argument( "--test_suite", default="features",