https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/005_visualization.html#sphx-glr-download-tutorial-10-key-features-005-visualization-py

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from pathlib import Path
import optuna
from reprpo.hp.helpers import optuna_df
from reprpo.training import train
from reprpo.experiments import experiment_configs
from reprpo.hp.space import search_spaces
from optuna.study.study import storages, get_all_study_names
from reprpo.hp.helpers import get_params, get_optuna_df


[2024-10-06 11:13:49,621] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib import rcParams
# Notebook-wide matplotlib defaults: 80 dpi and 4x2 inch figures.
rcParams.update({
    'figure.dpi': 80,
    'figure.figsize': (4, 2),
})

In [4]:
import warnings
warnings.filterwarnings("ignore")

In [5]:
from reprpo.hp.target import override, default_tuner_kwargs
from reprpo.experiments import experiment_configs
import copy

## Objective

In [5]:
# Seed constant for the notebook.
SEED = 42
# Metric name handed to the Optuna summary helpers and the plotting function.
key_metric = "acc_gain_vs_ref/oos"

In [None]:
# Storage URL passed to Optuna; the studies live in a SQLite file.
f_db = "sqlite:///optuna.db"
# The same location written as a path relative to the working directory.
f = f_db.replace("sqlite:///", "./")
print(f)
# Make sure the directory that holds the database file exists.
db_dir = Path(f).parent
db_dir.mkdir(parents=True, exist_ok=True)
f_db

## Opt

Note on pruning: it's only really useful with validation metrics and for long jobs that run over many epochs. I've got a small proxy job, so there is no need for it.

In [None]:
def plot_param_importances(study, key_metric):
    """Plot the objective value against each hyper-parameter of a study.

    One figure is drawn per row of ``get_optuna_df(study, key_metric)``:
    a boxen plot when the row's ``dist`` is ``'categorical'``, otherwise a
    scatter plot. Trials that used the row's ``best`` value are highlighted,
    and the title shows the parameter name, its ``importance`` and ``best``.

    Args:
        study: Optuna study; only trials in state ``COMPLETE`` are plotted.
        key_metric: Forwarded unchanged to ``get_optuna_df``.

    Returns:
        None. Each figure is rendered with ``plt.show()``.
    """
    # NOTE(review): get_optuna_df is assumed to return a frame indexed by
    # parameter name with `best`, `dist` and `importance` columns -- confirm.
    param_summary = get_optuna_df(study, key_metric)
    completed = (
        study.trials_dataframe()
        .query('state == "COMPLETE"')
        .sort_values('value', ascending=False)
    )

    for param, row in param_summary.iterrows():
        best = row.best
        col = f"params_{param}"
        # Boolean mask marking the trials that used the best value.
        is_best = completed[col].apply(lambda v: v == best)

        if row.dist == 'categorical':
            grid = sns.catplot(
                data=completed, x=col, y="value", kind="boxen",
                height=2, aspect=3,
                legend=False,
                hue=is_best,
            )
            ax = grid.ax  # catplot without row/col facets has a single Axes
        else:
            _, ax = plt.subplots(figsize=(6, 2))
            # legend=False replaces the old create-then-remove legend hack.
            sns.scatterplot(data=completed, x=col, y='value', hue=is_best,
                            alpha=0.5, ax=ax, legend=False)
            # Overlay the best-value trials with a distinct marker.
            sns.scatterplot(data=completed[is_best], x=col, y='value', ax=ax,
                            marker='x', color='blue', s=100)

        ax.set_title(f"{param} i={row.importance:2.2g} best={best}")
        ax.set_xlabel("")  # the parameter name is already in the title
        plt.show()

In [None]:

study_names = get_all_study_names(storage=f_db)

# Walk the stored studies and stop at the first one that yields a non-empty
# parameter summary; `study` is left bound to it for the cells below.
for study_name in study_names:
    print(study_name)
    study = optuna.load_study(study_name=study_name, storage=f_db)
    df_res2 = None  # reset so a failing study cannot reuse an earlier result
    try:
        df_res2 = optuna_df(study, key_metric)
        display(df_res2)
        plot_param_importances(study, key_metric)
    except ValueError as e:
        # Keep going, but say why this study was skipped.
        print(f"- {e}")
    if df_res2 is not None and len(df_res2):
        # Was `1/0`: stop cleanly instead of raising ZeroDivisionError,
        # which aborted "Run All".
        break

In [None]:
# Parameter summary for `study`, i.e. whichever study the loop in the
# previous cell left bound when it stopped.
df_res = get_optuna_df(study, key_metric)
df_res