From 9ccdec3883680eb4c109f01f7d2918593aefbfb2 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 15:41:44 +0800
Subject: [PATCH 01/22] Update classification.py

---
 sml/metrics/classification/classification.py | 110 +++++++++++++++++--
 1 file changed, 102 insertions(+), 8 deletions(-)

diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py
index 6b8c7fe2a..b142aec4f 100644
--- a/sml/metrics/classification/classification.py
+++ b/sml/metrics/classification/classification.py
@@ -98,13 +98,107 @@ def equal_range(x: jnp.ndarray, n_bin: int) -> jnp.ndarray:
     return result
 
 
-# TODO: more evaluation tools
+def _f1_score(y_true, y_pred):
+    """Calculate the F1 score."""
+    tp = jnp.sum(y_true * y_pred)
+    fp = jnp.sum(y_pred) - tp
+    fn = jnp.sum(y_true) - tp
+    f1 = 2 * tp / (2 * tp + fp + fn + 1e-10)
+    return f1
+
+def _precision_score(y_true, y_pred):
+    """Calculate the Precision score."""
+    tp = jnp.sum(y_true * y_pred)
+    fp = jnp.sum(y_pred) - tp
+    precision = tp / (tp + fp + 1e-10)
+    return precision
+
+def _recall_score(y_true, y_pred):
+    """Calculate the Recall score."""
+    tp = jnp.sum(y_true * y_pred)
+    fn = jnp.sum(y_true) - tp
+    recall = tp / (tp + fn + 1e-10)
+    return recall  
+
+def accuracy_score(y_true, y_pred):
+    """Calculate the Accuracy score."""
+    correct = jnp.sum(y_true == y_pred)
+    total = len(y_true)
+    accuracy = correct / total
+    return accuracy
+
+def transform_binary(y_true, y_pred, label):
+    y_true_transform = jnp.where(y_true == label, 1, 0)
+    y_pred_transform = jnp.where(y_pred != label, 0, 1)
+    return y_true_transform, y_pred_transform
+
+def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
+    f1_result = fun_score(_f1_score, y_true, y_pred, average, labels, pos_label, transform)
+    return f1_result
+
+def precision_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
+    f1_result = fun_score(_precision_score, y_true, y_pred, average, labels, pos_label, transform)
+    return f1_result
+
+def recall_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
+    f1_result = fun_score(_recall_score, y_true, y_pred, average, labels, pos_label, transform)
+    return f1_result
+
+def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
+    """
+    Compute precision, recall, f1.
 
+    Args:
+    fun : function, one of '_precision_score' / '_recall_score' / '_f1_score'.
 
-def compute_f1_score(
-    true_positive: jnp.ndarray, false_positive: jnp.ndarray, false_negative: jnp.ndarray
-):
-    """Calculate the F1 score."""
-    precision = true_positive / (true_positive + false_positive)
-    recall = true_positive / (true_positive + false_negative)
-    return 2 * precision * recall / (precision + recall)
+    y_true : 1d array-like, ground truth (correct) target values.
+
+    y_pred : 1d array-like, estimated targets as returned by a classifier.
+
+    average : {'binary'} or None, default='binary'
+        This parameter is required for multiclass/multilabel targets.
+        If ``None``, the scores for each class are returned.
+
+        ``'binary'``:
+            Only report results for the class specified by ``pos_label``.
+            This is applicable only if targets (``y_{true,pred}``) are binary.
+
+    labels : array-like, default=None
+        The set of labels to include when ``average != 'binary'``.
+
+    pos_label : int, float, default=1
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored.
+    
+    transform : bool, default=1
+        The problem is transformed into a binary classification with positive samples labeled 1 and negative samples labeled 0.
+
+    Returns:
+    -------
+    precision : float, or list of per-label scores when ``average=None``
+        Precision score.
+
+    recall : float, or list of per-label scores when ``average=None``
+        Recall score.
+
+    f1 : float, or list of per-label scores when ``average=None``
+        F1 score.
+    """
+
+    if average is None:
+        assert (
+            labels is not None
+        ), f"labels cannot be None"
+        fun_result = []
+        for i in labels:
+            y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)
+            fun_result.append(fun(y_true_binary, y_pred_binary))
+    elif average == 'binary':
+        if transform is True:
+            y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label)
+        else:
+            y_true_binary, y_pred_binary = y_true, y_pred
+        fun_result = fun(y_true_binary, y_pred_binary)
+    else:
+        raise ValueError("average should be None or 'binary'")
+    return fun_result
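
A quick plaintext sanity check of the API added above — a minimal sketch in plain JAX (no SPU simulator), assuming the `sml` package from the diff headers is importable:

    import jax.numpy as jnp

    from sml.metrics.classification.classification import (
        accuracy_score,
        f1_score,
        precision_score,
        recall_score,
    )

    y_true = jnp.array([0, 1, 1, 0, 1, 1])
    y_pred = jnp.array([0, 0, 1, 0, 1, 1])

    # Binary mode: scores are reported for the class given by pos_label.
    # Here tp=3, fp=0, fn=1.
    print(precision_score(y_true, y_pred, pos_label=1))  # ~1.0
    print(recall_score(y_true, y_pred, pos_label=1))     # ~0.75
    print(f1_score(y_true, y_pred, pos_label=1))         # ~0.857 (= 6/7)

    # Per-class mode: average=None requires an explicit label list and
    # returns one score per label.
    print(f1_score(y_true, y_pred, average=None, labels=[0, 1]))

    print(accuracy_score(y_true, y_pred))                # ~0.833 (= 5/6)

The 1e-10 terms added to the denominators guard against division by zero when a class never occurs, which is why results are compared with a small tolerance rather than exactly.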

From 3dba13c170f753cd416e0ebab453d40d39b519c0 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 15:51:24 +0800
Subject: [PATCH 02/22] Update classification.py

---
 sml/metrics/classification/classification.py | 46 ++++++++++++++------
 1 file changed, 32 insertions(+), 14 deletions(-)

diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py
index b142aec4f..96e5374b7 100644
--- a/sml/metrics/classification/classification.py
+++ b/sml/metrics/classification/classification.py
@@ -106,6 +106,7 @@ def _f1_score(y_true, y_pred):
     f1 = 2 * tp / (2 * tp + fp + fn + 1e-10)
     return f1
 
+
 def _precision_score(y_true, y_pred):
     """Calculate the Precision score."""
     tp = jnp.sum(y_true * y_pred)
@@ -113,12 +114,14 @@ def _precision_score(y_true, y_pred):
     precision = tp / (tp + fp + 1e-10)
     return precision
 
+
 def _recall_score(y_true, y_pred):
     """Calculate the Recall score."""
     tp = jnp.sum(y_true * y_pred)
     fn = jnp.sum(y_true) - tp
     recall = tp / (tp + fn + 1e-10)
-    return recall  
+    return recall
+
 
 def accuracy_score(y_true, y_pred):
     """Calculate the Accuracy score."""
@@ -127,24 +130,41 @@ def accuracy_score(y_true, y_pred):
     accuracy = correct / total
     return accuracy
 
+
 def transform_binary(y_true, y_pred, label):
     y_true_transform = jnp.where(y_true == label, 1, 0)
     y_pred_transform = jnp.where(y_pred != label, 0, 1)
     return y_true_transform, y_pred_transform
 
+
 def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
-    f1_result = fun_score(_f1_score, y_true, y_pred, average, labels, pos_label, transform)
-    return f1_result
+    f1_result = fun_score(
+        _f1_score, y_true, y_pred, average, labels, pos_label, transform
+    )
+     return f1_result
 
-def precision_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
-    f1_result = fun_score(_precision_score, y_true, y_pred, average, labels, pos_label, transform)
-    return f1_result
 
-def recall_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
-    f1_result = fun_score(_recall_score, y_true, y_pred, average, labels, pos_label, transform)
-    return f1_result
+def precision_score(
+    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+):
+    f1_result = fun_score(
+        _precision_score, y_true, y_pred, average, labels, pos_label, transform
+    )
+     return f1_result
+
+
+def recall_score(
+    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+):
+    f1_result = fun_score(
+        _recall_score, y_true, y_pred, average, labels, pos_label, transform
+    )
+     return f1_result
+
 
-def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
+def fun_score(
+    fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+):
     """
     Compute precision, recall, f1.
 
@@ -169,7 +189,7 @@ def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, t
     pos_label : int, float, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored.
-    
+
     transform : bool, default=1
         The problem is transformed into a binary classification with positive samples labeled 1 and negative samples labeled 0.
 
@@ -186,9 +206,7 @@ def fun_score(fun, y_true, y_pred, average='binary', labels=None, pos_label=1, t
     """
 
     if average is None:
-        assert (
-            labels is not None
-        ), f"labels cannot be None"
+        assert labels is not None, f"labels cannot be None"
         fun_result = []
         for i in labels:
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)

From b414e76e2d97e5f368a04622c877c90029f52647 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 16:11:42 +0800
Subject: [PATCH 03/22] Update classification.py

---
 sml/metrics/classification/classification.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py
index 96e5374b7..8f78d41d8 100644
--- a/sml/metrics/classification/classification.py
+++ b/sml/metrics/classification/classification.py
@@ -141,7 +141,7 @@ def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transfo
     f1_result = fun_score(
         _f1_score, y_true, y_pred, average, labels, pos_label, transform
     )
-     return f1_result
+    return f1_result
 
 
 def precision_score(
@@ -150,7 +150,7 @@ def precision_score(
     f1_result = fun_score(
         _precision_score, y_true, y_pred, average, labels, pos_label, transform
     )
-     return f1_result
+    return f1_result
 
 
 def recall_score(
@@ -159,7 +159,7 @@ def recall_score(
     f1_result = fun_score(
         _recall_score, y_true, y_pred, average, labels, pos_label, transform
     )
-     return f1_result
+    return f1_result
 
 
 def fun_score(

From 4804cb17cc58bdf837f9daf0a379354a6f62e26b Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 16:38:32 +0800
Subject: [PATCH 04/22] Update classification_test.py

---
 .../classification/classification_test.py     | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index 8ff59db5b..db14850b7 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -80,5 +80,43 @@ def digitize(y_pred, thresholds):
         np.testing.assert_almost_equal(true_score, score, decimal=2)
 
 
+    def test_classification(self):
+        sim = spsim.Simulator.simple(
+            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
+        )
+
+        def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1):
+            f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+            precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+            recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+            accuracy = accuracy_score(y_true, y_pred)
+            return f1,precision,recall,accuracy
+        
+        def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
+            f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            accuracy = metrics.accuracy_score(y_true, y_pred)
+            return f1,precision,recall,accuracy
+        
+        def check(spu_result, sk_result):
+            for  pair in zip(spu_result, sk_result):
+                np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
+
+        # Test binary
+        y_true = jnp.array([0, 1, 1, 0, 1, 1])
+        y_pred = jnp.array([0, 0, 1, 0, 1, 1])
+        spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, pos_label=1, transform=0)
+        sk_result = sklearn_proc(y_true, y_pred)
+        check(spu_result, sk_result)
+
+        # Test multiclass
+        y_true = jnp.array([0, 1, 1, 0, 2, 1])
+        y_pred = jnp.array([0, 0, 1, 0, 2, 1])
+        spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, average=None, labels=[0,1,2])
+        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2])
+        check(spu_result, sk_result)
+
+
 if __name__ == "__main__":
     unittest.main()
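
The reference side of this test can be reproduced standalone; a small sketch of the expected plaintext values for the binary case, using only scikit-learn:

    import numpy as np
    from sklearn import metrics

    y_true = np.array([0, 1, 1, 0, 1, 1])
    y_pred = np.array([0, 0, 1, 0, 1, 1])

    # With pos_label=1: tp=3, fp=0, fn=1.
    print(metrics.precision_score(y_true, y_pred))  # 1.0
    print(metrics.recall_score(y_true, y_pred))     # 0.75
    print(metrics.f1_score(y_true, y_pred))         # 0.8571...
    print(metrics.accuracy_score(y_true, y_pred))   # 0.8333...

The loose `rtol=1, atol=1e-5` tolerance in `check` leaves room for fixed-point error from the secure computation and for the 1e-10 smoothing terms on the SPU side.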

From 44c001409b45938bdd68481afc1ad2cdb28e0bd7 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 16:47:09 +0800
Subject: [PATCH 05/22] Update classification_emul.py

---
 .../classification/classification_emul.py     | 50 ++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index fbddcf8d6..87e900a5b 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -15,12 +15,14 @@
 import sys
 
 import numpy as np
+import jax.numpy as jnp
+from sklearn import metrics
 
 # add ops dir to the path
 sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
 
 import sml.utils.emulation as emulation
-from sml.metrics.classification.classification import roc_auc_score
+from sml.metrics.classification.classification import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
 
 
 # TODO: design the emulation framework, just like py.unittest
@@ -48,5 +50,51 @@ def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS):
         emulator.down()
 
 
+
+def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
+    try:
+        # bandwidth and latency only work for docker mode
+        emulator = emulation.Emulator(
+            emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
+        )
+        emulator.up()
+    
+        def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1):
+            f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+            precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+            recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+            accuracy = accuracy_score(y_true, y_pred)
+            return f1,precision,recall,accuracy
+        
+        def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
+            f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            accuracy = metrics.accuracy_score(y_true, y_pred)
+            return f1,precision,recall,accuracy
+        
+        def check(spu_result, sk_result):
+            for  pair in zip(spu_result, sk_result):
+                np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
+
+        # Test binary
+        y_true = jnp.array([0, 1, 1, 0, 1, 1])
+        y_pred = jnp.array([0, 0, 1, 0, 1, 1])
+        spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0)
+        sk_result = sklearn_proc(y_true, y_pred)
+        check(spu_result, sk_result)
+
+        # Test multiclass
+        y_true = jnp.array([0, 1, 1, 0, 2, 1])
+        y_pred = jnp.array([0, 0, 1, 0, 2, 1])
+        spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0,1,2])
+        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2])
+        check(spu_result, sk_result)
+
+    finally:
+        emulator.down()
+
+
 if __name__ == "__main__":
     emul_SGDClassifier(emulation.Mode.MULTIPROCESS)
+    emul_Classification(emulation.Mode.MULTIPROCESS)

From 6df82f35ace47e5dff7ccd6065e2c6e15a6a7868 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 16:52:14 +0800
Subject: [PATCH 06/22] Update classification_emul.py

---
 .../classification/classification_emul.py     | 68 ++++++++++++++-----
 1 file changed, 52 insertions(+), 16 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index 87e900a5b..53c53fe23 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -22,7 +22,13 @@
 sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
 
 import sml.utils.emulation as emulation
-from sml.metrics.classification.classification import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
+from sml.metrics.classification.classification import (
+    roc_auc_score,
+    f1_score,
+    precision_score,
+    recall_score,
+    accuracy_score,
+)
 
 
 # TODO: design the emulation framework, just like py.unittest
@@ -50,7 +56,6 @@ def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS):
         emulator.down()
 
 
-
 def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
     try:
         # bandwidth and latency only work for docker mode
@@ -59,29 +64,60 @@ def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
         )
         emulator.up()
     
-        def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1):
-            f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
-            precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
-            recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+def proc(
+            y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+        ):
+            f1 = f1_score(
+                y_true,
+                y_pred,
+                average=average,
+                labels=labels,
+                pos_label=pos_label,
+                transform=transform,
+            )
+            precision = precision_score(
+                y_true,
+                y_pred,
+                average=average,
+                labels=labels,
+                pos_label=pos_label,
+                transform=transform,
+            )
+            recall = recall_score(
+                y_true,
+                y_pred,
+                average=average,
+                labels=labels,
+                pos_label=pos_label,
+                transform=transform,
+            )
             accuracy = accuracy_score(y_true, y_pred)
-            return f1,precision,recall,accuracy
-        
+            return f1, precision, recall, accuracy
+
+
         def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
-            f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
-            precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
-            recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            f1 = metrics.f1_score(
+                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+            )
+            precision = metrics.precision_score(
+                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+            )
+            recall = metrics.recall_score(
+                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+            )
             accuracy = metrics.accuracy_score(y_true, y_pred)
-            return f1,precision,recall,accuracy
-        
+            return f1, precision, recall, accuracy
+
+
         def check(spu_result, sk_result):
-            for  pair in zip(spu_result, sk_result):
+            for pair in zip(spu_result, sk_result):
                 np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
 
         # Test binary
         y_true = jnp.array([0, 1, 1, 0, 1, 1])
         y_pred = jnp.array([0, 0, 1, 0, 1, 1])
-        spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0)
-        sk_result = sklearn_proc(y_true, y_pred)
+        spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
+        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
         check(spu_result, sk_result)
 
         # Test multiclass

From f7bb89d3df69adffe387a74cb0ad537f3524d6b7 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 16:58:17 +0800
Subject: [PATCH 07/22] Update classification_test.py

---
 .../classification/classification_test.py     | 48 +++++++++++++++----
 1 file changed, 39 insertions(+), 9 deletions(-)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index db14850b7..e7c3da81d 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -18,6 +18,7 @@
 
 import jax.numpy as jnp
 import numpy as np
+from sklearn import metrics
 
 import spu.spu_pb2 as spu_pb2
 import spu.utils.simulation as spsim
@@ -31,6 +32,10 @@
     bin_counts,
     equal_obs,
     roc_auc_score,
+    f1_score,
+    precision_score,
+    recall_score,
+    accuracy_score,
 )
 
 
@@ -79,19 +84,42 @@ def digitize(y_pred, thresholds):
 
         np.testing.assert_almost_equal(true_score, score, decimal=2)
 
-
     def test_classification(self):
         sim = spsim.Simulator.simple(
             3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
         )
 
-        def proc(y_true, y_pred, average='binary', labels=None, pos_label=1 ,transform=1):
-            f1 = f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
-            precision = precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
-            recall = recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label, transform=transform)
+        def proc(
+            y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+        ):
+            f1 = f1_score(
+                y_true,
+                y_pred,
+                average=average,
+                labels=labels,
+                pos_label=pos_label,
+                transform=transform,
+            )
+            precision = precision_score(
+                y_true,
+                y_pred,
+                average=average,
+                labels=labels,
+                pos_label=pos_label,
+                transform=transform,
+            )
+            recall = recall_score(
+                y_true,
+                y_pred,
+                average=average,
+                labels=labels,
+                pos_label=pos_label,
+                transform=transform,
+            )
             accuracy = accuracy_score(y_true, y_pred)
-            return f1,precision,recall,accuracy
-        
+            return f1, precision, recall, accuracy
+
+
         def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
             f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
             precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
@@ -113,8 +141,10 @@ def check(spu_result, sk_result):
         # Test multiclass
         y_true = jnp.array([0, 1, 1, 0, 2, 1])
         y_pred = jnp.array([0, 0, 1, 0, 2, 1])
-        spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, average=None, labels=[0,1,2])
-        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2])
+        spu_result = spsim.sim_jax(sim, proc)(
+            y_true, y_pred, average=None, labels=[0, 1, 2]
+        )
+        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
         check(spu_result, sk_result)
 
 

From 6d6f2254abdef4ee4a3e2dae582088d3e1de3212 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 16:58:35 +0800
Subject: [PATCH 08/22] Update classification_emul.py

---
 sml/metrics/classification/classification_emul.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index 53c53fe23..a5f88e58a 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -64,7 +64,7 @@ def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
         )
         emulator.up()
     
-def proc(
+        def proc(
             y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
         ):
             f1 = f1_score(

From 9e5adb8a5606d8626928e07656b958a44e885e0c Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 17:02:33 +0800
Subject: [PATCH 09/22] Update classification_emul.py

---
 sml/metrics/classification/classification_emul.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index a5f88e58a..6df71b94d 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -63,7 +63,7 @@ def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
             emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
         )
         emulator.up()
-    
+
         def proc(
             y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
         ):
@@ -94,7 +94,6 @@ def proc(
             accuracy = accuracy_score(y_true, y_pred)
             return f1, precision, recall, accuracy
 
-
         def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
             f1 = metrics.f1_score(
                 y_true, y_pred, average=average, labels=labels, pos_label=pos_label
@@ -108,7 +107,6 @@ def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
             accuracy = metrics.accuracy_score(y_true, y_pred)
             return f1, precision, recall, accuracy
 
-
         def check(spu_result, sk_result):
             for pair in zip(spu_result, sk_result):
                 np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
@@ -123,8 +121,8 @@ def check(spu_result, sk_result):
         # Test multiclass
         y_true = jnp.array([0, 1, 1, 0, 2, 1])
         y_pred = jnp.array([0, 0, 1, 0, 2, 1])
-        spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0,1,2])
-        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0,1,2])
+        spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
+        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
         check(spu_result, sk_result)
 
     finally:

From 8c193e5bc853f5c0b780ed897c994d441354c788 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 17:03:52 +0800
Subject: [PATCH 10/22] Update classification_test.py

---
 .../classification/classification_test.py        | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index e7c3da81d..b1c8e007f 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -121,14 +121,20 @@ def proc(
 
 
         def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
-            f1 = metrics.f1_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
-            precision = metrics.precision_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
-            recall = metrics.recall_score(y_true, y_pred, average=average, labels=labels, pos_label=pos_label)
+            f1 = metrics.f1_score(
+                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+            )
+            precision = metrics.precision_score(
+                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+            )
+            recall = metrics.recall_score(
+                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+            )
             accuracy = metrics.accuracy_score(y_true, y_pred)
-            return f1,precision,recall,accuracy
+            return f1, precision, recall, accuracy
         
         def check(spu_result, sk_result):
-            for  pair in zip(spu_result, sk_result):
+            for pair in zip(spu_result, sk_result):
                 np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
 
         # Test binary

From 6258145c0740bd2b2bb993cd46eda2b5b04e16e7 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 17:05:08 +0800
Subject: [PATCH 11/22] Update classification_test.py

---
 sml/metrics/classification/classification_test.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index b1c8e007f..fd5dafc7a 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -119,7 +119,6 @@ def proc(
             accuracy = accuracy_score(y_true, y_pred)
             return f1, precision, recall, accuracy
 
-
         def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
             f1 = metrics.f1_score(
                 y_true, y_pred, average=average, labels=labels, pos_label=pos_label
@@ -132,7 +131,7 @@ def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
             )
             accuracy = metrics.accuracy_score(y_true, y_pred)
             return f1, precision, recall, accuracy
-        
+
         def check(spu_result, sk_result):
             for pair in zip(spu_result, sk_result):
                 np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)

From 56c219961d79a72b1251fe4302eae4ddcca04c71 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 17:59:23 +0800
Subject: [PATCH 12/22] change function name

---
 sml/metrics/classification/classification_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index fd5dafc7a..89c37bd01 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -40,7 +40,7 @@
 
 
 class UnitTests(unittest.TestCase):
-    def test_simple(self):
+    def test_auc(self):
         sim = spsim.Simulator.simple(
             3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
         )

From 208b81f2fe73b5601dbfc6ee5e2eaa58dcca63af Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 17:59:59 +0800
Subject: [PATCH 13/22] change function name

---
 sml/metrics/classification/classification_emul.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index 6df71b94d..4da6a3a61 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -33,7 +33,7 @@
 
 # TODO: design the emulation framework, just like py.unittest
 # all emulation actions should begin with `emul_` (for reflection)
-def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS):
+def emul_auc(mode: emulation.Mode.MULTIPROCESS):
     try:
         # bandwidth and latency only work for docker mode
         emulator = emulation.Emulator(
@@ -130,5 +130,5 @@ def check(spu_result, sk_result):
 
 
 if __name__ == "__main__":
-    emul_SGDClassifier(emulation.Mode.MULTIPROCESS)
+    emul_auc(emulation.Mode.MULTIPROCESS)
     emul_Classification(emulation.Mode.MULTIPROCESS)

From 25626b0cf8f9ce92f98b62ecc7a47633837c7034 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 21:09:31 +0800
Subject: [PATCH 14/22] Update classification_emul.py

---
 .../classification/classification_emul.py     | 161 +++++++++---------
 1 file changed, 76 insertions(+), 85 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index 4da6a3a61..6f698d1c8 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -34,101 +34,92 @@
 # TODO: design the emulation framework, just like py.unittest
 # all emulation actions should begin with `emul_` (for reflection)
 def emul_auc(mode: emulation.Mode.MULTIPROCESS):
-    try:
-        # bandwidth and latency only work for docker mode
-        emulator = emulation.Emulator(
-            emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
-        )
-        emulator.up()
-
-        # Create dataset
-        row = 10000
-        y_true = np.random.randint(0, 2, (row,))
-        y_pred = np.random.random((row,))
-
-        # Run
-        result = emulator.run(roc_auc_score)(
-            y_true, y_pred
-        )  # X, y should be two-dimensional array
-        print(result)
+    # Create dataset
+    row = 10000
+    y_true = np.random.randint(0, 2, (row,))
+    y_pred = np.random.random((row,))
 
-    finally:
-        emulator.down()
+    # Run
+    result = emulator.run(roc_auc_score)(
+        y_true, y_pred
+    )  # X, y should be two-dimensional array
+    print(result)
 
 
 def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
-    try:
-        # bandwidth and latency only work for docker mode
-        emulator = emulation.Emulator(
-            emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
+    def proc(
+        y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+    ):
+        f1 = f1_score(
+            y_true,
+            y_pred,
+            average=average,
+            labels=labels,
+            pos_label=pos_label,
+            transform=transform,
         )
-        emulator.up()
+        precision = precision_score(
+            y_true,
+            y_pred,
+            average=average,
+            labels=labels,
+            pos_label=pos_label,
+            transform=transform,
+        )
+        recall = recall_score(
+            y_true,
+            y_pred,
+            average=average,
+            labels=labels,
+            pos_label=pos_label,
+            transform=transform,
+        )
+        accuracy = accuracy_score(y_true, y_pred)
+        return f1, precision, recall, accuracy
 
-        def proc(
-            y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
-        ):
-            f1 = f1_score(
-                y_true,
-                y_pred,
-                average=average,
-                labels=labels,
-                pos_label=pos_label,
-                transform=transform,
-            )
-            precision = precision_score(
-                y_true,
-                y_pred,
-                average=average,
-                labels=labels,
-                pos_label=pos_label,
-                transform=transform,
-            )
-            recall = recall_score(
-                y_true,
-                y_pred,
-                average=average,
-                labels=labels,
-                pos_label=pos_label,
-                transform=transform,
-            )
-            accuracy = accuracy_score(y_true, y_pred)
-            return f1, precision, recall, accuracy
+    def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
+        f1 = metrics.f1_score(
+            y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+        )
+        precision = metrics.precision_score(
+            y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+        )
+        recall = metrics.recall_score(
+            y_true, y_pred, average=average, labels=labels, pos_label=pos_label
+        )
+        accuracy = metrics.accuracy_score(y_true, y_pred)
+        return f1, precision, recall, accuracy
 
-        def sklearn_proc(y_true, y_pred, average='binary', labels=None, pos_label=1):
-            f1 = metrics.f1_score(
-                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
-            )
-            precision = metrics.precision_score(
-                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
-            )
-            recall = metrics.recall_score(
-                y_true, y_pred, average=average, labels=labels, pos_label=pos_label
-            )
-            accuracy = metrics.accuracy_score(y_true, y_pred)
-            return f1, precision, recall, accuracy
+    def check(spu_result, sk_result):
+        for pair in zip(spu_result, sk_result):
+            np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
 
-        def check(spu_result, sk_result):
-            for pair in zip(spu_result, sk_result):
-                np.testing.assert_allclose(pair[0], pair[1], rtol=1, atol=1e-5)
+    # Test binary
+    y_true = jnp.array([0, 1, 1, 0, 1, 1])
+    y_pred = jnp.array([0, 0, 1, 0, 1, 1])
+    spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
+    sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
+    check(spu_result, sk_result)
 
-        # Test binary
-        y_true = jnp.array([0, 1, 1, 0, 1, 1])
-        y_pred = jnp.array([0, 0, 1, 0, 1, 1])
-        spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
-        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
-        check(spu_result, sk_result)
+    # Test multiclass
+    y_true = jnp.array([0, 1, 1, 0, 2, 1])
+    y_pred = jnp.array([0, 0, 1, 0, 2, 1])
+    spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
+    sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
+    check(spu_result, sk_result)
 
-        # Test multiclass
-        y_true = jnp.array([0, 1, 1, 0, 2, 1])
-        y_pred = jnp.array([0, 0, 1, 0, 2, 1])
-        spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
-        sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
-        check(spu_result, sk_result)
 
+if __name__ == "__main__":
+    try:
+        # bandwidth and latency only work for docker mode
+        emulator = emulation.Emulator(
+            emulation.CLUSTER_ABY3_3PC,
+            emulation.Mode.MULTIPROCESS,
+            bandwidth=300,
+            latency=20,
+        )
+        emulator.up()
+        emul_auc(emulation.Mode.MULTIPROCESS)
+        emul_Classification(emulation.Mode.MULTIPROCESS)
     finally:
         emulator.down()
-
-
-if __name__ == "__main__":
-    emul_auc(emulation.Mode.MULTIPROCESS)
-    emul_Classification(emulation.Mode.MULTIPROCESS)
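
The restructuring above hoists the emulator lifecycle out of the individual `emul_*` functions so every check reuses one running cluster. A sketch of the same pattern, using only calls that appear in this series, but constructing the emulator before the `try` so `down()` is only reached once the object exists:

    import sml.utils.emulation as emulation

    if __name__ == "__main__":
        emulator = emulation.Emulator(
            emulation.CLUSTER_ABY3_3PC,
            emulation.Mode.MULTIPROCESS,
            bandwidth=300,  # bandwidth/latency only apply in docker mode
            latency=20,
        )
        try:
            emulator.up()
            # run any number of emul_* checks against the same cluster here
        finally:
            emulator.down()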

From d2c2ba07989c73b21760b505bbfaba4ee37a272c Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Thu, 16 Nov 2023 21:13:25 +0800
Subject: [PATCH 15/22] Update classification_emul.py

---
 sml/metrics/classification/classification_emul.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index 6f698d1c8..c023e366f 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -47,9 +47,7 @@ def emul_auc(mode: emulation.Mode.MULTIPROCESS):
 
 
 def emul_Classification(mode: emulation.Mode.MULTIPROCESS):
-    def proc(
-        y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
-    ):
+    def proc(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
         f1 = f1_score(
             y_true,
             y_pred,

From 61f2e19795cda2e922839351adce78d6047d3d86 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 11:18:22 +0800
Subject: [PATCH 16/22] Update classification_emul.py

---
 sml/metrics/classification/classification_emul.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index c023e366f..3712abda4 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -95,8 +95,8 @@ def check(spu_result, sk_result):
     # Test binary
     y_true = jnp.array([0, 1, 1, 0, 1, 1])
     y_pred = jnp.array([0, 0, 1, 0, 1, 1])
-    spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
-    sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
+    spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0)
+    sk_result = sklearn_proc(y_true, y_pred)
     check(spu_result, sk_result)
 
     # Test multiclass

From 285b276e8caa62d4215a11a0508c711e09c2e007 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 11:20:51 +0800
Subject: [PATCH 17/22] Update classification.py

---
 sml/metrics/classification/classification.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py
index 8f78d41d8..9d42973e8 100644
--- a/sml/metrics/classification/classification.py
+++ b/sml/metrics/classification/classification.py
@@ -137,7 +137,7 @@ def transform_binary(y_true, y_pred, label):
     return y_true_transform, y_pred_transform
 
 
-def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1):
+def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True):
     f1_result = fun_score(
         _f1_score, y_true, y_pred, average, labels, pos_label, transform
     )
@@ -145,7 +145,7 @@ def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transfo
 
 
 def precision_score(
-    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
 ):
     f1_result = fun_score(
         _precision_score, y_true, y_pred, average, labels, pos_label, transform
@@ -154,7 +154,7 @@ def precision_score(
 
 
 def recall_score(
-    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
 ):
     f1_result = fun_score(
         _recall_score, y_true, y_pred, average, labels, pos_label, transform
@@ -163,7 +163,7 @@ def recall_score(
 
 
 def fun_score(
-    fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
+    fun, y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
 ):
     """
     Compute precision, recall, f1.
@@ -190,8 +190,8 @@ def fun_score(
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored.
 
-    transform : bool, default=1
-        The problem is transformed into a binary classification with positive samples labeled 1 and negative samples labeled 0.
+    transform : bool, default=True
+        Binary classification only. If True, labels are transformed to 0/1 inside this function. Otherwise, the transformation is assumed to have been done beforehand, which reduces the cost of this function.
 
     Returns:
     -------
@@ -212,7 +212,7 @@ def fun_score(
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)
             fun_result.append(fun(y_true_binary, y_pred_binary))
     elif average == 'binary':
-        if transform is True:
+        if transform:
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label)
         else:
             y_true_binary, y_pred_binary = y_true, y_pred
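
With the semantics documented above, `transform=False` lets a caller that already holds 0/1 labels skip the two in-circuit `jnp.where` comparisons. A minimal sketch of the two equivalent calls (the inputs here are already binary, so both paths agree):

    import jax.numpy as jnp
    from sml.metrics.classification.classification import f1_score

    y_true = jnp.array([0, 1, 1, 0, 1, 1])
    y_pred = jnp.array([0, 0, 1, 0, 1, 1])

    # Binarize against pos_label inside the (secure) computation.
    a = f1_score(y_true, y_pred, pos_label=1, transform=True)

    # Labels were binarized beforehand, so skip the comparisons.
    b = f1_score(y_true, y_pred, transform=False)

    assert jnp.allclose(a, b)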

From 562659551686263fa35322d8f8b852e980f3bc07 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 11:27:23 +0800
Subject: [PATCH 18/22] Update classification.py

---
 sml/metrics/classification/classification.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py
index 9d42973e8..9bbd12ca7 100644
--- a/sml/metrics/classification/classification.py
+++ b/sml/metrics/classification/classification.py
@@ -137,7 +137,9 @@ def transform_binary(y_true, y_pred, label):
     return y_true_transform, y_pred_transform
 
 
-def f1_score(y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True):
+def f1_score(
+    y_true, y_pred, average='binary', labels=None, pos_label=1, transform=True
+):
     f1_result = fun_score(
         _f1_score, y_true, y_pred, average, labels, pos_label, transform
     )
@@ -212,7 +214,7 @@ def fun_score(
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)
             fun_result.append(fun(y_true_binary, y_pred_binary))
     elif average == 'binary':
-        if transform:
+        if transform is True:
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label)
         else:
             y_true_binary, y_pred_binary = y_true, y_pred

From c1decdb0e06479668b9a03804a2a8393ca71fbaf Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 17:01:36 +0800
Subject: [PATCH 19/22] Update classification.py

---
 sml/metrics/classification/classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py
index 9bbd12ca7..2d3882fca 100644
--- a/sml/metrics/classification/classification.py
+++ b/sml/metrics/classification/classification.py
@@ -214,7 +214,7 @@ def fun_score(
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i)
             fun_result.append(fun(y_true_binary, y_pred_binary))
     elif average == 'binary':
-        if transform is True:
+        if transform:
             y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label)
         else:
             y_true_binary, y_pred_binary = y_true, y_pred
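
A note on the `is True` → truthiness switch above: `is` performs an identity test, so it rejects values that are merely truthy — including the `transform=1` style defaults used earlier in the series. A two-line illustration:

    transform = 1
    print(transform is True)  # False: identity, not equivalence
    print(bool(transform))    # True: what `if transform:` evaluates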

From 90e3cd2d90cca8bcf572dc028d38641a3524fd96 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 17:02:30 +0800
Subject: [PATCH 20/22] Update classification_test.py

---
 sml/metrics/classification/classification_test.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index 89c37bd01..cd4901d69 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -139,16 +139,14 @@ def check(spu_result, sk_result):
         # Test binary
         y_true = jnp.array([0, 1, 1, 0, 1, 1])
         y_pred = jnp.array([0, 0, 1, 0, 1, 1])
-        spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, pos_label=1, transform=0)
+        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, 'binary', None, 1, False)
         sk_result = sklearn_proc(y_true, y_pred)
         check(spu_result, sk_result)
 
         # Test multiclass
         y_true = jnp.array([0, 1, 1, 0, 2, 1])
         y_pred = jnp.array([0, 0, 1, 0, 2, 1])
-        spu_result = spsim.sim_jax(sim, proc)(
-            y_true, y_pred, average=None, labels=[0, 1, 2]
-        )
+        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, None, [0,1,2], 1, True)
         sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
         check(spu_result, sk_result)
 

From ed5434f3c199eb703ea4add5b14e3aeb6e8a7a28 Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 17:03:49 +0800
Subject: [PATCH 21/22] Update classification_test.py

---
 sml/metrics/classification/classification_test.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py
index cd4901d69..6c9634a97 100644
--- a/sml/metrics/classification/classification_test.py
+++ b/sml/metrics/classification/classification_test.py
@@ -139,14 +139,18 @@ def check(spu_result, sk_result):
         # Test binary
         y_true = jnp.array([0, 1, 1, 0, 1, 1])
         y_pred = jnp.array([0, 0, 1, 0, 1, 1])
-        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, 'binary', None, 1, False)
+        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(
+            y_true, y_pred, 'binary', None, 1, False
+        )
         sk_result = sklearn_proc(y_true, y_pred)
         check(spu_result, sk_result)
 
         # Test multiclass
         y_true = jnp.array([0, 1, 1, 0, 2, 1])
         y_pred = jnp.array([0, 0, 1, 0, 2, 1])
-        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(y_true, y_pred, None, [0,1,2], 1, True)
+        spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 5))(
+            y_true, y_pred, None, [0, 1, 2], 1, True
+        )
         sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
         check(spu_result, sk_result)
 

From 1f36b09807a7d385fe6b957a69a1e87343ce714b Mon Sep 17 00:00:00 2001
From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com>
Date: Fri, 17 Nov 2023 17:06:00 +0800
Subject: [PATCH 22/22] Update classification_emul.py

---
 sml/metrics/classification/classification_emul.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py
index 3712abda4..bdebd734c 100644
--- a/sml/metrics/classification/classification_emul.py
+++ b/sml/metrics/classification/classification_emul.py
@@ -95,14 +95,18 @@ def check(spu_result, sk_result):
     # Test binary
     y_true = jnp.array([0, 1, 1, 0, 1, 1])
     y_pred = jnp.array([0, 0, 1, 0, 1, 1])
-    spu_result = emulator.run(proc)(y_true, y_pred, pos_label=1, transform=0)
+    spu_result = emulator.run(proc, static_argnums=(2, 5))(
+        y_true, y_pred, 'binary', None, 1, False
+    )
     sk_result = sklearn_proc(y_true, y_pred)
     check(spu_result, sk_result)
 
     # Test multiclass
     y_true = jnp.array([0, 1, 1, 0, 2, 1])
     y_pred = jnp.array([0, 0, 1, 0, 2, 1])
-    spu_result = emulator.run(proc)(y_true, y_pred, average=None, labels=[0, 1, 2])
+    spu_result = emulator.run(proc, static_argnums=(2, 5))(
+        y_true, y_pred, None, [0, 1, 2], 1, True
+    )
     sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
     check(spu_result, sk_result)