Skip to content

Commit

Permalink
Refactor: reduce number of decorator layers (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanlin2013 committed Jun 16, 2022
1 parent 9764fd3 commit 0f45606
Showing 1 changed file with 14 additions and 27 deletions.
41 changes: 14 additions & 27 deletions mbl/workflow/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,29 @@ def run(
)


def mlflow_s3_storage(profile_name: str):
def mlflow_tracker(profile_name: str):
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
@mlflow_mixin
def wrapper(config: Dict[str, Union[int, float, str]]):
boto3.setup_default_session(profile_name=profile_name)
func(*args, **kwargs)
mlflow_config = config.pop("mlflow")
mlflow.set_tags(mlflow_config["tags"])
try:
func(config)
except Exception as e:
mlflow.set_tag("error", type(e).__name__)
mlflow.log_text(traceback.format_exc(), "error.txt")
mlflow.end_run("FAILED")

return wrapper

return decorator


def mlflow_exception_catcher(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
mlflow.set_tag("error", type(e).__name__)
mlflow.log_text(traceback.format_exc(), "error.txt")
mlflow.end_run("FAILED")

return wrapper


class GridSearch(abc.ABC):
@staticmethod
@mlflow_mixin
@mlflow_tracker
@abc.abstractmethod
def experiment(config: Dict[str, Union[int, float, str]]):
return NotImplemented
Expand All @@ -102,12 +97,8 @@ class AthenaMetadata:
table: str = "tsdrg"

@staticmethod
@mlflow_mixin
@mlflow_s3_storage(profile_name="minio")
@mlflow_exception_catcher
@mlflow_tracker(profile_name="minio")
def experiment(config: Dict[str, Union[int, float, str]]):
mlflow_config = config.pop("mlflow")
mlflow.set_tags(mlflow_config["tags"])
mlflow.log_params(config)
experiment = RandomHeisenbergTSDRG(**config)
filename = Path("-".join([f"{k}_{v}" for k, v in config.items()]) + ".p")
Expand Down Expand Up @@ -151,12 +142,8 @@ class AthenaMetadata:
table: str = "folding_tsdrg"

@staticmethod
@mlflow_mixin
@mlflow_s3_storage(profile_name="minio")
@mlflow_exception_catcher
@mlflow_tracker(profile_name="minio")
def experiment(config: Dict[str, Union[int, float, str]]):
mlflow_config = config.pop("mlflow")
mlflow.set_tags(mlflow_config["tags"])
config = RandomHeisenbergFoldingTSDRGGridSearch.retrieve_energy_bounds(config)
mlflow.log_params(config)
experiment = RandomHeisenbergFoldingTSDRG(**config)
Expand Down

0 comments on commit 0f45606

Please sign in to comment.