Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
split unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Dec 18, 2019
1 parent ddb36da commit 1af5450
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion _unittests/ut_onnxrt/test_rt_valid_model_naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def fit_classification_model(self, model, n_classes, is_int=False,
model.fit(X_train, y_train)
return model, X_test

def test_model_bernoulli_nb_binary_classification(self):
def test_model_bernoulli_nb_bc_python(self):
model, X = self.fit_classification_model(BernoulliNB(), 2)
model_onnx = convert_sklearn(
model, "?", [("input", FloatTensorType([None, X.shape[1]]))],
Expand All @@ -44,12 +44,28 @@ def test_model_bernoulli_nb_binary_classification(self):
got2 = DataFrame(got['output_probability']).values
self.assertEqualArray(exp, got2, decimal=4)

def test_model_bernoulli_nb_bc_onnxruntime1(self):
model, X = self.fit_classification_model(BernoulliNB(), 2)
model_onnx = convert_sklearn(
model, "?", [("input", FloatTensorType([None, X.shape[1]]))],
dtype=numpy.float32)
exp1 = model.predict(X)
exp = model.predict_proba(X)

oinf = OnnxInference(model_onnx, runtime='onnxruntime1')
got = oinf.run({'input': X})
self.assertEqualArray(exp1, got['output_label'])
got2 = DataFrame(got['output_probability']).values
self.assertEqualArray(exp, got2, decimal=4)

def test_model_bernoulli_nb_bc_onnxruntime2(self):
model, X = self.fit_classification_model(BernoulliNB(), 2)
model_onnx = convert_sklearn(
model, "?", [("input", FloatTensorType([None, X.shape[1]]))],
dtype=numpy.float32)
exp1 = model.predict(X)
exp = model.predict_proba(X)

oinf = OnnxInference(model_onnx, runtime='onnxruntime2')
got = oinf.run({'input': X})
self.assertEqualArray(exp1, got['output_label'])
Expand Down

0 comments on commit 1af5450

Please sign in to comment.