In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from Clustering import Clustering
from DataNoiseAdder import DataNoiseAdder
from DatasetCorruptor import DatasetCorruptor
from DecisionTreeEnsemble import DecisionTreeEnsemble
from SyntheticDataGenerator import SyntheticDataGenerator
from EnsembleDiversity import EnsembleDiversity
from EnsembleMetrics import EnsembleMetrics

from utils import get_dataset, get_ensemble_preds_from_models, get_precision_recall_auc, auprc_threshs
from utils import plot_precision_recall, plot_aroc_at_curve, fitness_scatter
from utils import compute_metrics_in_buckets, flatten_df, compute_cluster_metrics
from utils import get_categorical_and_float_features
from utils import get_clusters_dict, make_noise_preds

from tqdm.notebook import tqdm

import warnings
# Suppress every warning for the rest of the session so cell outputs stay short.
# NOTE(review): this is a blanket filter — it also hides warnings that may be
# informative; consider restricting it to specific categories.
warnings.filterwarnings('ignore')

### Experiment Args

In [2]:
import os

# All tunable experiment settings live in this one dict; later cells read from it.
args = {
    # Random search
    'ntrls': 10,           # number of sub-ensembles sampled by the search loop
    'ensemble_size': 10,   # number of pool members drawn per sub-ensemble

    # Dataset location. The default is the original author's local checkout; export
    # OOD_ENSEMBLES_DATASET_PATH to run the notebook elsewhere without editing this cell.
    'dataset_path': os.environ.get(
        'OOD_ENSEMBLES_DATASET_PATH',
        "/Users/scottmerrill/Documents/UNC/Research/OOD-Ensembles/datasets",
    ),
    'dataset_name': 'heloc_tf',

    # Decision Tree/Model Pool Params (passed positionally to DecisionTreeEnsemble)
    'num_classifiers': 100,
    'feature_fraction': 0.5,
    'data_fraction': 0.8,
    'max_depth': 10,
    'min_samples_leaf': 5,
    'random_state': 1,

    # Clustering / noise settings (passed to get_clusters_dict and make_noise_preds)
    'clusters_list': [3, 10],
    'shift_feature_count': 5,
}

# Thresholds handed to get_precision_recall_auc and plot_aroc_at_curve.
AUCTHRESHS = np.array([0.1, 0.2, 0.3, 0.4, 1. ])

In [3]:
# Fetch the six arrays returned by get_dataset: the training split plus the `id` and
# `ood` validation splits (presumably in-distribution / out-of-distribution — confirm).
dataset_splits = get_dataset(args['dataset_path'], args['dataset_name'])
x_train, y_train, x_val_id, y_val_id, x_val_ood, y_val_ood = dataset_splits

# Width of the feature matrix
num_features = x_train.shape[1]

### Building and Training Model Pool

In [4]:
# Hyper-parameters handed to DecisionTreeEnsemble, listed in the positional order
# the constructor is called with.
pool_param_names = (
    'num_classifiers',
    'feature_fraction',
    'data_fraction',
    'max_depth',
    'min_samples_leaf',
    'random_state',
)
model_pool = DecisionTreeEnsemble(*(args[name] for name in pool_param_names))

# Fit the whole pool on the training split
model_pool.train(x_train, y_train)

In [6]:
    def save(self, file_path):
        """
        Save the entire state of the ensemble to a file.

        The constructor hyper-parameters, the classifiers and their feature/data
        subsets are pickled together as a single dict.

        :param file_path: The path where the model will be saved.
        """
        # Imported here because nothing at the top of this file imports pickle.
        import pickle

        # Collect the state before touching the file system, so a missing
        # attribute fails without leaving an empty/truncated file behind.
        state = {
            'num_classifiers': self.num_classifiers,
            'feature_fraction': self.feature_fraction,
            'data_fraction': self.data_fraction,
            'max_depth': self.max_depth,
            'min_samples_leaf': self.min_samples_leaf,
            'random_state': self.random_state,
            'classifiers': self.classifiers,
            'feature_subsets': self.feature_subsets,
            'data_subsets': self.data_subsets
        }
        with open(file_path, 'wb') as f:
            pickle.dump(state, f)
        print(f"Ensemble saved to {file_path}")
    
    @classmethod
    def load(cls, file_path):
        """
        Load the ensemble from a saved file.

        The file must contain the dict written by ``save``. Unpickling can execute
        arbitrary code, so only load files from a trusted source.

        :param file_path: The path to the saved ensemble file.
        :return: An instance of ``cls`` with the loaded state.
        """
        # Imported here because nothing at the top of this file imports pickle.
        import pickle

        # NOTE(security): pickle.load is unsafe on untrusted input.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)

        # Re-create the instance from the stored constructor arguments
        hyperparameter_names = (
            'num_classifiers',
            'feature_fraction',
            'data_fraction',
            'max_depth',
            'min_samples_leaf',
            'random_state',
        )
        ensemble = cls(**{name: data[name] for name in hyperparameter_names})

        # Restore the saved classifiers, feature, and data subsets
        for attr_name in ('classifiers', 'feature_subsets', 'data_subsets'):
            setattr(ensemble, attr_name, data[attr_name])

        print(f"Ensemble loaded from {file_path}")

        return ensemble

