/
run_baselines.py
31 lines (24 loc) · 1.01 KB
/
run_baselines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
import importlib
from poi.project_config import BaselinesConfig
from poi.baselines import run_all_baselines
def run_project_baselines(baselines_config: BaselinesConfig):
    """Run every baseline described by a project's ``BaselinesConfig``.

    Each field of ``baselines_config`` is forwarded, under the same keyword
    name, to ``poi.baselines.run_all_baselines``. Nothing is returned.

    Args:
        baselines_config: The project's baselines configuration object.
    """
    cfg = baselines_config
    # Keyword names mirror the config attribute names one-to-one.
    run_all_baselines(
        project_data_pipeline=cfg.project_data_pipeline,
        project_data_pipeline_kwargs_options=cfg.project_data_pipeline_kwargs_options,
        models=cfg.models,
        evaluation_function=cfg.evaluation_function,
        PROJECT_NAME=cfg.PROJECT_NAME,
    )
def main(argv=None):
    """Command-line entry point: run the baselines of one ``poi`` project.

    Parses a single positional ``project`` argument, imports
    ``poi.<project>.config``, reads its module-level ``baselines_config``
    and passes it to ``run_project_baselines``.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) lets argparse read ``sys.argv[1:]``, so calling
            ``main()`` with no arguments behaves exactly as before.

    Raises:
        SystemExit: with status 2 (via ``parser.error``) when the project's
            config module cannot be found, or when it does not define
            ``baselines_config``. Import errors raised *inside* an existing
            config module (e.g. a missing third-party dependency) are
            re-raised unchanged so real bugs are not hidden.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("project", help="name of the project to run baselines for")
    args = parser.parse_args(argv)
    project_name = args.project
    module_name = f"poi.{project_name}.config"

    try:
        project_config_module = importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        missing = err.name or ""
        # Only treat the error as "unknown project" when the missing module
        # is the config module itself or one of its parent packages; anything
        # else failed to import *from within* the config and must surface.
        on_project_path = bool(missing) and (
            module_name == missing or module_name.startswith(missing + ".")
        )
        if not on_project_path:
            raise
        parser.error(f"cannot load config for project {project_name!r}: {err}")

    try:
        baselines_config: BaselinesConfig = project_config_module.baselines_config
    except AttributeError:
        parser.error(f"module {module_name!r} does not define 'baselines_config'")

    run_project_baselines(baselines_config=baselines_config)
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()