Reference: [Optuna visualization tutorial](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/005_visualization.html#sphx-glr-download-tutorial-10-key-features-005-visualization-py)

In [1]:
# Automatically re-import edited modules (e.g. the local `reprpo` package) before each cell runs.
%reload_ext autoreload
%autoreload 2

In [2]:
import os

# Ask PyTorch's CUDA caching allocator for expandable segments, which can reduce
# fragmentation; in practice this kept GPU memory from filling up between trials
# despite explicit clearing. Set before `import torch` so it is in place early.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [3]:
import torch
import pandas as pd
from pathlib import Path
import optuna
from reprpo.hp.helpers import optuna_df

In [4]:
from reprpo.training import train
from reprpo.experiments import experiment_configs
from reprpo.hp.space import search_spaces

## Objective

In [5]:
# Experiment-wide configuration.
SEED = 42
# Metric the search maximises; passed as `key_metric` to `objective` and `optuna_df`.
key_metric = "acc_gain_vs_ref/oos"

import random
import numpy as np

# Seed every RNG this notebook touches -- Python, NumPy (e.g. the round-robin
# `np.random.shuffle` of search spaces) and torch -- so a fresh kernel
# reproduces the same schedule.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Silence noisy output for the search below: only WARNING+ loguru messages,
# wandb disabled, HF datasets in offline mode, tqdm bars off.
import os
import sys
from loguru import logger

# `logger.remove()` with no id drops every handler (including loguru's default
# stderr sink), so a single call is enough before adding our own.
logger.remove()
logger.add(sys.stderr, level="WARNING")

os.environ["WANDB_MODE"] = "disabled"
# NOTE(review): HF_DATASETS_OFFLINE and TQDM_DISABLE may only be read when
# `datasets` / `tqdm` are first imported. If the `reprpo` imports earlier in the
# notebook already pulled them in, these assignments have no effect -- consider
# moving them next to PYTORCH_CUDA_ALLOC_CONF at the top. TODO confirm.
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TQDM_DISABLE"] = "true"

In [7]:
# Optuna RDB storage: a SQLite file relative to the notebook's working directory.
# Both the filesystem path and the SQLAlchemy URL are derived from one file name,
# rather than recovering the path from the URL by string replacement.
DB_FILE = Path("optuna2.db")
f = f"./{DB_FILE}"             # local path, e.g. for `optuna-dashboard` or deletion
f_db = f"sqlite:///{DB_FILE}"  # URL passed as `storage=` to optuna
print(f)
Path(f).parent.mkdir(parents=True, exist_ok=True)
f_db

In [8]:
# print(f'to visualise run in cli\ncd nbs\noptuna-dashboard {f_db}')

In [9]:
from reprpo.hp.target import override, default_tuner_kwargs, list2tuples, objective
import functools
from reprpo.experiments import experiment_configs
import copy
import wandb

import optuna.pruners
from optuna_integration.wandb import WeightsAndBiasesCallback

## Opt

Note on pruning: it is only really useful with validation metrics and for long jobs over many epochs. This is a small proxy job, so pruning isn't needed.

In [10]:
# NOTE: this rebinds `experiment_configs` to the version from `reprpo.hp.space`,
# replacing the `reprpo.experiments` one imported in earlier cells; any later
# cell referencing `experiment_configs` sees this version.
from reprpo.hp.space import experiment_configs
# Display the available experiment names (presumably a dict keyed by experiment name).
experiment_configs.keys()

In [11]:
import warnings

# Optuna raises ExperimentalWarning when experimental APIs are used; hide it
# so the trial output stays readable.
warnings.filterwarnings(action="ignore", category=optuna.exceptions.ExperimentalWarning)

In [12]:
# List every study stored in `f_db` and show its results table.
# Use optuna's public top-level API rather than the private `optuna.study.study` module.
from optuna import storages, get_all_study_names
study_names = get_all_study_names(storage=f_db)

for study_name in study_names:
    print(study_name)
    study = optuna.load_study(study_name=study_name, storage=f_db)
    try:
        df_res = optuna_df(study, key_metric)
        display(df_res)
        print()
    except ValueError as e:
        # `optuna_df` could not summarise this study (presumably e.g. no
        # completed trials yet -- TODO confirm); show why instead of a bare dash.
        print(f'- ({e})')

In [13]:
# unit test, moved to pytest
# for exp_name, (N, trial2args) in search_spaces.items():
#     study = optuna.create_study(direction="maximize")
#     cfg = copy.deepcopy(experiment_configs[exp_name][1])
#     print('exp_name', exp_name)
#     for _ in range(10):
#         trial = study.ask()
#         kwargs = trial2args(trial)
#         override(cfg, default_tuner_kwargs)
#         override(cfg, kwargs)
#         kwargs = list2tuples(kwargs)

#     # try one dev run
#     kwargs['dev'] = True
#     _objective = functools.partial(objective, key_metric=key_metric, starter_experiment_name=exp_name, trial2args=trial2args)

#     study.optimize(_objective, 
#                 n_trials=1, # do 20 at a time, round robin, untill done
#                 gc_after_trial=True, 
#                 # catch=(AssertionError, OSError, RuntimeError, KeyError, torch.OutOfMemoryError)
#     )

In [14]:
# Round-robin hyperparameter search.
# Each pass visits every experiment's study in a random order and runs at most
# TRIALS_PER_ROUND new trials on it, so all experiments progress together and an
# interrupted run resumes from the SQLite storage. The loop ends when every study
# has reached its per-experiment budget (`max_trials` from `search_spaces`) or on Ctrl-C.
MAX_TRIALS = 250  # NOTE(review): not read below; per-experiment budgets come from `search_spaces`
TRIALS_PER_ROUND = 20  # trials per study before moving on to the next one
import numpy as np

spaces = list(search_spaces.items())
interrupted = False
while not interrupted:
    np.random.shuffle(spaces)
    # Set when at least one study still had budget in this pass; if none did, we are done.
    any_remaining = False
    for exp_name, (max_trials, trial2args) in spaces:
        try:
            study_name = str(exp_name)
            study = optuna.create_study(
                study_name=study_name,
                direction="maximize",
                load_if_exists=True,  # resume trials already stored in f_db
                storage=f_db,
                sampler=optuna.samplers.TPESampler(seed=SEED),
                # pruner=optuna.pruners.NopPruner(),
            )

            # Every stored trial (any state, including failed ones) counts against the
            # budget -- the same count as the rows of `study.trials_dataframe()` -- and
            # this also works for a brand-new, empty study.
            n = len(study.get_trials(deepcopy=False))
            if n > 0:
                print(f"loaded {n} {study_name} trials")
                # `optuna_df` can raise ValueError for some studies; a reporting
                # failure must not prevent this study from being optimised.
                try:
                    df_res = optuna_df(study, key_metric)
                    print(df_res.to_markdown())
                except ValueError as e:
                    print(f"no summary for {study_name}: {e}")

            if n < max_trials:
                any_remaining = True
                _objective = functools.partial(
                    objective,
                    key_metric=key_metric,
                    starter_experiment_name=exp_name,
                    trial2args=trial2args,
                )
                study.optimize(
                    _objective,
                    # don't overshoot the per-experiment budget on the final round
                    n_trials=min(TRIALS_PER_ROUND, max_trials - n),
                    gc_after_trial=True,
                    # trials raising these are recorded as failed and the study continues
                    catch=(AssertionError, OSError, RuntimeError, KeyError, torch.OutOfMemoryError),
                )

            print('=' * 80)
        except KeyboardInterrupt:
            # A bare `break` only leaves the inner loop and the `while` would start
            # another pass; set the flag so Ctrl-C stops the whole search.
            interrupted = True
            break
        except Exception as e:
            # Log and move on to the next experiment rather than killing the search.
            logger.exception(e)
    if not any_remaining:
        break