# [Multi-class Counterfactual explanation using DiCE](https://github.com/interpretml/DiCE)
DiCE generates counterfactual explanations for any ML model. The algorithm is based on the paper [Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations, 2020](https://arxiv.org/pdf/1905.07697).
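For reference, the paper frames counterfactual generation as choosing a set of $k$ counterfactuals $c_1,\dots,c_k$ for an input $x$ by trading off three terms: a prediction loss that pushes each $f(c_i)$ toward the desired outcome $y$, a proximity term that keeps each $c_i$ close to $x$, and a diversity term based on a determinantal point process. Roughly as stated in the paper, with $\lambda_1, \lambda_2$ as weighting hyperparameters:

```latex
C(x) = \arg\min_{c_1,\dots,c_k} \frac{1}{k}\sum_{i=1}^{k} \mathrm{yloss}\big(f(c_i), y\big)
       + \frac{\lambda_1}{k}\sum_{i=1}^{k} \mathrm{dist}(c_i, x)
       - \lambda_2\, \mathrm{dpp\_diversity}(c_1,\dots,c_k)

\mathrm{dpp\_diversity}(c_1,\dots,c_k) = \det(K), \qquad K_{i,j} = \frac{1}{1 + \mathrm{dist}(c_i, c_j)}
```

The notebook below uses `method="random"`, which samples candidate counterfactuals rather than optimizing this objective directly, so the formula is background on the approach, not a description of what the cells compute.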

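`dice_ml.Data(features={...})` can be used to define the permitted range of each continuous feature (and the allowed values of each categorical one), which helps keep the generated counterfactual set feasible.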
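As a rough illustration of what such a `features` dictionary holds, the sketch below assumes the format in which each continuous feature maps to a `[min, max]` range (categorical features would map to a list of allowed values). The ranges are the min/max of the full Iris dataset, and `is_feasible` is a hypothetical helper, not part of DiCE; it only shows the kind of range check these constraints imply.

```python
# Sketch only: assumes continuous features map to [min, max] in a
# dice_ml.Data(features=...) style dictionary. The ranges are the min/max
# of the full Iris dataset, not of the notebook's training split.
IRIS_FEATURE_RANGES = {
    "sepal length (cm)": [4.3, 7.9],
    "sepal width (cm)": [2.0, 4.4],
    "petal length (cm)": [1.0, 6.9],
    "petal width (cm)": [0.1, 2.5],
}
# Assumed usage (not run here):
# dice_ml.Data(features=IRIS_FEATURE_RANGES, outcome_name='species')


def is_feasible(candidate, feature_ranges):
    """Hypothetical helper: True if every feature value lies inside its [min, max] range."""
    return all(lo <= candidate[name] <= hi for name, (lo, hi) in feature_ranges.items())


# One counterfactual from this notebook, and an out-of-range variant of it.
cf = {"sepal length (cm)": 6.1, "sepal width (cm)": 2.8,
      "petal length (cm)": 6.2, "petal width (cm)": 1.2}
too_long = {**cf, "petal length (cm)": 9.0}

print(is_feasible(cf, IRIS_FEATURE_RANGES))        # True
print(is_feasible(too_long, IRIS_FEATURE_RANGES))  # False
```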

The `features_to_vary=[...]` argument of `explainer.generate_counterfactuals()` can also be set to explicitly specify which features may be perturbed (by default, all features can vary).
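To see what restricting `features_to_vary` means in practice, this plain-Python sketch takes the five counterfactuals displayed later in this notebook (with unchanged values filled in from the query instance) and checks which of them change only the petal features. `changed_features` and `respects_features_to_vary` are hypothetical helpers written for this illustration; as far as we know, DiCE applies the restriction while generating candidates rather than by filtering afterwards.

```python
FEATURES = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
QUERY = [6.1, 2.8, 4.7, 1.2]  # the query instance (predicted versicolor)

# Full feature vectors of the five displayed counterfactuals
# (values shown as "-" in the notebook are copied from QUERY).
CFS = [
    [6.1, 2.2, 5.1, 1.2],
    [6.1, 2.8, 6.3, 2.1],
    [6.1, 2.8, 6.5, 1.2],
    [6.1, 3.2, 6.2, 1.2],
    [6.1, 2.8, 6.2, 1.2],
]


def changed_features(query, cf, names=FEATURES):
    """Hypothetical helper: names of the features whose value differs from the query."""
    return [name for name, q, c in zip(names, query, cf) if q != c]


def respects_features_to_vary(query, cf, features_to_vary):
    """Hypothetical helper: True if cf changes only features listed in features_to_vary."""
    return set(changed_features(query, cf)) <= set(features_to_vary)


petals_only = ["petal length (cm)", "petal width (cm)"]
admissible = [i for i, cf in enumerate(CFS) if respects_features_to_vary(QUERY, cf, petals_only)]
print(admissible)  # [1, 2, 4]
```

Counterfactuals 0 and 3 also alter sepal width, so they could not appear if only the two petal features were allowed to vary.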

    

In [11]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import dice_ml
from dice_ml.utils import helpers  # helper functions

In [14]:
# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target_names[iris.target]

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

iris.target_names

array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

In [4]:
# Train a Random Forest Classifier
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

# Evaluate the model
y_pred = rf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")

Accuracy: 1.00


In [5]:
# Create a DataFrame for the features
features_df = pd.DataFrame(X_train, columns=iris.feature_names)

# Create a DataFrame for the target
target_df = pd.DataFrame(y_train, columns=['species'])

# Combine features and target into a single training DataFrame
train_df = pd.concat([features_df, target_df], axis=1)

# Define the data object for DiCE (all four Iris features are continuous)
dice_data = dice_ml.Data(dataframe=train_df, continuous_features=iris.feature_names, outcome_name='species')

In [6]:
# Create the DiCE model object
dice_model = dice_ml.Model(model=rf, backend="sklearn")

In [7]:
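# Suppress sklearn's UserWarning about feature names (the model was fit on a NumPy array, while DiCE passes DataFrames to it)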
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)

# Initialize the DiCE explainer
explainer = dice_ml.Dice(dice_data, dice_model, method="random")

# Choose a sample instance from the test set
sample_instance = X_test[0:1]
sample_instance_df = pd.DataFrame(sample_instance, columns=iris.feature_names)

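# Generate counterfactuals for the sample instance.
# desired_class=2 is the index of 'virginica' in rf.classes_ (['setosa', 'versicolor', 'virginica'])
counterfactuals = explainer.generate_counterfactuals(sample_instance_df, total_CFs=5, desired_class=2)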

# Visualize the counterfactuals
counterfactuals.visualize_as_dataframe(show_only_changes=True)


Query instance (original outcome : versicolor)

|   | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | species |
|---|---|---|---|---|---|
| 0 | 6.1 | 2.8 | 4.7 | 1.2 | versicolor |

Diverse Counterfactual set (new outcome: virginica)

|   | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | species |
|---|---|---|---|---|---|
| 0 | - | 2.2 | 5.1 | - | virginica |
| 1 | - | - | 6.3 | 2.1 | virginica |
| 2 | - | - | 6.5 | - | virginica |
| 3 | - | 3.2 | 6.2 | - | virginica |
| 4 | - | - | 6.2 | - | virginica |

## Explanation
The diverse counterfactual set lists perturbed versions of the query instance that the model classifies as the desired class; a `-` means the feature keeps its original value. The last row shows that if the query instance had a petal length of 6.2 cm instead of 4.7 cm, that change alone would be enough for the model to classify it as virginica rather than versicolor. The other rows read the same way: for example, row 1 reaches virginica by changing both petal length (6.3 cm) and petal width (2.1 cm).
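The reading above can be made mechanical. This sketch rebuilds each displayed row against the query instance (6.1, 2.8, 4.7, 1.2), treating `-` as "unchanged", and reports the old value, new value, and difference for every changed feature. The helper `changes` is hypothetical, written only for this illustration.

```python
FEATURES = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
QUERY = [6.1, 2.8, 4.7, 1.2]  # query instance, predicted versicolor

# Rows of the diverse counterfactual set as displayed ("-" = unchanged).
SHOWN = [
    ["-", 2.2, 5.1, "-"],
    ["-", "-", 6.3, 2.1],
    ["-", "-", 6.5, "-"],
    ["-", 3.2, 6.2, "-"],
    ["-", "-", 6.2, "-"],
]


def changes(query_row, shown_row, names=FEATURES):
    """Hypothetical helper: map each changed feature to (old, new, new - old)."""
    result = {}
    for name, old, new in zip(names, query_row, shown_row):
        if new != "-":
            result[name] = (old, new, round(new - old, 2))
    return result


for i, row in enumerate(SHOWN):
    print(i, changes(QUERY, row))

# Counterfactuals that flip the prediction by changing exactly one feature.
single_feature_rows = [i for i, row in enumerate(SHOWN) if len(changes(QUERY, row)) == 1]
print(single_feature_rows)  # [2, 4]
```

Rows 2 and 4 each change only petal length; row 4 does it with the smaller increase (+1.5 cm versus +1.8 cm).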

In [8]:
# Local feature importance scores
# For a given sample_instance_df, higher score implies a feature is (locally) important to get the desired_class

sample_instance = X_test[0:1]
sample_instance_df = pd.DataFrame(sample_instance, columns=iris.feature_names)

local = explainer.local_feature_importance(sample_instance_df, total_CFs=10, desired_class=2)
print(local.local_importance)


[{'petal length (cm)': 0.9, 'sepal length (cm)': 0.5, 'sepal width (cm)': 0.3, 'petal width (cm)': 0.2}]





For this sample (X_test[0:1]), petal length has the highest local importance (0.9) for reaching the desired class, virginica.
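DiCE's local importance appears to be the fraction of generated counterfactuals in which a feature differs from the query instance. The sketch below recomputes that fraction from the ten local counterfactuals printed by the `local.cf_examples_list[0].final_cfs_df` cell, and it reproduces the scores printed above. `local_importance` is a hypothetical helper, not DiCE's own code.

```python
FEATURES = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
QUERY = [6.1, 2.8, 4.7, 1.2]  # X_test[0], predicted versicolor

# Feature values of the 10 local counterfactuals printed in this notebook.
LOCAL_CFS = [
    [6.1, 2.8, 6.1, 1.2],
    [7.6, 2.8, 6.0, 1.2],
    [6.1, 4.0, 4.7, 2.2],
    [7.3, 2.8, 6.8, 1.2],
    [6.1, 2.8, 6.3, 1.9],
    [5.6, 2.8, 5.7, 1.2],
    [6.1, 2.0, 5.1, 1.2],
    [6.1, 3.2, 6.0, 1.2],
    [4.5, 2.8, 5.8, 1.2],
    [6.8, 2.8, 6.2, 1.2],
]


def local_importance(query, cfs, names=FEATURES, tol=1e-9):
    """Hypothetical helper: share of counterfactuals in which each feature changed,
    sorted from most to least important."""
    scores = {}
    for j, name in enumerate(names):
        n_changed = sum(abs(cf[j] - query[j]) > tol for cf in cfs)
        scores[name] = n_changed / len(cfs)
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


scores = local_importance(QUERY, LOCAL_CFS)
print(scores)
# {'petal length (cm)': 0.9, 'sepal length (cm)': 0.5, 'sepal width (cm)': 0.3, 'petal width (cm)': 0.2}
```

Petal length scores 0.9 because it changes in 9 of the 10 counterfactuals (only row 2 keeps 4.7 cm).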

In [9]:
# Global feature importance scores
# Aggregate the local feature importance of each input observation to get the global importance

sample_instance = X_test[0:20]
sample_instance_df = pd.DataFrame(sample_instance, columns=iris.feature_names)

gbl = explainer.global_feature_importance(sample_instance_df, total_CFs=10, desired_class=2)
print(gbl.summary_importance)


{'petal length (cm)': 0.675, 'petal width (cm)': 0.535, 'sepal length (cm)': 0.355, 'sepal width (cm)': 0.22}





For a set of samples (X_test[0:20]), petal length (0.675) and petal width (0.535) have the highest global importance.
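Assuming DiCE pools the per-instance counts (how many counterfactuals changed each feature) and divides by the total number of counterfactuals, the global score is a counterfactual-weighted average of the local scores; with 10 counterfactuals per instance it reduces to their plain mean. `pooled_importance` is a hypothetical helper. Instance A uses the local scores reported above for X_test[0:1], while instance B's scores are made up purely for illustration.

```python
def pooled_importance(per_instance):
    """Hypothetical helper.

    per_instance: list of (local_scores, n_cfs) pairs. Each local score is turned
    back into a count of counterfactuals, the counts are summed per feature, and
    each sum is divided by the total number of counterfactuals.
    """
    totals, n_total = {}, 0
    for scores, n_cfs in per_instance:
        n_total += n_cfs
        for name, score in scores.items():
            totals[name] = totals.get(name, 0.0) + score * n_cfs
    return {name: count / n_total for name, count in totals.items()}


# Instance A: local scores reported for X_test[0:1], from 10 counterfactuals.
INSTANCE_A = ({"petal length (cm)": 0.9, "sepal length (cm)": 0.5,
               "sepal width (cm)": 0.3, "petal width (cm)": 0.2}, 10)
# Instance B: made-up scores, for illustration only (also 10 counterfactuals).
INSTANCE_B = ({"petal length (cm)": 0.5, "sepal length (cm)": 0.2,
               "sepal width (cm)": 0.1, "petal width (cm)": 0.8}, 10)

pooled = pooled_importance([INSTANCE_A, INSTANCE_B])
print({name: round(value, 3) for name, value in pooled.items()})

# With equal counterfactual counts, pooling equals the plain mean of the local scores.
mean = {name: (INSTANCE_A[0][name] + INSTANCE_B[0][name]) / 2 for name in pooled}
print(all(abs(pooled[name] - mean[name]) < 1e-9 for name in pooled))  # True
```

Under the same assumption, and if each of the 20 instances received its full 10 counterfactuals, the reported 0.675 for petal length corresponds to 135 of the 200 counterfactuals changing it.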

In [10]:
# Analyze the counterfactuals
print("Local Counterfactuals:")
print(local.cf_examples_list[0].final_cfs_df)

# gbl.cf_examples_list[0] holds the counterfactuals for the first query instance, X_test[0]
print("\nGlobal Counterfactuals:")
print(gbl.cf_examples_list[0].final_cfs_df)

Local Counterfactuals:
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \
0                6.1               2.8                6.1               1.2   
1                7.6               2.8                6.0               1.2   
2                6.1               4.0                4.7               2.2   
3                7.3               2.8                6.8               1.2   
4                6.1               2.8                6.3               1.9   
5                5.6               2.8                5.7               1.2   
6                6.1               2.0                5.1               1.2   
7                6.1               3.2                6.0               1.2   
8                4.5               2.8                5.8               1.2   
9                6.8               2.8                6.2               1.2   

     species  
0  virginica  
1  virginica  
2  virginica  
3  virginica  
4  virginica  
5  virginica  
6 