Skip to content

Commit

Permalink
Test Reordering: Run previously failing tests first (#101123)
Browse files Browse the repository at this point in the history
Makes the CI prioritize running any test files that had a failing test in a previous iteration of the given PR.

A follow up to #100522 which makes the `.pytest_cache` available to use here

A concrete example:
1. Person A pushes a new commit and creates a PR.
2. 2 hours later, test_im_now_broken.py fails
3. Person A attempts to fix the test, but the test is actually still broken
4. The CI, seeing that test_im_now_broken.py had failed on a previous run, will now prioritize running that test first. Instead of waiting another 2 hours to get a signal, Person A only needs to wait ~15 minutes (which is how long it takes for tests to start running)

# Testing
I modified a file to make the tests invoking it fail and triggered CI twice with this failure.

First run: https://github.com/pytorch/pytorch/actions/runs/4963943209/jobs/8883800811
Test step took 1h 9m to run

Second run: https://github.com/pytorch/pytorch/actions/runs/4965016776/jobs/8885657992
Test step failed within 2m 27s

Pull Request resolved: #101123
Approved by: https://github.com/malfet, https://github.com/huydhn
  • Loading branch information
ZainRizvi authored and pytorchmergebot committed May 16, 2023
1 parent b5ed606 commit b147401
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 26 deletions.
19 changes: 19 additions & 0 deletions tools/shared/logging_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def pluralize(count: int, singular_word: str, plural_word: str = "") -> str:
if count == 1:
return f"{count} {singular_word}"

if not plural_word:
plural_word = f"{singular_word}s"

return f"{count} {plural_word}"


def duration_to_str(seconds: float) -> str:
if seconds < 0.00001:
return "0s"
elif seconds < 60:
return f"{seconds:.1f}s"
elif seconds < 3600:
return f"{seconds / 60:.1f}m"
else:
return f"{seconds / 3600:.1f}h"
89 changes: 87 additions & 2 deletions tools/test/test_test_selections.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import io
import json
import pathlib
import random
import sys
import unittest
from collections import defaultdict
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Set, Tuple
from unittest import mock

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
from tools.testing.test_selections import (
_get_previously_failing_tests,
calculate_shards,
get_reordered_tests,
ShardedTest,
THRESHOLD,
)
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
Expand Down Expand Up @@ -328,5 +337,81 @@ def test_calculate_2_shards_against_optimal_shards(self) -> None:
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])


def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
file_object = io.StringIO()
json.dump(contents, file_object)
file_object.seek(0)
return file_object


class TestParsePrevTests(unittest.TestCase):
@mock.patch("pathlib.Path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set()

found_tests = _get_previously_failing_tests()

self.assertSetEqual(expected_failing_test_files, found_tests)

@mock.patch("pathlib.Path.exists", return_value=True)
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files: Set[str] = set()

found_tests = _get_previously_failing_tests()

self.assertSetEqual(expected_failing_test_files, found_tests)
mock_open.assert_called()

lastfailed_with_multiple_tests_per_file = {
"test/test_car.py::TestCar::test_num[17]": True,
"test/test_car.py::TestBar::test_num[25]": True,
"test/test_far.py::TestFar::test_fun_copy[17]": True,
"test/test_bar.py::TestBar::test_fun_copy[25]": True,
}

@mock.patch("pathlib.Path.exists", return_value=True)
@mock.patch(
"builtins.open",
return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
)
def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files = {"test_car", "test_bar", "test_far"}
found_tests = _get_previously_failing_tests()

self.assertSetEqual(expected_failing_test_files, found_tests)

@mock.patch(
"tools.testing.test_selections._get_previously_failing_tests",
return_value={"test4"},
)
@mock.patch(
"tools.testing.test_selections._get_modified_tests",
return_value={"test2", "test4"},
)
def test_get_reordered_tests(
self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
]

expected_prioritized_tests = {"test4", "test2"}
expected_remaining_tests = {"test1", "test3", "test5"}

prioritized_tests, remaining_tests = get_reordered_tests(tests)

# Just want to check the names of the tests
prioritized_tests_name = {test.name for test in prioritized_tests}
remaining_tests_name = {test.name for test in remaining_tests}

self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
self.assertSetEqual(expected_remaining_tests, remaining_tests_name)


if __name__ == "__main__":
unittest.main()
125 changes: 101 additions & 24 deletions tools/testing/test_selections.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import json
import math
import os
import subprocess
from pathlib import Path

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from warnings import warn

from tools.shared.logging_utils import duration_to_str, pluralize

from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests

Expand Down Expand Up @@ -37,7 +42,7 @@ class ShardedTest(NamedTuple):
name: str
shard: int
num_shards: int
time: Optional[float]
time: Optional[float] # In seconds

def __str__(self) -> str:
return f"{self.name} {self.shard}/{self.num_shards}"
Expand Down Expand Up @@ -133,50 +138,122 @@ def _query_changed_test_files() -> List[str]:
return lines


def _get_previously_failing_tests() -> Set[str]:
PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed")

if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists():
warn(
f"No pytorch cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}"
)
return set()

with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f:
last_failed_tests = json.load(f)

prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)
return _python_test_file_to_test_name(prioritized_tests)


