In [1]:
%reload_ext autoreload
%autoreload 2

- https://ax.dev/tutorials/gpei_hartmann_service.html
- https://ax.dev/versions/0.4.1/tutorials/gpei_hartmann_service.html

In [2]:
from ax import optimize

In [3]:
from reprpo.training import train
import tyro
from reprpo.experiments import experiment_configs
from reprpo.interventions import Interventions, DPOConfig, ReprPOConfig
from reprpo.interventions.losses import Losses
from reprpo.interventions.transforms import Transforms

# training_args = tyro.extras.overridable_config_cli(experiment_configs)
# training_args

In [4]:
import copy


def setattrattr(cfg, k, v):
    """
    Set an attribute on ``cfg``, following a dotted path of any depth.

    Args:
        cfg: object whose (possibly nested) attribute is assigned.
        k: attribute name. Dots separate nesting levels, e.g. ``"loss.β"``
            assigns ``cfg.loss.β`` and ``"a.b.c"`` assigns ``cfg.a.b.c``.
        v: value to assign.

    Returns:
        The result of the final ``setattr`` call, i.e. ``None``.

    Raises:
        AttributeError: if an intermediate attribute on the path does not
            exist (raised by ``getattr``).
    """
    if "." in k:
        # Split on the first dot only: the remainder may itself be dotted and
        # is handled by the recursive call. An unbounded split would fail to
        # unpack for paths deeper than two levels.
        head, rest = k.split(".", 1)
        return setattrattr(getattr(cfg, head), rest, v)
    return setattr(cfg, k, v)

In [5]:
# quick 2m per run
# Fixed settings applied to every trial before the tuner's own suggestions.
tuner_kwargs = {
    "verbose": 0,
    "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # ideally would be SFT
    "batch_size": 64,
    "load_in_4bit": True,
    "collection_layers_side": [8, 10, 12, 14, 16, 18],
    "eval_samples": 64,
}


def override(cfg, overrides):
    """
    Apply each ``name -> value`` entry of ``overrides`` to ``cfg`` in place.

    Names may be dotted paths; assignment is delegated to ``setattrattr``.

    Args:
        cfg: config object to modify.
        overrides: mapping of attribute name (optionally dotted) to new value.

    Returns:
        The same ``cfg`` object, after modification.
    """
    for k in overrides:
        value = overrides[k]
        try:
            setattrattr(cfg, k, value)
        except ValueError:
            # Only ValueError is downgraded to a warning; any other exception
            # type raised while assigning propagates to the caller.
            print(f"WARNING: {k} not found in config")
    return cfg


def objective_func(**kwargs):
    """
    Train once with the suggested hyperparameters and return the result.

    Starts from a deep copy of the "side-ether-prefvec" experiment config,
    applies the fixed ``tuner_kwargs`` first and the per-trial ``kwargs``
    second (so trial values win on conflicts), then calls ``train``.

    Args:
        **kwargs: per-trial overrides keyed by (optionally dotted) config
            attribute name.

    Returns:
        Whatever ``train(cfg)`` returns.
    """
    base_cfg = experiment_configs["side-ether-prefvec"][1]
    # override() returns the object it mutated, so the two passes can be nested.
    trial_cfg = override(override(copy.deepcopy(base_cfg), tuner_kwargs), kwargs)
    return train(trial_cfg)

In [6]:
from pathlib import Path

# Experiment name; also used as the stem of the JSON snapshot file.
name = "ether-prefvec2"

# Location of the saved Ax experiment; create its directory up front.
exp_f = Path(f"../outputs/ax/{name}.json")
exp_f.parent.mkdir(parents=True, exist_ok=True)
exp_f

PosixPath('../outputs/ax/ether-prefvec2.json')

Note: you can have dependent params
- https://github.com/facebook/Ax/issues/1454

In [7]:
import warnings
import os
from ax.core.parameter import AxParameterWarning

import sys

# Silence warnings raised from Ax modules, plus Ax's parameter warnings
# wherever they are raised from.
warnings.filterwarnings("ignore", module="ax")
warnings.simplefilter("ignore", AxParameterWarning)

from loguru import logger

