In [None]:
import optuna
import matplotlib.pylab as plt
import seaborn as sns
import pandas as pd

In [None]:
# Both studies are read from the same SQLite database file.
storage_name = 'sqlite:///ice_fishing_model_11_25.db'
study_names = ['ice_fishing_model_11_25_cma-es_3.0_0.1',
               'ice_fishing_model_11_25_cma-es_2.5_0.1']

# One trials table per study, in the same order as study_names.
dfs = []
for s in study_names:
    # load_study() only opens an existing study and raises KeyError when the
    # name is not present in the storage. create_study(load_if_exists=True)
    # would instead silently create a new, empty study (e.g. after a typo in
    # the study name) and the analysis below would run on an empty table.
    study = optuna.load_study(
        study_name=s,
        storage=storage_name,
        # Same sampler settings as before so `study` stays configured identically.
        sampler=optuna.samplers.CmaEsSampler(n_startup_trials=500),
    )
    dfs.append(study.trials_dataframe(attrs=("number", "value", "params", "state", 'user_attrs')))

In [None]:
# Quick look at the first rows of the first study's trials table.
dfs[0].head(n=5)

In [None]:
# Keep only trials whose state is 'COMPLETE' and whose trial number is at
# least 500 (the same value passed as n_startup_trials when loading the studies).
filtered_dfs = []
for trials in dfs:
    is_complete = trials['state'] == 'COMPLETE'
    past_startup = trials['number'] >= 500
    filtered_dfs.append(trials[is_complete & past_startup])
dfs = filtered_dfs

In [None]:
# Distribution of the objective value over the first study's remaining trials.
dfs[0]['value'].hist()

In [None]:
# For every study, rank trials by objective value (highest first) and keep the
# leading 20 % of rows; reset_index() keeps the old index as an 'index' column.
top_20_dfs = []
for trials in dfs:
    n_keep = int(len(trials) * 0.2)
    ranked = trials.sort_values(by='value', ascending=False)
    top_20_dfs.append(ranked.head(n_keep).reset_index())

In [None]:
# Stack the per-study top-20 % tables row-wise into a single frame
# (each table keeps its own 0..n-1 index labels).
top_20_percent = pd.concat(top_20_dfs, axis=0, ignore_index=False)

In [None]:
# Columns selected from the trials dataframe, in the order they are kept.
col_orig = [
    'value', 'user_attrs_fish_abundance', 'user_attrs_ss_tau',
    'params_slw_base', 'params_slw_fish', 'params_slw_soc', 'params_slw_time',
    'user_attrs_ssw_fail', 'user_attrs_ssw_loc', 'user_attrs_ssw_soc', 'user_attrs_ssw_suc',
    'params_ssw_fail_ls', 'params_ssw_soc_ls', 'params_ssw_suc_ls',
]

# Short names of the model parameters; reused later to pick the columns to plot.
params = [
    'slw_base', 'slw_fish', 'slw_soc', 'slw_time',
    'ssw_fail', 'ssw_loc', 'ssw_soc', 'ssw_suc',
    'ssw_fail_ls', 'ssw_soc_ls', 'ssw_suc_ls',
]

# New column names, position-aligned with col_orig.
col_new = ['catch', 'fish_abundance', 'ss_tau'] + params

# zip() stops at the shorter list, so a column added to only one of the two
# lists would silently be dropped from the mapping; fail loudly instead.
if len(col_orig) != len(col_new):
    raise ValueError(
        f"col_orig has {len(col_orig)} entries but col_new has {len(col_new)}; "
        "the two lists must be position-aligned"
    )

# original column name -> short column name
col_name_mapping = dict(zip(col_orig, col_new))

In [None]:
# Keep only the columns of interest and switch to the short column names.
# NOTE: the result is bound to the same name, so this cell expects the
# original (un-renamed) columns and is meant to run once per kernel session.
top_20_percent = top_20_percent.loc[:, col_orig].rename(columns=col_name_mapping)

In [None]:
# Quick look at the first rows of the combined table.
top_20_percent.head(n=5)

In [None]:
# Reshape to long form: one row per (fish_abundance, parameter, value) so that
# each parameter can be drawn in its own facet.
params_wide = top_20_percent[params + ['fish_abundance']]
top_20_percent_melted = params_wide.melt(
    id_vars=['fish_abundance'],
    var_name='parameter',
    value_name='value',
)

In [None]:
# Replace the numeric fish_abundance codes with readable category labels.
abundance_labels = {3.0: 'low', 2.5: 'middle', 2.0: 'high'}
top_20_percent_melted = top_20_percent_melted.replace({'fish_abundance': abundance_labels})

In [None]:
# One point-plot facet per parameter: mean value with a standard-deviation
# error bar for each fish-abundance level. Facets get independent y-axes.
g = sns.catplot(
    data=top_20_percent_melted,
    kind='point',
    x='fish_abundance',
    y='value',
    hue='fish_abundance',
    errorbar='sd',
    col='parameter',
    col_wrap=4,
    sharey=False,
    height=3,
    aspect=0.7,
)

# Title each facet with the bare parameter name so it can be matched against
# the keys of the per-parameter dictionaries used below.
g.set_titles("{col_name}")

# Manual y-axis range (lower, upper) for each parameter facet.
limits = dict(
    slw_base=(-20, -1),
    slw_fish=(-15.0, 0.0),
    slw_soc=(-3, 3.0),
    slw_time=(0.0, 2.0),
    ssw_fail=(0, 0.6),
    ssw_loc=(0, 0.6),
    ssw_soc=(0, 0.6),
    ssw_suc=(0, 0.6),
    ssw_fail_ls=(0.5, 20),
    ssw_soc_ls=(0.5, 20),
    ssw_suc_ls=(0.5, 50),
)

# y-position of the horizontal reference line drawn in a facet.
# No entry is defined for ssw_fail, ssw_loc, ssw_soc and ssw_suc
# (their 0.25 entries are disabled), so those facets get no line.
ref_lines = dict(
    slw_base=-3,
    slw_fish=-1.72,
    slw_soc=0,  # -0.33
    slw_time=0.13,
    ssw_fail_ls=10,
    ssw_soc_ls=10,
    ssw_suc_ls=25,
)

# Per-facet styling: fix the y-range and draw the reference line wherever the
# corresponding dictionary has an entry for the facet's parameter.
for ax, col_name in zip(g.axes.flat, g.col_names):

    # 1. Manual y-axis range
    y_range = limits.get(col_name)
    if y_range is not None:
        ax.set_ylim(y_range)

    # 2. Dashed red horizontal reference line
    ref_value = ref_lines.get(col_name)
    if ref_value is not None:
        ax.axhline(
            y=ref_value,
            color='red',
            linestyle='--',
            linewidth=2,
            alpha=0.8,
            label='Target',  # label only shows up if an axes legend is drawn
        )
        
# Re-anchor the grid's legend: its lower-right corner is placed at the given
# figure-relative position; tweak the anchor if the legend gets cut off.
legend_anchor = (0.93, 0.2)
sns.move_legend(
    g,
    "lower right",
    bbox_to_anchor=legend_anchor,
    frameon=True,            # draw a box around the legend
    title="Fish abundance",  # human-readable legend title
)

plt.tight_layout()
plt.show()