# Step Multimodal

In [None]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append("../../")

from copy import deepcopy

import pandas as pd
from captum.attr import (DeepLift, DeepLiftShap, FeatureAblation,
                         FeaturePermutation, GradientShap, GuidedBackprop,
                         InputXGradient, IntegratedGradients, NoiseTunnel,
                         Saliency, ShapleyValueSampling)
from evobench.continuous import StepMultimodal
from evosolve.continuous import dg2
import plotly.io as pio
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from hell import Surrogate, SurrogateData, plot, util
from hell.linkage import EmpiricalLinkage

# Fix the global random seed through pytorch_lightning's helper.
seed_everything(42)
# Select plotly's "notebook" renderer as the default for figures.
pio.renderers.default = "notebook"

In [None]:
# Benchmark under study. `blocks=[10] * 10` is presumably ten blocks of ten
# variables each -- confirm against evobench's StepMultimodal.
benchmark = StepMultimodal(blocks=[10] * 10, step_size=2, verbose=1)

# Inputs go through a standard scaler, targets through a min-max scaler.
x_preprocessing = Pipeline(steps=[("standard-scaler", StandardScaler())])
y_preprocessing = Pipeline(steps=[("min-max-scaler", MinMaxScaler())])

# Dataset wrapper around the benchmark and both pipelines.
# `splits` is presumably the train/validation/test fractions -- TODO confirm.
data = SurrogateData(
    benchmark,
    x_preprocessing,
    y_preprocessing,
    n_samples=1e5,
    splits=(0.6, 0.2, 0.2),
    batch_size=200,
)

In [None]:
# ? sanity check
# Fit a gradient-boosted tree regressor on the training split and report its
# R^2 on the test split, as a quick baseline for the sampled data.

from xgboost import XGBRegressor
from sklearn.metrics import r2_score

# `n_jobs` is the scikit-learn-style thread-count argument of XGBRegressor;
# the `nthread` alias used here before is deprecated in that wrapper.
xgb_model = XGBRegressor(n_estimators=200, n_jobs=8)
xgb_model.fit(data.x_train, data.y_train)
y_pred = xgb_model.predict(data.x_test)
# Last expression of the cell, so the notebook displays the score.
r2_score(data.y_test, y_pred)

In [None]:
import torch

# Surrogate model built from the benchmark's genome size and the same
# preprocessing pipelines that were handed to SurrogateData.
surrogate = Surrogate(
    benchmark.genome_size,
    x_preprocessing, y_preprocessing,
    n_layers=1, learning_rate=2e-4, weight_decay=1e-8,
)

# Stop training once the monitored "val/r2" value (maximised) has not
# improved for 5 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor="val/r2",
    min_delta=0.000,
    patience=5,
    verbose=False,
    mode="max",
)

trainer = Trainer(
    max_epochs=100,
    # Train on one GPU when CUDA is available; otherwise fall back to the CPU
    # (gpus=0) instead of failing on machines without a GPU.
    gpus=1 if torch.cuda.is_available() else 0,
    progress_bar_refresh_rate=50,
    callbacks=[early_stop_callback],
)

In [None]:
# Train the surrogate on the data module exposed by SurrogateData.
trainer.fit(surrogate, data.data_module)
# Call eval() on the surrogate before it is handed to the attribution methods
# (presumably the torch.nn.Module inference-mode switch -- confirm).
surrogate.eval()

In [None]:
# Evaluate one EmpiricalLinkage decomposer per captum attribution method.
# Each entry of the tuple is a callable that wraps the trained surrogate in an
# explainer; the explainers are created in the listed order, one at a time,
# right before their EmpiricalLinkage wrapper.
xai_results = util.test_xais(
    benchmark,
    data.x_preprocessing,
    decomposers=[
        EmpiricalLinkage(benchmark, build_explainer(surrogate), data.x_preprocessing)
        for build_explainer in (
            DeepLift,
            FeatureAblation,
            GradientShap,
            GuidedBackprop,
            InputXGradient,
            IntegratedGradients,
            # NoiseTunnel wraps another attribution method rather than the model.
            lambda model: NoiseTunnel(IntegratedGradients(model)),
            Saliency,
        )
    ],
    n_samples=100,
)

In [None]:
# Display the benchmark's `ffe` attribute after the XAI-based runs
# (presumably a fitness-function-evaluation counter -- confirm in evobench).
benchmark.ffe

In [None]:
# Reset `ffe` to zero before the DG2 run so the next reading starts from a
# clean value.
benchmark.ffe = 0

In [None]:
# Baseline: run the same evaluation with DG2's EmpiricalLinkage, which is
# built from the benchmark alone (no surrogate model involved).
dg2_results = util.test_decomposer(
    dg2.EmpiricalLinkage(benchmark),
    n_samples=100,
)

In [None]:
# Display `ffe` again; it was reset to zero before the DG2 run, so this value
# was accumulated during that run only.
benchmark.ffe

In [None]:
# Stack the XAI-based and DG2 result tables into one frame (rows appended,
# original row indices kept).
results = pd.concat([xai_results, dg2_results])

In [None]:
# Plot the hit ratio for the combined results.
plot.hit_ratio(results)

In [None]:
# Ranking plot for the "mean_reciprocal_rank" metric of the combined results.
plot.ranking_metric(
    results, metric="mean_reciprocal_rank", title="Mean Reciprocal Ranking"
)

In [None]:
# Ranking plot for the "mean_average_precision" metric of the combined results.
plot.ranking_metric(
    results, metric="mean_average_precision", title="Mean Average Precision"
)

In [None]:
# Ranking plot for the NDCG metric of the combined results.
# NOTE(review): the "$1" in both strings looks like a leftover regex
# back-reference from a search-and-replace -- confirm the real metric column
# name produced by util.test_xais / util.test_decomposer before relying on it.
plot.ranking_metric(
    results,
    metric="ndcg$1",
    title="NDCG$1"
)

In [None]:
# Average every metric per method. `numeric_only=True` restricts the mean to
# numeric columns explicitly: pandas >= 2.0 no longer drops non-numeric columns
# by default and raises TypeError on them, while older versions already
# behaved this way.
results_mean = results.groupby(by="method").mean(numeric_only=True)

In [None]:
# Write the per-method means, transposed (metrics as rows, methods as columns),
# to a CSV file in the current working directory.
# NOTE(review): the file name "trap.csv" does not match this Step Multimodal
# notebook and may be a copy-paste leftover that overwrites another
# benchmark's results -- confirm the intended output path.
results_mean.T.to_csv("trap.csv")