[特征缩放的重要性](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_scaling_importance.html#sphx-glr-auto-examples-preprocessing-plot-scaling-importance-py)

In [3]:
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_wine
from sklearn.pipeline import make_pipeline

In [4]:
# Load the wine dataset bundled with scikit-learn; the returned object's
# .DESCR, .feature_names, .data and .target attributes are used by later cells.
wine = load_wine()

In [20]:
# Print the dataset's description text.
print("【DESCR】\n", wine.DESCR)
# Print the feature names.
print("【feature_names】\n", wine.feature_names)
# For the feature matrix and then the label vector: print the shape,
# followed by the first five entries.
for header, values in (("【data】\n", wine.data), ("【target】\n", wine.target)):
    print(header, values.shape)
    print(values[:5])

【DESCR】
 .. _wine_dataset:

Wine recognition dataset
------------------------

**Data Set Characteristics:**

    :Number of Instances: 178 (50 in each of three classes)
    :Number of Attributes: 13 numeric, predictive attributes and the class
    :Attribute Information:
 		- Alcohol
 		- Malic acid
 		- Ash
		- Alcalinity of ash  
 		- Magnesium
		- Total phenols
 		- Flavanoids
 		- Nonflavanoid phenols
 		- Proanthocyanins
		- Color intensity
 		- Hue
 		- OD280/OD315 of diluted wines
 		- Proline

    - class:
            - class_0
            - class_1
            - class_2
		
    :Summary Statistics:
    
                                   Min   Max   Mean     SD
    Alcohol:                      11.0  14.8    13.0   0.8
    Malic Acid:                   0.74  5.80    2.34  1.12
    Ash:                          1.36  3.23    2.36  0.27
    Alcalinity of Ash:            10.6  30.0    19.5   3.3
    Magnesium:                    70.0 162.0    99.7  14.3
    Total Phenols:        

In [6]:
# Feature matrix and class labels.
X = wine.data
y = wine.target
# Hold out 30% of the samples for testing; the fixed random_state keeps the
# split (and therefore the accuracies reported below) reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

In [9]:
# Baseline without scaling: reduce the raw features to 2 principal components,
# then classify with Gaussian naive Bayes. fit() returns the pipeline itself,
# so construction and training are chained.
unscaled_clf = make_pipeline(PCA(n_components=2), GaussianNB()).fit(X_train, y_train)
pred_test = unscaled_clf.predict(X_test)

In [7]:
# Same PCA + Gaussian naive Bayes pipeline, but with a StandardScaler step in
# front so the features are standardized before PCA. fit() returns the
# pipeline itself, so construction and training are chained.
std_clf = make_pipeline(
    StandardScaler(),
    PCA(n_components=2),
    GaussianNB(),
).fit(X_train, y_train)
pred_test_std = std_clf.predict(X_test)

In [10]:
# Report test-set accuracy for the predictions stored in pred_test
# (the pipeline without a scaling step).
print("\nPrediction accuracy for the normal test dataset with PCA")
acc_unscaled = accuracy_score(y_test, pred_test)
print(f"{acc_unscaled:.2%}\n")


Prediction accuracy for the normal test dataset with PCA
81.48%



In [8]:
# Report test-set accuracy for the predictions stored in pred_test_std
# (the pipeline that standardizes features first).
print("\nPrediction accuracy for the standardized test dataset with PCA")
acc_scaled = accuracy_score(y_test, pred_test_std)
print(f"{acc_scaled:.2%}\n")


Prediction accuracy for the standardized test dataset with PCA
98.15%



In [17]:
# Convert the fitted, standardized pipeline into ONNX format and save it.
import os

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Declare the model input: a float tensor with a variable batch dimension
# (None) and one column per training feature.
initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))]
onx = convert_sklearn(std_clf, initial_types=initial_type)

# Create the output directory first: open(..., "wb") does not create missing
# parent directories, so on a fresh checkout the write would otherwise fail
# with FileNotFoundError.
onnx_path = "saved_model/GNB_wine.onnx"
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
with open(onnx_path, "wb") as f:
    f.write(onx.SerializeToString())



In [18]:
# Compute the prediction with ONNX Runtime
import onnxruntime as rt
import numpy

# Name the execution provider explicitly. Since ONNX Runtime 1.9, builds that
# ship additional providers (e.g. the GPU package) raise an error when
# InferenceSession is created without `providers`; the CPU provider is
# available in every build.
sess = rt.InferenceSession(
    "saved_model/GNB_wine.onnx",
    providers=["CPUExecutionProvider"],
)
input_name = sess.get_inputs()[0].name
# The first model output is requested and stored as the predicted labels.
label_name = sess.get_outputs()[0].name
# The model input was declared as a float tensor at export time, so the
# float64 test rows are cast to float32 before being fed to the session.
pred_onx = sess.run([label_name], {input_name: X_test[:5].astype(numpy.float32)})[0]
pred_onx

array([0, 0, 2, 0, 1], dtype=int64)

In [19]:
# Display the first five test rows, i.e. the same slice that was passed to the
# ONNX Runtime session above (before the float32 cast).
X_test[:5]

array([[1.364e+01, 3.100e+00, 2.560e+00, 1.520e+01, 1.160e+02, 2.700e+00,
        3.030e+00, 1.700e-01, 1.660e+00, 5.100e+00, 9.600e-01, 3.360e+00,
        8.450e+02],
       [1.421e+01, 4.040e+00, 2.440e+00, 1.890e+01, 1.110e+02, 2.850e+00,
        2.650e+00, 3.000e-01, 1.250e+00, 5.240e+00, 8.700e-01, 3.330e+00,
        1.080e+03],
       [1.293e+01, 2.810e+00, 2.700e+00, 2.100e+01, 9.600e+01, 1.540e+00,
        5.000e-01, 5.300e-01, 7.500e-01, 4.600e+00, 7.700e-01, 2.310e+00,
        6.000e+02],
       [1.373e+01, 1.500e+00, 2.700e+00, 2.250e+01, 1.010e+02, 3.000e+00,
        3.250e+00, 2.900e-01, 2.380e+00, 5.700e+00, 1.190e+00, 2.710e+00,
        1.285e+03],
       [1.237e+01, 1.170e+00, 1.920e+00, 1.960e+01, 7.800e+01, 2.110e+00,
        2.000e+00, 2.700e-01, 1.040e+00, 4.680e+00, 1.120e+00, 3.480e+00,
        5.100e+02]])