diff --git a/torchbenchmark/__init__.py b/torchbenchmark/__init__.py index 634b49887c..0ed2003558 100644 --- a/torchbenchmark/__init__.py +++ b/torchbenchmark/__init__.py @@ -4,6 +4,7 @@ import sys from urllib import request import importlib +from typing import Any, List, Tuple proxy_suggestion = "Unable to verify https connectivity, " \ "required for setup.\n" \ @@ -14,7 +15,7 @@ install_file = 'install.py' -def _test_https(test_url='https://github.com', timeout=0.5): +def _test_https(test_url: str = 'https://github.com', timeout: float = 0.5) -> bool: try: request.urlopen(test_url, timeout=timeout) except OSError: @@ -22,7 +23,7 @@ def _test_https(test_url='https://github.com', timeout=0.5): return True -def _install_deps(model_path, verbose=True): +def _install_deps(model_path: str, verbose: bool = True) -> Tuple[bool, Any]: run_args = [ [sys.executable, install_file], ] @@ -35,7 +36,7 @@ def _install_deps(model_path, verbose=True): if not verbose: run_kwargs['stderr'] = subprocess.STDOUT run_kwargs['stdout'] = subprocess.PIPE - subprocess.run(*run_args, **run_kwargs) + subprocess.run(*run_args, **run_kwargs) # type: ignore else: return (False, f"No install.py is found in {model_path}.") except subprocess.CalledProcessError as e: @@ -43,15 +44,15 @@ def _install_deps(model_path, verbose=True): except Exception as e: return (False, e) - return (True, None) - + return (True, None) -def _list_model_paths(): + +def _list_model_paths() -> List[str]: p = Path(__file__).parent.joinpath(model_dir) return sorted(str(child.absolute()) for child in p.iterdir() if child.is_dir()) -def setup(verbose=True, continue_on_fail=False): +def setup(verbose: bool = True, continue_on_fail: bool = False) -> bool: if not _test_https(): print(proxy_suggestion) sys.exit(-1) @@ -66,7 +67,7 @@ def setup(verbose=True, continue_on_fail=False): print("FAIL") try: errmsg = errmsg.decode() - except: + except Exception: pass failures[model_path] = errmsg if not continue_on_fail: @@ -81,15 +82,20 @@ def setup(verbose=True, continue_on_fail=False): return len(failures) == 0 + def list_models(): models = [] for model_path in _list_model_paths(): model_name = os.path.basename(model_path) - module = importlib.import_module(f'.models.{model_name}', package=__name__) - if not hasattr(module, 'Model'): + try: + module = importlib.import_module(f'.models.{model_name}', package=__name__) + except ModuleNotFoundError as e: + print(f"Warning: Could not find dependent module {e.name} for Model {model_name}, skip it") + continue + Model = getattr(module, 'Model', None) + if Model is None: print(f"Warning: {module} does not define attribute Model, skip it") continue - Model = getattr(module, 'Model') if not hasattr(Model, 'name'): Model.name = model_name models.append(Model)