[Add] f1_score/precision_score/recall_score/accuracy_score (#405)
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #383 

## Possible side effects?

- Performance: supports multi-class classification (see the usage sketch below)

- Backward compatibility:
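
A minimal usage sketch of the new API for reviewers (my illustration, not part of the diff; the data values are made up):

```python
import jax.numpy as jnp
from sml.metrics.classification.classification import f1_score, accuracy_score

y_true = jnp.array([0, 1, 2, 1])
y_pred = jnp.array([0, 2, 2, 1])

# Binary-style: score only the class given by pos_label.
f1_pos = f1_score(y_true, y_pred, average='binary', pos_label=1)

# Multi-class: one one-vs-rest score per label.
f1_per_class = f1_score(y_true, y_pred, average=None, labels=[0, 1, 2])
acc = accuracy_score(y_true, y_pred)
```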
tarantula-leo committed Nov 19, 2023
1 parent c8b7283 commit 5746ad0
Showing 3 changed files with 291 additions and 27 deletions.
128 changes: 121 additions & 7 deletions sml/metrics/classification/classification.py
@@ -98,13 +98,127 @@ def equal_range(x: jnp.ndarray, n_bin: int) -> jnp.ndarray:
    return result


# TODO: more evaluation tools
def _f1_score(y_true, y_pred):
    """Calculate the F1 score."""
    tp = jnp.sum(y_true * y_pred)
    fp = jnp.sum(y_pred) - tp
    fn = jnp.sum(y_true) - tp
    f1 = 2 * tp / (2 * tp + fp + fn + 1e-10)
    return f1


def _precision_score(y_true, y_pred):
    """Calculate the Precision score."""
    tp = jnp.sum(y_true * y_pred)
    fp = jnp.sum(y_pred) - tp
    precision = tp / (tp + fp + 1e-10)
    return precision


def _recall_score(y_true, y_pred):
    """Calculate the Recall score."""
    tp = jnp.sum(y_true * y_pred)
    fn = jnp.sum(y_true) - tp
    recall = tp / (tp + fn + 1e-10)
    return recall
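
A plain-NumPy sanity check of these count-based formulas (my example, not part of the diff). Note that `2*tp / (2*tp + fp + fn)` is algebraically the harmonic mean of precision and recall, and the `1e-10` term only guards the degenerate 0/0 case:

```python
import numpy as np
from sklearn import metrics

y_true = np.array([0, 1, 1, 0, 1, 1])
y_pred = np.array([0, 0, 1, 0, 1, 1])

tp = np.sum(y_true * y_pred)      # 3 true positives
fp = np.sum(y_pred) - tp          # 0 false positives
fn = np.sum(y_true) - tp          # 1 false negative
f1 = 2 * tp / (2 * tp + fp + fn)  # 6/7 ≈ 0.857
assert np.isclose(f1, metrics.f1_score(y_true, y_pred))
```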


def accuracy_score(y_true, y_pred):
    """Calculate the Accuracy score."""
    correct = jnp.sum(y_true == y_pred)
    total = len(y_true)
    accuracy = correct / total
    return accuracy


def transform_binary(y_true, y_pred, label):
    y_true_transform = jnp.where(y_true == label, 1, 0)
    y_pred_transform = jnp.where(y_pred != label, 0, 1)
    return y_true_transform, y_pred_transform
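
For intuition, `transform_binary` reduces a multi-class problem to one-vs-rest for a single label; this is how `fun_score` below obtains per-class scores. A small worked example (values are mine):

```python
import jax.numpy as jnp

y_true = jnp.array([0, 1, 2, 2, 1])
y_pred = jnp.array([0, 2, 2, 1, 1])

# One-vs-rest view for label 2:
t, p = transform_binary(y_true, y_pred, 2)
# t == [0, 0, 1, 1, 0], p == [0, 1, 1, 0, 0]
```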


def f1_score(
    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
    f1_result = fun_score(
        _f1_score, y_true, y_pred, average, labels, pos_label, transform
    )
    return f1_result


def precision_score(
    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
    precision_result = fun_score(
        _precision_score, y_true, y_pred, average, labels, pos_label, transform
    )
    return precision_result


def recall_score(
    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
    recall_result = fun_score(
        _recall_score, y_true, y_pred, average, labels, pos_label, transform
    )
    return recall_result


def fun_score(
    fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
    """
    Compute precision, recall, or F1 score.

    Args:
        fun : callable, one of `_precision_score` / `_recall_score` / `_f1_score`.
        y_true : 1d array-like, ground truth (correct) target values.
        y_pred : 1d array-like, estimated targets as returned by a classifier.
        average : {'binary'} or None, default='binary'
            This parameter is required for multiclass/multilabel targets.
            If ``None``, the scores for each class are returned.
            ``'binary'``:
                Only report results for the class specified by ``pos_label``.
                This is applicable only if targets (``y_{true,pred}``) are binary.
        labels : array-like, default=None
            The set of labels to include when ``average != 'binary'``.
        pos_label : int or float, default=1
            The class to report if ``average='binary'`` and the data is binary.
            If the data are multiclass or multilabel, this parameter is ignored.
        transform : bool, default=True
            Binary classification only. If True, the labels are transformed to
            0/1 explicitly. Otherwise, do the transformation beforehand, which
            reduces the cost of this function.

    Returns:
        score : float if ``average='binary'``; otherwise a list with one
            one-vs-rest score per entry of ``labels``, as computed by ``fun``.
    """

    if average is None:
        assert labels is not None, "labels cannot be None"
        fun_result = []
        for i in labels:
            y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)
            fun_result.append(fun(y_true_binary, y_pred_binary))
    elif average == 'binary':
        if transform:
            y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label)
        else:
            y_true_binary, y_pred_binary = y_true, y_pred
        fun_result = fun(y_true_binary, y_pred_binary)
    else:
        raise ValueError("average should be None or 'binary'")
    return fun_result
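
Putting the pieces together, a minimal sketch of how the public entry points compose (assuming the module import path used by the tests below):

```python
import jax.numpy as jnp
from sml.metrics.classification.classification import (
    f1_score,
    precision_score,
    recall_score,
)

y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])

# Binary: a single float for class pos_label=1.
p_bin = precision_score(y_true, y_pred, average='binary', pos_label=1)

# Multi-class: a list with one one-vs-rest score per entry of `labels`.
f1_all = f1_score(y_true, y_pred, average=None, labels=[0, 1, 2])
r_all = recall_score(y_true, y_pred, average=None, labels=[0, 1, 2])
```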
113 changes: 94 additions & 19 deletions sml/metrics/classification/classification_emul.py
@@ -15,38 +15,113 @@
import sys

import numpy as np
import jax.numpy as jnp
from sklearn import metrics

# add ops dir to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))

import sml.utils.emulation as emulation
from sml.metrics.classification.classification import roc_auc_score
from sml.metrics.classification.classification import (
roc_auc_score,
f1_score,
precision_score,
recall_score,
accuracy_score,
)


# TODO: design the emulation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)
def emul_auc(mode: emulation.Mode.MULTIPROCESS):
    # Create dataset
    row = 10000
    y_true = np.random.randint(0, 2, (row,))
    y_pred = np.random.random((row,))

    # Run
    result = emulator.run(roc_auc_score)(y_true, y_pred)
    print(result)


def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
    def proc(
        y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
    ):
        f1 = f1_score(
            y_true,
            y_pred,
            average=average,
            labels=labels,
            pos_label=pos_label,
            transform=transform,
        )
        precision = precision_score(
            y_true,
            y_pred,
            average=average,
            labels=labels,
            pos_label=pos_label,
            transform=transform,
        )
        recall = recall_score(
            y_true,
            y_pred,
            average=average,
            labels=labels,
            pos_label=pos_label,
            transform=transform,
        )
        accuracy = accuracy_score(y_true, y_pred)
        return f1, precision, recall, accuracy

    def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
        f1 = metrics.f1_score(
            y_true, y_pred, average=average, labels=labels, pos_label=pos_label
        )
        precision = metrics.precision_score(
            y_true, y_pred, average=average, labels=labels, pos_label=pos_label
        )
        recall = metrics.recall_score(
            y_true, y_pred, average=average, labels=labels, pos_label=pos_label
        )
        accuracy = metrics.accuracy_score(y_true, y_pred)
        return f1, precision, recall, accuracy

    def check(spu_result, sk_result):
        for pair in zip(spu_result, sk_result):
            np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
    # Test binary
    y_true = jnp.array([0, 1, 1, 0, 1, 1])
    y_pred = jnp.array([0, 0, 1, 0, 1, 1])
    spu_result = emulator.run(proc, static_argnums=(2, 5))(
        y_true, y_pred, 'binary', None, 1, False
    )
    sk_result = sklearn_proc(y_true, y_pred)
    check(spu_result, sk_result)

    # Test multiclass
    y_true = jnp.array([0, 1, 1, 0, 2, 1])
    y_pred = jnp.array([0, 0, 1, 0, 2, 1])
    spu_result = emulator.run(proc, static_argnums=(2, 5))(
        y_true, y_pred, None, [0, 1, 2], 1, True
    )
    sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
    check(spu_result, sk_result)


if __name__ == "__main__":
emul_SGDClassifier(emulation.Mode.MULTIPROCESS)
try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
emulation.CLUSTER_ABY3_3PC,
emulation.Mode.MULTIPROCESS,
bandwidth=300,
latency=20,
)
emulator.up()
emul_auc(emulation.Mode.MULTIPROCESS)
emul_Classification(emulation.Mode.MULTIPROCESS)
finally:
emulator.down()
77 changes: 76 additions & 1 deletion sml/metrics/classification/classification_test.py
@@ -18,6 +18,7 @@

import jax.numpy as jnp
import numpy as np
from sklearn import metrics

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim
@@ -31,11 +32,15 @@
    bin_counts,
    equal_obs,
    roc_auc_score,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
)


class UnitTests(unittest.TestCase):
    def test_auc(self):
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )
@@ -79,6 +84,76 @@ def digitize(y_pred, thresholds):

        np.testing.assert_almost_equal(true_score, score, decimal=2)

    def test_classification(self):
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
        )

        def proc(
            y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
        ):
            f1 = f1_score(
                y_true,
                y_pred,
                average=average,
                labels=labels,
                pos_label=pos_label,
                transform=transform,
            )
            precision = precision_score(
                y_true,
                y_pred,
                average=average,
                labels=labels,
                pos_label=pos_label,
                transform=transform,
            )
            recall = recall_score(
                y_true,
                y_pred,
                average=average,
                labels=labels,
                pos_label=pos_label,
                transform=transform,
            )
            accuracy = accuracy_score(y_true, y_pred)
            return f1, precision, recall, accuracy

        def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
            f1 = metrics.f1_score(
                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
            )
            precision = metrics.precision_score(
                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
            )
            recall = metrics.recall_score(
                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
            )
            accuracy = metrics.accuracy_score(y_true, y_pred)
            return f1, precision, recall, accuracy

        def check(spu_result, sk_result):
            for pair in zip(spu_result, sk_result):
                np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)

        # Test binary
        y_true = jnp.array([0, 1, 1, 0, 1, 1])
        y_pred = jnp.array([0, 0, 1, 0, 1, 1])
        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(
            y_true, y_pred, 'binary', None, 1, False
        )
        sk_result = sklearn_proc(y_true, y_pred)
        check(spu_result, sk_result)

        # Test multiclass
        y_true = jnp.array([0, 1, 1, 0, 2, 1])
        y_pred = jnp.array([0, 0, 1, 0, 2, 1])
        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(
            y_true, y_pred, None, [0, 1, 2], 1, True
        )
        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
        check(spu_result, sk_result)


if __name__ == "__main__":
unittest.main()
