In [None]:
import os
import warnings
from typing import List
import logging

warnings.filterwarnings(action='ignore', module='numpy')
warnings.filterwarnings(action='ignore', module='pandas')
warnings.filterwarnings(action='ignore', module='sklearn')
warnings.filterwarnings(action='ignore', module='tensorflow')

import ax
import pandas as pd
import ray
import numpy as np
import scipy as sp
import seaborn as sns
import matplotlib.pyplot as plt
from ax.service.ax_client import AxClient
from ax.plot.contour import interact_contour, plot_contour
from ax.plot.slice import plot_slice
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.service.utils.best_point import get_best_from_model_predictions, get_best_raw_objective_point
from ray import tune
from ray.tune import track, JupyterNotebookReporter
from ray.tune.suggest.ax import AxSearch
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
import seaborn as sns

from synthesized.insight.latent import get_latent_space, latent_dimension_usage, total_latent_space_usage
from synthesized.insight.dataset import describe_dataset_values, describe_dataset, classification_score
from tune_utils import AxSearch2


# Raise the root logger's threshold to CRITICAL so only critical records are emitted,
# keeping library log chatter out of the notebook output.
root_logger = logging.getLogger()
root_logger.setLevel(logging.CRITICAL)

def train_evaluate(params):
    """Fit a latent space for one Tune trial and report its latent-space usage metrics.

    Called by ``tune.run`` with a trial configuration proposed by the Ax search algorithm.
    The function reads the module-level ``data`` DataFrame, which is loaded in a later cell,
    so that cell must have run before ``tune.run`` is invoked. Keep the single-argument
    signature: it is what lets Tune's function API route metrics through ``track.log``.

    Args:
        params: trial config with keys ``num_rows``, ``random_seed``, ``beta`` and
            ``num_iterations``. The config also carries ``trial``, which is not used here.
            It only makes otherwise identical configurations distinct (presumably a
            replicate index; confirm).

    Reports (via ``track.log``):
        ``m_dim_<i>`` / ``s_dim_<i>``: per-dimension usage for ``usage_type='mean'`` /
            ``'stddev'``. Values are sorted in descending order, so ``<i>`` is a usage rank,
            not the original latent dimension index.
        ``m_total`` / ``s_total``: total latent-space usage for each usage type.
    """
    # NOTE(review): data.sample raises if num_rows exceeds len(data) (no replacement).
    latent_space = get_latent_space(
        df=data.sample(params['num_rows'], random_state=params['random_seed']),
        num_iterations=params['num_iterations'],
        beta=params['beta'],
    )

    # (metric key prefix, usage_type passed to the insight helpers)
    usage_types = (('m', 'mean'), ('s', 'stddev'))

    metrics = {}
    for prefix, usage_type in usage_types:
        # Keyword `ascending`: pandas >= 2.0 rejects positional axis/ascending in sort_values.
        dim_usage = latent_dimension_usage(df_latent=latent_space, usage_type=usage_type).sort_values(
            'usage', ascending=False
        )
        metrics.update({f'{prefix}_dim_{rank}': val for rank, val in enumerate(dim_usage['usage'])})

    for prefix, usage_type in usage_types:
        metrics[f'{prefix}_total'] = total_latent_space_usage(df_latent=latent_space, usage_type=usage_type)

    track.log(**metrics)

In [None]:
# Initialise Ax's notebook plotting so Ax figures shown via `render(...)` display inline.
init_notebook_plotting()

In [None]:
# Attach to an already-running Ray cluster (address='auto') instead of starting a local one.
# The redis password is no longer hard-wired: set RAY_REDIS_PASSWORD to override it. The
# fallback is the previously hard-coded value, which appears to be Ray's stock default
# redis password (confirm for the Ray version in use).
# log_to_driver=False keeps worker stdout/stderr from being streamed into the notebook.
ray.init(
    address='auto',
    redis_password=os.environ.get('RAY_REDIS_PASSWORD', '5241590000000000'),
    log_to_driver=False,
)

In [None]:
# Pure Sobol (quasi-random) sweep: the strategy has a single Sobol step and no model-based
# step after it. The objective declared below is therefore recorded, but it does not steer
# which configurations are proposed.
gs = GenerationStrategy(steps=[
    GenerationStep(
        model=Models.SOBOL,
        # deduplicate: never propose the same arm twice; seed None: Sobol sequence is not reproducible.
        model_kwargs={'seed': None, 'deduplicate': True},
        model_gen_kwargs=None,
        num_trials=-1,  # -1: this step is never exhausted
        min_trials_observed=0,
        enforce_num_trials=True,
        max_parallelism=20,
    ),
])

axc = AxClient(
    generation_strategy=gs,
    enforce_sequential_optimization=False,
    verbose_logging=False,
)

