# Counterfactual explanations

In [1]:
import trustyai

trustyai.init(
    path=[
        "../dep/org/kie/kogito/explainability-core/1.8.0.Final/*",
        "../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar",
        "../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar",
        "../dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar",
        "../dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar",
        "../dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar",
        "../dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar",
    ]
)

## Simple example

We start by defining our black-box model, typically represented by

$$
f(\mathbf{x}) = \mathbf{y}
$$

where $\mathbf{x}=\{x_1, x_2, \dots,x_m\}$ and $\mathbf{y}=\{y_1, y_2, \dots,y_n\}$.

Our toy model takes an all-numerical input $\mathbf{x}$ and returns `true` if the sum of the components of $\mathbf{x}$ is within a threshold $\epsilon$ of a point $\mathbf{C}$ and `false` otherwise, that is:

$$
f(\mathbf{x}, \epsilon, \mathbf{C})=\begin{cases}
\text{true},\qquad \text{if}\ \mathbf{C}-\epsilon<\sum_{i=1}^m x_i <\mathbf{C}+\epsilon \\
\text{false},\qquad \text{otherwise}
\end{cases}
$$

This model is provided in the `TestUtils` module. We instantiate it with $\mathbf{C}=500$ and $\epsilon=1.0$.

In [2]:
from trustyai.utils import TestUtils

center = 500.0
epsilon = 1.0

model = TestUtils.getSumThresholdModel(center, epsilon)
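To make the decision rule concrete, here is a minimal plain-Python sketch of the same function. It follows the equation above, including the strict inequalities; `sum_threshold` is a hypothetical name, and whether `TestUtils.getSumThresholdModel` treats the boundaries in exactly the same way is an assumption.

```python
def sum_threshold(x, center=500.0, epsilon=1.0):
    """Hypothetical re-implementation of the toy model f(x, epsilon, C).

    Returns True if the sum of the components of x lies strictly
    between center - epsilon and center + epsilon, False otherwise.
    """
    total = sum(x)
    return center - epsilon < total < center + epsilon


print(sum_threshold([6.95, 0.69, 9.43, 3.69]))   # False: the sum is about 20.76
print(sum_threshold([485.41, 0.69, 9.29, 3.69]))  # True: the sum is about 499.08
```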

Next we need to define a **goal**.
Writing the counterfactual as $f(\mathbf{x'})=\mathbf{y'}$, the goal is the desired output $\mathbf{y'}$, and the counterfactual result is an input $\mathbf{x'}$ that satisfies $f(\mathbf{x'})=\mathbf{y'}$.

We will define our goal as `true`, that is, the sum is within the vicinity of the point $\mathbf{C}$. The goal is a list of `Output` objects, each of which takes the following parameters:

- The output name
- The output type
- The output value (wrapped in `Value`)
- A confidence threshold, which we will leave at zero (no threshold)

In [3]:
from trustyai.model import Output, Type, Value

goal = [Output("inside", Type.BOOLEAN, Value(True), 0.0)]

We will now define our initial features, $\mathbf{x}$. Each feature can be instantiated with `FeatureFactory`; since we want numerical features, we use `FeatureFactory.newNumericalFeature`. Each of the four features gets a random value in $[0, 10)$.

In [4]:
import random
from trustyai.model import FeatureFactory

features = [FeatureFactory.newNumericalFeature(f"x{i+1}", random.random()*10.0) for i in range(4)]

Because each feature lies in $[0, 10)$, the sum of the four features is below 40 and therefore cannot be within $\epsilon$ (1.0) of $\mathbf{C}$ (500.0). As such, the model prediction will be `false`:

In [5]:
feature_sum = 0.0
for f in features:
    value = f.getValue().asNumber()
    print(f"Feature {f.getName()} has value {value}")
    feature_sum += value
print(f"\nFeatures sum is {feature_sum}")

Feature x1 has value 6.953686434260184
Feature x2 has value 0.6895992287088226
Feature x3 has value 9.429677348990124
Feature x4 has value 3.6853630123991774

Features sum is 20.75832602435831
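The claim that the original prediction must be `false` can be checked with a small plain-Python sketch: since each feature is drawn from $[0, 10)$, no draw of four features can sum to anywhere near 500. The `sum_threshold` helper is hypothetical and mirrors the equation above, not the `TestUtils` implementation.

```python
import random


def sum_threshold(x, center=500.0, epsilon=1.0):
    # True if sum(x) lies strictly within epsilon of center
    return center - epsilon < sum(x) < center + epsilon


random.seed(42)  # fixed seed so the sketch is reproducible
# 10,000 draws of four features, each uniform in [0, 10) like the cell above
draws = [[random.random() * 10.0 for _ in range(4)] for _ in range(10_000)]

largest = max(sum(d) for d in draws)
print(f"Largest sum over all draws: {largest:.3f}")  # always below 40
print(any(sum_threshold(d) for d in draws))  # False
```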


The next step is to specify the **constraints** of the features, i.e. which features can be changed and which should be fixed: `True` fixes a feature to its original value, while `False` lets it vary. Since we want all features to be able to change, we specify `False` for all of them:

In [6]:
constraints = [False] * 4
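As a plain-Python illustration of how such a mask is read, `False` marks a feature the search may change and `True` one it must keep. The feature names and the `free_features` helper below are hypothetical, for illustration only.

```python
def free_features(names, constraints):
    # Keep only the features whose constraint flag is False (i.e. mutable)
    return [name for name, fixed in zip(names, constraints) if not fixed]


names = ["x1", "x2", "x3", "x4"]
print(free_features(names, [False] * 4))                 # ['x1', 'x2', 'x3', 'x4']
print(free_features(names, [True, False, False, True]))  # ['x2', 'x3']
```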

Finally, we specify the **bounds** for the counterfactual search. These are typically set using domain-specific knowledge or taken from the data. Here we simply choose an arbitrary but sensible range: every feature can vary between `0` and `1000`.

In [7]:
from trustyai.model.domain import NumericalFeatureDomain

feature_boundaries = [NumericalFeatureDomain.create(0.0, 1000.0)] * 4
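A quick sanity check on these bounds, sketched in plain Python with the bounds written as hypothetical `(lower, upper)` tuples rather than TrustyAI domain objects: the goal can only be reached if a sum near 500 lies between the smallest and largest sums the bounds allow.

```python
# Per-feature (lower, upper) bounds, mirroring the 0 to 1000 range above
bounds = [(0.0, 1000.0)] * 4


def within_bounds(x, bounds):
    # True if every component of x lies inside its [lower, upper] interval
    return all(lo <= v <= hi for v, (lo, hi) in zip(x, bounds))


min_sum = sum(lo for lo, _ in bounds)
max_sum = sum(hi for _, hi in bounds)
print(min_sum, max_sum)  # 0.0 4000.0
print(min_sum < 500.0 < max_sum)  # True: the target sum is reachable
print(within_bounds([485.41, 0.69, 9.29, 3.69], bounds))  # True
```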

In order to use the boundaries in the explainer we need to wrap all of them in a `DataDomain` class:

In [8]:
from trustyai.model import DataDomain

data_domain = DataDomain(feature_boundaries)

We can now set up the **explainer**.

First, we configure the termination criteria. For this example, the counterfactual search stops after at most 10,000 score calculations and returns the best result found so far.

In [9]:
from org.optaplanner.core.config.solver.termination import TerminationConfig
from org.kie.kogito.explainability.local.counterfactual import CounterfactualConfigurationFactory
from java.lang import Long

termination_config = TerminationConfig().withScoreCalculationCountLimit(Long.valueOf(10_000))

solver_config = (
    CounterfactualConfigurationFactory.builder()
    .withTerminationConfig(termination_config)
    .build()
)

We can now instantiate the explainer itself using `CounterfactualExplainer` and our `solver_config` configuration.

In [10]:
from org.kie.kogito.explainability.local.counterfactual import CounterfactualExplainer

explainer = CounterfactualExplainer.builder().withSolverConfig(solver_config).build()

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.


We will now express the counterfactual problem defined above:

- `original` represents our $\mathbf{x}$, which we know gives a prediction of `false`
- `goals` represents our $\mathbf{y'}$, that is, our desired prediction (`true`)
- `domain` represents the boundaries for the counterfactual search
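Putting the pieces together, the search can be written schematically as a constrained optimization problem. This is a sketch: $d$ stands for some distance between inputs (the specific metric used by TrustyAI is not described here), and $l_i$, $u_i$ are the lower and upper bounds of feature $i$ (here $0$ and $1000$).

$$
\begin{aligned}
\mathbf{x'} = \arg\min_{\mathbf{z}}\ & d(\mathbf{x}, \mathbf{z}) \\
\text{subject to}\ & f(\mathbf{z}, \epsilon, \mathbf{C}) = \mathbf{y'}, \\
& z_i = x_i \quad \text{for every constrained feature } i, \\
& l_i \le z_i \le u_i \quad \text{for } i = 1, \dots, m.
\end{aligned}
$$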

In [11]:
from trustyai.model import PredictionFeatureDomain, PredictionInput, PredictionOutput

original = PredictionInput(features)
goals = PredictionOutput(goal)
domain = PredictionFeatureDomain(data_domain.getFeatureDomains())

We wrap these quantities in a `CounterfactualPrediction` (the UUID simply labels this search instance):

In [12]:
import uuid
from trustyai.model import CounterfactualPrediction

prediction = CounterfactualPrediction(original, goals, domain, constraints, None, uuid.uuid4())

We now request the counterfactual $\mathbf{x'}$ which is closest to $\mathbf{x}$ and which satisfies $f(\mathbf{x'}, \epsilon, \mathbf{C})=\mathbf{y'}$:

In [13]:
explanation_async = explainer.explainAsync(prediction, model)

The counterfactual explainer API is asynchronous: `explainAsync` returns a future, so we call `.get()` to wait for the result:

In [14]:
explanation = explanation_async.get()
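To illustrate what the explainer is searching for, here is a deliberately naive random-search sketch in plain Python. It is not the OptaPlanner-based solver TrustyAI uses: each proposal starts from the original $\mathbf{x}$, resets one free feature to a random in-bounds value, and the sketch keeps the proposal closest to $\mathbf{x}$ among those for which the model returns the goal. The L1 distance is an assumption, the `budget` of model calls only loosely mirrors the 10,000 score-calculation limit, and all names are hypothetical.

```python
import random


def sum_threshold(x, center=500.0, epsilon=1.0):
    # Same decision rule as the toy model: True if sum(x) is within epsilon of center
    return center - epsilon < sum(x) < center + epsilon


def naive_counterfactual(x, goal, model, constraints, bounds, budget=10_000, seed=0):
    """Random search for an x' close to x (L1 distance) with model(x') == goal.

    Only features whose constraint is False are perturbed, and every proposal
    stays inside its (lower, upper) bound. Stops after `budget` model calls.
    Returns None if no proposal reached the goal.
    """
    rng = random.Random(seed)
    free = [i for i, fixed in enumerate(constraints) if not fixed]
    best, best_dist = None, float("inf")
    for _ in range(budget):
        candidate = list(x)
        # Reset one randomly chosen free feature to a random in-bounds value
        i = rng.choice(free)
        lo, hi = bounds[i]
        candidate[i] = rng.uniform(lo, hi)
        if model(candidate) == goal:
            dist = sum(abs(a - b) for a, b in zip(candidate, x))
            if dist < best_dist:
                best, best_dist = candidate, dist
    return best


# Usage with the original features printed earlier in this notebook
x = [6.953686434260184, 0.6895992287088226, 9.429677348990124, 3.6853630123991774]
bounds = [(0.0, 1000.0)] * 4

cf = naive_counterfactual(x, True, sum_threshold, [False] * 4, bounds)
print(cf, sum(cf))

# Same search with x1 and x4 fixed
cf_fixed = naive_counterfactual(x, True, sum_threshold, [True, False, False, True], bounds)
print(cf_fixed, sum(cf_fixed))
```

With only one feature changed per proposal, every successful candidate moves roughly 478 units in a single coordinate; a real solver such as the one behind `CounterfactualExplainer` explores the space far more systematically.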

We can now inspect the counterfactual $\mathbf{x'}$ and check that the sum of its features lies within $\epsilon$ of $\mathbf{C}$:

In [15]:
feature_sum = 0.0
for entity in explanation.getEntities():
    print(entity)
    feature_sum += entity.getProposedValue()
    
print(f"\nFeature sum is {feature_sum}")

java.lang.DoubleFeature{value=485.4101987057185, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}
java.lang.DoubleFeature{value=0.6895992287088226, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}
java.lang.DoubleFeature{value=9.291426877845232, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}
java.lang.DoubleFeature{value=3.6853630123991774, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}

Feature sum is 499.0765878246718
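The printed values can be double-checked in plain Python by copying the numbers from the output above: the sum lies inside $(\mathbf{C}-\epsilon, \mathbf{C}+\epsilon) = (499, 501)$, and comparing with the original features shows which ones the search actually changed.

```python
# Values copied from the printed outputs above
original = [6.953686434260184, 0.6895992287088226, 9.429677348990124, 3.6853630123991774]
counterfactual = [485.4101987057185, 0.6895992287088226, 9.291426877845232, 3.6853630123991774]

total = sum(counterfactual)
print(499.0 < total < 501.0)  # True: within epsilon of C

# Features whose value differs between the original and the counterfactual
changed = [f"x{i + 1}" for i, (a, b) in enumerate(zip(original, counterfactual)) if a != b]
print(changed)  # ['x1', 'x3']
```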


## Constrained features

As we've seen, it is possible to constrain a specific feature $x_i$ by setting the corresponding element of the _constraints_ list to `True`.

In this example, we now want to fix $x_1$ and $x_4$. That is, these features should have the same value in the counterfactual $\mathbf{x'}$ as in the original $\mathbf{x}$.

In [16]:
constraints = [True, False, False, True]  # x1 and x4 fixed; x2 and x3 free to change

We simply need to wrap the previous quantities with the new constraints:

In [17]:
prediction = CounterfactualPrediction(original, goals, domain, constraints, None, uuid.uuid4())

And request a new counterfactual explanation:

In [18]:
explanation = explainer.explainAsync(prediction, model).get()

We can see that $x_1$ and $x_4$ have the same values as in the original, and that the new counterfactual still satisfies the goal.

In [19]:
print(f"Original x1: {features[0].getValue()}")
print(f"Original x4: {features[3].getValue()}\n")

for entity in explanation.getEntities():
    print(entity)

Original x1: 6.953686434260184
Original x4: 3.6853630123991774

java.lang.DoubleFeature{value=6.953686434260184, intRangeMinimum=6.953686434260184, intRangeMaximum=6.953686434260184, id='x1'}
java.lang.DoubleFeature{value=0.7810382333337529, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}
java.lang.DoubleFeature{value=488.03303690921916, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}
java.lang.DoubleFeature{value=3.6853630123991774, intRangeMinimum=3.6853630123991774, intRangeMaximum=3.6853630123991774, id='x4'}
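As a final plain-Python check on the constrained result, using the numbers copied from the output above: every feature flagged `True` keeps its original value, and the sum still lies inside $(499, 501)$.

```python
# Values copied from the printed outputs above
original = [6.953686434260184, 0.6895992287088226, 9.429677348990124, 3.6853630123991774]
counterfactual = [6.953686434260184, 0.7810382333337529, 488.03303690921916, 3.6853630123991774]
constraints = [True, False, False, True]

# Every constrained (True) feature must be unchanged
fixed_ok = all(o == c for o, c, fixed in zip(original, counterfactual, constraints) if fixed)
print(fixed_ok)  # True

print(499.0 < sum(counterfactual) < 501.0)  # True: the goal is still met
```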
