Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/scripts/run-bisection.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ python bisection.py --work-dir ${BISECT_BASE}/gh${GITHUB_RUN_ID} \
--pytorch-src ${PYTORCH_SRC_DIR} \
--torchbench-src ${TORCHBENCH_SRC_DIR} \
--config ${BISECT_BASE}/config.yaml \
--output ${BISECT_BASE}/gh${GITHUB_RUN_ID}/result.json \
--debug
--output ${BISECT_BASE}/gh${GITHUB_RUN_ID}/result.json
40 changes: 32 additions & 8 deletions bisection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,21 @@ def exist_dir_path(string):
# For example, ["test_eval[yolov3-cpu-eager]", "test_train[yolov3-gpu-eager]"]
# -> "((eval and yolov3 and cpu and eager) or (train and yolov3 and gpu and eager))"
# If targets is None, run everything except slomo
def targets_to_bmfilter(targets: List[str], models: List[str]) -> str:
    """Translate a list of bisection targets into a TorchBench filter string.

    Each target is either a full test id such as
    ``test_eval[yolov3-cpu-eager]`` (turned into
    ``(eval and yolov3 and cpu and eager)``) or a bare model name that
    appears in ``models`` (turned into ``(<model>)``). The per-target
    clauses are joined with ``or`` and wrapped in parentheses.

    Args:
        targets: Test ids or model names to run. ``None`` or an empty list
            selects everything except slomo.
        models: Known model names accepted as bare targets.

    Returns:
        The filter expression, e.g.
        ``((eval and yolov3 and cpu and eager) or (resnet50))``.

    Raises:
        SystemExit: With exit code 1 when a target is neither a recognized
            test id nor a known model name.
    """
    if not targets:
        return "(not slomo)"
    # Compile once, outside the loop; raw string so ``\[`` is a regex escape
    # rather than an (invalid) Python string escape.
    test_id = re.compile(r"test_(train|eval)\[([a-zA-Z0-9_]+)-([a-z]+)-([a-z]+)\]")
    bmfilter_names = []
    for test in targets:
        matched = test_id.match(test)
        if matched:
            partial_name = " and ".join(matched.groups())
        elif test in (models or ()):
            # Not a full test id: fall back to treating the target itself
            # as a model name. Check ``matched`` *before* calling
            # ``.groups()`` -- a failed match returns None.
            partial_name = test
        else:
            print(f"Cannot recognize the TorchBench filter: {test}. Exit.")
            raise SystemExit(1)
        bmfilter_names.append(f"({partial_name})")
    return "(" + " or ".join(bmfilter_names) + ")"

Expand Down Expand Up @@ -128,11 +135,23 @@ def prep(self) -> bool:
repo_origin_url = gitutils.get_git_origin(self.srcpath)
if not repo_origin_url == TORCH_GITREPO:
print(f"WARNING: Unmatched repo origin url: {repo_origin_url} with standard {TORCH_GITREPO}")
self.update_repos()
return True


# Update pytorch, torchtext, torchvision, and torchaudio repo
def update_repos(self):
    """Clean and update the PyTorch checkout and every TORCHBENCH_DEPS repo.

    For ``self.srcpath`` and each path in ``TORCHBENCH_DEPS.values()``, the
    working tree is cleaned with ``gitutils.clean_git_repo`` and then
    updated with ``gitutils.update_git_repo``.

    Raises:
        AssertionError: If updating any repository fails.
    """
    repos = [self.srcpath, *TORCHBENCH_DEPS.values()]
    for repo in repos:
        gitutils.clean_git_repo(repo)
        # Do the update outside of an ``assert`` statement: asserts are
        # stripped under ``python -O``, which would silently skip the
        # update altogether. AssertionError is kept so existing callers
        # observe the same exception type.
        if not gitutils.update_git_repo(repo):
            raise AssertionError(f"Failed to update master branch of {repo}.")

# Get all commits between start and end, save them in self.commits
def init_commits(self, start: str, end: str) -> bool:
commits = gitutils.get_git_commits(self.srcpath, start, end)
def init_commits(self, start: str, end: str, abtest: bool) -> bool:
if not abtest:
commits = gitutils.get_git_commits(self.srcpath, start, end)
else:
commits = [start, end]
if not commits or len(commits) < 2:
print(f"Failed to retrieve commits from {start} to {end} in {self.srcpath}.")
return False
Expand Down Expand Up @@ -220,26 +239,31 @@ class TorchBench:
timelimit: int # timeout limit in minutes
workdir: str
devbig: str
models: List[str]
torch_src: TorchSource