Ensemble saved to test.pkl
Ensemble loaded from test.pkl


In [9]:
# Round-trip the trained pool through save()/load() and confirm the reloaded copy
# predicts identically on x_val_id. `model_pool2` was not defined by any cell in this
# notebook, so it is created here to keep the check runnable on a fresh kernel.
# NOTE(review): assumes the imported DecisionTreeEnsemble exposes save()/load() — confirm.
model_pool.save('test.pkl')
model_pool2 = DecisionTreeEnsemble.load('test.pkl')
(model_pool.predict(x_val_id) == model_pool2.predict(x_val_id)).all()

True

### Caching Model Pool Predictions

In [None]:
# Score the full model pool on the `ood` validation split. The resulting
# precision/recall/AUC values (mp_*) are passed to the plotting helpers further down.
# NOTE(review): `model_pool_preds` and `model_pool_pred_probs` are rebound inside the
# random search loop, so after that loop they no longer hold these pool-level values.
model_pool_preds = model_pool.predict(x_val_ood)
model_pool_pred_probs = model_pool.predict_proba(x_val_ood)
mp_precision, mp_recall, mp_auc = get_precision_recall_auc(model_pool_pred_probs, y_val_ood, AUCTHRESHS)

### Caching Individual Model Predictions

In [None]:
# For each split, store the individual-model predictions (transposed) as
# `model_pool.<split>_preds` and the individual-model probabilities as
# `model_pool.<split>_pred_probs`. Later cells discover these attributes by name.
cached_splits = (
    ('train', x_train),
    ('val_id', x_val_id),
    ('val_ood', x_val_ood),
)
for split_name, split_x in cached_splits:
    setattr(model_pool, f'{split_name}_preds', model_pool.get_individual_predictions(split_x).T)
    setattr(model_pool, f'{split_name}_pred_probs', model_pool.get_individual_probabilities(split_x))

### Clustering Data and Noise Methods

In [None]:
# Cluster the train and val_id splits. The keys must match the prefixes the random
# search loop derives from the cached attribute names ('train_preds' and 'val_id').
train_clusters, val_id_clusters = get_clusters_dict(x_train, x_val_id, args['clusters_list'])
all_clusters_dict = {
    'train_preds': train_clusters,
    'val_id': val_id_clusters,
}

# make_noise_preds receives the dict and returns the version used from here on
all_clusters_dict = make_noise_preds(
    x_train,
    y_train,
    x_val_id,
    model_pool,
    args['shift_feature_count'],
    args['clusters_list'],
    all_clusters_dict,
)

### Synthetic Data

In [None]:
generator = SyntheticDataGenerator(x_train, y_train)

# Every synthetic set gets as many rows as the training split
n_synth = x_train.shape[0]


def _cache_synth_preds(prefix, synth_x):
    """Store individual-model predictions/probabilities on synth_x as model_pool.<prefix>_*."""
    setattr(model_pool, f'{prefix}_preds', model_pool.get_individual_predictions(synth_x).T)
    setattr(model_pool, f'{prefix}_pred_probs', model_pool.get_individual_probabilities(synth_x))


# interpolation
interp_x, interp_y = generator.interpolate(n_synth)
_cache_synth_preds('synth_interp', interp_x)
del interp_x  # features are no longer needed once predictions are cached

# GMM
gmm_x, gmm_y = generator.gaussian_mixture(n_synth)
_cache_synth_preds('synth_gmm', gmm_x)
del gmm_x

# decision tree
dt_x, dt_y = generator.decision_tree(n_synth)
_cache_synth_preds('synth_dt', dt_x)
del dt_x

# Labels for each synthetic set, keyed by the same prefix used for the cached attributes
synth_data_dict = dict(
    synth_interp=interp_y.astype('int64'),
    synth_gmm=gmm_y.astype('int64'),
    synth_dt=dt_y.astype('int64'),
)

