Skip to content

Commit

Permalink
[ci] remove dead code related to test selection (#81163)
Browse files Browse the repository at this point in the history
Since we are using Rockset for all this now, remove the code that used
the S3 path.
Pull Request resolved: #81163
Approved by: https://github.com/janeyx99
  • Loading branch information
suo authored and pytorchmergebot committed Jul 12, 2022
1 parent 9f58d5d commit d321be6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 295 deletions.
131 changes: 0 additions & 131 deletions tools/stats/export_slow_tests.py

This file was deleted.

165 changes: 1 addition & 164 deletions tools/testing/test_selections.py
@@ -1,60 +1,9 @@
import json
import os
import subprocess

from tools.stats.s3_stat_parser import (
get_previous_reports_for_branch,
Report,
Version2Report,
HAVE_BOTO3,
)
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests

from typing import Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict


class JobTimeJSON(TypedDict):
    """Schema of the JSON test-times cache file written by export_S3_test_times."""

    # Git commit hash the stats were collected at (from `git rev-parse HEAD`).
    commit: str
    # CI job identifier, taken from the BUILD_ENVIRONMENT env var (may be "").
    JOB_BASE_NAME: str
    # Mapping of test file name -> average runtime in seconds.
    job_times: Dict[str, float]


def _get_stripped_CI_job() -> str:
    """Return the current CI job name from BUILD_ENVIRONMENT, or "" if unset."""
    job_name = os.environ.get("BUILD_ENVIRONMENT")
    return job_name if job_name is not None else ""


def _get_job_times_json(job_times: Dict[str, float]) -> JobTimeJSON:
    """Package *job_times* together with the current commit hash and CI job name."""
    # Record HEAD so a later read can tell whether the cached stats are stale.
    head_commit = subprocess.check_output(
        ["git", "rev-parse", "HEAD"], encoding="ascii"
    ).strip()
    return {
        "commit": head_commit,
        "JOB_BASE_NAME": _get_stripped_CI_job(),
        "job_times": job_times,
    }


def _calculate_job_times(reports: List["Report"]) -> Dict[str, float]:
    """Average each test file's total_seconds across all *reports*.

    Returns a mapping of test file name -> mean runtime in seconds.
    Only version-2 S3 report formats are accepted.
    """
    # test file name -> (running average, number of samples folded in so far)
    running: Dict[str, Tuple[float, int]] = {}
    for report in reports:
        v_report = cast(Version2Report, report)
        assert (
            "format_version" in v_report.keys() and v_report.get("format_version") == 2
        ), "S3 format currently handled is version 2 only"
        report_files: Dict[str, Any] = v_report["files"]
        for file_name, file_stats in report_files.items():
            seen = running.get(file_name)
            if seen is None:
                running[file_name] = (file_stats["total_seconds"], 1)
            else:
                avg_so_far, count_so_far = seen
                count = count_so_far + 1
                # Incremental mean, kept in the original arithmetic order so
                # floating-point results match exactly.
                avg = (
                    avg_so_far * count_so_far + file_stats["total_seconds"]
                ) / count
                running[file_name] = (avg, count)

    return {file_name: avg for file_name, (avg, _) in running.items()}
from typing import Dict, List, Tuple


def calculate_shards(
Expand Down Expand Up @@ -91,63 +40,6 @@ def calculate_shards(
return sharded_jobs


def _pull_job_times_from_S3() -> Dict[str, float]:
    """Fetch and average historic per-test runtimes for this CI job from S3.

    Returns an empty dict (after emitting a CI warning) when boto3 is not
    available or no reports exist for origin/viable/strict.
    """
    reports: List["Report"] = []
    if HAVE_BOTO3:
        reports = get_previous_reports_for_branch(
            "origin/viable/strict", _get_stripped_CI_job()
        )
    else:
        print(
            "Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser."
        )
        print(
            "If not installed, please install boto3 for automatic sharding and test categorization."
        )

    if not reports:
        print("::warning:: Gathered no reports from S3. Please proceed without them.")
        return dict()

    return _calculate_job_times(reports)


def _query_past_job_times(test_times_file: Optional[str] = None) -> Dict[str, float]:
    """Read historic test job times from a file.

    If the file doesn't exist, or was recorded for a different commit or CI
    job, download fresh stats from S3 and export them to the file instead.
    Returns a mapping of test file name -> runtime in seconds.
    """
    if test_times_file and os.path.exists(test_times_file):
        with open(test_times_file) as file:
            test_times_json: JobTimeJSON = json.load(file)

        # The cache is only valid for the exact commit and CI job it was
        # written for; compare both before trusting it.
        curr_commit = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], encoding="ascii"
        ).strip()
        file_commit = test_times_json.get("commit", "")
        curr_ci_job = _get_stripped_CI_job()
        file_ci_job = test_times_json.get("JOB_BASE_NAME", "N/A")
        if curr_commit != file_commit:
            print(f"Current test times file is from different commit {file_commit}.")
        elif curr_ci_job != file_ci_job:
            print(f"Current test times file is for different CI job {file_ci_job}.")
        else:
            print(
                f"Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values."
            )
            return test_times_json.get("job_times", {})

        # Found file, but commit or CI job in JSON doesn't match
        print(
            f"Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}"
        )

    # Fall-through: no usable cache — pull from S3 and (re)write the file.
    job_times = export_S3_test_times(test_times_file)

    return job_times


def _query_changed_test_files() -> List[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
Expand All @@ -161,47 +53,6 @@ def _query_changed_test_files() -> List[str]:
return lines


# Get sharded test allocation based on historic S3 data.
def get_shard_based_on_S3(
    which_shard: int, num_shards: int, tests: List[str], test_times_file: str
) -> List[str]:
    """Return the tests assigned to shard *which_shard* (1-based) of *num_shards*."""
    # Short circuit and don't do any work if there's only 1 shard
    if num_shards == 1:
        return tests

    jobs_to_times = _query_past_job_times(test_times_file)

    # Got no stats from S3, returning early to save runtime
    if not jobs_to_times:
        print(
            "::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
        )
        # Fallback: round-robin — every num_shards-th test starting at our index.
        return tests[which_shard - 1 :: num_shards]

    sharded_jobs = calculate_shards(num_shards, tests, jobs_to_times)
    return sharded_jobs[which_shard - 1][1]


def get_slow_tests_based_on_S3(
    test_list: List[str], td_list: List[str], slow_test_threshold: int
) -> List[str]:
    """Get list of slow tests based on historic S3 data."""
    jobs_to_times: Dict[str, float] = _query_past_job_times()

    # Got no stats from S3, returning early to save runtime
    if not jobs_to_times:
        print("::warning:: Gathered no stats from S3. No new slow tests calculated.")
        return []

    # A test is "slow" when we have a timing for it, it is not already in
    # td_list, and its runtime exceeds the threshold.
    return [
        test
        for test in test_list
        if test in jobs_to_times
        and test not in td_list
        and jobs_to_times[test] > slow_test_threshold
    ]


def get_reordered_tests(tests: List[str]) -> List[str]:
"""Get the reordered test filename list based on github PR history or git changed file."""
prioritized_tests: List[str] = []
Expand Down Expand Up @@ -242,20 +93,6 @@ def get_reordered_tests(tests: List[str]) -> List[str]:
return tests


# TODO Refactor this and unify with tools.stats.export_slow_tests
def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str, float]:
    """Pull per-test runtimes from S3 and, when a filename is given, cache them as JSON.

    Returns the mapping of test file name -> runtime in seconds either way.
    """
    test_times: Dict[str, float] = _pull_job_times_from_S3()
    if test_times_filename is None:
        return test_times

    print(f"Exporting S3 test stats to {test_times_filename}.")
    if os.path.exists(test_times_filename):
        print(f"Overwriting existent file: {test_times_filename}")
    with open(test_times_filename, "w+") as file:
        json.dump(
            _get_job_times_json(test_times),
            file,
            indent=" ",
            separators=(",", ": "),
        )
        file.write("\n")
    return test_times


def get_test_case_configs(dirpath: str) -> None:
    """Populate *dirpath* with the slow-test and disabled-test config data.

    Delegates to the tools.stats.import_test_stats helpers; presumably they
    fetch/cache the lists into *dirpath* — confirm in that module.
    """
    get_slow_tests(dirpath=dirpath)
    get_disabled_tests(dirpath=dirpath)

0 comments on commit d321be6

Please sign in to comment.