# Search space: 7 num_rows x 2 beta x 2 num_iterations x 5 trial values (assuming Ax treats
# the int bounds of `trial` as an integer range) = 140 distinct configurations.
# Keep the parameter order as-is: it fixes the Sobol dimension order.
axc.create_experiment(
    name="capacity_tuning",
    objective_name="m_total",
    minimize=True,
    parameters=[
        {"name": "num_rows", "type": "choice", "values": [1000, 2000, 4000, 8000, 16000, 32000, 64000]},
        {"name": "beta", "type": "choice", "values": [1.0, 2.0]},
        {"name": "num_iterations", "type": "choice", "values": [10000, 20000]},
        {"name": "trial", "type": "range", "bounds": [1, 5]},
        {"name": "random_seed", "type": "fixed", "value": 161833},
    ],
)

In [None]:
# Number of Tune trials to run (140 also equals the number of distinct configurations in the
# Ax search space, assuming `trial` is an integer parameter).
NUM_TRIALS = 140
# NOTE: despite the name, this is the path to a CSV file, not a directory.
DATASET_DIR = 'data/credit_with_categoricals.csv'

# Load once in the driver, dropping incomplete rows; train_evaluate samples rows from `data`.
data = pd.read_csv(DATASET_DIR).dropna()

# Desired loss-sample size (50k rows), capped at the number of rows that survived dropna.
# Not referenced in the cells shown here.
loss_sample_size = min(50_000, len(data))

In [None]:
# Launch NUM_TRIALS trials of train_evaluate on the Ray cluster. Configurations are proposed by
# AxSearch2, which wraps the AxClient `axc` (note: it takes the client, not a search space).
analysis = tune.run(
    train_evaluate,
    search_alg=AxSearch2(axc, mode='min'),
    num_samples=NUM_TRIALS,
    resources_per_trial={"cpu": 2},  # GPU variant: resources_per_trial={"gpu": 1}
    max_failures=3,
    verbose=1,  # 1 = status updates only; 2 = also print each trial's results
    progress_reporter=JupyterNotebookReporter(max_progress_rows=100, overwrite=True),
    return_trials=False,
)

In [None]:
# Flatten each trial's last reported result into long format, with two rows per trial (one per
# usage type):
#   - The m_/s_ prefix is stripped from metric keys (m_dim_0 -> dim_0, m_total -> total).
#   - The trial config (num_rows, beta, num_iterations, trial, random_seed) is merged in as
#     plain columns.
# Iterate over the trials Tune actually ran rather than range(NUM_TRIALS). Skip trials that
# never reported a result (e.g. ones that exhausted max_failures) instead of crashing on them.
results2 = []
for tune_trial in analysis.trials:
    last_result = tune_trial.last_result or {}
    if 'config' not in last_result:
        continue

    result = {k: v for k, v in last_result.items() if 'dim' in k or k in ['m_total', 's_total']}
    result.update(last_result['config'])

    result_m = {(k[2:] if k.startswith('m_') else k): v for k, v in result.items() if not k.startswith('s_')}
    result_m['usage_type'] = 'mean'
    result_s = {(k[2:] if k.startswith('s_') else k): v for k, v in result.items() if not k.startswith('m_')}
    result_s['usage_type'] = 'stddev'

    results2.append(result_m)
    results2.append(result_s)

assert results2, 'no Tune trial reported any results'
results = pd.DataFrame(results2)

In [None]:
# Reshape to one row per (trial config, usage_type, dimension). Columns named dim_<i> are
# relabelled to the integer <i> before melting; the `total` column keeps its string name.
df_exp = (
    results
    .rename(columns=lambda col: int(col[4:]) if 'dim' in col else col)
    .melt(
        id_vars=['trial', 'num_rows', 'usage_type', 'beta', 'num_iterations', 'random_seed'],
        var_name='dim',
        value_name='usage',
    )
)

In [None]:
# Palette cycles coral + 3x ocean blue, so every 4th dimension rank (0, 4, 8, ...) stands out.
colors = ["coral", "ocean blue", "ocean blue", "ocean blue"]*8

df_exp['scenario'] = [f'ni{a}_b{b}' for a, b in zip(df_exp['num_iterations'], df_exp['beta'])]

# Build the filter as one boolean mask. Chaining df[mask1][mask2] reindexes mask2 against the
# already-filtered frame and triggers pandas' "Boolean Series key will be reindexed" warning.
mean_dims = df_exp.loc[(df_exp['dim'] != 'total') & (df_exp['usage_type'] == 'mean')]

# Per-dimension mean usage vs. dataset size, one facet row per (num_iterations, beta) scenario.
sns.catplot(
    data=mean_dims, kind='bar', aspect=4,
    x='num_rows', y='usage', hue='dim', row='scenario',
    palette=sns.xkcd_palette(colors),
);