This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
/
svm_converters.py
57 lines (49 loc) · 1.8 KB
/
svm_converters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
@file
@brief Rewrites some of the converters implemented in
:epkg:`sklearn-onnx`.
"""
import numpy
from skl2onnx.operator_converters.support_vector_machines import (
convert_sklearn_svm_regressor,
convert_sklearn_svm_classifier)
def _op_type_domain_regressor(container):
    """
    Chooses the ONNX operator used to convert an SVM regressor,
    based on the floating point type of the conversion container.

    :param container: conversion container; only its ``dtype``
        attribute is read
    :return: tuple ``(op_type, op_domain, op_version)`` —
        ``('SVMRegressor', 'ai.onnx.ml', 1)`` for float32,
        ``('SVMRegressorDouble', 'mlprodict', 1)`` for float64
    :raises RuntimeError: if ``container.dtype`` is neither
        float32 nor float64
    """
    dtype = container.dtype
    # Ordered (type, operator spec) pairs. Equality is used rather than a
    # dict lookup so that numpy dtype instances such as
    # numpy.dtype('float32') still match the scalar types.
    candidates = (
        (numpy.float32, ('SVMRegressor', 'ai.onnx.ml', 1)),
        (numpy.float64, ('SVMRegressorDouble', 'mlprodict', 1)),
    )
    for expected, spec in candidates:
        if dtype == expected:
            return spec
    raise RuntimeError("Unsupported dtype {}.".format(dtype))
def _op_type_domain_classifier(container):
    """
    Chooses the ONNX operator used to convert an SVM classifier,
    based on the floating point type of the conversion container.

    :param container: conversion container; only its ``dtype``
        attribute is read
    :return: tuple ``(op_type, op_domain, op_version)`` —
        ``('SVMClassifier', 'ai.onnx.ml', 1)`` for float32,
        ``('SVMClassifierDouble', 'mlprodict', 1)`` for float64
    :raises RuntimeError: if ``container.dtype`` is neither
        float32 nor float64
    """
    dtype = container.dtype
    # Ordered (type, operator spec) pairs. Equality is used rather than a
    # dict lookup so that numpy dtype instances such as
    # numpy.dtype('float32') still match the scalar types.
    candidates = (
        (numpy.float32, ('SVMClassifier', 'ai.onnx.ml', 1)),
        (numpy.float64, ('SVMClassifierDouble', 'mlprodict', 1)),
    )
    for expected, spec in candidates:
        if dtype == expected:
            return spec
    raise RuntimeError("Unsupported dtype {}.".format(dtype))
def new_convert_sklearn_svm_regressor(scope, operator, container):
    """
    Converts an SVM regressor by delegating to the
    :epkg:`sklearn-onnx` converter. The ONNX operator is picked from
    ``container.dtype``, so that double-precision models can be
    converted too.

    The operator is ``SVMRegressor`` (domain ``ai.onnx.ml``) for float32
    and ``SVMRegressorDouble`` (domain ``mlprodict``) for float64.

    :param scope: conversion scope, forwarded unchanged
    :param operator: operator being converted, forwarded unchanged
    :param container: conversion container; its ``dtype`` selects the
        operator
    :raises RuntimeError: if ``container.dtype`` is not supported
    """
    op_type, op_domain, op_version = _op_type_domain_regressor(container)
    options = dict(op_type=op_type, op_domain=op_domain,
                   op_version=op_version)
    convert_sklearn_svm_regressor(scope, operator, container, **options)
def new_convert_sklearn_svm_classifier(scope, operator, container):
    """
    Converts an SVM classifier by delegating to the
    :epkg:`sklearn-onnx` converter. The ONNX operator is picked from
    ``container.dtype``, so that double-precision models can be
    converted too.

    The operator is ``SVMClassifier`` (domain ``ai.onnx.ml``) for float32
    and ``SVMClassifierDouble`` (domain ``mlprodict``) for float64.

    :param scope: conversion scope, forwarded unchanged
    :param operator: operator being converted, forwarded unchanged
    :param container: conversion container; its ``dtype`` selects the
        operator
    :raises RuntimeError: if ``container.dtype`` is not supported
    """
    op_type, op_domain, op_version = _op_type_domain_classifier(container)
    options = dict(op_type=op_type, op_domain=op_domain,
                   op_version=op_version)
    convert_sklearn_svm_classifier(scope, operator, container, **options)