https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/005_visualization.html#sphx-glr-download-tutorial-10-key-features-005-visualization-py

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to local packages take
# effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pandas as pd

import optuna

# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_rank
from optuna.visualization import plot_slice
from optuna.visualization import plot_timeline



In [3]:
from reprpo.training import train
import tyro
from reprpo.experiments import experiment_configs
from reprpo.interventions import Interventions, DPOConfig, ReprPOConfig, DPOProjGradConfig
from reprpo.interventions.losses import Losses
from reprpo.interventions.transforms import Transforms

## Objective

In [4]:
from pathlib import Path
# Entry of the training result dict that the study optimizes.
key_metric = "acc_gain_vs_ref/oos"
# Name of the Optuna study; also reused for the database file and the W&B run name.
study_name = "projgrad"
# Single seed shared by torch and the Optuna sampler.
SEED = 42

torch.manual_seed(SEED)
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
from reprpo.ax.target import objective_func

In [6]:
# Silence noisy output: only loguru messages at WARNING or above reach stderr.
import os
import sys
from loguru import logger

# remove() without an argument drops every registered handler, so one call is
# enough before installing the WARNING-level stderr sink.
logger.remove()
logger.add(sys.stderr, level="WARNING")

# Environment switches for third-party libraries; they are set here, before the
# study runs. Presumably: W&B logging off, Hugging Face datasets offline, tqdm
# progress bars hidden -- confirm against each library's documentation.
os.environ["WANDB_MODE"] = "disabled"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TQDM_DISABLE"] = "true"

https://github.com/vwxyzjn/cleanrl/blob/master/tuner_example.py

In [7]:
def trial2args(trial: optuna.Trial):
    """Sample one set of hyperparameter overrides from ``trial``.

    Every value is drawn through the trial's ``suggest_*`` API under a
    parameter name equal to its dictionary key. The commented-out entries
    are search spaces for other configs that are currently not sampled.

    Args:
        trial: The Optuna trial to draw suggestions from.

    Returns:
        dict: Option name -> sampled value, in the order the values were drawn.
    """

    def flag(name):
        # Boolean switches are sampled as a two-way categorical choice.
        return trial.suggest_categorical(name, [True, False])

    return {
        "learning-rate": trial.suggest_float("learning-rate", 1e-6, 1e-3, log=True),
        "collect_input": flag("collect_input"),
        "collect_hs": flag("collect_hs"),

        # ## OrthoConfig
        # "orthogonal_map": trial.suggest_categorical("orthogonal_map", ('householder', 'cayley', 'matrix_exp')),

        # ## ETHERConfig
        # "nb": trial.suggest_int("nb", 1, 32),
        # "Htype": trial.suggest_categorical("Htype", ["ether", "etherplus", "oft", "etherplusHH"]),
        # # "ether_dropout": trial.suggest_float("ether_dropout", 0.1, 10.0),
        # "flip_side": trial.suggest_categorical("flip_side", [True, False]),
        # "reduction": trial.suggest_int("reduction", 1, 200),

        # ## HRAConfig
        # "r": trial.suggest_int("r", 2, 128),
        # "apply_GS": trial.suggest_categorical("apply_GS", [True, False]),

        # ## SVDConfig
        # "quantile": trial.suggest_categorical("quantile", [0.1, 0.5, 0.75, 1]),
        # "dual_svd": trial.suggest_categorical("dual_svd", [True, False]),

        # # prefvec
        # "loss.β": trial.suggest_float("loss.β", 1e-6, 2.0, log=True),
        # "use_orth_loss": trial.suggest_categorical("use_orth_loss", [True, False]),
        # "use_angle_loss": trial.suggest_categorical("use_angle_loss", [True, False]),
        # "use_dpo_loss": trial.suggest_categorical("use_dpo_loss", [True, False]),
        # "use_nll_loss": trial.suggest_categorical("use_nll_loss", [True, False]),
        # "weight_tokens": trial.suggest_categorical("weight_tokens", [True, False]),

        # ## RankLossConfig
        # "α": trial.suggest_float("α", 0, 10.0),

        # ## MSELossConfig
        # "α": trial.suggest_float("α", 0, 10.0),

        # projgrad
        "β": trial.suggest_float("β", 0.0, 1.0, log=False),
        "reverse_pref": flag("reverse_pref"),
        "scale_orth": flag("scale_orth"),
        "weight_dim": trial.suggest_int("weight_dim", 0, 2),
        "neg_slope": trial.suggest_categorical("neg_slope", [0, 0.1, 0.5, 1]),
    }

In [9]:
# # TODO we can report early
# trial.report(aggregated_normalized_score, step=seed)
# # TODO we can report fails
# if trial.should_prune():
#     raise optuna.TrialPruned()

# https://github.com/optuna/optuna-integration/blob/935b44965316acb6076188a1e708481fe5d2978d/optuna_integration/pytorch_lightning/pytorch_lightning.py#L28

In [10]:
import optuna.pruners
from optuna_integration.wandb import WeightsAndBiasesCallback
# Settings for the W&B callback: runs go to the "reprpo-optuna" project and
# are named after the study.
wandb_kwargs = {"project": "reprpo-optuna", "name": study_name}
# Callback later passed to study.optimize(callbacks=[...]).
# NOTE(review): WANDB_MODE is set to "disabled" earlier in this notebook, so
# presumably nothing is actually uploaded -- confirm that is intended.
wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs)

  wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Note on pruning: it is only really useful with validation metrics and for long jobs that run over many epochs. I have a small proxy job, so there is no need for it here.