### Label Shift

In [None]:
# Flipped-label copies of the train and val_id targets (train first, then val_id),
# used by the `_flip` variants of the fitness metrics in the search loop.
y_train_flipped, y_val_flipped = (
    DataNoiseAdder.label_flip(labels) for labels in (y_train, y_val_id)
)

In [None]:
# Inspect the dtype of the training labels (the synthetic labels above are cast to int64).
y_train.dtype

### Random Search Loop

In [None]:
# Build (preds_attr, pred_probs_attr, prefix) triples for every cached prediction set
# on the pool. Each `<name>_preds` attribute is paired with `<name>_pred_probs` by name,
# not by list position, so an unmatched attribute cannot silently shift the pairing.
PREDS_SUFFIX = '_preds'
PROBS_SUFFIX = '_pred_probs'
HELD_OUT_PREDS = 'val_ood_preds'  # OOD predictions are scored separately in the search loop

pred_attributes = []
for preds_name in dir(model_pool):  # dir() returns names sorted, so the order is stable
    if not preds_name.endswith(PREDS_SUFFIX) or preds_name == HELD_OUT_PREDS:
        continue

    pred_prob_name = preds_name[:-len(PREDS_SUFFIX)] + PROBS_SUFFIX
    if not hasattr(model_pool, pred_prob_name):
        raise AttributeError(
            f"model_pool has '{preds_name}' but no matching '{pred_prob_name}'"
        )

    # The prefix is the first two '_'-separated tokens, e.g. 'val_id_preds' -> 'val_id'
    # and 'train_preds' -> 'train_preds'.
    prefix_name = '_'.join(preds_name.split('_')[:2])
    pred_attributes.append((preds_name, pred_prob_name, prefix_name))

In [None]:
# Display the (preds attribute, pred_probs attribute, prefix) triples built above.
pred_attributes

In [None]:
# Per-trial OOD curves are appended column-wise; fitness metrics are appended row-wise.
precisions_df = pd.DataFrame()
recalls_df = pd.DataFrame()
aucs_df = pd.DataFrame()
fitness_df = pd.DataFrame()

