diff --git a/_doc/examples/plot_transformer_discrepancy.py b/_doc/examples/plot_transformer_discrepancy.py
new file mode 100644
index 00000000..0bb8aea5
--- /dev/null
+++ b/_doc/examples/plot_transformer_discrepancy.py
@@ -0,0 +1,281 @@
+"""
+.. _example-transform-discrepancy:
+
+Dealing with discrepancies (tf-idf)
+===================================
+
+.. index:: tf-idf
+
+`TfidfVectorizer <https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html>`_
+is one transformer for which the converted ONNX model
+may produce different results. The larger the vocabulary is,
+the higher the probability of getting different results.
+This example proposes an equivalent model with no discrepancies.
+
+.. contents::
+    :local:
+
+Imports, setups
++++++++++++++++
+
+All imports. It also registers ONNX converters for :epkg:`xgboost`
+and :epkg:`lightgbm`.
+"""
+import pprint
+import numpy
+import pandas
+from sklearn.pipeline import Pipeline
+from sklearn.compose import ColumnTransformer
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from onnxruntime import InferenceSession
+from mlprodict.onnx_conv import to_onnx
+from mlprodict.plotting.text_plot import onnx_simple_text_plot
+from mlprodict.onnxrt import OnnxInference
+from mlprodict.sklapi import OnnxTransformer, OnnxSpeedupTransformer
+
+
+def print_sparse_matrix(m):
+    # Renders a (sparse) matrix as text, '.' for a null cell,
+    # a letter from 'A' to 'Z' growing with the value otherwise.
+    nonan = numpy.nan_to_num(m)
+    mi, ma = nonan.min(), nonan.max()
+    if mi == ma:
+        ma += 1
+    mat = numpy.empty(m.shape, dtype=numpy.str_)
+    mat[:, :] = '.'
+    if hasattr(m, 'todense'):
+        dense = m.todense()
+    else:
+        dense = m
+    for i in range(m.shape[0]):
+        for j in range(m.shape[1]):
+            if dense[i, j] > 0:
+                c = int((dense[i, j] - mi) / (ma - mi) * 25)
+                mat[i, j] = chr(ord('A') + c)
+    return '\n'.join(''.join(line) for line in mat)
+
+
+def max_diff(a, b):
+    # Returns the maximum absolute difference between two matrices.
+    if a.shape != b.shape:
+        raise ValueError(
+            f"Cannot compare matrices with different shapes "
+            f"{a.shape} != {b.shape}.")
+    d = numpy.abs(a - b).max()
+    return d
+
+################################
+# Artificial datasets
+# +++++++++++++++++++
+#
+# A few tricky strings: urls, code, timestamps.
+
+
+strings = numpy.array([
+    "This a sentence.",
+    "This a sentence with more characters $^*&'(-...",
+    """var = ClassName(var2, user=mail@anywhere.com, pwd"""
+    """=")_~-('&]@^\\`|[{#")""",
+    "c79857654",
+    "https://complex-url.com/;76543u3456?g=hhh&h=23",
+    "This is a kind of timestamp 01-03-05T11:12:13",
+    "https://complex-url.com/;dd76543u3456?g=ddhhh&h=23",
+]).reshape((-1, 1))
+labels = numpy.array(['http' in s for s in strings[:, 0]], dtype=numpy.int64)
+data = []  # stores the discrepancy measures gathered in this example
+
+pprint.pprint(strings)
+
+################################
+# Fit a TfIdfVectorizer
+# +++++++++++++++++++++
+
+tfidf = Pipeline([
+    ('pre', ColumnTransformer([
+        ('tfidf', TfidfVectorizer(ngram_range=(1, 2)), 0)
+    ]))
+])
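+
+################################
+# As a quick check of what ``ngram_range=(1, 2)`` produces, the next
+# cell is a small addition to this example: it fits a vectorizer on two
+# made-up sentences and prints the vocabulary, which contains both
+# unigrams and bi-grams.
+
+check = TfidfVectorizer(ngram_range=(1, 2)).fit(
+    ["a short sentence", "another short sentence here"])
+print(check.get_feature_names_out())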
+
+################################
+# We leave a couple of strings out of the training set.
+
+tfidf.fit(strings[:-2])
+tr = tfidf.transform(strings)
+tfidf_step = tfidf.steps[0][1].transformers_[0][1]
+print(f"output columns: {tfidf_step.get_feature_names_out()}")
+print(f"rendered outputs, shape={tr.shape!r}")
+print(print_sparse_matrix(tr))
+
+################################
+# Conversion to ONNX
+# ++++++++++++++++++
+
+onx = to_onnx(tfidf, strings)
+print(onnx_simple_text_plot(onx))
+
+
+################################
+# Execution with ONNX and explanation of the discrepancies
+# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+for rt in ['python', 'onnxruntime1']:
+    oinf = OnnxInference(onx, runtime=rt)
+    got = oinf.run({'X': strings})['variable']
+    d = max_diff(tr, got)
+    data.append(dict(diff=d, runtime=rt, exp='baseline'))
+    print(f"runtime={rt!r}, shape={got.shape!r}, differences={d:g}")
+    print(print_sparse_matrix(got))
+
+################################
+# The ONNX conversion is not exactly equivalent: the ONNX Tokenizer
+# produces different tokens. Looking at the strings tokenized by ONNX,
+# the word `h` appears in the sequence `amp|h|23` and the bi-gram
+# `amp,23` is never produced on this short example.
+
+oinf = OnnxInference(onx, runtime='python', inplace=False)
+res = oinf.run({'X': strings}, intermediate=True)
+pprint.pprint(['|'.join(s) for s in res['tokenized']])
+
+################################
+# By default, :epkg:`scikit-learn` tokenizes with a regular expression.
+
+print(f"tokenizer pattern: {tfidf_step.token_pattern!r}.")
+
+################################
+# :epkg:`onnxruntime` relies on :epkg:`re2` to handle the regular
+# expression and its syntax differs from Python regular expressions.
+
+onx = to_onnx(tfidf, strings,
+              options={TfidfVectorizer: {'tokenexp': r'(?u)\b\w\w+\b'}})
+print(onnx_simple_text_plot(onx))
+try:
+    InferenceSession(onx.SerializeToString())
+except Exception as e:
+    print(f"ERROR: {e!r}.")
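+
+################################
+# Python's :mod:`re` module accepts this pattern. The next cell is a
+# small addition to this example (a minimal sketch): it tokenizes one
+# of the tricky strings with scikit-learn's default pattern. Note that
+# the single `h` is not a token, the pattern requires two characters.
+
+import re
+pattern = re.compile(r"(?u)\b\w\w+\b")
+print(pattern.findall("https://complex-url.com/;76543u3456?g=hhh&h=23"))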
+
+################################
+# A pipeline
+# ++++++++++
+#
+# Let's assume the pipeline is followed by a logistic regression.
+
+pipe = Pipeline([
+    ('pre', ColumnTransformer([
+        ('tfidf', TfidfVectorizer(ngram_range=(1, 2)), 0)])),
+    ('logreg', LogisticRegression())])
+pipe.fit(strings[:-2], labels[:-2])
+pred = pipe.predict_proba(strings)
+print(f"predictions:\n{pred}")
+
+################################
+# Let's convert it into ONNX and check the predictions.
+
+onx = to_onnx(pipe, strings, options={'zipmap': False})
+for rt in ['python', 'onnxruntime1']:
+    oinf = OnnxInference(onx, runtime=rt)
+    pred_onx = oinf.run({'X': strings})['probabilities']
+    d = max_diff(pred, pred_onx)
+    data.append(dict(diff=d, runtime=rt, exp='replace'))
+    print(f"ONNX prediction {rt!r} - diff={d}:\n{pred_onx!r}")
+
+################################
+# There are discrepancies because the regular expressions used by ONNX
+# and by scikit-learn are not exactly the same. In this case, the
+# runtime cannot replicate what python does. The runtime can be
+# extended (see :epkg:`onnxruntime-extensions`) but this example
+# explores another direction.
+#
+# Replace the TfIdfVectorizer by ONNX before next step
+# ++++++++++++++++++++++++++++++++++++++++++++++++++++
+#
+# Let's start by training the
+# :class:`sklearn.feature_extraction.text.TfidfVectorizer`.
+
+tfidf = TfidfVectorizer(ngram_range=(1, 2))
+tfidf.fit(strings[:-2, 0])
+
+#########################################
+# Once it is trained, we convert it into ONNX and replace
+# it by a new transformer which uses ONNX to transform the features.
+# That's the purpose of class
+# :class:`mlprodict.sklapi.onnx_transformer.OnnxTransformer`.
+# It takes an ONNX graph and executes it to transform
+# the input features. It follows the scikit-learn API.
+
+onx = to_onnx(tfidf, strings)
+
+pipe = Pipeline([
+    ('pre', ColumnTransformer([
+        ('tfidf', OnnxTransformer(onx, runtime='onnxruntime1'), [0])])),
+    ('logreg', LogisticRegression())])
+pipe.fit(strings[:-2], labels[:-2])
+pred = pipe.predict_proba(strings)
+print(f"predictions:\n{pred}")
+
+#########################################
+# Let's convert the whole pipeline to ONNX.
+
+onx = to_onnx(pipe, strings, options={'zipmap': False})
+for rt in ['python', 'onnxruntime1']:
+    oinf = OnnxInference(onx, runtime=rt)
+    pred_onx = oinf.run({'X': strings})['probabilities']
+    d = max_diff(pred, pred_onx)
+    data.append(dict(diff=d, runtime=rt, exp='OnnxTransformer'))
+    print(f"ONNX prediction {rt!r} - diff={d}:\n{pred_onx!r}")
+
+#########################################
+# There are no discrepancies anymore. However, this option requires
+# training a transformer first, converting it into ONNX and replacing
+# the original step with an equivalent ONNX-based transformer.
+# Another class does all of this automatically.
+#
+# Train with scikit-learn, transform with ONNX
+# ++++++++++++++++++++++++++++++++++++++++++++
+#
+# Everything is done with the following class:
+# :class:`mlprodict.sklapi.onnx_speed_up.OnnxSpeedupTransformer`.
+
+pipe = Pipeline([
+    ('pre', ColumnTransformer([
+        ('tfidf', OnnxSpeedupTransformer(
+            TfidfVectorizer(ngram_range=(1, 2)),
+            runtime='onnxruntime1',
+            enforce_float32=False), 0)])),
+    ('logreg', LogisticRegression())])
+pipe.fit(strings[:-2], labels[:-2])
+pred = pipe.predict_proba(strings)
+print(f"predictions:\n{pred}")
+
+#########################################
+# Let's convert the whole pipeline to ONNX.
+
+onx = to_onnx(pipe, strings, options={'zipmap': False})
+for rt in ['python', 'onnxruntime1']:
+    oinf = OnnxInference(onx, runtime=rt)
+    pred_onx = oinf.run({'X': strings})['probabilities']
+    d = max_diff(pred, pred_onx)
+    data.append(dict(diff=d, runtime=rt, exp='OnnxSpeedupTransformer'))
+    print(f"ONNX prediction {rt!r} - diff={d}:\n{pred_onx!r}")
+
+############################################
+# This class was originally created to replace one part of a pipeline
+# with ONNX in order to speed up predictions. There is no discrepancy
+# either. Let's display the pipeline.
+print(onnx_simple_text_plot(onx))
+
+############################################
+# Graph
+# +++++
+
+df = pandas.DataFrame(data)
+df
+
+############################################
+# Plot of the differences per experiment (onnxruntime1).
+
+df[df.runtime == 'onnxruntime1'][['exp', 'diff']].set_index(
+    'exp').plot(kind='barh')
+
+
+# import matplotlib.pyplot as plt
+# plt.show()
diff --git a/_doc/examples/plot_usparse_xgboost.py b/_doc/examples/plot_usparse_xgboost.py
index 451d34b2..ae132fc0 100644
--- a/_doc/examples/plot_usparse_xgboost.py
+++ b/_doc/examples/plot_usparse_xgboost.py
@@ -33,8 +33,6 @@
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
-from sklearn.experimental import (  # noqa
-    enable_hist_gradient_boosting)  # noqa
 from sklearn.ensemble import (
     RandomForestClassifier, HistGradientBoostingClassifier)
 from xgboost import XGBClassifier
@@ -92,6 +90,96 @@
 # sparse matrices to be converted into dense matrices.
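+################################
+# ``sparse_threshold`` controls whether the ColumnTransformer output
+# stays sparse or becomes dense: the stacked result is kept sparse only
+# if its overall density is below the threshold. The next cell is a
+# small addition illustrating this on a made-up two-row dataset (it
+# relies on the imports already present in this example).
+
+for th in [1., 0.]:
+    ct = ColumnTransformer([('count', CountVectorizer(), 0)],
+                           sparse_threshold=th)
+    res = ct.fit_transform(numpy.array([["a b a"], ["b c b"]]))
+    print(f"sparse_threshold={th} -> {type(res)}")
+
+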
+def make_pipeline(model, insert_replace, sparse_threshold):
+    """Builds the training pipeline, *insert_replace* adds
+    ReplaceTransformer after the TF-IDF step."""
+    if model == HistGradientBoostingClassifier:
+        kwargs = dict(max_iter=5)
+    elif model == XGBClassifier:
+        kwargs = dict(n_estimators=5, use_label_encoder=False)
+    else:
+        kwargs = dict(n_estimators=5)
+
+    if insert_replace:
+        pipe = Pipeline([
+            ('union', ColumnTransformer([
+                ('scale1', StandardScaler(), [0, 1]),
+                ('subject',
+                 Pipeline([
+                     ('count', CountVectorizer()),
+                     ('tfidf', TfidfTransformer()),
+                     ('repl', ReplaceTransformer()),  # added transformer
+                 ]), "text"),
+            ], sparse_threshold=sparse_threshold)),
+            ('cast', CastTransformer()),
+            ('cls', model(max_depth=3, **kwargs)),
+        ])
+    else:
+        pipe = Pipeline([
+            ('union', ColumnTransformer([
+                ('scale1', StandardScaler(), [0, 1]),
+                ('subject',
+                 Pipeline([
+                     ('count', CountVectorizer()),
+                     ('tfidf', TfidfTransformer())
+                 ]), "text"),
+            ], sparse_threshold=sparse_threshold)),
+            ('cast', CastTransformer()),
+            ('cls', model(max_depth=3, **kwargs)),
+        ])
+    return pipe
+
+
+def model_to_onnx(pipe, options):
+    """Converts the pipeline into ONNX and saves it into 'model.onnx'."""
+    with warnings.catch_warnings(record=False):
+        warnings.simplefilter("ignore", (FutureWarning, UserWarning))
+        model_onnx = to_onnx(
+            pipe,
+            initial_types=[('input', FloatTensorType([None, 2])),
+                           ('text', StringTensorType([None, 1]))],
+            target_opset={'': 14, 'ai.onnx.ml': 2},
+            options=options)
+
+    with open('model.onnx', 'wb') as f:
+        f.write(model_onnx.SerializeToString())
+    return model_onnx
+
+
+def print_status(obs, inputs, pipe, model_onnx, pred_onx, diff, verbose):
+    """Computes a second discrepancy measure and prints details when
+    the discrepancies are significant (uses the global dataframe df)."""
+    if verbose:
+        def td(a):
+            if hasattr(a, 'todense'):
+                b = a.todense()
+                ind = set(a.indices)
+                for i in range(b.shape[1]):
+                    if i not in ind:
+                        b[0, i] = numpy.nan
+                return b
+            return a
+
+        oinf = OnnxInference(model_onnx)
+        pred_onx2 = oinf.run(inputs)
+        diff2 = numpy.abs(
+            pred_onx2['probabilities'].ravel() -
+            pipe.predict_proba(df).ravel()).sum()
+        obs['discrepency2'] = diff2
+
+    if diff > 0.1:
+        for i, (l1, l2) in enumerate(
+                zip(pipe.predict_proba(df),
+                    pred_onx['probabilities'])):
+            d = numpy.abs(l1 - l2).sum()
+            if verbose and d > 0.1:
+                print("\nDISCREPENCY DETAILS")
+                print(d, i, l1, l2)
+                pre = pipe.steps[0][-1].transform(df)
+                print("idf", pre[i].dtype, td(pre[i]))
+                pre2 = pipe.steps[1][-1].transform(pre)
+                print("cas", pre2[i].dtype, td(pre2[i]))
+                inter = oinf.run(inputs, intermediate=True)
+                onx = inter['tfidftr_norm']
+                print("onx", onx.dtype, onx[i])
+                onx = inter['variable3']
+
+
 def make_pipelines(df_train, y_train, models=None,
                    sparse_threshold=1., replace_nan=False,
                    insert_replace=False, verbose=False):
@@ -104,42 +192,7 @@ def make_pipelines(df_train, y_train, models=None,
     pipes = []
 
     for model in tqdm(models):
-
-        if model == HistGradientBoostingClassifier:
-            kwargs = dict(max_iter=5)
-        elif model == XGBClassifier:
-            kwargs = dict(n_estimators=5, use_label_encoder=False)
-        else:
-            kwargs = dict(n_estimators=5)
-
-        if insert_replace:
-            pipe = Pipeline([
-                ('union', ColumnTransformer([
-                    ('scale1', StandardScaler(), [0, 1]),
-                    ('subject',
-                     Pipeline([
-                         ('count', CountVectorizer()),
-                         ('tfidf', TfidfTransformer()),
-                         ('repl', ReplaceTransformer()),
-                     ]), "text"),
-                ], sparse_threshold=sparse_threshold)),
-                ('cast', CastTransformer()),
-                ('cls', model(max_depth=3, **kwargs)),
-            ])
-        else:
-            pipe = Pipeline([
-                ('union', ColumnTransformer([
-                    ('scale1', StandardScaler(), [0, 1]),
-                    ('subject',
-                     Pipeline([
-                         ('count', CountVectorizer()),
-                         ('tfidf', TfidfTransformer())
-                     ]), "text"),
-                ], sparse_threshold=sparse_threshold)),
-                ('cast', CastTransformer()),
-                ('cls', model(max_depth=3, **kwargs)),
-            ])
-
+        pipe = make_pipeline(model, insert_replace, sparse_threshold)
         try:
             pipe.fit(df_train, y_train)
         except TypeError as e:
@@ -150,68 +203,25 @@ def make_pipelines(df_train, y_train, models=None,
         options = {model: {'zipmap': False}}
         if replace_nan:
             options[TfidfTransformer] = {'nan': True}
+        model_onnx = model_to_onnx(pipe, options)  # convert
 
-        with warnings.catch_warnings(record=False):
-            warnings.simplefilter("ignore", (FutureWarning, UserWarning))
-            model_onnx = to_onnx(
-                pipe,
-                initial_types=[('input', FloatTensorType([None, 2])),
-                               ('text', StringTensorType([None, 1]))],
-                target_opset={'': 14, 'ai.onnx.ml': 2},
-                options=options)
-
-        with open('model.onnx', 'wb') as f:
-            f.write(model_onnx.SerializeToString())
         oinf = OnnxInference(model_onnx)
         inputs = {"input": df[[0, 1]].values.astype(numpy.float32),
                   "text": df[["text"]].values}
         pred_onx = oinf.run(inputs)
 
+        # check
         diff = numpy.abs(
             pred_onx['probabilities'].ravel() -
             pipe.predict_proba(df).ravel()).sum()
 
-        if verbose:
-            def td(a):
-                if hasattr(a, 'todense'):
-                    b = a.todense()
-                    ind = set(a.indices)
-                    for i in range(b.shape[1]):
-                        if i not in ind:
-                            b[0, i] = numpy.nan
-                    return b
-                return a
-
-            oinf = OnnxInference(model_onnx)
-            pred_onx2 = oinf.run(inputs)
-            diff2 = numpy.abs(
-                pred_onx2['probabilities'].ravel() -
-                pipe.predict_proba(df).ravel()).sum()
-
-        if diff > 0.1:
-            for i, (l1, l2) in enumerate(
-                    zip(pipe.predict_proba(df),
-                        pred_onx['probabilities'])):
-                d = numpy.abs(l1 - l2).sum()
-                if verbose and d > 0.1:
-                    print("\nDISCREPENCY DETAILS")
-                    print(d, i, l1, l2)
-                    pre = pipe.steps[0][-1].transform(df)
-                    print("idf", pre[i].dtype, td(pre[i]))
-                    pre2 = pipe.steps[1][-1].transform(pre)
-                    print("cas", pre2[i].dtype, td(pre2[i]))
-                    inter = oinf.run(inputs, intermediate=True)
-                    onx = inter['tfidftr_norm']
-                    print("onx", onx.dtype, onx[i])
-                    onx = inter['variable3']
-
         obs = dict(model=model.__name__, discrepencies=diff,
                    model_onnx=model_onnx, pipe=pipe)
-        if verbose:
-            obs['discrepency2'] = diff2
+        print_status(obs, inputs, pipe, model_onnx, pred_onx, diff, verbose)
         pipes.append(obs)
 
     return pipes
diff --git a/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_1_simple.rst b/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_1_simple.rst
index 6a6fa6f8..f3af14f1 100644
--- a/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_1_simple.rst
+++ b/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_1_simple.rst
@@ -20,8 +20,6 @@ used in the ONNX graph.
     ../../gyexamples/plot_dbegin_options
     ../../gyexamples/plot_dbegin_options_list
     ../../gyexamples/plot_dbegin_options_zipmap
-    ../../gyexamples/plot_ebegin_float_double
-    ../../gyexamples/plot_funny_sigmoid
     ../../gyexamples/plot_fbegin_investigate
     ../../gyexamples/plot_gbegin_dataframe
     ../../gyexamples/plot_gbegin_transfer_learning
diff --git a/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_4_complex.rst b/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_4_complex.rst
index a67d6458..d2451a80 100644
--- a/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_4_complex.rst
+++ b/_doc/sphinxdoc/source/tutorials/tutorial_skl/tutorial_4_complex.rst
@@ -1,11 +1,24 @@
-Complex Scenarios
-=================
+Complex Scenarios and discrepancies
+===================================
 
 Discrepencies may happen. Let's see some unexpected cases.
 
+Dealing with discrepancies
+++++++++++++++++++++++++++
+
+.. toctree::
+    :maxdepth: 1
+
+    ../../gyexamples/plot_ebegin_float_double
+    ../../gyexamples/plot_funny_sigmoid
+
+Unexpected issues
++++++++++++++++++
+
 .. toctree::
     :maxdepth: 1
 
     ../../gyexamples/plot_usparse_xgboost
     ../../gyexamples/plot_gexternal_lightgbm_reg
+    ../../gyexamples/plot_transformer_discrepancy