In [None]:
import logging
import os

import ax
import pandas as pd
import ray
from ax.service.ax_client import AxClient
from ax.plot.contour import interact_contour, plot_contour
from ax.plot.slice import plot_slice
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.service.utils.best_point import get_best_from_model_predictions, get_best_raw_objective_point
from ray import tune
from ray.tune import track, JupyterNotebookReporter
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

from synthesized import HighDimSynthesizer
from tune_utils import AxSearch2


# Module-level logger for this notebook (its effective level is inherited from the root).
logger = logging.getLogger(__name__)

# Raise the root logger's threshold to CRITICAL so that lower-severity records from
# loggers inheriting the root level are hidden from the notebook output.
root_logger = logging.getLogger()
root_logger.setLevel(logging.CRITICAL)

def ray_callback(synthesizer, iteration, losses):
    """Training callback that reports the current iteration number to Ray Tune.

    Args:
        synthesizer: The synthesizer being trained (unused; accepted to match the callback signature).
        iteration: Current training iteration, forwarded to ``track.log``.
        losses: Current losses (unused; accepted to match the callback signature).

    Returns:
        Always ``False``. Presumably a truthy return would ask ``learn`` to stop early
        (TODO confirm against HighDimSynthesizer), so training is never interrupted here.
    """
    track.log(iteration=iteration)
    stop_training = False
    return stop_training


def train_evaluate(parameterization):
    """Train one HighDimSynthesizer configuration and report its final losses to Ray Tune.

    Used as the Ray Tune function trainable in ``tune.run``. Keep the single-argument
    signature, because Ray inspects the trainable's arguments.

    Args:
        parameterization: Trial config proposed by Ax. It is forwarded as keyword
            arguments to ``HighDimSynthesizer``, including the fixed ``max_training_time``.

    Reads the notebook globals ``data`` (training DataFrame) and ``loss_sample_size``
    (number of rows used for the final loss evaluation). Both are set in the
    data-loading cell, which must run before ``tune.run``.

    Reports via ``track.log``:
        - ``mean_loss``: the Ax objective, taken from ``'total-loss'``.
        - ``reconstruction_loss``
        - ``kl_loss``
        - ``iteration``: the synthesizer's final global step.
    """
    with HighDimSynthesizer(df=data, **parameterization) as synthesizer:
        # Per-iteration progress is reported through ray_callback while training.
        # num_iterations=None: presumably training is bounded by max_training_time instead (TODO confirm).
        synthesizer.learn(data, num_iterations=None, callback=ray_callback, callback_freq=1)

        # Evaluate every trial on the SAME rows (fixed random_state). Differences in
        # mean_loss then reflect the hyperparameters, not which rows happened to be sampled.
        eval_rows = data.sample(loss_sample_size, random_state=0)
        feed_dict = synthesizer.get_data_feed_dict(synthesizer.preprocess(eval_rows))
        losses = synthesizer.get_losses(data=feed_dict)

        # .numpy().item() converts scalar tensors into plain Python numbers for Ray Tune.
        track.log(
            mean_loss=losses['total-loss'].numpy().item(),
            reconstruction_loss=losses['reconstruction-loss'].numpy().item(),
            kl_loss=losses['kl-loss'].numpy().item(),
            iteration=synthesizer.global_step.numpy().item(),
        )

In [None]:
# Configure Ax's notebook plotting so the render(...) calls in later cells can display figures.
init_notebook_plotting()

In [None]:
# Connect to an already-running Ray cluster (address='auto').
# The Redis password comes from the RAY_REDIS_PASSWORD environment variable instead of
# being hard-coded in the notebook. The fallback is the value previously used here.
# log_to_driver=False stops worker output from being forwarded into the notebook.
ray.init(
    address='auto',
    redis_password=os.environ.get('RAY_REDIS_PASSWORD', '5241590000000000'),
    log_to_driver=False,
)

In [None]:
# Two-phase Ax generation strategy:
#   1. 20 quasi-random Sobol trials to explore the search space.
#   2. Model-based Bayesian optimisation (Models.GPEI) for every remaining trial.
sobol_step = GenerationStep(
    model=Models.SOBOL,
    num_trials=20,
    min_trials_observed=15,  # completed trials required before moving on to the next step
    max_parallelism=20,
    enforce_num_trials=True,
    model_kwargs={'deduplicate': True, 'seed': None},  # seed=None: Sobol points are not reproducible
    model_gen_kwargs=None,
)
gpei_step = GenerationStep(
    model=Models.GPEI,
    num_trials=-1,  # -1: no limit, this model generates all subsequent trials
    min_trials_observed=0,
    max_parallelism=20,
    enforce_num_trials=True,
    model_kwargs=None,
    model_gen_kwargs=None,
)
gs = GenerationStrategy(steps=[sobol_step, gpei_step])

