In [1]:
%reload_ext autoreload
%autoreload 2

- Based on the Ax Service API tutorial: https://ax.dev/tutorials/gpei_hartmann_service.html

In [2]:
from ax import optimize

In [3]:
from reprpo.training import train
import tyro
from reprpo.experiments import experiment_configs
from reprpo.interventions import Interventions, DPOConfig, ReprPOConfig
from reprpo.interventions.losses import Losses
from reprpo.interventions.transforms import Transforms

# training_args = tyro.extras.overridable_config_cli(experiment_configs)
# training_args



In [4]:
import copy

def setattrattr(cfg, k, v):
    """
    Set an attribute that may be a dotted path, e.g. ``setattrattr(o, "a.b", v)`` does ``o.a.b = v``.

    Paths of any depth are supported ("a.b.c"): the key is split on its FIRST dot,
    the head is looked up with ``getattr`` and the remainder is set recursively.

    Args:
        cfg: object to modify (mutated in place).
        k: attribute name, optionally dotted.
        v: value to assign.

    Returns:
        None (the result of the final ``setattr``).

    Raises:
        AttributeError: if an intermediate attribute on the path does not exist.
    """
    if '.' in k:
        # split only once so deeper paths recurse instead of failing to unpack
        head, rest = k.split('.', 1)
        return setattrattr(getattr(cfg, head), rest, v)
    return setattr(cfg, k, v)

In [5]:
# Fixed overrides applied to every trial (before the trial's own parameters);
# sized so that a single training run takes roughly 2 minutes.
tuner_kwargs = {
    "verbose": 0,
    "base_model": 'TinyLlama/TinyLlama-1.1B-Chat-v1.0',  # ideally this would be an SFT model
    "batch_size": 64,
    "load_in_4bit": True,
    "collection_layers_side": [8, 10, 12, 14, 16, 18],
    "eval_samples": 24,
}



def override(cfg, overrides):
    """
    Apply a mapping of (possibly dotted) attribute paths to values onto ``cfg`` in place.

    E.g. ``{"lr": 1e-4, "loss.β": 0.1}`` sets ``cfg.lr`` and ``cfg.loss.β``.
    Keys whose path does not already exist on ``cfg`` are reported with a warning
    and skipped, instead of being created as new attributes.

    Args:
        cfg: config object to modify (mutated in place).
        overrides: mapping of attribute path -> value, applied in iteration order.

    Returns:
        ``cfg`` itself, for convenience.
    """
    for k, v in overrides.items():
        *parents, leaf = k.split('.')
        target = cfg
        # getattr raises AttributeError for a missing intermediate attribute, while a
        # plain setattr would silently create a missing leaf - check both explicitly.
        try:
            for attr in parents:
                target = getattr(target, attr)
            found = hasattr(target, leaf)
        except AttributeError:
            found = False
        if not found:
            print(f"3 WARNING: {k} not found in config")
            continue
        setattr(target, leaf, v)
    return cfg

def objective_func(**kwargs):
    """
    Run one training job for a single Ax trial and return its result.

    A fresh deep copy of the 'side-ether-prefvec' experiment config is taken,
    then the fixed ``tuner_kwargs`` are applied, followed by the trial's own
    ``kwargs`` (so trial parameters win on conflicts). The config is passed to
    ``train`` and whatever ``train`` returns is returned unchanged.
    """
    print('1 kwargs', kwargs)
    base_cfg = experiment_configs['side-ether-prefvec'][1]
    cfg = copy.deepcopy(base_cfg)  # never mutate the shared template config
    for overrides in (tuner_kwargs, kwargs):
        override(cfg, overrides)
    result = train(cfg)
    print('r', result)
    return result

In [6]:
from pathlib import Path
# Where the Ax optimisation snapshot for this run is stored.
name = 'ax1'
exp_f = Path('../outputs/ax') / f'{name}.json'
exp_f.parent.mkdir(parents=True, exist_ok=True)

