https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/005_visualization.html#sphx-glr-download-tutorial-10-key-features-005-visualization-py

In [1]:
# Automatically reload edited Python modules (e.g. the local `reprpo` package) before running each cell.
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pandas as pd
from pathlib import Path

import optuna

# These are the Matplotlib versions of the plots; import them from `optuna.visualization`
# instead of `optuna.visualization.matplotlib` to get interactive Plotly figures.
from optuna.visualization.matplotlib import plot_contour
from optuna.visualization.matplotlib import plot_edf
from optuna.visualization.matplotlib import plot_intermediate_values
from optuna.visualization.matplotlib import plot_optimization_history
from optuna.visualization.matplotlib import plot_parallel_coordinate
from optuna.visualization.matplotlib import plot_param_importances
from optuna.visualization.matplotlib import plot_rank
from optuna.visualization.matplotlib import plot_slice
from optuna.visualization.matplotlib import plot_timeline



In [3]:
from reprpo.training import train
import tyro
from reprpo.experiments import experiment_configs
from reprpo.ax.parameters import search_spaces

from reprpo.interventions import Interventions, DPOConfig, ReprPOConfig, DPOProjGradConfig
from reprpo.interventions.losses import Losses
from reprpo.interventions.transforms import Transforms

## Objective

In [4]:
# Run-wide settings: RNG seed and the metric each Optuna study maximises.
SEED = 42
key_metric = "acc_gain_vs_ref/oos"  # key looked up in the metrics dict returned by `train`

torch.manual_seed(SEED)
# Use the GPU whenever torch can see one.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Silence noisy output: keep only WARNING-and-above loguru messages, disable
# W&B, keep HF datasets offline and ask tqdm to hide its progress bars.
import os
import sys

from loguru import logger

# With no argument, remove() drops every handler (including loguru's default
# stderr sink), so a single call is enough before adding our own sink.
logger.remove()
logger.add(sys.stderr, level="WARNING")  # `sys.stderr`, not the undocumented `os.sys`

os.environ["WANDB_MODE"] = "disabled"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TQDM_DISABLE"] = "true"

In [6]:
# Optuna storage URL plus the matching filesystem path; the parent directory
# must exist before Optuna opens the SQLite database.
f_db = "sqlite:///../outputs/optuna/optuna.db"
f = "./" + f_db.removeprefix("sqlite:///")
print(f)
Path(f).parent.mkdir(parents=True, exist_ok=True)
f_db

./../outputs/optuna/optuna.db


'sqlite:///../outputs/optuna/optuna.db'

In [7]:
from reprpo.ax.target import override, tuner_kwargs
from reprpo.experiments import experiment_configs
import copy


import optuna.pruners
from optuna_integration.wandb import WeightsAndBiasesCallback

Note on pruning: it is only really useful with validation metrics reported during training and for long jobs spanning many epochs. I've got a small proxy job, so there is no need (hence `NopPruner` below).

In [8]:
# Target number of trials per Optuna study.
MAX_TRIALS = 150


def list2tuples(d):
    """Return a copy of ``d`` with every list value converted to a tuple.

    Tuples are immutable and, unlike lists, hashable (when their items are).
    The input mapping is left untouched, so the caller's dict is never
    mutated behind its back.

    Args:
        d: mapping of parameter name -> value.

    Returns:
        A new dict with the same keys; list values become tuples and every
        other value is passed through unchanged.
    """
    return {k: tuple(v) if isinstance(v, list) else v for k, v in d.items()}

for starter_experiment_name, trial2args in search_spaces.items():
    # One persistent Optuna study per search space, named after the experiment.
    study_name = str(starter_experiment_name)

    def objective_func(kwargs, trial):
        """Train one trial's configuration and return the metrics returned by ``train``.

        Starts from a deep copy of the base experiment config, then applies the
        shared ``tuner_kwargs`` followed by the sampled ``kwargs``.
        """
        cfg = copy.deepcopy(experiment_configs[starter_experiment_name][1])
        override(cfg, tuner_kwargs)
        # Convert list values to tuples *before* applying them so that `cfg`
        # actually receives the converted values (converting afterwards had no
        # effect on what was already applied).
        kwargs = list2tuples(kwargs)
        override(cfg, kwargs)
        return train(cfg, trial=trial)

    def objective(trial: optuna.Trial) -> float:
        """Optuna objective: sample hyper-parameters via ``trial2args`` and return ``key_metric``."""
        kwargs = trial2args(trial)
        metrics = objective_func(kwargs, trial)
        return metrics[key_metric]

    os.environ["WANDB_NOTEBOOK_NAME"] = f"{study_name}.ipynb"
    wandb_kwargs = {"project": "reprpo-optuna", "name": study_name}
    wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs)

    # load_if_exists=True resumes a study already stored in the same SQLite DB.
    study = optuna.create_study(
        study_name=study_name,
        direction="maximize",
        load_if_exists=True,
        storage=f_db,
        sampler=optuna.samplers.TPESampler(seed=SEED),
        pruner=optuna.pruners.NopPruner(),
    )

    complete_states = (optuna.trial.TrialState.COMPLETE,)

    # Count already-finished trials so a resumed study is only topped up to
    # MAX_TRIALS, instead of running MAX_TRIALS *more* trials on every re-run.
    n = 0
    if study.get_trials(deepcopy=False):
        n = len(study.get_trials(deepcopy=False, states=complete_states))
        print(f"loaded {n} {study_name} trials")
    if n < MAX_TRIALS:
        study.optimize(
            objective,
            n_trials=MAX_TRIALS - n,
            callbacks=[wandbc],
            gc_after_trial=True,
            # catch=(Exception,),  # uncomment to mark failing trials as FAIL and keep going
        )

    # `best_trial` raises ValueError when no trial has completed, so guard it.
    if study.get_trials(deepcopy=False, states=complete_states):
        print('study.best_trial', study.best_trial)
    else:
        print(f"no completed trials for {study_name}")

  wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs)
