# Counterfactual explanations

In [1]:
import trustyai

# Java dependencies handed to TrustyAI at start-up. Each entry is a path
# relative to this notebook; the first one is a wildcard over the whole
# explainability-core 1.8.0.Final directory, the others are individual jars.
# NOTE(review): presumably these end up on the JVM classpath — confirm.
dependency_paths = [
    "../dep/org/kie/kogito/explainability-core/1.8.0.Final/*",
    "../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar",
    "../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar",
    "../dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar",
    "../dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar",
    "../dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar",
    "../dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar",
]

# Must run before any other trustyai / Java import in this notebook.
trustyai.init(path=dependency_paths)

## Simple example

We start by defining our black-box model, typically represented by

$$
f(\mathbf{x}) = \mathbf{y}
$$

Where $\mathbf{x}=\{x_1, x_2, \dots,x_m\}$ and $\mathbf{y}=\{y_1, y_2, \dots,y_n\}$.

Our example toy model, in this case, takes an all-numerical input $\mathbf{x}$ and returns a $\mathbf{y}$ which is `true` if the sum of the $\mathbf{x}$ components is within a threshold $\epsilon$ of a point $\mathbf{C}$, and `false` otherwise, that is:

$$
f(\mathbf{x}, \epsilon, \mathbf{C})=\begin{cases}
\text{true},\qquad \text{if}\ \mathbf{C}-\epsilon<\sum_{i=1}^m x_i <\mathbf{C}+\epsilon \\
\text{false},\qquad \text{otherwise}
\end{cases}
$$

This model is provided in the `TestUtils` module. We instantiate it with $\mathbf{C}=500$ and $\epsilon=10.0$.

In [4]:
from trustyai.utils import TestUtils

# Parameters of the toy model: the point C the feature sum is compared
# against, and the half-width epsilon of the accepted interval around it.
center, epsilon = 500.0, 10.0

# Sum-threshold black box provided by TestUtils: per the formula above it
# answers `true` when C - epsilon < sum(x) < C + epsilon, `false` otherwise.
model = TestUtils.getSumThresholdModel(center, epsilon)

Next we need to define a **goal**.
If our model is $f(\mathbf{x'})=\mathbf{y'}$ we are then defining our $\mathbf{y'}$ and the counterfactual result will be the $\mathbf{x'}$ which satisfies $f(\mathbf{x'})=\mathbf{y'}$.

We will define our goal as `true`, that is, the sum is within the vicinity of a (to be defined) point $\mathbf{C}$. The goal is a list of `Output` objects, each of which takes the following parameters

- The feature name
- The feature type
- The feature value (wrapped in `Value`)
- A confidence threshold, which we will leave at zero (no threshold)

In [5]:
from trustyai.model import Output, Type, Value

# Desired outcome of the counterfactual search: a single boolean output
# named "inside" whose value should be True.
goal = [
    Output(
        "inside",      # name
        Type.BOOLEAN,  # type
        Value(True),   # target value, wrapped in Value
        0.0,           # confidence threshold (zero = no threshold)
    )
]

In [None]:
import random
from trustyai.model import FeatureFactory

# Build the original input: four numerical features named f-num1..f-num4,
# each with a pseudo-random value in [0, 10).
# A dedicated, seeded generator makes the values identical on every
# "Restart & Run All" and leaves the global `random` state untouched.
feature_rng = random.Random(42)

features = [
    FeatureFactory.newNumericalFeature(f"f-num{i+1}", feature_rng.random() * 10.0)
    for i in range(4)
]

# Show the starting point the counterfactual search will move away from.
for f in features:
    print(f"Feature {f.getName()} has value {f.getValue()}")

In [None]:
# One boolean flag per input feature, sized from `features` so the two lists
# cannot drift apart if the number of features changes.
# NOTE(review): presumably False means the feature is free to be changed by
# the search — confirm against the CounterfactualPrediction documentation.
constraints = [False] * len(features)

In [None]:
from trustyai.model.domain import NumericalFeatureDomain

# One search domain per input feature, sized from `features` so the lists
# stay aligned. Every feature gets the same numerical domain, created with
# bounds 0.0 and 1000.0 (a single domain object is shared by all entries).
feature_boundaries = [NumericalFeatureDomain.create(0.0, 1000.0)] * len(features)

In [None]:
from trustyai.model import DataDomain

# Wrap the per-feature domains in a DataDomain; its getFeatureDomains()
# result is later passed to PredictionFeatureDomain.
data_domain = DataDomain(feature_boundaries)

In [None]:
# Model parameters (same values as in the earlier model cell): the point C
# and the threshold epsilon passed to the sum-threshold model.
center, epsilon = 500.0, 10.0

In [None]:
from trustyai.utils import TestUtils

# Re-create the sum-threshold model from `center` and `epsilon`
# (repeats the construction done in the earlier model cell).
model = TestUtils.getSumThresholdModel(center, epsilon)

In [None]:
from org.optaplanner.core.config.solver.termination import TerminationConfig
from org.kie.kogito.explainability.local.counterfactual import CounterfactualConfigurationFactory
from java.lang import Long

# Bound the search: the termination config is given a score-calculation
# count limit of 10,000, passed as a java.lang.Long.
score_calculation_limit = Long.valueOf(10_000)
termination_config = TerminationConfig().withScoreCalculationCountLimit(score_calculation_limit)

# Solver configuration for the counterfactual explainer, built with the
# termination rule above.
config_builder = CounterfactualConfigurationFactory.builder()
solver_config = config_builder.withTerminationConfig(termination_config).build()

In [None]:
from org.kie.kogito.explainability.local.counterfactual import CounterfactualExplainer

# Counterfactual explainer wired to the solver configuration defined above.
explainer = (
    CounterfactualExplainer.builder()
    .withSolverConfig(solver_config)
    .build()
)

In [None]:
from trustyai.model import PredictionFeatureDomain, PredictionInput, PredictionOutput

# Per-feature search domains, taken from the data domain.
feature_domains = data_domain.getFeatureDomains()

# The three ingredients of a counterfactual prediction:
# the original features, the desired outputs (goal), and the feature domains.
inputs = PredictionInput(features)
outputs = PredictionOutput(goal)
domain = PredictionFeatureDomain(feature_domains)

In [None]:
import uuid
from trustyai.model import CounterfactualPrediction

# Bundle everything the explainer needs into a single prediction object.
prediction = CounterfactualPrediction(
    inputs,        # original feature values
    outputs,       # desired outcome (the goal)
    domain,        # per-feature search domains
    constraints,   # per-feature boolean flags
    None,          # NOTE(review): optional argument left unset — confirm its meaning in the API docs
    uuid.uuid4(),  # freshly generated random identifier
)

In [None]:
# Start the counterfactual search for `prediction` against `model`.
# The returned handle is resolved with .get() in the next cell.
explanation_async = explainer.explainAsync(prediction, model)

In [None]:
# Retrieve the result of the asynchronous explanation.
# NOTE(review): presumably this blocks until the solver terminates — confirm.
explanation = explanation_async.get()

In [None]:
# Print each entity of the counterfactual result, one per line.
counterfactual_entities = explanation.getEntities()
for counterfactual_entity in counterfactual_entities:
    print(counterfactual_entity)