axc = AxClient(generation_strategy=gs, verbose_logging=False, enforce_sequential_optimization=False)

# Search space for the HighDimSynthesizer hyperparameters.
# - The first four parameters have int bounds; presumably Ax infers integer parameters from them.
# - max_training_time is fixed and passed unchanged to every trial.
# - The objective reported by train_evaluate (mean_loss) is minimised.
axc.create_experiment(
    name="capacity_tuning",
    parameters=[
        {"name": "capacity", "type": "range", "bounds": [8, 128]},
        {"name": "latent_size", "type": "range", "bounds": [8, 128]},
        {"name": "num_layers", "type": "range", "bounds": [1, 4]},
        {"name": "residual_depths", "type": "range", "bounds": [2, 6]},
        {"name": "learning_rate", "type": "range", "bounds": [1e-5, 1e-1], "log_scale": True},
        {"name": "max_training_time", "type": "fixed", "value": 120.0},
    ],
    objective_name="mean_loss",
    minimize=True,
)

In [None]:
# Load the training data and drop rows with missing values.
# train_evaluate reads `data` and `loss_sample_size` from this cell as notebook globals.
data = pd.read_csv('data/credit_with_categoricals.csv').dropna()

# Final losses are evaluated on at most 50,000 rows, or on the whole dataset if it is smaller.
loss_sample_size = min(50_000, len(data))

In [None]:
# Run the hyperparameter search. Ray Tune launches train_evaluate trials whose configs
# are proposed by Ax through the project's AxSearch2 wrapper, which takes the AxClient itself.
search_alg = AxSearch2(axc, mode='min')
reporter = JupyterNotebookReporter(overwrite=True, max_progress_rows=100)

analysis = tune.run(
    train_evaluate,
    num_samples=100,
    search_alg=search_alg,
    verbose=1,  # 1 = status updates only; 2 = also show per-trial results.
    resources_per_trial={"cpu": 2},  # For GPU trials use resources_per_trial={"gpu": 1}.
    max_failures=3,
    progress_reporter=reporter,
)

In [None]:
# Best parameters among the trials actually evaluated (raw observed objective values).
params, mean_value = get_best_raw_objective_point(axc.experiment)
print("Best Trial Params:")
print(params, mean_value)

# Best parameters according to the fitted model's predictions.
# Ax's return type is Optional, and so is the prediction inside it. When no model
# prediction is available, report that instead of crashing with a TypeError on unpacking.
# When present, the prediction is presumably (means, covariances) keyed by metric name.
best_predicted = get_best_from_model_predictions(axc.experiment)
if best_predicted is None or best_predicted[1] is None:
    print("Estimated Best Params: unavailable (no model predictions for this experiment)")
else:
    params, (mean_value, variance) = best_predicted
    print("Estimated Best Params:")
    print(params, mean_value, variance)



In [None]:
# Plot the per-parameter feature importances reported by AxClient.get_feature_importances().
render(axc.get_feature_importances())

In [None]:
# One slice plot per tunable parameter: the model's predicted mean_loss as that parameter varies.
# Loop over the search space's range parameters rather than the keys of a best-point dict.
# This skips the fixed max_training_time parameter and removes the dependency on
# variables left over from the previous cell.
for param in axc.experiment.search_space.range_parameters:
    try:
        render(plot_slice(axc.generation_strategy.model, param, 'mean_loss'))
    except ValueError as err:
        # Report the skipped plot instead of silently swallowing the error.
        print(f"Skipping slice plot for {param!r}: {err}")

In [None]:
# Interactive contour plot of the current model's predictions for mean_loss over pairs of parameters.
render(interact_contour(axc.generation_strategy.model, metric_name='mean_loss'))

In [None]:
# Plot the optimization trace (objective progress across trials) from AxClient.get_optimization_trace().
render(axc.get_optimization_trace())

In [None]:
# Tabulate every trial as a DataFrame; as the cell's last expression it is rendered as a rich table.
axc.get_trials_data_frame()