# Replace loguru's existing sinks with a single stderr sink at WARNING level.
# Calling remove() with no handler id drops every registered sink, so one
# call is enough.
logger.remove()
# logger.add(sys.stdout, level="INFO")
logger.add(sys.stderr, level="WARNING")

# Environment switches for the libraries used during training.
# NOTE(review): presumably these turn off W&B logging, keep `datasets` from
# contacting the hub, and hide tqdm bars — confirm against each library.
os.environ["WANDB_MODE"] = "disabled"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TQDM_DISABLE"] = "true"

In [8]:
from ax.service.ax_client import AxClient, ObjectiveProperties

# Start a fresh experiment unless a JSON snapshot already exists, in which
# case resume from it.
if not exp_f.exists():
    ax_client = AxClient(verbose_logging=False)

    def _on_off(param_name):
        """Search-space entry for a boolean switch (a two-valued choice)."""
        return {"name": param_name, "type": "choice", "values": [False, True]}

    search_space = [
        # main
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        _on_off("collect_input"),
        _on_off("collect_hs"),
        # {
        #     "name": "loss",
        #     "type": "choice",
        #     "values": ["mse", "prefvec", "rank"],
        #     "dependents": {
        #         "prefvec": ["loss.β", "loss.use_dpo_loss","loss.use_nll_loss","loss.weight_tokens","loss.use_orth_loss",],
        #         "rank": ["loss.α"],
        #         "mse": ["loss.α"],
        #     },
        # },
        # {
        #     "name": "transform",
        #     "type": "choice",
        #     "values": ["ether", "svd", "ortho", "none", "hra"],
        #     "dependents": {
        #         "ether": ["transform.nb", "transform.Htype", "transform.reduction"],
        #         "svd": ["transform.quantile", "transform.dual_svd"],
        #     },
        # },
        # NOT prefvec
        # {
        #     "name": "loss.α",
        #     "type": "range",
        #     "bounds": [1.e-6, 2.],
        #     "log_scale": True,
        # },
        # # SVD
        # {
        #     "name": "transform.quantile",
        #     "type": "choice",
        #     "values": [0.1, 0.25, 0.5, 0.75, 1.0],
        # },
        # {
        #     "name": "transform.dual_svd",
        #     "type": "choice",
        #     "values": [False, True],
        # },
        # prefvec
        {"name": "loss.β", "type": "range", "bounds": [1.e-6, 2.], "log_scale": True},
        _on_off("loss.use_dpo_loss"),
        _on_off("loss.use_nll_loss"),
        _on_off("loss.use_angle_loss"),
        _on_off("loss.weight_tokens"),
        _on_off("loss.use_orth_loss"),
        # ether
        {"name": "transform.nb", "type": "range", "bounds": [1, 64]},
        {
            "name": "transform.Htype",
            "type": "choice",
            "values": ["ether", "etherplus", "oft", "etherplusHH"],
        },
        {"name": "transform.reduction", "type": "range", "bounds": [1, 128]},
    ]

    # Metrics recorded for every trial, one row per metric family and one
    # column per split (train / test / oos / rnd). Only the objective below
    # is optimized.
    tracked_metrics = [
        "acc/train", "acc/test", "acc/oos", "acc/rnd",
        "acc_gain_vs_ref/train", "acc_gain_vs_ref/test",
        "acc_gain_vs_ref/oos", "acc_gain_vs_ref/rnd",
        "perplexity_gain_vs_ref/train", "perplexity_gain_vs_ref/test",
        "perplexity_gain_vs_ref/oos", "perplexity_gain_vs_ref/rnd",
        "preference_logp_gain/train", "preference_logp_gain/test",
        "preference_logp_gain/oos", "preference_logp_gain/rnd",
        "preference_logp_gain_vs_ref/train", "preference_logp_gain_vs_ref/test",
        "preference_logp_gain_vs_ref/oos", "preference_logp_gain_vs_ref/rnd",
    ]

    ax_client.create_experiment(
        name=name,
        parameters=search_space,
        tracking_metric_names=tracked_metrics,
        # Maximize the accuracy gain on the "oos" split.
        objectives={"acc_gain_vs_ref/oos": ObjectiveProperties(minimize=False)},
    )
