In [None]:
import optuna
from plotly.io import show
from optuna.visualization import plot_param_importances
from hydra import compose, initialize
from pathlib import Path
import pandas as pd

In [64]:
with initialize(version_base=None, config_path="../configs/filepaths"):
    filepaths = compose(config_name="base")

In [65]:
split_strategy = "random_reaction_center"
split_idx = 0
objective = "val_roc"

In [120]:
models = ['rc_agg', 'rc_cxn', 'bom', 'cgr', 'mfp', 'drfp', 'rxnfp',]
n_top = 10
for i, model in enumerate(models):
    study_name = f"{model}_{split_strategy}_split_{split_idx}_obj_{objective}"
    storage = f"sqlite:///{filepaths.results}/hpo/{study_name}.db"
    study = optuna.load_study(
        study_name=study_name,
        storage=storage,
    )
    df = study.trials_dataframe()
    df.sort_values(
        by=["value"],
        ascending=False,
        inplace=True,
    )
    average_top_values = df["value"].head(n_top).mean()
    best_trial = df.iloc[0]
    best_params = {
        key.replace("params_", ""): value
        for key, value in best_trial.items()
        if key.startswith("params_")
    }
    best_value = best_trial["value"]
    best_trial_number = best_trial["number"]
    print("-" * 80)
    print(f"Model #{i}: {model} | Best {objective}: {best_value:.4f} on trial {best_trial_number} | Top-{n_top}-mean: {average_top_values} | # of trials: {len(df)}")
    print(f"Best params: {best_params}")
    print("-" * 80)

--------------------------------------------------------------------------------
Model #0: rc_agg | Best val_roc: 0.9028 on trial 110 | Top-10-mean: 0.8977639734745025 | # of trials: 169
Best params: {'data/neg_multiple': np.int64(4), 'model/d_h_encoder': np.int64(95), 'model/encoder_depth': np.int64(4), 'training/pos_multiplier': np.int64(1)}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Model #1: rc_cxn | Best val_roc: 0.8950 on trial 47 | Top-10-mean: 0.8793943822383881 | # of trials: 92
Best params: {'data/neg_multiple': np.int64(5), 'model/d_h_encoder': np.int64(33), 'model/encoder_depth': np.int64(6), 'training/pos_multiplier': np.int64(2)}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Model #2: bom | Best val_roc: 0.8509 on trial 84 | Top-10-mean: 0.8

In [114]:
model = "rxnfp"
study_name = f"{model}_{split_strategy}_split_{split_idx}_obj_{objective}"
storage = f"sqlite:///{filepaths.results}/hpo/{study_name}.db"
study = optuna.load_study(
    study_name=study_name,
    storage=storage,
)

In [115]:
print(len(study.trials))
optuna.visualization.plot_optimization_history(study)

1000


In [116]:
fig = plot_param_importances(study)
fig.show()

In [119]:
mode = lambda x: x.mode().values[0]
mean = lambda x: x.mean()

df = study.trials_dataframe()
df.sort_values(
    by=["value"],
    ascending=False,
    inplace=True,
)
k = 5
top_k = df.iloc[:k][["value", "user_attrs_n_epochs", "duration", *[col for col in df.columns if col.startswith("params_")]]]

aggs = {
    "params_data/neg_multiple": mode,
    "params_model/d_h_encoder": mean,
    "params_model/model": mode,
    "params_model/encoder_depth": mode,
    "params_training/pos_multiplier": mode,
}

for col in top_k.columns:
    
    if col.startswith("params_"):
        try:
            print(f"{col}: {int(aggs[col](top_k[col]))}")
        except:
            pass
display(top_k)


params_data/neg_multiple: 3
params_model/d_h_encoder: 55
params_model/encoder_depth: 4
params_training/pos_multiplier: 1


Unnamed: 0,value,user_attrs_n_epochs,duration,params_data/neg_multiple,params_model/d_h_encoder,params_model/encoder_depth,params_model/model,params_training/pos_multiplier
760,0.832892,10.0,0 days 00:00:44.446009,3,36,4,linear,1
610,0.827421,10.0,0 days 00:00:45.466971,3,49,4,linear,1
91,0.82603,10.0,0 days 00:00:46.452361,3,94,4,linear,1
305,0.824831,10.0,0 days 00:00:46.544497,3,65,4,linear,1
526,0.82446,10.0,0 days 00:00:46.072071,3,35,4,linear,1
