Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Add] f1_score/precision_score/recall_score/accuracy_score #405

Merged
merged 22 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 121 additions & 7 deletions sml/metrics/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,127 @@ def equal_range(x: jnp.ndarray, n_bin: int) -> jnp.ndarray:
return result


# TODO: more evaluation tools
def _f1_score(y_true, y_pred):
"""Calculate the F1 score."""
tp = jnp.sum(y_true * y_pred)
fp = jnp.sum(y_pred) - tp
fn = jnp.sum(y_true) - tp
f1 = 2 * tp / (2 * tp + fp + fn + 1e-10)
return f1


def _precision_score(y_true, y_pred):
"""Calculate the Precision score."""
tp = jnp.sum(y_true * y_pred)
fp = jnp.sum(y_pred) - tp
precision = tp / (tp + fp + 1e-10)
return precision


def _recall_score(y_true, y_pred):
"""Calculate the Recall score."""
tp = jnp.sum(y_true * y_pred)
fn = jnp.sum(y_true) - tp
recall = tp / (tp + fn + 1e-10)
return recall


def compute_f1_score(
true_positive: jnp.ndarray, false_positive: jnp.ndarray, false_negative: jnp.ndarray
def accuracy_score(y_true, y_pred):
"""Calculate the Accuracy score."""
correct = jnp.sum(y_true == y_pred)
total = len(y_true)
accuracy = correct / total
return accuracy


def transform_binary(y_true, y_pred, label):
y_true_transform = jnp.where(y_true == label, 1, 0)
y_pred_transform = jnp.where(y_pred != label, 0, 1)
return y_true_transform, y_pred_transform


def f1_score(
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
"""Calculate the F1 score."""
precision = true_positive / (true_positive + false_positive)
recall = true_positive / (true_positive + false_negative)
return 2 * precision * recall / (precision + recall)
f1_result = fun_score(
_f1_score, y_true, y_pred, average, labels, pos_label, transform
)
return f1_result


def precision_score(
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
f1_result = fun_score(
_precision_score, y_true, y_pred, average, labels, pos_label, transform
)
return f1_result


def recall_score(
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
f1_result = fun_score(
_recall_score, y_true, y_pred, average, labels, pos_label, transform
)
return f1_result


def fun_score(
fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
):
"""
Compute precision, recall, f1.

Args:
fun : function, support '_precision_score' / '_recall_score' / '_f1_score'.

y_true : 1d array-like, ground truth (correct) target values.

y_pred : 1d array-like, estimated targets as returned by a classifier.

average : {'binary'} or None, default='binary'
This parameter is required for multiclass/multilabel targets.
If ``None``, the scores for each class are returned.

``'binary'``:
Only report results for the class specified by ``pos_label``.
This is applicable only if targets (``y_{true,pred}``) are binary

labels : array-like, default=None
The set of labels to include when ``average != 'binary'``.

pos_label : int, float, default=1
The class to report if ``average='binary'`` and the data is binary.
If the data are multiclass or multilabel, this will be ignored;

transform : bool, default=True
Binary classification only. If True, then the transformation of label to 0/1 will be done explicitly. Else, you can do it beforehand which decrease the costs of this function.

Returns:
-------
precision : float, shape = [n_unique_labels] for multi-classification
Precision score.

recall : float, shape = [n_unique_labels] for multi-classification
Recall score.

f1 : float, shape = [n_unique_labels] for multi-classification
F1 score.
"""

if average is None:
assert labels is not None, f"labels cannot be None"
fun_result = []
for i in labels:
y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)
fun_result.append(fun(y_true_binary, y_pred_binary))
elif average == 'binary':
if transform:
y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label)
else:
y_true_binary, y_pred_binary = y_true, y_pred
fun_result = fun(y_true_binary, y_pred_binary)
else:
raise ValueError("average should be None or 'binary'")
return fun_result
113 changes: 94 additions & 19 deletions sml/metrics/classification/classification_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,113 @@
import sys

import numpy as np
import jax.numpy as jnp
from sklearn import metrics

# add ops dir to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))

import sml.utils.emulation as emulation
from sml.metrics.classification.classification import roc_auc_score
from sml.metrics.classification.classification import (
roc_auc_score,
f1_score,
precision_score,
recall_score,
accuracy_score,
)


# TODO: design the enumation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)
def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS):
try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
def emul_auc(mode: emulation.Mode.MULTIPROCESS):
# Create dataset
row = 10000
y_true = np.random.randint(0, 2, (row,))
y_pred = np.random.random((row,))

# Run
result = emulator.run(roc_auc_score)(
y_true, y_pred
) # X, y should be two-dimension array
print(result)


def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
def proc(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
f1 = f1_score(
y_true,
y_pred,
average=average,
labels=labels,
pos_label=pos_label,
transform=transform,
)
emulator.up()
precision = precision_score(
y_true,
y_pred,
average=average,
labels=labels,
pos_label=pos_label,
transform=transform,
)
recall = recall_score(
y_true,
y_pred,
average=average,
labels=labels,
pos_label=pos_label,
transform=transform,
)
accuracy = accuracy_score(y_true, y_pred)
return f1, precision, recall, accuracy

def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
f1 = metrics.f1_score(
y_true, y_pred, average=average, labels=labels, pos_label=pos_label
)
precision = metrics.precision_score(
y_true, y_pred, average=average, labels=labels, pos_label=pos_label
)
recall = metrics.recall_score(
y_true, y_pred, average=average, labels=labels, pos_label=pos_label
)
accuracy = metrics.accuracy_score(y_true, y_pred)
return f1, precision, recall, accuracy

# Create dataset
row = 10000
y_true = np.random.randint(0, 2, (row,))
y_pred = np.random.random((row,))
def check(spu_result, sk_result):
for pair in zip(spu_result, sk_result):
np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)

# Run
result = emulator.run(roc_auc_score)(
y_true, y_pred
) # X, y should be two-dimension array
print(result)
# Test binary
y_true = jnp.array([0, 1, 1, 0, 1, 1])
y_pred = jnp.array([0, 0, 1, 0, 1, 1])
spu_result = emulator.run(proc, static_argnums=(2, 5))(
y_true, y_pred, 'binary', None, 1, False
)
sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)

