1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -26,6 +26,7 @@ jobs:
. venv/bin/activate
pip install -r requirements.txt
pip install -r requirements-dev.txt
pip install onnxruntime-training

- save_cache:
paths:
5 changes: 5 additions & 0 deletions .gitignore
@@ -3,6 +3,7 @@ coverage.html/*
.coverage
dist/*
build/*
data/*
*egg-info/*
doc/auto_examples/*
doc/_static/viz.js
@@ -15,3 +16,7 @@ temp_*
examples/pipeline_lightgbm.onnx
examples/model.onnx
tests/model.onnx
*.jpg
*.onnx
*.pt

1 change: 1 addition & 0 deletions .travis.yml
@@ -11,6 +11,7 @@ install:
- pip install -r requirements.txt
- pip install scikit-learn
- pip install -r requirements-dev.txt
- pip install onnxruntime-training
before_script:
- gcc --version
- python setup.py build_ext --inplace
2 changes: 1 addition & 1 deletion appveyor.yml
@@ -17,7 +17,7 @@ before_test:
- "%PYTHON%\\python -u setup.py build_ext --inplace"

test_script:
- "%PYTHON%\\python -m unittest discover tests"
- "%PYTHON%\\python -m unittest discover tests --verbose"
- "%PYTHON%\\python -m flake8 tests"
- "%PYTHON%\\python -m flake8 onnxcustom"
- "%PYTHON%\\python -m flake8 examples"
2 changes: 1 addition & 1 deletion azure-pipelines.yml
@@ -93,7 +93,7 @@ jobs:
python setup.py build_ext --inplace
displayName: 'Build package'
- script: |
python -m pytest -v -v
python -m pytest tests -v -v
displayName: 'Runs Unit Tests'
- script: |
python -u setup.py bdist_wheel
3 changes: 1 addition & 2 deletions bin/flake8.bat
@@ -2,8 +2,7 @@
set current=%~dp0
set root=%current%..
cd %root%
set pythonexe="c:\Python387_x64\python.exe"
if not exist %pythonexe% set pythonexe="c:\Python372_x64\python.exe"
set pythonexe="python.exe"

@echo running 'python -m autopep8 --in-place --aggressive --aggressive -r'
%pythonexe% -m autopep8 --in-place --aggressive --aggressive -r onnxcustom tests examples setup.py doc/conf.py
5 changes: 2 additions & 3 deletions bin/unittest.bat
@@ -2,11 +2,10 @@
set current=%~dp0
set root=%current%..
cd %root%
set pythonexe="c:\Python387_x64\python.exe"
if not exist %pythonexe% set pythonexe="c:\Python372_x64\python.exe"
set pythonexe="python.exe"

@echo running 'python -m unittest discover tests'
%pythonexe% -m unittest discover tests
%pythonexe% -m unittest discover tests --verbose

if %errorlevel% neq 0 exit /b %errorlevel%
@echo Done Testing.
1 change: 1 addition & 0 deletions doc/tutorial_1_simple.rst
@@ -23,3 +23,4 @@ used in the ONNX graph.
auto_examples/plot_fbegin_investigate
auto_examples/plot_gbegin_dataframe
auto_examples/plot_gbegin_transfer_learning
auto_examples/plot_gbegin_cst
2 changes: 2 additions & 0 deletions doc/tutorial_2_new_converter.rst
@@ -35,3 +35,5 @@ Following section shows how to create a custom converter.
auto_examples/plot_lcustom_options
auto_examples/plot_mcustom_parser
auto_examples/plot_mcustom_parser_dataframe
auto_examples/plot_catwoe_transformer
auto_examples/plot_woe_transformer
1 change: 1 addition & 0 deletions doc/tutorial_4_complex.rst
@@ -8,3 +8,4 @@ Discrepencies may happen. Let's see some unexpected cases.
:maxdepth: 1

auto_examples/plot_usparse_xgboost
auto_examples/plot_gexternal_lightgbm_reg
3 changes: 2 additions & 1 deletion examples/plot_abegin_convert_pipeline.py
@@ -56,7 +56,8 @@
# into single float and ONNX runtimes may not fully
# support doubles.

onx = to_onnx(ereg, X_train[:1].astype(numpy.float32))
onx = to_onnx(ereg, X_train[:1].astype(numpy.float32),
target_opset=12)

###################################
# Prediction with ONNX
233 changes: 233 additions & 0 deletions examples/plot_catwoe_transformer.py
@@ -0,0 +1,233 @@
"""
.. _example-catwoe-transformer:

Converter for WOEEncoder from category_encoders
================================================

`WOEEncoder <https://contrib.scikit-learn.org/category_encoders/woe.html>`_
is a transformer implemented in `category_encoders
<https://contrib.scikit-learn.org/category_encoders/>`_. As such,
no converter for it is included in *sklearn-onnx*, which only
implements converters for *scikit-learn* models. Nevertheless, this
example demonstrates how to implement a custom converter
for *WOEEncoder*. The code is not fully tested for all possible
cases the original encoder can handle.

.. index:: WOE, WOEEncoder

.. contents::
:local:

A simple example
++++++++++++++++

Let's take the `Iris dataset
<https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html>`_.
Every feature is converted into an integer.
"""
import numpy as np
from onnxruntime import InferenceSession
from sklearn.datasets import load_iris
from sklearn.preprocessing import OrdinalEncoder as SklOrdinalEncoder
from category_encoders import WOEEncoder, OrdinalEncoder
from skl2onnx import update_registered_converter, to_onnx, get_model_alias
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.common.utils import check_input_and_output_numbers
from skl2onnx.algebra.onnx_ops import OnnxCast
from skl2onnx.algebra.onnx_operator import OnnxSubEstimator
from skl2onnx.sklapi import WOETransformer
import skl2onnx.sklapi.register # noqa

data = load_iris()
X, y = data.data, data.target
X = X.astype(np.int64)[:, :2]
y = (y == 2).astype(np.int64)

woe = WOEEncoder(cols=[0]).fit(X, y)
print(woe.transform(X[:5]))

########################################
# Let's look into the trained parameters of the model.
# It appears that WOEEncoder uses an OrdinalEncoder
# but not the one from scikit-learn. We need to add a
# converter for this model too.

print("encoder", type(woe.ordinal_encoder), woe.ordinal_encoder)
print("mapping", woe.mapping)
print("encoder.mapping", woe.ordinal_encoder.mapping)
print("encoder.cols", woe.ordinal_encoder.cols)

######################################
# Custom converter for OrdinalEncoder
# +++++++++++++++++++++++++++++++++++
#
# We start from example :ref:`l-plot-custom-converter`
# and then write the conversion.


def ordenc_to_sklearn(op_mapping):
"Converts OrdinalEncoder mapping to scikit-learn OrdinalEncoder."
cats = []
for column_map in op_mapping:
col = column_map['col']
while len(cats) <= col:
cats.append(None)
mapping = column_map['mapping']
res = []
for i in range(mapping.shape[0]):
if np.isnan(mapping.index[i]):
continue
ind = mapping.iloc[i]
while len(res) <= ind:
res.append(0)
res[ind] = mapping.index[i]
cats[col] = np.array(res, dtype=np.int64)

skl_ord = SklOrdinalEncoder(categories=cats, dtype=np.int64)
skl_ord.categories_ = cats
return skl_ord
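

#####################################
# A quick sanity check (illustrative only, not part of the converter):
# apply the helper to the mapping of the encoder fitted above, assuming
# ``woe.ordinal_encoder.mapping`` follows the ``[{'col': ..., 'mapping':
# ...}]`` format printed earlier.

skl_check = ordenc_to_sklearn(woe.ordinal_encoder.mapping)
print("recovered categories", skl_check.categories_)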


def ordinal_encoder_shape_calculator(operator):
check_input_and_output_numbers(
operator, input_count_range=1, output_count_range=1)
input_type = operator.inputs[0].type.__class__
input_dim = operator.inputs[0].get_first_dimension()
shape = operator.inputs[0].type.shape
second_dim = None if len(shape) != 2 else shape[1]
output_type = input_type([input_dim, second_dim])
operator.outputs[0].type = output_type


def ordinal_encoder_converter(scope, operator, container):
op = operator.raw_operator
opv = container.target_opset
X = operator.inputs[0]

skl_ord = ordenc_to_sklearn(op.mapping)
cat = OnnxSubEstimator(skl_ord, X, op_version=opv,
output_names=operator.outputs[:1])
cat.add_to(scope, container)


update_registered_converter(
OrdinalEncoder, "CategoricalEncoderOrdinalEncoder",
ordinal_encoder_shape_calculator,
ordinal_encoder_converter)


###################################
# Let's compute the output on a short example.


enc = OrdinalEncoder(cols=[0, 1])
enc.fit(X)
print(enc.transform(X[:5]))


###################################
# Let's check the ONNX conversion produces the same results.


ord_onx = to_onnx(enc, X[:1], target_opset=14)
sess = InferenceSession(ord_onx.SerializeToString())
print(sess.run(None, {'X': X[:5]})[0])
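
#####################################
# A more explicit comparison (a sketch, assuming ``enc.transform``
# returns a pandas DataFrame as category_encoders usually does): both
# outputs should be identical.

expected = np.asarray(enc.transform(X[:5]), dtype=np.float64)
got = np.asarray(sess.run(None, {'X': X[:5]})[0], dtype=np.float64)
np.testing.assert_almost_equal(expected, got)
print("ordinal encodings match")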

######################################
# That works.
#
# Custom converter for WOEEncoder
# +++++++++++++++++++++++++++++++
#
# We start from example :ref:`l-plot-custom-converter`
# and then write the conversion.


def woeenc_to_sklearn(op_mapping):
"Converts WOEEncoder mapping to scikit-learn OrdinalEncoder."
cats = []
ws = []
for column_map in op_mapping.items():
col = column_map[0]
while len(cats) <= col:
cats.append('passthrough')
ws.append(None)
mapping = column_map[1]
intervals = []
weights = []
for i in range(mapping.shape[0]):
ind = mapping.index[i]
if ind < 0:
continue
intervals.append((float(ind - 1), float(ind), False, True))
weights.append(mapping.iloc[i])
cats[col] = intervals
ws[col] = weights

skl = WOETransformer(intervals=cats, weights=ws, onehot=False)
skl.fit(None)
return skl


def woe_encoder_parser(
scope, model, inputs, custom_parsers=None):
if len(inputs) != 1:
raise RuntimeError(
"Unexpected number of inputs: %d != 1." % len(inputs))
if inputs[0].type is None:
raise RuntimeError(
"Unexpected type: %r." % (inputs[0], ))
alias = get_model_alias(type(model))
this_operator = scope.declare_local_operator(alias, model)
this_operator.inputs.append(inputs[0])
this_operator.outputs.append(
scope.declare_local_variable('catwoe', FloatTensorType()))
return this_operator.outputs


def woe_encoder_shape_calculator(operator):
check_input_and_output_numbers(
operator, input_count_range=1, output_count_range=1)
input_dim = operator.inputs[0].get_first_dimension()
shape = operator.inputs[0].type.shape
second_dim = None if len(shape) != 2 else shape[1]
output_type = FloatTensorType([input_dim, second_dim])
operator.outputs[0].type = output_type


def woe_encoder_converter(scope, operator, container):
op = operator.raw_operator
opv = container.target_opset
X = operator.inputs[0]

sub = OnnxSubEstimator(op.ordinal_encoder, X,
op_version=opv)
cast = OnnxCast(sub, op_version=opv, to=np.float32)
skl_ord = woeenc_to_sklearn(op.mapping)
cat = OnnxSubEstimator(skl_ord, cast, op_version=opv,
output_names=operator.outputs[:1],
input_types=[FloatTensorType()])
cat.add_to(scope, container)


update_registered_converter(
WOEEncoder, "CategoricalEncoderWOEEncoder",
woe_encoder_shape_calculator,
woe_encoder_converter,
parser=woe_encoder_parser)


###################################
# Let's compute the output on a short example.

woe = WOEEncoder(cols=[0, 1]).fit(X, y)
print(woe.transform(X[:5]))


###################################
# Let's check the ONNX conversion produces the same results.


woe_onx = to_onnx(woe, X[:1], target_opset=14)
sess = InferenceSession(woe_onx.SerializeToString())
print(sess.run(None, {'X': X[:5]})[0])
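
#####################################
# And an explicit comparison again (a sketch, assuming the only
# differences come from float32 rounding of the WOE weights).

expected = np.asarray(woe.transform(X[:5]), dtype=np.float32)
got = sess.run(None, {'X': X[:5]})[0]
np.testing.assert_almost_equal(expected, got, decimal=5)
print("woe encodings match")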