def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
prioritized_tests = set()

# The keys are formatted as "test_file.py::test_class::test_method[params]"
# We just need the test_file part
for test in last_failed_tests:
parts = test.split("::")
if len(parts) > 1:
test_file = parts[0]
prioritized_tests.add(test_file)

return prioritized_tests


def _get_modified_tests() -> Set[str]:
try:
changed_files = _query_changed_test_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
# If unable to get changed files from git, quit without doing any sorting
return set()

return _python_test_file_to_test_name(set(changed_files))


def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
prefix = f"test{os.path.sep}"
valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")}
valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests}

return valid_tests


def get_reordered_tests(
tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
"""
Get the reordered test filename list based on github PR history or git changed file.
We prioritize running test files that were changed.
"""
prioritized_tests: List[str] = []
if len(prioritized_tests) == 0:
try:
changed_files = _query_changed_test_files()
except Exception:
# If unable to get changed files from git, quit without doing any sorting
return ([], tests)

prefix = f"test{os.path.sep}"
prioritized_tests = [
f for f in changed_files if f.startswith(prefix) and f.endswith(".py")
]
prioritized_tests = [f[len(prefix) :] for f in prioritized_tests]
prioritized_tests = [f[: -len(".py")] for f in prioritized_tests]
print("Prioritized test from test file changes.")

def print_tests(tests: Set[str], test_group_description: str) -> None:
if not tests:
return

print(f"{test_group_description}:")
for test in tests:
print(f" {test}")

prioritized_tests: Set[str] = set()

pri_test = _get_previously_failing_tests()
print_tests(
pri_test, "If run, these tests will prioritized because they previously failed"
)
prioritized_tests |= pri_test

pri_test |= _get_modified_tests()
print_tests(
pri_test, "If run, these tests will be prioritized because they were modified"
)
prioritized_tests |= pri_test

bring_to_front = []
the_rest = []

test_time_for_regular_tests_so_far = 0.0
# how much sooner did we run prioritized tests compared to a naive ordering
time_savings_sec = 0.0

for test in tests:
if test.name in prioritized_tests:
bring_to_front.append(test)
# Calculate approx time saved by reordering
time_savings_sec = test_time_for_regular_tests_so_far
else:
the_rest.append(test)
if len(tests) == len(bring_to_front) + len(the_rest):
print(
f"reordering tests for PR:\n"
f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
)
return (bring_to_front, the_rest)
else:
test_time_for_regular_tests_so_far += test.get_time()

if len(tests) != len(bring_to_front) + len(the_rest):
print(
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
)
return ([], tests)

# TODO: Would be great to upload these stats to RDS/Rockset!
test_cnt_str = pluralize(len(tests), "test")
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
print(
f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise"
)

prioritized_test_names = [t.name for t in bring_to_front]
print(f"Prioritized: {prioritized_test_names}")
remaining_test_names = [t.name for t in the_rest]
print(f"The Rest: {remaining_test_names}")

return (bring_to_front, the_rest)


def get_test_case_configs(dirpath: str) -> None:
get_slow_tests(dirpath=dirpath)
Expand Down

0 comments on commit b147401

Please sign in to comment.