In [1]:
%reload_ext autoreload
%autoreload 2

Adapted from the Ax Service API tutorial (GPEI on Hartmann6):
- https://ax.dev/tutorials/gpei_hartmann_service.html
- https://ax.dev/versions/0.4.1/tutorials/gpei_hartmann_service.html (Ax 0.4.1 version)

In [2]:
from reprpo.training import train
import tyro
from reprpo.experiments import experiment_configs
from reprpo.interventions import Interventions, DPOConfig, ReprPOConfig
from reprpo.interventions.losses import Losses
from reprpo.interventions.transforms import Transforms


In [3]:
import pandas as pd

from ax.service.ax_client import logger as ax_logger

# Raise the verbosity of Ax's service logger.
# NOTE(review): the AxClient created later with verbose_logging=False may reset
# this logger's level and override DEBUG — confirm.
ax_logger.setLevel(level="DEBUG")


Note: parameters can depend on each other (dependent parameters), see:
- https://github.com/facebook/Ax/issues/1454

In [4]:
import warnings
import os
import sys
from ax.core.parameter import AxParameterWarning

# Silence warnings raised from Ax modules to keep cell output readable.
warnings.filterwarnings("ignore", module="ax")
warnings.simplefilter("ignore", AxParameterWarning)

from loguru import logger

# Replace loguru's handlers with a single WARNING-level stderr sink.
# A single remove() call already drops every existing handler.
logger.remove()
# logger.add(sys.stdout, level="INFO")
logger.add(sys.stderr, level="WARNING")

# NOTE(review): some libraries read these variables only when first imported
# (e.g. `datasets` reads HF_DATASETS_OFFLINE). Any such library already pulled in
# by the imports above will not see them; set them before those imports.
os.environ["WANDB_MODE"] = "disabled"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TQDM_DISABLE"] = "true"

In [5]:
from reprpo.ax.parameters import parameters_ether_prefvec, parameters_loss, parameters_transform
from reprpo.ax.target import objective_func


In [6]:
from pathlib import Path

# Metric being optimised; also used for ranking and plots further down.
key_metric = "acc_gain_vs_ref/oos"

# Search spaces that can be optimised, keyed by experiment name. `name` picks
# one and also names the saved experiment file.
SEARCH_SPACES = {
    "ether-prefvec2": parameters_ether_prefvec,
    "loss": parameters_loss,
    "transform": parameters_transform,
}
name = "transform"
parameters = SEARCH_SPACES[name]

exp_f = Path(f"../outputs/ax/{name}.json")
exp_f.parent.mkdir(exist_ok=True, parents=True)
exp_f


PosixPath('../outputs/ax/transform.json')

In [7]:
# Uncomment and run to delete the saved experiment file, so the next cell
# creates a fresh experiment instead of loading the old one.
# exp_f.unlink()

In [8]:
from ax.service.ax_client import AxClient, ObjectiveProperties
import torch

# Device for Ax's torch-based models. Fall back to CPU so a saved experiment can
# still be loaded and inspected on a machine without a GPU.
ax_kwargs = dict(
    torch_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    verbose_logging=False,
)

# Every metric family is reported on the same four evaluation splits.
METRIC_FAMILIES = (
    "acc",
    "acc_gain_vs_ref",
    "perplexity_gain_vs_ref",
    "preference_logp_gain",
    "preference_logp_gain_vs_ref",
)
EVAL_SPLITS = ("train", "test", "oos", "rnd")

if exp_f.exists():
    # Resume the experiment saved by a previous session.
    ax_client = AxClient.load_from_json_file(filepath=exp_f, **ax_kwargs)
    print('loaded')
else:
    ax_client = AxClient(**ax_kwargs)

    ax_client.create_experiment(
        name=name,
        parameters=parameters,
        # Recorded for analysis; not optimised. Order: family-major, split-minor.
        tracking_metric_names=[
            f"{family}/{split}" for family in METRIC_FAMILIES for split in EVAL_SPLITS
        ],
        # Maximise the metric chosen in the config cell.
        objectives={key_metric: ObjectiveProperties(minimize=False)},
    )

[INFO 09-26 07:47:39] ax.service.utils.instantiation: Inferred value type of ParameterType.STRING for parameter transform. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-26 07:47:39] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter transform.nb. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-26 07:47:39] ax.service.utils.instantiation: Inferred value type of ParameterType.STRING for parameter transform.Htype. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 09-26 07:47:39] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter transform.reduction. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float',

## Run

In [9]:
from tqdm.auto import tqdm
import time

N_TRIALS = 450  # maximum number of new trials requested in this session

for _iteration in tqdm(range(N_TRIALS)):
    t0 = time.time()
    # Named `trial_params` (not `parameters`) so the search-space definition read
    # by the experiment-setup cell is not overwritten.
    trial_params, trial_index = ax_client.get_next_trial()
    print(f"Time to get_next_trial: {time.time() - t0}")
    try:
        r = objective_func(**trial_params)
        print(trial_params, r)
    except KeyboardInterrupt:
        # Manual stop: don't leave this trial RUNNING in the saved experiment.
        ax_client.abandon_trial(trial_index=trial_index)
        ax_client.save_to_json_file(filepath=exp_f)
        break
    except Exception:
        # Record the failure and persist progress before surfacing the error.
        ax_client.log_trial_failure(trial_index=trial_index)
        ax_client.save_to_json_file(filepath=exp_f)
        raise
    ax_client.complete_trial(trial_index=trial_index, raw_data=r)
    # Checkpoint after every trial: each one is a full training run, so a crash
    # or kernel restart must not throw away completed results.
    ax_client.save_to_json_file(filepath=exp_f)

  0%|          | 0/450 [00:00<?, ?it/s]