[I 2024-09-29 00:35:32,336] Using an existing study with name 'projgrad' instead of creating a new one.


loaded 175 projgrad trials
study.best_trial FrozenTrial(number=70, state=1, values=[1.057915057915058], datetime_start=datetime.datetime(2024, 9, 28, 17, 25, 12, 131962), datetime_complete=datetime.datetime(2024, 9, 28, 17, 28, 13, 416392), params={'learning-rate': 0.00012426382563887213, 'β': 0.7386239719822631, 'reverse_pref': True, 'scale_orth': True, 'weight_dim': 0, 'neg_slope': 0.5}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'learning-rate': FloatDistribution(high=0.001, log=True, low=1e-06, step=None), 'β': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'reverse_pref': CategoricalDistribution(choices=(True, False)), 'scale_orth': CategoricalDistribution(choices=(True, False)), 'weight_dim': IntDistribution(high=2, log=False, low=0, step=1), 'neg_slope': CategoricalDistribution(choices=(0, 0.1, 0.5, 1))}, trial_id=71, value=None)


  wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs)


[I 2024-09-29 00:35:32,439] Using an existing study with name 'side-ether-prefvec' instead of creating a new one.
[W 2024-09-29 00:35:32,510] Trial 6 failed with parameters: {'learning-rate': 1.329291894316217e-05, 'collect_input': True, 'collect_hs': True, 'nb': 5, 'Htype': 'etherplus', 'flip_side': False, 'reduction': 167, 'loss.β': 2.177484667394932e-05, 'use_orth_loss': False, 'use_angle_loss': False, 'use_dpo_loss': True, 'use_nll_loss': True, 'weight_tokens': False} because of the following error: TypeError("unhashable type: 'dict'").
Traceback (most recent call last):
  File "/workspace/repr-preference-optimization/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_2550942/1891008565.py", line 23, in objective
    r = objective_func(kwargs, trial)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_2550942/1891008565.py", line 18, in objective_f

loaded 0 side-ether-prefvec trials


TypeError: unhashable type: 'dict'

In [None]:
# Post-mortem debugger for the most recent exception (e.g. the failed trial above); remove before Run All.
%debug

In [22]:
# plot_timeline(study)

In [None]:
# NOTE: `study` is whichever study the optimisation loop above created last, not every study.
plot_optimization_history(study)

In [16]:
# plot_intermediate_values(study)

In [21]:
# plot_contour(study)


In [None]:
# Per-parameter slice plot for the last study created by the optimisation loop.
plot_slice(study)


In [None]:
# Hyper-parameter importances for the last study created by the optimisation loop.
plot_param_importances(study)

### Appendix 1: dataclass to Optuna search space

In [20]:
# import inspect
# import typing
# from typing import Literal

# def optuna_suggest_from_dataclass(t):
#     n = t.__name__
#     print(f'## {n}')
#     sig = inspect.signature(t)
#     for name, param in sig.parameters.items():
#         if param.annotation== bool:
#             print(f'"{name}": trial.suggest_categorical("{name}", [True, False]),')
#         elif param.annotation==int:
#             print(f'"{name}": trial.suggest_int("{name}", 1, 10),')
#         elif param.annotation ==float:
#             print(f'"{name}": trial.suggest_float("{name}", 0.1, 10.0),')
#         elif param.annotation == str:
#             print(f'"{name}": trial.suggest_categorical("{name}", ["a", "b", "c"]),')
#         elif param.annotation == tuple:
#             print(f'"{name}": trial.suggest_categorical("{name}", [(1, 2), (3, 4), (5, 6)]),')
#         elif typing.get_origin(param.annotation) == Literal:
#             print(f'"{name}": trial.suggest_categorical("{name}", {param.annotation.__args__}),')
#         else:
#             print(f"!!Unknown type {param}")
#             # print(name, param.default, param.annotation)

# optuna_suggest_from_dataclass(ReprPOConfig)
# for t in Transforms:
#     print(f'## {t}')
#     optuna_suggest_from_dataclass(t.value)
# for l in Losses:
#     print(f'## {l}')
#     optuna_suggest_from_dataclass(l.value)


# optuna_suggest_from_dataclass(DPOProjGradConfig)