In [24]:
# SQLite storage URL handed to Optuna; one database file per study name.
f_db = f"sqlite:///../outputs/optuna/{study_name}.db"
# Swap the URL scheme for "./" to get the same location as a plain file path.
f = f_db.replace('sqlite:///', './')
print(f)
# Make sure the directory that will hold the database file exists.
Path(f).parent.mkdir(parents=True, exist_ok=True)
# Last expression: show the storage URL as the cell output.
f_db

./../outputs/optuna/projgrad.db


'sqlite:///../outputs/optuna/projgrad.db'

In [29]:
from reprpo.ax.target import override, tuner_kwargs
from reprpo.experiments import experiment_configs
import copy


# Key into experiment_configs selecting the base config each trial starts from.
starter_experiment_name = "projgrad"

def objective_func(kwargs, trial):
    """Train once with the given overrides and return the result of ``train``.

    A deep copy of the base config (element ``[1]`` of the
    ``experiment_configs[starter_experiment_name]`` entry) is modified with
    ``override``: first with ``tuner_kwargs``, then with ``kwargs``, so the
    trial's values are applied last. The shared base config is never mutated.

    NOTE(review): this definition replaces the ``objective_func`` imported
    from ``reprpo.ax.target`` earlier in the notebook -- confirm the shadowing
    is intended.

    Args:
        kwargs: Overrides sampled for this trial.
        trial: The Optuna trial, forwarded to ``train``.

    Returns:
        Whatever ``train(cfg, trial=trial)`` returns.
    """
    base_cfg = experiment_configs[starter_experiment_name][1]
    cfg = copy.deepcopy(base_cfg)
    # Tuner-wide settings first, then the per-trial subcommand overrides.
    for overrides in (tuner_kwargs, kwargs):
        override(cfg, overrides)
    return train(cfg, trial=trial)

def objective(trial: optuna.Trial) -> float:
    """Optuna objective: sample overrides, train, and return the key metric.

    Args:
        trial: The Optuna trial to sample hyperparameters from.

    Returns:
        The ``key_metric`` entry of the training result.
    """
    sampled = trial2args(trial)
    metrics = objective_func(sampled, trial)
    return metrics[key_metric]

In [30]:
# Create the study, or resume it: with load_if_exists=True an existing study
# of the same name in the f_db storage is reused instead of raising.
study = optuna.create_study(
    study_name=study_name,
    direction="maximize",
    load_if_exists=True,
    storage=f_db,
    sampler=optuna.samplers.TPESampler(seed=SEED),
    # No pruning: the notebook's note on pruning says it is not needed for
    # this small proxy job.
    pruner=optuna.pruners.NopPruner(),
)
# Run until 30 trials finish or 10 minutes (timeout is in seconds) have
# elapsed, whichever comes first. Re-running this cell on a resumed study
# adds further trials to it.
study.optimize(objective, n_trials=30, timeout=60*10, callbacks=[wandbc])

[I 2024-09-28 08:28:34,270] Using an existing study with name 'projgrad' instead of creating a new one.


In [None]:
# Best trial so far (highest objective value, since the study maximizes).
study.best_trial

In [None]:
# Timeline of when each trial ran and the state it finished in.
plot_timeline(study)

In [None]:
# Objective value of every trial together with the best value so far.
plot_optimization_history(study)

In [None]:
# plot_intermediate_values(study)

In [None]:
# Contour plots of the objective over pairs of sampled hyperparameters.
plot_contour(study)


In [None]:
# Objective value plotted against each hyperparameter individually.
plot_slice(study)


In [None]:
# Estimated relative importance of each hyperparameter for the objective.
plot_param_importances(study)

### Appendix 1: dataclass to Optuna

In [None]:
# import inspect
# import typing
# from typing import Literal

# def optuna_suggest_from_dataclass(t):
#     n = t.__name__
#     print(f'## {n}')
#     sig = inspect.signature(t)
#     for name, param in sig.parameters.items():
#         if param.annotation== bool:
#             print(f'"{name}": trial.suggest_categorical("{name}", [True, False]),')
#         elif param.annotation==int:
#             print(f'"{name}": trial.suggest_int("{name}", 1, 10),')
#         elif param.annotation ==float:
#             print(f'"{name}": trial.suggest_float("{name}", 0.1, 10.0),')
#         elif param.annotation == str:
#             print(f'"{name}": trial.suggest_categorical("{name}", ["a", "b", "c"]),')
#         elif param.annotation == tuple:
#             print(f'"{name}": trial.suggest_categorical("{name}", [(1, 2), (3, 4), (5, 6)]),')
#         elif typing.get_origin(param.annotation) == Literal:
#             print(f'"{name}": trial.suggest_categorical("{name}", {param.annotation.__args__}),')
#         else:
#             print(f"!!Unknown type {param}")
#             # print(name, param.default, param.annotation)

# optuna_suggest_from_dataclass(ReprPOConfig)
# for t in Transforms:
#     print(f'## {t}')
#     optuna_suggest_from_dataclass(t.value)
# for l in Losses:
#     print(f'## {l}')
#     optuna_suggest_from_dataclass(l.value)


# optuna_suggest_from_dataclass(DPOProjGradConfig)