From 9ccdec3883680eb4c109f01f7d2918593aefbfb2 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 15:41:44 +0800 Subject: [PATCH 01/22] Update classification.py --- sml/metrics/classification/classification.py | 110 +++++++++++++++++-- 1 file changed, 102 insertions(+), 8 deletions(-) diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 6b8c7fe2..b142aec4 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -98,13 +98,107 @@ def equal_range(x: jnp.ndarray, n_bin: int) -> jnp.ndarray: return result -# TODO: more evaluation tools +def _f1_score(y_true, y_pred): + """Calculate the F1 score.""" + tp = jnp.sum(y_true * y_pred) + fp = jnp.sum(y_pred) - tp + fn = jnp.sum(y_true) - tp + f1 = 2 * tp / (2 * tp + fp + fn + 1e-10) + return f1 + +def _precision_score(y_true, y_pred): + """Calculate the Precision score.""" + tp = jnp.sum(y_true * y_pred) + fp = jnp.sum(y_pred) - tp + precision = tp / (tp + fp + 1e-10) + return precision + +def _recall_score(y_true, y_pred): + """Calculate the Recall score.""" + tp = jnp.sum(y_true * y_pred) + fn = jnp.sum(y_true) - tp + recall = tp / (tp + fn + 1e-10) + return recall + +def accuracy_score(y_true, y_pred): + """Calculate the Accuracy score.""" + correct = jnp.sum(y_true == y_pred) + total = len(y_true) + accuracy = correct / total + return accuracy + +def transform_binary(y_true, y_pred, label): + y_true_transform = jnp.where(y_true == label, 1, 0) + y_pred_transform = jnp.where(y_pred != label, 0, 1) + return y_true_transform, y_pred_transform + +def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): + f1_result = fun_score(_f1_score, y_true, y_pred, average, labels, pos_label, transform) + return f1_result + +def precision_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): + f1_result = fun_score(_precision_score, y_true, y_pred, average, labels, pos_label, transform) + return f1_result + +def recall_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): + f1_result = fun_score(_recall_score, y_true, y_pred, average, labels, pos_label, transform) + return f1_result + +def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): + """ + Compute precision, recall, f1. + Args: + fun : function, support '_precision_score' / '_recall_score' / '_f1_score'. -def compute_f1_score( - true_positive: jnp.ndarray, false_positive: jnp.ndarray, false_negative: jnp.ndarray -): - """Calculate the F1 score.""" - precision = true_positive / (true_positive + false_positive) - recall = true_positive / (true_positive + false_negative) - return 2 * precision * recall / (precision + recall) + y_true : 1d array-like, ground truth (correct) target values. + + y_pred : 1d array-like, estimated targets as returned by a classifier. + + average : {'binary'} or None, default='binary' + This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. + + ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary + + labels : array-like, default=None + The set of labels to include when ``average != 'binary'``. + + pos_label : int, float, default=1 + The class to report if ``average='binary'`` and the data is binary. 
+ If the data are multiclass or multilabel, this will be ignored; + + transform : bool, default=1 + The problem is transformed into a binary classification with positive samples labeled 1 and negative samples labeled 0. + + Returns: + ------- + precision : float, shape = [n_unique_labels] for multi-classification + Precision score. + + recall : float, shape = [n_unique_labels] for multi-classification + Recall score. + + f1 : float, shape = [n_unique_labels] for multi-classification + F1 score. + """ + + if average is None: + assert ( + labels is not None + ), f"labels cannot be None" + fun_result = [] + for i in labels: + y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i) + fun_result.append(fun(y_true_binary, y_pred_binary)) + elif average == 'binary': + if transform is True: + y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label) + else: + y_true_binary, y_pred_binary = y_true, y_pred + fun_result = fun(y_true_binary, y_pred_binary) + else: + raise ValueError("average should be None or 'binary'") + return fun_result From 3dba13c170f753cd416e0ebab453d40d39b519c0 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 15:51:24 +0800 Subject: [PATCH 02/22] Update classification.py --- sml/metrics/classification/classification.py | 46 ++++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index b142aec4..96e5374b 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -106,6 +106,7 @@ def _f1_score(y_true, y_pred): f1 = 2 * tp / (2 * tp + fp + fn + 1e-10) return f1 + def _precision_score(y_true, y_pred): """Calculate the Precision score.""" tp = jnp.sum(y_true * y_pred) @@ -113,12 +114,14 @@ def _precision_score(y_true, y_pred): precision = tp / (tp + fp + 1e-10) return precision + def _recall_score(y_true, y_pred): """Calculate the Recall score.""" tp = jnp.sum(y_true * y_pred) fn = jnp.sum(y_true) - tp recall = tp / (tp + fn + 1e-10) - return recall + return recall + def accuracy_score(y_true, y_pred): """Calculate the Accuracy score.""" @@ -127,24 +130,41 @@ def accuracy_score(y_true, y_pred): accuracy = correct / total return accuracy + def transform_binary(y_true, y_pred, label): y_true_transform = jnp.where(y_true == label, 1, 0) y_pred_transform = jnp.where(y_pred != label, 0, 1) return y_true_transform, y_pred_transform + def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): - f1_result = fun_score(_f1_score, y_true, y_pred, average, labels, pos_label, transform) - return f1_result + f1_result = fun_score( + _f1_score, y_true, y_pred, average, labels, pos_label, transform + ) + return f1_result -def precision_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): - f1_result = fun_score(_precision_score, y_true, y_pred, average, labels, pos_label, transform) - return f1_result -def recall_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): - f1_result = fun_score(_recall_score, y_true, y_pred, average, labels, pos_label, transform) - return f1_result +def precision_score( + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 +): + f1_result = fun_score( + _precision_score, y_true, y_pred, average, labels, pos_label, transform + ) + return f1_result + + +def recall_score( + y_true, y_pred, average='binary', 
labels=None, pos_label=1, transform=1 +): + f1_result = fun_score( + _recall_score, y_true, y_pred, average, labels, pos_label, transform + ) + return f1_result + -def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): +def fun_score( + fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 +): """ Compute precision, recall, f1. @@ -169,7 +189,7 @@ def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, t pos_label : int, float, default=1 The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; - + transform : bool, default=1 The problem is transformed into a binary classification with positive samples labeled 1 and negative samples labeled 0. @@ -186,9 +206,7 @@ def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, t """ if average is None: - assert ( - labels is not None - ), f"labels cannot be None" + assert labels is not None, f"labels cannot be None" fun_result = [] for i in labels: y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i) From b414e76e2d97e5f368a04622c877c90029f52647 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:11:42 +0800 Subject: [PATCH 03/22] Update classification.py --- sml/metrics/classification/classification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 96e5374b..8f78d41d 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -141,7 +141,7 @@ def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transfo f1_result = fun_score( _f1_score, y_true, y_pred, average, labels, pos_label, transform ) - return f1_result + return f1_result def precision_score( @@ -150,7 +150,7 @@ def precision_score( f1_result = fun_score( _precision_score, y_true, y_pred, average, labels, pos_label, transform ) - return f1_result + return f1_result def recall_score( @@ -159,7 +159,7 @@ def recall_score( f1_result = fun_score( _recall_score, y_true, y_pred, average, labels, pos_label, transform ) - return f1_result + return f1_result def fun_score( From 4804cb17cc58bdf837f9daf0a379354a6f62e26b Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:38:32 +0800 Subject: [PATCH 04/22] Update classification_test.py --- .../classification/classification_test.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index 8ff59db5..db14850b 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -80,5 +80,43 @@ def digitize(y_pred, thresholds): np.testing.assert_almost_equal(true_score, score, decimal=2) + def test_classification(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + ) + + def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1): + f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) + precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) + recall = recall_score(y_true, y_pred, 
average=average, labels=labels, pos_label=pos_label, transform=transform) + accuracy = accuracy_score(y_true, y_pred) + return f1,precision,recall,accuracy + + def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): + f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + accuracy = metrics.accuracy_score(y_true, y_pred) + return f1,precision,recall,accuracy + + def check(spu_result, sk_result): + for pair in zip(spu_result, sk_result): + np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) + + # Test binary + y_true = jnp.array([0, 1, 1, 0, 1, 1]) + y_pred = jnp.array([0, 0, 1, 0, 1, 1]) + spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, pos_label=1, transform=0) + sk_result = sklearn_proc(y_true, y_pred) + check(spu_result, sk_result) + + # Test multiclass + y_true = jnp.array([0, 1, 1, 0, 2, 1]) + y_pred = jnp.array([0, 0, 1, 0, 2, 1]) + spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, average=None, labels=[0,1,2]) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2]) + check(spu_result, sk_result) + + if __name__ == "__main__": unittest.main() From 44c001409b45938bdd68481afc1ad2cdb28e0bd7 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:47:09 +0800 Subject: [PATCH 05/22] Update classification_emul.py --- .../classification/classification_emul.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index fbddcf8d..87e900a5 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -15,12 +15,14 @@ import sys import numpy as np +import jax.numpy as jnp +from sklearn import metrics # add ops dir to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) import sml.utils.emulation as emulation -from sml.metrics.classification.classification import roc_auc_score +from sml.metrics.classification.classification import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score # TODO: design the enumation framework, just like py.unittest @@ -48,5 +50,51 @@ def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS): emulator.down() + +def emul_Classification(mode: emulation.Mode.MULTIPROCESS): + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 + ) + emulator.up() + + def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1): + f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) + precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) + recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) + accuracy = accuracy_score(y_true, y_pred) + return f1,precision,recall,accuracy + + def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): + f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + precision = metrics.precision_score(y_true, y_pred, 
average=average, labels=labels, pos_label=pos_label) + recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + accuracy = metrics.accuracy_score(y_true, y_pred) + return f1,precision,recall,accuracy + + def check(spu_result, sk_result): + for pair in zip(spu_result, sk_result): + np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) + + # Test binary + y_true = jnp.array([0, 1, 1, 0, 1, 1]) + y_pred = jnp.array([0, 0, 1, 0, 1, 1]) + spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0) + sk_result = sklearn_proc(y_true, y_pred) + check(spu_result, sk_result) + + # Test multiclass + y_true = jnp.array([0, 1, 1, 0, 2, 1]) + y_pred = jnp.array([0, 0, 1, 0, 2, 1]) + spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0,1,2]) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2]) + check(spu_result, sk_result) + + finally: + emulator.down() + + if __name__ == "__main__": emul_SGDClassifier(emulation.Mode.MULTIPROCESS) + emul_Classification(emulation.Mode.MULTIPROCESS) From 6df82f35ace47e5dff7ccd6065e2c6e15a6a7868 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:52:14 +0800 Subject: [PATCH 06/22] Update classification_emul.py --- .../classification/classification_emul.py | 68 ++++++++++++++----- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 87e900a5..53c53fe2 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -22,7 +22,13 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) import sml.utils.emulation as emulation -from sml.metrics.classification.classification import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score +from sml.metrics.classification.classification import ( + roc_auc_score, + f1_score, + precision_score, + recall_score, + accuracy_score, +) # TODO: design the enumation framework, just like py.unittest @@ -50,7 +56,6 @@ def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS): emulator.down() - def emul_Classification(mode: emulation.Mode.MULTIPROCESS): try: # bandwidth and latency only work for docker mode @@ -59,29 +64,60 @@ def emul_Classification(mode: emulation.Mode.MULTIPROCESS): ) emulator.up() - def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1): - f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) - precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) - recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) +def proc( + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 + ): + f1 = f1_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) + precision = precision_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) + recall = recall_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) accuracy = accuracy_score(y_true, y_pred) - return f1,precision,recall,accuracy - + return f1, precision, recall, accuracy + + def sklearn_proc(y_true, y_pred, average='binary', 
labels=None, pos_label=1): - f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) - precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) - recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + f1 = metrics.f1_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + precision = metrics.precision_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + recall = metrics.recall_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) accuracy = metrics.accuracy_score(y_true, y_pred) - return f1,precision,recall,accuracy - + return f1, precision, recall, accuracy + + def check(spu_result, sk_result): - for pair in zip(spu_result, sk_result): + for pair in zip(spu_result, sk_result): np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) # Test binary y_true = jnp.array([0, 1, 1, 0, 1, 1]) y_pred = jnp.array([0, 0, 1, 0, 1, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0) - sk_result = sklearn_proc(y_true, y_pred) + spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) # Test multiclass From f7bb89d3df69adffe387a74cb0ad537f3524d6b7 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:58:17 +0800 Subject: [PATCH 07/22] Update classification_test.py --- .../classification/classification_test.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index db14850b..e7c3da81 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np +from sklearn import metrics import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim @@ -31,6 +32,10 @@ bin_counts, equal_obs, roc_auc_score, + f1_score, + precision_score, + recall_score, + accuracy_score, ) @@ -79,19 +84,42 @@ def digitize(y_pred, thresholds): np.testing.assert_almost_equal(true_score, score, decimal=2) - def test_classification(self): sim = spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 ) - def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1): - f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) - precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) - recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform) + def proc( + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 + ): + f1 = f1_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) + precision = precision_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) + recall = recall_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) accuracy = accuracy_score(y_true, y_pred) - return f1,precision,recall,accuracy - + return f1, precision, recall, accuracy + + def 
sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) @@ -113,8 +141,10 @@ def check(spu_result, sk_result): # Test multiclass y_true = jnp.array([0, 1, 1, 0, 2, 1]) y_pred = jnp.array([0, 0, 1, 0, 2, 1]) - spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, average=None, labels=[0,1,2]) - sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2]) + spu_result = spsim.sim_jax(sim, proc)( + y_true, y_pred, average=None, labels=[0, 1, 2] + ) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) From 6d6f2254abdef4ee4a3e2dae582088d3e1de3212 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:58:35 +0800 Subject: [PATCH 08/22] Update classification_emul.py --- sml/metrics/classification/classification_emul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 53c53fe2..a5f88e58 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -64,7 +64,7 @@ def emul_Classification(mode: emulation.Mode.MULTIPROCESS): ) emulator.up() -def proc( + def proc( y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 ): f1 = f1_score( From 9e5adb8a5606d8626928e07656b958a44e885e0c Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:02:33 +0800 Subject: [PATCH 09/22] Update classification_emul.py --- sml/metrics/classification/classification_emul.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index a5f88e58..6df71b94 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -63,7 +63,7 @@ def emul_Classification(mode: emulation.Mode.MULTIPROCESS): emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 ) emulator.up() - + def proc( y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 ): @@ -94,7 +94,6 @@ def proc( accuracy = accuracy_score(y_true, y_pred) return f1, precision, recall, accuracy - def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): f1 = metrics.f1_score( y_true, y_pred, average=average, labels=labels, pos_label=pos_label @@ -108,7 +107,6 @@ def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): accuracy = metrics.accuracy_score(y_true, y_pred) return f1, precision, recall, accuracy - def check(spu_result, sk_result): for pair in zip(spu_result, sk_result): np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) @@ -123,8 +121,8 @@ def check(spu_result, sk_result): # Test multiclass y_true = jnp.array([0, 1, 1, 0, 2, 1]) y_pred = jnp.array([0, 0, 1, 0, 2, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0,1,2]) - sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2]) + spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) finally: From 
8c193e5bc853f5c0b780ed897c994d441354c788 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:03:52 +0800 Subject: [PATCH 10/22] Update classification_test.py --- .../classification/classification_test.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index e7c3da81..b1c8e007 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -121,14 +121,20 @@ def proc( def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): - f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) - precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) - recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label) + f1 = metrics.f1_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + precision = metrics.precision_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + recall = metrics.recall_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) accuracy = metrics.accuracy_score(y_true, y_pred) - return f1,precision,recall,accuracy + return f1, precision, recall, accuracy def check(spu_result, sk_result): - for pair in zip(spu_result, sk_result): + for pair in zip(spu_result, sk_result): np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) # Test binary From 6258145c0740bd2b2bb993cd46eda2b5b04e16e7 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:05:08 +0800 Subject: [PATCH 11/22] Update classification_test.py --- sml/metrics/classification/classification_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index b1c8e007..fd5dafc7 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -119,7 +119,6 @@ def proc( accuracy = accuracy_score(y_true, y_pred) return f1, precision, recall, accuracy - def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): f1 = metrics.f1_score( y_true, y_pred, average=average, labels=labels, pos_label=pos_label @@ -132,7 +131,7 @@ def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): ) accuracy = metrics.accuracy_score(y_true, y_pred) return f1, precision, recall, accuracy - + def check(spu_result, sk_result): for pair in zip(spu_result, sk_result): np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) From 56c219961d79a72b1251fe4302eae4ddcca04c71 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:59:23 +0800 Subject: [PATCH 12/22] change function name --- sml/metrics/classification/classification_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index fd5dafc7..89c37bd0 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -40,7 +40,7 @@ class UnitTests(unittest.TestCase): - def test_simple(self): + def 
test_auc(self): sim = spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 ) From 208b81f2fe73b5601dbfc6ee5e2eaa58dcca63af Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:59:59 +0800 Subject: [PATCH 13/22] change function name --- sml/metrics/classification/classification_emul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 6df71b94..4da6a3a6 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -33,7 +33,7 @@ # TODO: design the enumation framework, just like py.unittest # all emulation action should begin with `emul_` (for reflection) -def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS): +def emul_auc(mode: emulation.Mode.MULTIPROCESS): try: # bandwidth and latency only work for docker mode emulator = emulation.Emulator( @@ -130,5 +130,5 @@ def check(spu_result, sk_result): if __name__ == "__main__": - emul_SGDClassifier(emulation.Mode.MULTIPROCESS) + emul_auc(emulation.Mode.MULTIPROCESS) emul_Classification(emulation.Mode.MULTIPROCESS) From 25626b0cf8f9ce92f98b62ecc7a47633837c7034 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 21:09:31 +0800 Subject: [PATCH 14/22] Update classification_emul.py --- .../classification/classification_emul.py | 161 +++++++++--------- 1 file changed, 76 insertions(+), 85 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 4da6a3a6..6f698d1c 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -34,101 +34,92 @@ # TODO: design the enumation framework, just like py.unittest # all emulation action should begin with `emul_` (for reflection) def emul_auc(mode: emulation.Mode.MULTIPROCESS): - try: - # bandwidth and latency only work for docker mode - emulator = emulation.Emulator( - emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 - ) - emulator.up() - - # Create dataset - row = 10000 - y_true = np.random.randint(0, 2, (row,)) - y_pred = np.random.random((row,)) - - # Run - result = emulator.run(roc_auc_score)( - y_true, y_pred - ) # X, y should be two-dimension array - print(result) + # Create dataset + row = 10000 + y_true = np.random.randint(0, 2, (row,)) + y_pred = np.random.random((row,)) - finally: - emulator.down() + # Run + result = emulator.run(roc_auc_score)( + y_true, y_pred + ) # X, y should be two-dimension array + print(result) def emul_Classification(mode: emulation.Mode.MULTIPROCESS): - try: - # bandwidth and latency only work for docker mode - emulator = emulation.Emulator( - emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 + def proc( + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 + ): + f1 = f1_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, ) - emulator.up() + precision = precision_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) + recall = recall_score( + y_true, + y_pred, + average=average, + labels=labels, + pos_label=pos_label, + transform=transform, + ) + accuracy = accuracy_score(y_true, y_pred) + return f1, precision, recall, accuracy - def proc( - 
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 - ): - f1 = f1_score( - y_true, - y_pred, - average=average, - labels=labels, - pos_label=pos_label, - transform=transform, - ) - precision = precision_score( - y_true, - y_pred, - average=average, - labels=labels, - pos_label=pos_label, - transform=transform, - ) - recall = recall_score( - y_true, - y_pred, - average=average, - labels=labels, - pos_label=pos_label, - transform=transform, - ) - accuracy = accuracy_score(y_true, y_pred) - return f1, precision, recall, accuracy + def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): + f1 = metrics.f1_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + precision = metrics.precision_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + recall = metrics.recall_score( + y_true, y_pred, average=average, labels=labels, pos_label=pos_label + ) + accuracy = metrics.accuracy_score(y_true, y_pred) + return f1, precision, recall, accuracy - def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1): - f1 = metrics.f1_score( - y_true, y_pred, average=average, labels=labels, pos_label=pos_label - ) - precision = metrics.precision_score( - y_true, y_pred, average=average, labels=labels, pos_label=pos_label - ) - recall = metrics.recall_score( - y_true, y_pred, average=average, labels=labels, pos_label=pos_label - ) - accuracy = metrics.accuracy_score(y_true, y_pred) - return f1, precision, recall, accuracy + def check(spu_result, sk_result): + for pair in zip(spu_result, sk_result): + np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) - def check(spu_result, sk_result): - for pair in zip(spu_result, sk_result): - np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5) + # Test binary + y_true = jnp.array([0, 1, 1, 0, 1, 1]) + y_pred = jnp.array([0, 0, 1, 0, 1, 1]) + spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) + check(spu_result, sk_result) - # Test binary - y_true = jnp.array([0, 1, 1, 0, 1, 1]) - y_pred = jnp.array([0, 0, 1, 0, 1, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) - sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) - check(spu_result, sk_result) + # Test multiclass + y_true = jnp.array([0, 1, 1, 0, 2, 1]) + y_pred = jnp.array([0, 0, 1, 0, 2, 1]) + spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) + sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) + check(spu_result, sk_result) - # Test multiclass - y_true = jnp.array([0, 1, 1, 0, 2, 1]) - y_pred = jnp.array([0, 0, 1, 0, 2, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) - sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) - check(spu_result, sk_result) +if __name__ == "__main__": + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, + emulation.Mode.MULTIPROCESS, + bandwidth=300, + latency=20, + ) + emulator.up() + emul_auc(emulation.Mode.MULTIPROCESS) + emul_Classification(emulation.Mode.MULTIPROCESS) finally: emulator.down() - - -if __name__ == "__main__": - emul_auc(emulation.Mode.MULTIPROCESS) - emul_Classification(emulation.Mode.MULTIPROCESS) From d2c2ba07989c73b21760b505bbfaba4ee37a272c Mon Sep 17 00:00:00 2001 From: tarantula-leo 
<54618933+tarantula-leo@users.noreply.github.com> Date: Thu, 16 Nov 2023 21:13:25 +0800 Subject: [PATCH 15/22] Update classification_emul.py --- sml/metrics/classification/classification_emul.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 6f698d1c..c023e366 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -47,9 +47,7 @@ def emul_auc(mode: emulation.Mode.MULTIPROCESS): def emul_Classification(mode: emulation.Mode.MULTIPROCESS): - def proc( - y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 - ): + def proc(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): f1 = f1_score( y_true, y_pred, From 61f2e19795cda2e922839351adce78d6047d3d86 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:18:22 +0800 Subject: [PATCH 16/22] Update classification_emul.py --- sml/metrics/classification/classification_emul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index c023e366..3712abda 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -95,8 +95,8 @@ def check(spu_result, sk_result): # Test binary y_true = jnp.array([0, 1, 1, 0, 1, 1]) y_pred = jnp.array([0, 0, 1, 0, 1, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) - sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) + spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0) + sk_result = sklearn_proc(y_true, y_pred) check(spu_result, sk_result) # Test multiclass From 285b276e8caa62d4215a11a0508c711e09c2e007 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:20:51 +0800 Subject: [PATCH 17/22] Update classification.py --- sml/metrics/classification/classification.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 8f78d41d..9d42973e 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -137,7 +137,7 @@ def transform_binary(y_true, y_pred, label): return y_true_transform, y_pred_transform -def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1): +def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True): f1_result = fun_score( _f1_score, y_true, y_pred, average, labels, pos_label, transform ) @@ -145,7 +145,7 @@ def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transfo def precision_score( - y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True ): f1_result = fun_score( _precision_score, y_true, y_pred, average, labels, pos_label, transform @@ -154,7 +154,7 @@ def precision_score( def recall_score( - y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True ): f1_result = fun_score( _recall_score, y_true, y_pred, average, labels, pos_label, transform @@ 
-163,7 +163,7 @@ def recall_score( def fun_score( - fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1 + fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True ): """ Compute precision, recall, f1. @@ -190,8 +190,8 @@ def fun_score( The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; - transform : bool, default=1 - The problem is transformed into a binary classification with positive samples labeled 1 and negative samples labeled 0. + transform : bool, default=True + Binary classification only. If True, then the transformation of label to 0/1 will be done explicitly. Else, you can do it beforehand which decrease the costs of this function. Returns: ------- @@ -212,7 +212,7 @@ def fun_score( y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i) fun_result.append(fun(y_true_binary, y_pred_binary)) elif average == 'binary': - if transform is True: + if transform: y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label) else: y_true_binary, y_pred_binary = y_true, y_pred From 562659551686263fa35322d8f8b852e980f3bc07 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:27:23 +0800 Subject: [PATCH 18/22] Update classification.py --- sml/metrics/classification/classification.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 9d42973e..9bbd12ca 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -137,7 +137,9 @@ def transform_binary(y_true, y_pred, label): return y_true_transform, y_pred_transform -def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True): +def f1_score( + y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True +): f1_result = fun_score( _f1_score, y_true, y_pred, average, labels, pos_label, transform ) @@ -212,7 +214,7 @@ def fun_score( y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i) fun_result.append(fun(y_true_binary, y_pred_binary)) elif average == 'binary': - if transform: + if transform is True: y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label) else: y_true_binary, y_pred_binary = y_true, y_pred From c1decdb0e06479668b9a03804a2a8393ca71fbaf Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:01:36 +0800 Subject: [PATCH 19/22] Update classification.py --- sml/metrics/classification/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 9bbd12ca..2d3882fc 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -214,7 +214,7 @@ def fun_score( y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i) fun_result.append(fun(y_true_binary, y_pred_binary)) elif average == 'binary': - if transform is True: + if transform: y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label) else: y_true_binary, y_pred_binary = y_true, y_pred From 90e3cd2d90cca8bcf572dc028d38641a3524fd96 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:02:30 +0800 Subject: 
[PATCH 20/22] Update classification_test.py --- sml/metrics/classification/classification_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index 89c37bd0..cd4901d6 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -139,16 +139,14 @@ def check(spu_result, sk_result): # Test binary y_true = jnp.array([0, 1, 1, 0, 1, 1]) y_pred = jnp.array([0, 0, 1, 0, 1, 1]) - spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, pos_label=1, transform=0) + spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, 'binary', None, 1, False) sk_result = sklearn_proc(y_true, y_pred) check(spu_result, sk_result) # Test multiclass y_true = jnp.array([0, 1, 1, 0, 2, 1]) y_pred = jnp.array([0, 0, 1, 0, 2, 1]) - spu_result = spsim.sim_jax(sim, proc)( - y_true, y_pred, average=None, labels=[0, 1, 2] - ) + spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, None, [0,1,2], 1, True) sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) From ed5434f3c199eb703ea4add5b14e3aeb6e8a7a28 Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:03:49 +0800 Subject: [PATCH 21/22] Update classification_test.py --- sml/metrics/classification/classification_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index cd4901d6..6c9634a9 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -139,14 +139,18 @@ def check(spu_result, sk_result): # Test binary y_true = jnp.array([0, 1, 1, 0, 1, 1]) y_pred = jnp.array([0, 0, 1, 0, 1, 1]) - spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, 'binary', None, 1, False) + spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))( + y_true, y_pred, 'binary', None, 1, False + ) sk_result = sklearn_proc(y_true, y_pred) check(spu_result, sk_result) # Test multiclass y_true = jnp.array([0, 1, 1, 0, 2, 1]) y_pred = jnp.array([0, 0, 1, 0, 2, 1]) - spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, None, [0,1,2], 1, True) + spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))( + y_true, y_pred, None, [0, 1, 2], 1, True + ) sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) From 1f36b09807a7d385fe6b957a69a1e87343ce714b Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:06:00 +0800 Subject: [PATCH 22/22] Update classification_emul.py --- sml/metrics/classification/classification_emul.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 3712abda..bdebd734 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -95,14 +95,18 @@ def check(spu_result, sk_result): # Test binary y_true = jnp.array([0, 1, 1, 0, 1, 1]) y_pred = jnp.array([0, 0, 1, 0, 1, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0) + spu_result = emulator.run(proc, static_argnums=(2, 5))( 
+ y_true, y_pred, 'binary', None, 1, False + ) sk_result = sklearn_proc(y_true, y_pred) check(spu_result, sk_result) # Test multiclass y_true = jnp.array([0, 1, 1, 0, 2, 1]) y_pred = jnp.array([0, 0, 1, 0, 2, 1]) - spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2]) + spu_result = emulator.run(proc, static_argnums=(2, 5))( + y_true, y_pred, None, [0, 1, 2], 1, True + ) sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result)
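
Editor's note — worked example (not part of the patch series). The series above leaves classification.py computing precision, recall, and F1 from TP/FP/FN counts with a 1e-10 smoothing term, using transform_binary for the one-vs-rest reduction when average=None. Below is a minimal plain-JAX sketch of that final state, checked against scikit-learn on the same multiclass arrays the tests use. It assumes only jax, numpy, and scikit-learn are installed; the SPU simulator/emulator layers are deliberately omitted. _f1_score and transform_binary are copied from the patch, and the driver loop stands in for fun_score(..., average=None).

    import jax.numpy as jnp
    import numpy as np
    from sklearn import metrics

    def _f1_score(y_true, y_pred):
        # Counts-based F1; the 1e-10 term keeps the division well-defined
        # when a class never appears in y_true or y_pred.
        tp = jnp.sum(y_true * y_pred)
        fp = jnp.sum(y_pred) - tp
        fn = jnp.sum(y_true) - tp
        return 2 * tp / (2 * tp + fp + fn + 1e-10)

    def transform_binary(y_true, y_pred, label):
        # One-vs-rest reduction: `label` becomes the positive class (1).
        return jnp.where(y_true == label, 1, 0), jnp.where(y_pred != label, 0, 1)

    y_true = jnp.array([0, 1, 1, 0, 2, 1])
    y_pred = jnp.array([0, 0, 1, 0, 2, 1])

    # Per-class F1, mirroring fun_score(..., average=None, labels=[0, 1, 2]).
    ours = [float(_f1_score(*transform_binary(y_true, y_pred, k))) for k in (0, 1, 2)]
    ref = metrics.f1_score(np.asarray(y_true), np.asarray(y_pred),
                           average=None, labels=[0, 1, 2])
    np.testing.assert_allclose(ours, ref, atol=1e-5)
    print(ours)  # ~[0.8, 0.8, 1.0]

The epsilon smoothing is presumably what makes these metrics MPC-friendly: it avoids the data-dependent zero-division branch that sklearn's zero_division handling performs, which would be awkward to express in a fixed, traceable computation graph.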
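
Editor's note — on static_argnums (not part of the patch series). Patches 20-22 switch the test and emulation drivers from keyword arguments to positional calls with static_argnums=(2, 5). The likely reason, which the commit messages do not state: fun_score branches on average with ordinary Python if/elif and iterates labels with a Python for loop, so those values must be concrete at trace time rather than secret-shared tensors, and the patch's own calls show that spsim.sim_jax and emulator.run accept static_argnums for exactly this. A toy illustration with jax.jit standing in for the SPU tracer — score here is a made-up function, not the patch's API:

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=(2,))
    def score(y_true, y_pred, average):
        tp = jnp.sum(y_true * y_pred)
        if average == 'binary':  # plain Python branch -> resolved at trace time
            return tp / (jnp.sum(y_pred) + 1e-10)   # precision
        return tp / (jnp.sum(y_true) + 1e-10)       # recall

    print(score(jnp.array([0, 1, 1, 0, 1, 1]),
                jnp.array([0, 0, 1, 0, 1, 1]), 'binary'))  # 1.0

Without the static marking, tracing would fail on the string comparison; with it, each distinct value of average simply produces its own compiled program.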