finally:
emulator.down()
# Test multiclass
y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])
spu_result = emulator.run(proc, static_argnums=(2, 5))(
y_true, y_pred, None, [0, 1, 2], 1, True
)
sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
check(spu_result, sk_result)


if __name__ == "__main__":
emul_SGDClassifier(emulation.Mode.MULTIPROCESS)
try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
emulation.CLUSTER_ABY3_3PC,
emulation.Mode.MULTIPROCESS,
bandwidth=300,
latency=20,
)
emulator.up()
emul_auc(emulation.Mode.MULTIPROCESS)
emul_Classification(emulation.Mode.MULTIPROCESS)
finally:
emulator.down()
77 changes: 76 additions & 1 deletion sml/metrics/classification/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import jax.numpy as jnp
import numpy as np
from sklearn import metrics

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim
Expand All @@ -31,11 +32,15 @@
bin_counts,
equal_obs,
roc_auc_score,
f1_score,
precision_score,
recall_score,
accuracy_score,
)


class UnitTests(unittest.TestCase):
def test_simple(self):
def test_auc(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)
Expand Down Expand Up @@ -79,6 +84,76 @@ def digitize(y_pred, thresholds):

np.testing.assert_almost_equal(true_score, score, decimal=2)

def test_classification(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

def proc(
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
):
f1 = f1_score(
y_true,
y_pred,
average=average,
labels=labels,
pos_label=pos_label,
transform=transform,
)
precision = precision_score(
y_true,
y_pred,
average=average,
labels=labels,
pos_label=pos_label,
transform=transform,
)
recall = recall_score(
y_true,
y_pred,
average=average,
labels=labels,
pos_label=pos_label,
transform=transform,
)
accuracy = accuracy_score(y_true, y_pred)
return f1, precision, recall, accuracy

def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
f1 = metrics.f1_score(
y_true, y_pred, average=average, labels=labels, pos_label=pos_label
)
precision = metrics.precision_score(
y_true, y_pred, average=average, labels=labels, pos_label=pos_label
)
recall = metrics.recall_score(
y_true, y_pred, average=average, labels=labels, pos_label=pos_label
)
accuracy = metrics.accuracy_score(y_true, y_pred)
return f1, precision, recall, accuracy

def check(spu_result, sk_result):
for pair in zip(spu_result, sk_result):
np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)

# Test binary
y_true = jnp.array([0, 1, 1, 0, 1, 1])
y_pred = jnp.array([0, 0, 1, 0, 1, 1])
spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(
y_true, y_pred, 'binary', None, 1, False
)
sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)

# Test multiclass
y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])
spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(
y_true, y_pred, None, [0, 1, 2], 1, True
)
sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
check(spu_result, sk_result)


if __name__ == "__main__":
unittest.main()
Loading