Skip to content

Commit

Permalink
Add support for SAX and PAA for CF sparsity validation (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
isaksamsten committed Nov 30, 2022
1 parent 6838b03 commit f7627d6
Showing 1 changed file with 47 additions and 6 deletions.
53 changes: 47 additions & 6 deletions src/wildboar/metrics/_counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

from ..distance import paired_distance, pairwise_distance
from ..explain.counterfactual import proximity
from ..transform import (
piecewice_aggregate_approximation,
symbolic_aggregate_approximation,
)
from ..utils.validation import check_array


Expand Down Expand Up @@ -276,7 +280,15 @@ def proximity_score(
)


def compactness_score(x_true, x_counterfactuals, *, rtol=1.0e-5, atol=1.0e-8):
def compactness_score(
x_true,
x_counterfactuals,
*,
window=None,
n_bins=None,
atol=1.0e-8,
average=True,
):
"""Return the compactness of the counterfactuals as measured by the
fraction of changed timesteps. The fewer timesteps have changed between the original
and the counterfactual, the lower the score.
Expand All @@ -291,11 +303,14 @@ def compactness_score(x_true, x_counterfactuals, *, rtol=1.0e-5, atol=1.0e-8):
or (n_samples, n_dims, n_timeteps)
The counterfactual samples
rtol : float, optional
Parameter to `np.isclose`.
atol : float,
The absolute tolerance.
window : int, optional
If set, evaluate the difference between windows of specified size.
atol : float, optional
Parameter to `np.isclose`.
n_bins : int, optional
If set, evaluate the set overlap of SAX transformed series.
Returns
-------
Expand All @@ -316,7 +331,33 @@ def compactness_score(x_true, x_counterfactuals, *, rtol=1.0e-5, atol=1.0e-8):
% (x_true.shape, x_counterfactuals.shape)
)

return 1 - np.mean(np.isclose(x_counterfactuals, x_true, rtol=rtol, atol=atol))
if window is not None:
if n_bins is not None:

def score(x_counterfactuals, x_true):
x_counterfactuals = symbolic_aggregate_approximation(
x_counterfactuals, window=window, n_bins=n_bins
)
x_true = symbolic_aggregate_approximation(
x_true, window=window, n_bins=n_bins
)
return x_counterfactuals == x_true

else:

def score(x_counterfactuals, x_true):
x_counterfactuals = piecewice_aggregate_approximation(
x_counterfactuals, window=window
)
x_true = piecewice_aggregate_approximation(x_true, window=window)
return np.mean(np.isclose(x_counterfactuals, x_true, rtol=0, atol=atol))

else:

def score(x_counterfactuals, x_true):
return np.isclose(x_counterfactuals, x_true, rtol=0, atol=atol)

return 1 - np.mean(score(x_counterfactuals, x_true), axis=None if average else 0)


def validity_score(y_pred, y_counterfactual, sample_weight=None):
Expand Down

0 comments on commit f7627d6

Please sign in to comment.