def __init__(self, srcpath: str,
             torch_src: TorchSource,
             timelimit: int,
             workdir: str,
             devbig: str,
             branch: str = "master"):
    """Record the benchmark checkout location and run settings.

    Args:
        srcpath: Path of the benchmark source checkout.
        torch_src: The TorchSource this benchmark runs against.
        timelimit: Timeout limit in minutes.
        workdir: Working directory for this run.
        devbig: Stored as-is; its truthiness is checked when running
            the benchmark.
        branch: Branch name, "master" by default.
    """
    self.srcpath, self.torch_src = srcpath, torch_src
    self.timelimit, self.workdir = timelimit, workdir
    self.devbig, self.branch = devbig, branch
    # Starts empty; filled in later once the models directory is scanned.
    self.models = []

def prep(self) -> bool:
# Verify the code in srcpath is pytorch/benchmark
repo_origin_url = gitutils.get_git_origin(self.srcpath)
if not repo_origin_url == TORCHBENCH_GITREPO:
print(f"WARNING: Unmatched repo origin url: {repo_origin_url} with standard {TORCHBENCH_GITREPO}")
# get list of models
self.models = [ model for model in os.listdir(os.path.join(self.srcpath, "torchbenchmark", "models"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to do this instead of importing the 'list_models' function from torchbenchmark?

Copy link
Contributor Author

@xuzhao9 xuzhao9 Apr 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is called during preparation, when the pytorch/torchvision/torchtext packages are not yet built and installed (they are installed later, when actually running the bisection/abtesting). Therefore, the list_models function doesn't work at this point.

if os.path.isdir(os.path.join(self.srcpath, "torchbenchmark", "models", model)) ]
return True

def run_benchmark(self, commit: Commit, targets: List[str]) -> str:
Expand All @@ -253,7 +277,7 @@ def run_benchmark(self, commit: Commit, targets: List[str]) -> str:
os.remove(os.path.join(output_dir, f))
else:
os.mkdir(output_dir)
bmfilter = targets_to_bmfilter(targets)
bmfilter = targets_to_bmfilter(targets, self.models)
print(f"Running TorchBench for commit: {commit.sha}, filter {bmfilter} ...", end="", flush=True)
if not self.devbig:
command = f"""bash .github/scripts/run-bench.sh "{output_dir}" "{bmfilter}" &> {output_dir}/benchmark.log"""
Expand Down Expand Up @@ -395,7 +419,7 @@ def regression(self, left: Commit, right: Commit, targets: List[str]) -> List[st
def prep(self) -> bool:
if not self.torch_src.prep():
return False
if not self.torch_src.init_commits(self.start, self.end):
if not self.torch_src.init_commits(self.start, self.end, self.abtest):
return False
if not self.bench.prep():
return False
Expand Down
11 changes: 10 additions & 1 deletion torchbenchmark/util/gitutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@
from datetime import datetime
from typing import Optional, List

def update_git_repo(repo: str, branch: str) -> bool:
def clean_git_repo(repo: str) -> bool:
    """Remove all untracked and ignored files from a git checkout.

    Runs ``git clean -xdf`` with ``repo`` as the working directory.

    Args:
        repo: Path of the git repository to clean.

    Returns:
        True when the command succeeds, False otherwise (a message is
        printed on failure).
    """
    try:
        # Argument list instead of a shell string: no shell is needed for a
        # fixed command, and nothing gets re-parsed by /bin/sh.
        subprocess.check_call(["git", "clean", "-xdf"], cwd=repo)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing ``git`` executable or a missing ``repo``
        # directory, which surface as exceptions without a shell.
        print(f"Failed to cleanup git repo {repo}")
        # Honor the declared ``-> bool`` contract (was ``None``).
        return False
    return True

def update_git_repo_branch(repo: str, branch: str) -> bool:
try:
command = f"git pull origin {branch}"
out = subprocess.check_output(command, cwd=repo, shell=True).decode().strip()
Expand Down