Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions torchbenchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from urllib import request
import importlib
from typing import Any, List, Tuple

proxy_suggestion = "Unable to verify https connectivity, " \
"required for setup.\n" \
Expand All @@ -14,15 +15,15 @@
install_file = 'install.py'


def _test_https(test_url='https://github.com', timeout=0.5):
def _test_https(test_url: str = 'https://github.com', timeout: float = 0.5) -> bool:
try:
request.urlopen(test_url, timeout=timeout)
except OSError:
return False
return True


def _install_deps(model_path, verbose=True):
def _install_deps(model_path: str, verbose: bool = True) -> Tuple[bool, Any]:
run_args = [
[sys.executable, install_file],
]
Expand All @@ -35,23 +36,23 @@ def _install_deps(model_path, verbose=True):
if not verbose:
run_kwargs['stderr'] = subprocess.STDOUT
run_kwargs['stdout'] = subprocess.PIPE
subprocess.run(*run_args, **run_kwargs)
subprocess.run(*run_args, **run_kwargs) # type: ignore
else:
return (False, f"No install.py is found in {model_path}.")
except subprocess.CalledProcessError as e:
return (False, e.output)
except Exception as e:
return (False, e)

return (True, None)

return (True, None)

def _list_model_paths():

def _list_model_paths() -> List[str]:
p = Path(__file__).parent.joinpath(model_dir)
return sorted(str(child.absolute()) for child in p.iterdir() if child.is_dir())


def setup(verbose=True, continue_on_fail=False):
def setup(verbose: bool = True, continue_on_fail: bool = False) -> bool:
if not _test_https():
print(proxy_suggestion)
sys.exit(-1)
Expand All @@ -66,7 +67,7 @@ def setup(verbose=True, continue_on_fail=False):
print("FAIL")
try:
errmsg = errmsg.decode()
except:
except Exception:
pass
failures[model_path] = errmsg
if not continue_on_fail:
Expand All @@ -81,15 +82,20 @@ def setup(verbose=True, continue_on_fail=False):

return len(failures) == 0


def list_models():
models = []
for model_path in _list_model_paths():
model_name = os.path.basename(model_path)
module = importlib.import_module(f'.models.{model_name}', package=__name__)
if not hasattr(module, 'Model'):
try:
module = importlib.import_module(f'.models.{model_name}', package=__name__)
except ModuleNotFoundError as e:
print(f"Warning: Could not find dependent module {e.name} for Model {model_name}, skip it")
continue
Model = getattr(module, 'Model', None)
if Model is None:
print(f"Warning: {module} does not define attribute Model, skip it")
continue
Model = getattr(module, 'Model')
if not hasattr(Model, 'name'):
Model.name = model_name
models.append(Model)
Expand Down