else:
    ax_client = AxClient.load_from_json_file(filepath=exp_f, verbose_logging=False)

[INFO 09-25 14:04:22] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter lr. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-25 14:04:22] ax.service.utils.instantiation: Inferred value type of ParameterType.BOOL for parameter collect_input. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-25 14:04:22] ax.service.utils.instantiation: Inferred value type of ParameterType.BOOL for parameter collect_hs. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-25 14:04:22] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter loss.β. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in par

## Run

In [9]:
from tqdm.auto import tqdm

# Ask Ax for up to 450 trials, train on each suggestion and report the result.
# The try/finally saves the experiment however the loop ends (normal
# completion, Ctrl-C, or an error raised by Ax itself), so finished trials
# are never lost.
try:
    for _ in tqdm(range(450)):
        parameters, trial_index = ax_client.get_next_trial()
        try:
            r = objective_func(**parameters)
        except KeyboardInterrupt:
            # Abandon the half-finished trial so it is not left RUNNING in
            # the saved experiment, then stop cleanly.
            ax_client.abandon_trial(
                trial_index=trial_index, reason="interrupted by user"
            )
            break
        except Exception:
            logger.exception(f"Error in objective_func: parameters={parameters}")
            # Report the failure to Ax; otherwise the trial stays RUNNING and
            # keeps counting against the generation strategy's parallelism
            # limit.
            ax_client.log_trial_failure(trial_index=trial_index)
            continue
        print(parameters, r)
        ax_client.complete_trial(trial_index=trial_index, raw_data=r)
finally:
    ax_client.save_to_json_file(filepath=exp_f)

best_parameters, metrics = ax_client.get_best_parameters()

  0%|          | 0/450 [00:00<?, ?it/s]



ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=4.698871626213426e-06,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

Using the latest cached version of the dataset since wassname/genies_preferences couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'us_history_textbook' at /root/.cache/huggingface/datasets/wassname___genies_preferences/us_history_textbook/0.0.0/9e92ec3b21e9800bb26e9f7cdc5792103b651b15 (last modified on Wed Sep 25 15:00:42 2024).
Using the latest cached version of the dataset since wassname/genies_preferences couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'us_history_textbook' at /root/.cache/huggingface/datasets/wassname___genies_preferences/us_history_textbook/0.0.0/9e92ec3b21e9800bb26e9f7cdc5792103b651b15 (last modified on Wed Sep 25 15:00:42 2024).
Using the latest cached version of the dataset since wassname/genies_preferences couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'us_history_textbook' at /root/.cache/huggingface/datasets/wassname___genies_preferences/us

ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=0.00019311629320111775,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                          

[INFO 09-25 15:27:25] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=1.8174077223017635e-05,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                          

[INFO 09-25 15:38:25] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=0.0004311289520020512,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

[INFO 09-25 16:23:16] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=2.299695442734178e-06,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

[INFO 09-25 17:00:56] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=0.0004929435553744471,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

[INFO 09-25 17:25:14] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping
[INFO 09-25 17:25:14] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=0.0005301425105697182,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

[INFO 09-25 17:56:24] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=0.00028782301272903255,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                          

[INFO 09-25 18:16:20] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=0.00010509713959707121,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                          

[INFO 09-25 18:36:04] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=1e-06,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                                  'base_mod

[INFO 09-25 18:55:47] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=1.518572645007845e-06,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

[INFO 09-25 19:57:55] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=5.989903332998635e-06,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                           

[INFO 09-25 21:24:43] ax.modelbridge.base: Untransformed parameter 2.0000000000000004 greater than upper bound 2.0, clamping


ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=1e-06,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj',
                                  'base_mod

KeyboardInterrupt: 

In [10]:
# Persist the experiment to its JSON snapshot file (exp_f).
ax_client.save_to_json_file(filepath=exp_f)

## View

In [13]:
import pandas as pd

# Per-trial overview from the generation strategy: generation step, model,
# trial index, status and arm parameterizations.
df = ax_client.generation_strategy.trials_as_df
# Unused alternative: expand the 'Arm Parameterizations' dicts into columns.
# d = df['Arm Parameterizations'].values
# pd.DataFrame([next(iter(dd.values())) for dd in d])
df