Note: you can have dependent parameters; see:
- https://github.com/facebook/Ax/issues/1454

In [7]:
import warnings
import os
import sys
from ax.core.parameter import AxParameterWarning
warnings.filterwarnings("ignore", module="ax")
warnings.simplefilter("ignore", AxParameterWarning)

from loguru import logger
# Remove every existing loguru handler (a single call removes all of them),
# then only emit WARNING and above, to stderr.
logger.remove()
logger.add(sys.stderr, level="WARNING")

# NOTE(review): reprpo.training was already imported in an earlier cell; any library
# that reads these variables only at import time will not see them when set here.
# Consider moving these assignments above the reprpo imports.
os.environ["WANDB_MODE"] = "disabled"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TQDM_DISABLE"] = "true"

In [8]:
# SearchSpace(parameters=[RangeParameter(name='lr', parameter_type=FLOAT, range=[1e-06, 0.4], log_scale=True), RangeParameter(name='loss.β', parameter_type=FLOAT, range=[1e-06, 0.4], log_scale=True), ChoiceParameter(name='loss.use_dpo_loss', parameter_type=BOOL, values=[False, True], is_ordered=True, sort_values=True), ChoiceParameter(name='loss.use_nll_loss', parameter_type=BOOL, values=[False, True], is_ordered=True, sort_values=True), ChoiceParameter(name='loss.weight_tokens', parameter_type=BOOL, values=[False, True], is_ordered=True, sort_values=True), ChoiceParameter(name='loss.use_orth_loss', parameter_type=BOOL, values=[False, True], is_ordered=True, sort_values=True)], parameter_constraints=[]).

In [None]:
# r = objective_func()
# r

In [None]:
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import branin

N_TRIALS = 15  # new trials per execution of this cell (~2 min each with tuner_kwargs)

if exp_f.exists():
    # Resume: the JSON snapshot already holds the experiment (search space, objective,
    # previous trials), so the experiment must not be created a second time.
    ax_client = AxClient.load_from_json_file(filepath=str(exp_f))
else:
    ax_client = AxClient(verbose_logging=False)
    ax_client.create_experiment(
        name=name,  # named after the snapshot file rather than the tutorial's placeholder
        parameters=[
            {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
            {"name": "loss.β", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
            {
                "name": "loss.use_dpo_loss",
                "type": "choice",
                "values": [False, True],
            },
            {
                "name": "loss.use_nll_loss",
                "type": "choice",
                "values": [False, True],
            },
            {
                "name": "loss.weight_tokens",
                "type": "choice",
                "values": [False, True],
            },
            {
                "name": "loss.use_orth_loss",
                "type": "choice",
                "values": [False, True],
            },
        ],
        objectives={"oos": ObjectiveProperties(minimize=False)},
    )

for _ in range(N_TRIALS):
    parameters, trial_index = ax_client.get_next_trial()
    try:
        r = objective_func(**parameters)
    except Exception:
        logger.exception(f"Error in objective_func: parameters={parameters}")
        # Tell Ax the trial failed; otherwise it stays in the RUNNING state indefinitely.
        ax_client.log_trial_failure(trial_index=trial_index)
    else:
        # NOTE(review): assumes `train` returns raw data Ax accepts for the "oos"
        # objective (e.g. a dict containing "oos") - confirm against reprpo.training.train.
        ax_client.complete_trial(trial_index=trial_index, raw_data=r)
    # Checkpoint after every trial so a crash mid-sweep does not lose finished runs.
    ax_client.save_to_json_file(filepath=str(exp_f))

best = ax_client.get_best_parameters()
# get_best_parameters can return None (e.g. when no trial completed successfully)
best_parameters, metrics = best if best is not None else (None, None)
best_parameters

In [None]:
# Persist the full optimisation state as a JSON snapshot (reload with AxClient.load_from_json_file).
ax_client.save_to_json_file(filepath=str(exp_f))