Time to get_next_trial: 0.020870447158813477
ReprPOConfig(dataset='us_history_textbook',
             verbose=0,
             dev=False,
             load_in_4bit=True,
             load_in_8bit=False,
             use_gradient_checkpointing=False,
             batch_size=64,
             n_samples=1800,
             eval_samples=64,
             max_length=196,
             max_prompt_length=96,
             base_model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
             lr=1e-05,
             collection_layers_side=[8, 10, 12, 14, 16, 18],
             collection_keys_in=('base_model.model.model.layers.{layer}.self_attn.o_proj',
                                 'base_model.model.model.layers.{layer}.mlp.down_proj'),
             collection_keys_out=('base_model.model.model.layers.{layer}.self_attn.q_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.k_proj',
                                  'base_model.model.model.layers.{layer}.self_attn.v_proj'

In [9]:
# Persist the AxClient state to `exp_f` so a later session can reload it.
ax_client.save_to_json_file(exp_f)

In [None]:
# Fast: pick the best *observed* trial rather than using model predictions.
best_observed = ax_client.get_best_parameters(use_model_predictions=False)
best_observed[0]

In [None]:
g = ax_client.generation_strategy
# NOTE(review): `_nodes` is private Ax API. Node 1 is assumed to be the
# model-based step whose kwargs carry `torch_device` — confirm.
gg = g._nodes[1]
node_kwargs = gg.model_kwargs
print(node_kwargs['torch_device'])  # device currently stored on the node
# Fitting with torch on the GPU takes around 1.3 mins.
node_kwargs['torch_device'] = torch.device("cuda")
gg


## View results

In [24]:

# Disabled alternative: tabular view of trials from the generation strategy.
# df = ax_client.generation_strategy.trials_as_df
# df

In [None]:
df = ax_client.get_trials_data_frame()
# Per-trial objective plus the running best, so optimisation progress is visible
# at a glance. cummax skips NaN (e.g. failed trials).
progress_ax = df[key_metric].plot(
    xlabel="iteration",
    ylabel=key_metric,
    label=key_metric,
    title=f"{name}: {key_metric} by trial",
)
df[key_metric].cummax().plot(ax=progress_ax, label="best so far")
progress_ax.legend();

In [None]:
df = ax_client.get_trials_data_frame()

# Skip the first four columns and rank trials by the objective, best first.
# NOTE(review): assumes Ax puts exactly four bookkeeping (non-metric) columns
# first — confirm for the installed Ax version.
ranked = df.iloc[:, 4:].sort_values(key_metric, ascending=False)  # .head(20)

# Drop every column whose name ends in one of the non-oos split suffixes.
for split_suffix in ("train", "test", "rnd"):
    ranked = ranked.loc[:, ~ranked.columns.str.endswith(split_suffix)]
d = ranked


def make_pretty(styler):
    """Caption the results table and shade each column on its own colour scale."""
    return styler.set_caption("Ax results").background_gradient(axis=0, cmap="seismic_r")


make_pretty(d.style)

In [None]:
# Best parameters according to the model's predictions (the default), as opposed
# to the observed-only lookup with use_model_predictions=False.
best = ax_client.get_best_parameters()
best_parameters, values = best
best_parameters

In [14]:
# Uncomment to display the raw (means, covariances) pair for the best arm.
# values

In [None]:
# Split the best arm's predicted values into point estimates and covariances;
# show the point estimates.
(means, covariances) = values
means

## Plots

In [None]:
g = ax_client.generation_strategy
# NOTE(review): relies on private `_nodes`; index 1 is assumed to be the
# model-based step — confirm. Fitting with torch on the GPU takes around 2 mins.
gg = g._nodes[1]
gg.model_kwargs.update(torch_device=torch.device("cuda"))
gg


In [27]:
# Fit Ax's current model on the trial data collected so far.
# First call is slow: possibly more than 20 mins was observed here.
ax_client.fit_model()

In [28]:
# Calling fit_model again returned almost instantly here — presumably the
# fitted state is reused when no new data arrived (TODO confirm). This call is
# redundant in a Restart & Run All.
ax_client.fit_model()

In [None]:
# AxClient/ObjectiveProperties are already imported in the experiment-setup cell,
# and the tutorial's `hartmann6` test function is not used in this notebook.
from ax.utils.notebook.plotting import init_notebook_plotting, render

# Enable Ax's plot rendering inside the notebook.
init_notebook_plotting()

In [30]:
from plotly import io as pio

# Render plotly figures with the JupyterLab renderer.
pio.renderers.default = "jupyterlab"

In [None]:
# Parameters of the very first trial.
ax_client.get_trial_parameters(trial_index=0)

In [None]:
# Objective over iterations. `objective_optimum` is optional and omitted here.
optimization_trace = ax_client.get_optimization_trace()
render(optimization_trace)

In [33]:
# Disabled: contour plot from the AxClient.
# # plot
# render(ax_client.get_contour_plot())

In [None]:
from ax.plot.slice import plot_slice

# Current model from the generation strategy and the parameters of the search
# space it models over.
model = ax_client.generation_strategy.model
model_space = model.model_space
ss = model_space.parameters
ss


In [None]:
from ax.core.parameter import RangeParameter

# plot_slice only supports range parameters. Filtering on the class (instead of
# raw enum values like 1/2) names the intent and skips numeric choice/fixed
# parameters, which plot_slice would reject.
for param_name, param in ss.items():
    if isinstance(param, RangeParameter):
        render(plot_slice(model, param_name, key_metric))

In [None]:
# Parameters ranked by the model's feature importance for the objective.
importances = model.feature_importances(key_metric)
pd.Series(importances).sort_values(ascending=False)