[INFO 09-26 00:00:01] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.


Unnamed: 0,Generation Step,Generation Model(s),Trial Index,Trial Status,Arm Parameterizations
0,[GenerationStep_0],[Sobol],0,COMPLETED,"{'0_0': {'lr': 0.0, 'collect_input': False, 'c..."
1,[GenerationStep_0],[Sobol],1,COMPLETED,"{'1_0': {'lr': 0.36, 'collect_input': True, 'c..."
2,[GenerationStep_0],[Sobol],2,COMPLETED,"{'2_0': {'lr': 0.0, 'collect_input': False, 'c..."
3,[GenerationStep_0],[Sobol],3,COMPLETED,"{'3_0': {'lr': 0.0, 'collect_input': True, 'co..."
4,[GenerationStep_0],[Sobol],4,COMPLETED,"{'4_0': {'lr': 0.0, 'collect_input': False, 'c..."
5,[GenerationStep_0],[Sobol],5,COMPLETED,"{'5_0': {'lr': 0.0, 'collect_input': True, 'co..."
6,[GenerationStep_0],[Sobol],6,COMPLETED,"{'6_0': {'lr': 0.03, 'collect_input': False, '..."
7,[GenerationStep_0],[Sobol],7,COMPLETED,"{'7_0': {'lr': 0.0, 'collect_input': True, 'co..."
8,[GenerationStep_0],[Sobol],8,COMPLETED,"{'8_0': {'lr': 0.0, 'collect_input': False, 'c..."
9,[GenerationStep_0],[Sobol],9,COMPLETED,"{'9_0': {'lr': 0.05, 'collect_input': True, 'c..."


In [16]:
# One row per trial, with metric values and parameter values as columns.
df = ax_client.get_trials_data_frame()
# Drop the first four columns, then rank trials by the objective, best first.
trimmed = df.iloc[:, 4:]
d = trimmed.sort_values(by="acc_gain_vs_ref/oos", ascending=False)


def make_pretty(styler):
    """
    Caption the table "Ax results" and shade its cells with a colour gradient.

    The gradient uses the "seismic_r" colormap with ``axis=None``.

    Args:
        styler: a pandas ``Styler`` (e.g. ``frame.style``).

    Returns:
        The styler, with caption and gradient applied.
    """
    # Styler methods return the styler itself, so the calls can be chained.
    return styler.set_caption("Ax results").background_gradient(
        axis=None, cmap="seismic_r"
    )


# Display the sorted results as a captioned, colour-graded table.
make_pretty(d.style)



