Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 29 additions & 43 deletions .github/scripts/filter_test_configs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#!/usr/bin/env python3

import json
import logging
import os
import re
import subprocess
import sys
import warnings
from enum import Enum
from functools import lru_cache
from logging import info
from typing import Any, Callable, Dict, List, Optional, Set
from urllib.request import Request, urlopen

Expand All @@ -17,33 +19,7 @@

PREFIX = "test-config/"

# Same as shard names
VALID_TEST_CONFIG_LABELS = {
f"{PREFIX}{label}"
for label in {
"backwards_compat",
"crossref",
"default",
"deploy",
"distributed",
"docs_tests",
"dynamo",
"force_on_cpu",
"functorch",
"inductor",
"inductor_distributed",
"inductor_huggingface",
"inductor_timm",
"inductor_torchbench",
"jit_legacy",
"multigpu",
"nogpu_AVX512",
"nogpu_NO_AVX2",
"slow",
"tsan",
"xla",
}
}
logging.basicConfig(level=logging.INFO)


def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:
Expand Down Expand Up @@ -155,19 +131,25 @@ def get_labels(pr_number: int) -> Set[str]:
}


def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
"""
Return the list of matching labels
"""
return {l for l in labels if re.match(label_regex, l)}


def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
"""
Select the list of test config to run from the test matrix. The logic works
as follows:

If the PR has one or more labels as specified in the VALID_TEST_CONFIG_LABELS set, only
these test configs will be selected. This also works with ciflow labels, for example,
if a PR has both ciflow/trunk and test-config/functorch, only trunk functorch builds
and tests will be run
If the PR has one or more test-config labels as specified, only these test configs
will be selected. This also works with ciflow labels, for example, if a PR has both
ciflow/trunk and test-config/functorch, only trunk functorch builds and tests will
be run.

If the PR has none of the test-config label, all tests are run as usual.
"""

filtered_test_matrix: Dict[str, List[Any]] = {"include": []}

for entry in test_matrix.get("include", []):
Expand All @@ -177,18 +159,19 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis

label = f"{PREFIX}{config_name.strip()}"
if label in labels:
print(
f"Select {config_name} because label {label} is presented in the pull request by the time the test starts"
)
msg = f"Select {config_name} because label {label} is present in the pull request by the time the test starts"
info(msg)
filtered_test_matrix["include"].append(entry)

valid_test_config_labels = labels.intersection(VALID_TEST_CONFIG_LABELS)

if not filtered_test_matrix["include"] and not valid_test_config_labels:
# Found no valid label and the filtered test matrix is empty, return the same
test_config_labels = filter_labels(labels, re.compile(f"{PREFIX}.+"))
if not filtered_test_matrix["include"] and not test_config_labels:
info("Found no test-config label on the PR, so all test configs are included")
# Found no test-config label and the filtered test matrix is empty, return the same
# test matrix as before so that all tests can be run normally
return test_matrix
else:
msg = f"Found {test_config_labels} on the PR so only these test configs are run"
info(msg)
# When the filter test matrix contain matches or if a valid test config label
# is found in the PR, return the filtered test matrix
return filtered_test_matrix
Expand Down Expand Up @@ -374,30 +357,33 @@ def process_jobs(
# - If the target record has the job (config) name, only that test config
# will be skipped or marked as unstable
if not target_job_cfg:
print(
msg = (
f"Issue {target_url} created by {author} has {issue_type.value} "
+ f"all CI jobs for {workflow} / {job_name}"
)
info(msg)
return _filter_jobs(
test_matrix=test_matrix,
issue_type=issue_type,
)

if target_job_cfg == BUILD_JOB_NAME:
print(
msg = (
f"Issue {target_url} created by {author} has {issue_type.value} "
+ f"the build job for {workflow} / {job_name}"
)
info(msg)
return _filter_jobs(
test_matrix=test_matrix,
issue_type=issue_type,
)

if target_job_cfg in (TEST_JOB_NAME, BUILD_AND_TEST_JOB_NAME):
print(
msg = (
f"Issue {target_url} created by {author} has {issue_type.value} "
+ f"all the test jobs for {workflow} / {job_name}"
)
info(msg)
return _filter_jobs(
test_matrix=test_matrix,
issue_type=issue_type,
Expand Down Expand Up @@ -497,7 +483,7 @@ def perform_misc_tasks(

# Obviously, if the job name includes unstable, then this is an unstable job
is_unstable = job_name and IssueType.UNSTABLE.value in job_name
if not is_unstable and test_matrix:
if not is_unstable and test_matrix and test_matrix.get("include"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind what the mps bug was? I don't understand how this fixes it

Copy link
Contributor Author

@huydhn huydhn Mar 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should have added a link to the description, but here it is https://github.com/pytorch/pytorch/actions/runs/8329872641/job/22802012305?pr=121381#step:9:151. So you can see the issue there in Is the current job unstable? True, this is wrong because there is no unstable issue for MPS. The bug is in this part is_unstable = all(IssueType.UNSTABLE.value in r for r in test_matrix["include"]) which returns True when test_matrix["include"] is an empty array.

Another way to express this is all(False for _ in []) returns True :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha true by vacuity

# Even when the job name doesn't mention unstable, we will also mark it as
# unstable when the test matrix only includes unstable jobs. Basically, this
# logic allows build or build-and-test jobs to be marked as unstable too.
Expand Down
12 changes: 5 additions & 7 deletions .github/scripts/test_filter_test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
remove_disabled_jobs,
set_periodic_modes,
SUPPORTED_PERIODICAL_MODES,
VALID_TEST_CONFIG_LABELS,
)


Expand Down Expand Up @@ -273,13 +272,13 @@ def test_filter(self) -> None:
testcases = [
{
"test_matrix": '{include: [{config: "default", runner: "linux"}]}',
"expected": '{"include": [{"config": "default", "runner": "linux"}]}',
"description": "No match, keep the same test matrix",
"expected": '{"include": []}',
"description": "Request test-config/cfg but the test matrix doesn't have it",
},
{
"test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "plain-cfg"}]}',
"expected": '{"include": [{"config": "default", "runner": "linux"}, {"config": "plain-cfg"}]}',
"description": "No match because there is no prefix or suffix, keep the same test matrix",
"expected": '{"include": []}',
"description": "A valid test config label needs to start with test-config/",
},
{
"test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", shard: 1}]}',
Expand All @@ -294,9 +293,8 @@ def test_filter(self) -> None:
)
self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))

def test_filter_with_valid_label(self) -> None:
def test_filter_with_test_config_label(self) -> None:
mocked_labels = {f"{PREFIX}cfg", "ciflow/trunk"}
VALID_TEST_CONFIG_LABELS.add(f"{PREFIX}cfg")

testcases = [
{
Expand Down