In [None]:
# (Re)load IPython's autoreload extension and select mode 2, which re-imports
# modules automatically before executing a cell, so edits to the project
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
from scipy import stats

import torch.optim as optim
import torch
import torch.nn as nn

from sklearn.preprocessing import LabelEncoder

import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display,clear_output

import warnings
warnings.filterwarnings("ignore")

In [None]:
from dataset import CompasDataset
from experiments import Benchmarking
from utils.logger_config import setup_logger
from tqdm import tqdm
from models.wrapper import PYTORCH_MODELS

# Shared logger; the cells below report their progress through logger.info.
logger = setup_logger()

In [None]:
from dataset import dataset_loader
from experiments.counterfactual import *
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.svm import SVC
from models import PyTorchDNN, PyTorchLinearSVM, PyTorchRBFNet,PyTorchLogisticRegression,PyTorchLinearSVM
from sklearn.metrics import accuracy_score, classification_report
from sklearn.gaussian_process import GaussianProcessClassifier

In [None]:
# Load the 'compas' data through dataset_loader (data_path 'data/', no dropped
# features, n_bins=None) and wrap the loaded object in the project's
# CompasDataset.
name = 'compas'
dataset_ares = dataset_loader(
    name,
    data_path='data/',
    dropped_features=[],
    n_bins=None,
)
dataset = CompasDataset(dataset_ares=dataset_ares)

In [None]:
# Model input width: column count of the dataset frame minus one
# (presumably the one excluded column is the target -- confirm against CompasDataset).
input_dim = dataset.get_dataframe().shape[1] - 1

# Seed both PyTorch and NumPy with the same value before any model is built.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Passed on to experiment.compute_intervention_policies in a later cell.
Avalues_method = 'max'

# Counterfactual generators to run; only the uncommented entries are used.
counterfactual_algorithms = [
    'DisCount',
    # 'DiCE',
    # 'GlobeCE',
    # 'AReS',
    # 'KNN',
]

# Models to benchmark, each paired with a backend tag ('PYT' or 'sklearn').
# Only the RBF network is active; the commented entries are alternatives.
# NOTE(review): BaggingClassifier is not imported in the visible cells, so it
# would need an import before being enabled.
benchmark_models = [
    (PyTorchRBFNet(input_dim=input_dim, hidden_dim=input_dim), 'PYT'),
    # (PyTorchLogisticRegression(input_dim=input_dim), 'PYT'),
    # (PyTorchDNN(input_dim=input_dim), 'PYT'),
    # (PyTorchLinearSVM(input_dim=input_dim), 'PYT'),
    # (BaggingClassifier(),'sklearn'),
    # (GaussianProcessClassifier(),'sklearn'),
    # (XGBClassifier(), 'sklearn'),
    # (LGBMClassifier(),'sklearn'),
    # (RandomForestClassifier(), 'sklearn'),
    # (GradientBoostingClassifier(), 'sklearn'),
    # (AdaBoostClassifier(), 'sklearn'),
]

# Shapley method identifiers handed to Benchmarking.
benchmark_shapley_methods = [
    "Train_Distri",
    "Train_OTMatch",
    "CF_UniformMatch",
    "CF_RandomMatch",
    "CF_OTMatch",
]

# Distance metric identifiers handed to Benchmarking.
benchmark_distance_metrics = [
    'optimal_transport',
    'mean_difference',
    'median_difference',
    'max_mean_discrepancy',
]

experiment = Benchmarking(
    dataset=dataset,
    models=benchmark_models,
    shapley_methods=benchmark_shapley_methods,
    distance_metrics=benchmark_distance_metrics,
    md_baseline=False,
)

# Fit the configured models with the notebook seed, then report their performance.
experiment.train_and_evaluate_models(random_state=seed)
experiment.models_performance()


# Build model_counterfactuals: model_name -> {algorithm -> generator result}.
# Each enabled algorithm is resolved by name to a function called
# compute_<algorithm>_counterfactuals in the notebook namespace (populated by
# `from experiments.counterfactual import *`).
logger.info("\n\n------Compute Counterfactuals------")
sample_num = 50  # forwarded to every generator as sample_num
model_counterfactuals = {}
for model, model_name in zip(experiment.models, experiment.model_names):
    model_counterfactuals[model_name] = {}

    for algorithm in counterfactual_algorithms:
        # DisCount is only run for models listed in PYTORCH_MODELS.
        if algorithm == 'DisCount' and model_name not in PYTORCH_MODELS:
            logger.info(f'Skipping {algorithm} for {model_name} due to incompatability')
            continue
        logger.info(f'Computing {model_name} counterfactuals with {algorithm}')

        function_name = f"compute_{algorithm}_counterfactuals"
        func = globals().get(function_name)
        if func is None:
            # No generator with that name is in scope: report it and move on.
            print(f"Function {function_name} is not defined.")
            continue

        # The call is deliberately outside any exception handler: a KeyError
        # raised inside the generator must surface as a real error rather than
        # being reported as a missing function.
        model_counterfactuals[model_name][algorithm] = func(
            experiment.X_test,
            model=model,
            target_name=experiment.dataset.target_name,
            sample_num=sample_num,
            experiment=experiment,
        )




In [None]:
# Compute the intervention policies from the per-model, per-algorithm
# counterfactuals gathered above, using the configured Avalues_method.
# The trailing ';' keeps the notebook from echoing the call's return value.
logger.info("\n\n------Compute Shapley Values------")
experiment.compute_intervention_policies(
    model_counterfactuals=model_counterfactuals,
    Avalues_method=Avalues_method,
);

In [None]:
logger.info("\n\n------Evaluating Distance Performance Under Interventions------")

# Intervention counts to evaluate: 0, 5, 10, ..., 100.
intervention_counts = range(0, 101, 5)

# 100 trials per setting, with replace=False.
experiment.evaluate_distance_performance_under_interventions(
    intervention_num_list=intervention_counts,
    trials_num=100,
    replace=False,
)


In [None]:
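# Optional: load a previously pickled experiment instead of recomputing it.
# Only unpickle files you trust -- pickle.load can execute arbitrary code.
# import pickle
# with open(f"pickles/{dataset.name}_experiment.pickle", "rb") as input_file:
#     experiment = pickle.load(input_file)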

In [None]:
from experiments import plotting

# Plot the intervention-vs-distance results held by the experiment.
# save_to_file=False: presumably the figure is only shown, not written to disk
# -- confirm in experiments.plotting.
plotting.intervention_vs_distance(experiment, save_to_file=False)

In [None]:
from experiments.latex import TikzPlotGenerator

# Pick the result series to export: one model and one counterfactual method.
model_name = 'PyTorchRBFNet'
cf_method = 'DisCount'
data = experiment.distance_results[model_name][cf_method]

# Methods to include in the generated plot.
tikz_methods = [
    "Train_Distri",
    "Train_OTMatch",
    "CF_UniformMatch",
    "CF_RandomMatch",
    "CF_OTMatch",
]

# Other metrics that can be substituted here:
# 'optimal_transport', 'max_mean_discrepancy', 'mean_difference'.
plot_generator = TikzPlotGenerator(data)
tikz_code = plot_generator.generate_plot_code(
    metric='median_difference',
    methods=tikz_methods,
)
print(tikz_code)