Unnamed: 0,acc/oos,acc/rnd,acc/test,acc/train,acc_gain_vs_ref/oos,acc_gain_vs_ref/rnd,acc_gain_vs_ref/test,acc_gain_vs_ref/train,perplexity_gain_vs_ref/oos,perplexity_gain_vs_ref/rnd,perplexity_gain_vs_ref/test,perplexity_gain_vs_ref/train,preference_logp_gain/oos,preference_logp_gain/rnd,preference_logp_gain/test,preference_logp_gain/train,preference_logp_gain_vs_ref/oos,preference_logp_gain_vs_ref/rnd,preference_logp_gain_vs_ref/test,preference_logp_gain_vs_ref/train,lr,collect_input,collect_hs,loss.β,loss.use_dpo_loss,loss.use_nll_loss,loss.use_angle_loss,loss.weight_tokens,loss.use_orth_loss,transform.nb,transform.reduction,transform.Htype
48,0.744,0.921875,0.96875,1.0,1.113772,1.0,1.0,1.0,1.049971,1.356395,1.162667,1.170551,16.725189,16.179222,42.420944,42.461075,6.18819,2.693804,5.405641,6.314675,0.000403,False,False,1e-06,False,False,True,False,True,6,60,etherplusHH
54,0.742667,0.921875,0.96875,1.0,1.111776,1.0,1.0,1.0,1.058746,1.366543,1.177389,1.189223,16.856361,16.306145,42.506218,42.645554,6.310843,2.820724,5.490918,6.499156,0.000408,False,False,1.2e-05,False,False,True,False,True,9,57,ether
45,0.741333,0.90625,0.984375,1.0,1.10978,0.983051,1.016129,1.0,1.075222,1.406664,1.199246,1.210066,16.893654,16.469284,42.735207,42.76059,6.362766,2.983863,5.719899,6.614193,0.000416,False,False,1e-06,False,False,True,False,True,5,56,etherplusHH
53,0.741333,0.90625,0.984375,1.0,1.10978,0.983051,1.016129,1.0,1.106522,1.480148,1.232965,1.255576,17.095062,16.676277,43.281715,43.346733,6.58431,3.190855,6.266418,7.200348,0.000426,False,False,5.3e-05,False,False,True,False,True,3,55,ether
52,0.741333,0.90625,0.96875,1.0,1.10978,0.983051,1.0,1.0,1.056137,1.38391,1.180059,1.18595,16.813591,16.1502,42.501968,42.587128,6.273644,2.664779,5.486655,6.440735,0.000407,False,False,2e-06,False,False,True,False,True,16,61,etherplusHH
49,0.74,0.90625,0.984375,1.0,1.107784,0.983051,1.016129,1.0,1.077529,1.42435,1.201568,1.218279,16.936951,16.562836,42.871269,42.801796,6.414001,3.077415,5.855974,6.655408,0.000417,False,False,3e-06,False,False,True,False,True,1,63,etherplusHH
46,0.738667,0.90625,0.953125,1.0,1.105788,0.983051,0.983871,1.0,1.063996,1.380798,1.183572,1.194358,16.835587,16.362389,42.539574,42.602707,6.299283,2.876973,5.52427,6.456314,0.000411,False,False,1e-06,False,False,True,False,True,10,57,etherplusHH
4,0.737333,0.921875,0.953125,1.0,1.103792,1.0,0.983871,1.0,0.99967,1.272706,1.080832,1.08364,16.21595,15.622299,41.56469,41.762199,5.636047,2.136883,4.549378,5.615801,0.000358,False,False,1.3e-05,False,False,True,False,True,38,75,etherplusHH
50,0.737333,0.90625,0.984375,1.0,1.103792,0.983051,1.016129,1.0,1.156394,1.593674,1.294931,1.325845,17.071503,17.08638,43.277275,43.083633,6.592074,3.600962,6.261967,6.937241,0.000442,False,False,1e-06,False,False,True,False,True,1,62,etherplusHH
57,0.733333,0.921875,0.953125,1.0,1.097804,1.0,0.983871,1.0,1.045012,1.275345,1.123301,1.120911,16.709702,16.213425,42.831848,43.191765,6.140002,2.728009,5.816542,7.045368,0.000385,False,False,0.000372,False,False,True,False,True,1,56,ether


In [None]:
# Retrieve the best parameters found so far. `values` is unpacked in the
# next cell as a (means, covariances) pair.
best_parameters, values = ax_client.get_best_parameters()
best_parameters

In [None]:
# Split the values returned with the best parameters and show the means.
# NOTE(review): presumably `values` can be None when Ax has no model
# predictions, in which case this unpacking fails — confirm against the Ax docs.
means, covariances = values
means

## plot

In [None]:
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import init_notebook_plotting, render

# Initialise Ax's notebook plotting support before any render() call.
init_notebook_plotting()

In [None]:
# Contour of the objective over learning rate and loss.β. The metric name
# must be one registered on the experiment; there is no metric called plain
# "oos", so use the objective "acc_gain_vs_ref/oos".
render(
    ax_client.get_contour_plot(
        param_x="lr", param_y="loss.β", metric_name="acc_gain_vs_ref/oos"
    )
)

In [None]:
# Plot the optimization trace; no objective optimum is passed because that
# argument is optional.
render(ax_client.get_optimization_trace())  # Objective_optimum is optional.

In [None]:
# plot: contour with no explicit axes or metric, leaving the choice to Ax's defaults
render(ax_client.get_contour_plot())

In [None]:
from ax.plot.slice import plot_slice

# Slice of the objective along the learning rate, using the generation
# strategy's current model. The metric name must be one registered on the
# experiment; there is no metric called plain "oos".
model = ax_client.generation_strategy.model
render(plot_slice(model, "lr", "acc_gain_vs_ref/oos"))