for trial in tqdm(range(args['ntrls'])):
    # Draw a random sub-ensemble from the pool.
    # NOTE(review): members are sampled with replacement (the same model can appear more
    # than once) and no seed is set in this notebook, so trials differ between runs —
    # confirm both are intended.
    indices = np.random.choice(model_pool.num_classifiers, size=args['ensemble_size'], replace=True)

    # ood preds of sub-ensemble
    ood_preds, ood_pred_probs = get_ensemble_preds_from_models(model_pool.val_ood_pred_probs[indices])

    # save OOD precision/recalls separately
    precision, recall, auc = get_precision_recall_auc(ood_pred_probs, y_val_ood, AUCTHRESHS)

    recalls_df = pd.concat([recalls_df, pd.DataFrame(recall)], axis=1)
    precisions_df = pd.concat([precisions_df, pd.DataFrame(precision)], axis=1)
    aucs_df = pd.concat([aucs_df, pd.DataFrame(auc)], axis=1)

    # One flat record of scalar metrics per trial
    trial_metrics = {'generation': trial,
                     'ensemble_files': ','.join(str(x) for x in indices)}
    cluster_metrics = pd.DataFrame()

    # Compute all Fitness Metrics
    for label_flip in [0, 1]:
        for pred_tuple in pred_attributes:
            preds_name, pred_prob_name, prefix_name = pred_tuple

            # get clustering associated with data transformation
            # (looked up before any '_flip' suffix is appended to the prefix)
            if 'synth' not in prefix_name:
                clusters_dict = all_clusters_dict[prefix_name]

            # Pick the labels that match this prediction set
            if 'train' in prefix_name:
                if label_flip:
                    Y = y_train_flipped
                    prefix_name = prefix_name + '_flip'
                else:
                    Y = y_train

            elif 'synth' in prefix_name:
                # synthetic sets have no flipped-label variant
                if label_flip:
                    continue
                else:
                    Y = synth_data_dict[prefix_name]

            else:
                if label_flip:
                    prefix_name = prefix_name + '_flip'
                    Y = y_val_flipped
                else:
                    Y = y_val_id

            model_pool_preds = getattr(model_pool, preds_name)
            model_pool_pred_probs = getattr(model_pool, pred_prob_name)

            # restrict the cached pool outputs to the sampled members
            model_preds = model_pool_preds[indices]
            model_pred_probs = model_pool_pred_probs[indices]

            # id val preds of sub-ensemble
            ensemble_preds, ensemble_pred_probs = get_ensemble_preds_from_models(model_pred_probs)
            metrics = EnsembleMetrics(Y, ensemble_preds, ensemble_pred_probs[:,1])
            diversity = EnsembleDiversity(Y, model_preds)

            trial_metrics.update({
                # ensemble-level metrics
                f'{prefix_name}_acc': metrics.accuracy(),
                f'{prefix_name}_auc': metrics.auc(),
                f'{prefix_name}_prec': metrics.precision(),
                f'{prefix_name}_rec': metrics.recall(),
                f'{prefix_name}_f1': metrics.f1(),
                f'{prefix_name}_mae': metrics.mean_absolute_error(),
                f'{prefix_name}_mse': metrics.mean_squared_error(),
                f'{prefix_name}_logloss': metrics.log_loss(),

                # diversity
                f'{prefix_name}_q_statistic': np.mean(diversity.q_statistic()),
                f'{prefix_name}_correlation_coefficient': np.mean(diversity.correlation_coefficient()),
                f'{prefix_name}_entropy': np.mean(diversity.entropy()),
                f'{prefix_name}_diversity_measure': diversity.diversity_measure(),
                f'{prefix_name}_hamming_distance': np.mean(diversity.hamming_distance()),
                f'{prefix_name}_error_rate': np.mean(diversity.error_rate()),
                # Stored under its own key: this entry used to reuse '<prefix>_auc' and
                # silently overwrote the ensemble AUC recorded above.
                f'{prefix_name}_diversity_auc': np.mean(diversity.auc()),
                f'{prefix_name}_brier_score': np.mean(diversity.brier_score()),
                f'{prefix_name}_ensemble_variance': np.mean(diversity.ensemble_variance()),
            })

            if 'synth' not in prefix_name:
                # compute cluster metrics
                tmp_cluster = compute_cluster_metrics(clusters_dict, ensemble_preds, model_preds, model_pred_probs, Y)
                col_names = [prefix_name + '_' + x for x in tmp_cluster.columns]
                col_names = [name.replace('_val_acc', '') for name in col_names]
                col_names = [name.replace('_train_acc', '') for name in col_names]
                tmp_cluster.columns = col_names

                cluster_metrics = pd.concat([cluster_metrics, tmp_cluster], axis=1)

    # Scalar metrics and cluster metrics side by side form this trial's single row
    raw_metrics = pd.DataFrame([trial_metrics])
    trial_row = pd.concat([raw_metrics, cluster_metrics], axis=1)
    fitness_df = pd.concat([fitness_df, trial_row])
    #precisions_df.to_csv(save_path+'/precisions_df.csv', index=False)
    #recalls_df.to_csv(save_path+'/recalls_df.csv', index=False)
    #aucs_df.to_csv(save_path+'/aucs_df.csv', index=False)

In [None]:
# Lift pandas' column display limit (a session-wide setting) and show the fitness table.
# NOTE(review): this renders every row; `fitness_df.head()` would keep the output short.
pd.set_option('display.max_columns', None)
fitness_df

### Plot Results

In [None]:
# Renumber rows 0..n-1 so each row label matches its trial position
fitness_df = fitness_df.reset_index(drop=True)

# Map rank (1 = highest) -> row label of the three trials with the largest fitness value
top_fitness_rows = fitness_df.nlargest(3, 'val_upscaleshift_meanshift_val_error_rate_mean').index
best_fitness_index = dict(enumerate(top_fitness_rows, start=1))

In [None]:
# Two side-by-side panels: precision/recall on the left, AUC-at-threshold on the right
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
pr_ax, auc_ax = axs

plot_precision_recall(
    precisions_df, recalls_df, mp_precision, mp_recall, best_fitness_index, ax=pr_ax
)
plot_aroc_at_curve(AUCTHRESHS, aucs_df, mp_auc, best_fitness_index, ax=auc_ax)

### Fitness Function Diagnosis


In [None]:
# Fitness columns to compare against the per-trial OOD AUCs
cols = ['train_boundaryshift_acc', 'val_upscaleshift_meanshift_val_error_rate_mean']
fig, axs = plt.subplots(1, len(cols), figsize=(12, 6))

# One scatter panel per fitness column, left to right
for ax, col in zip(axs, cols):
    fitness_scatter(fitness_df, aucs_df, col, ax=ax)
plt.tight_layout(pad=2)

In [None]:
# Show the first rows of the fitness table with every column visible.
# (The display option was already set by an earlier cell; repeating it is harmless.)
pd.set_option('display.max_columns',None)
fitness_df.head()