In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

In [1]:
def knn_iris():
    """
    用 KNN 算法对鸢尾花进行分类
    """
    # 1. 获取数据
    iris = load_iris()

    # 2. 划分数据集
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=6)

    # 3. 特征工程: 标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)

    # 4. KNN 算法预估器
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)

    # 5. 模型评估
    # 方法1: 直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict: \n", y_predict)
    print("直接比对真实值和预测值: \n", y_test == y_predict)

    # 方法2: 计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率为: \n", score)
    return None


In [3]:
def knn_iris_gscv():
    """
    用 KNN 算法对鸢尾花进行分类, 添加网络搜索和交叉验证
    """
    # 1) 获取数据
    iris = load_iris()

    # 2) 划分数据集
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=22)

    # 3) 特征工程: 标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)

    # 4) KNN 算法预估器
    estimator = KNeighborsClassifier()

    # 加入网络搜索与交叉验证
    # 参数准备
    param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
    estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
    estimator.fit(x_train, y_train)

    # 5) 模型评估
    # 方法1: 直接对比真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict: \n", y_predict)
    print("直接对比真实值和预测值: \n", y_test == y_predict)

    # 方法2: 计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率为: \n", score)

    print("最佳参数为: \n", estimator.best_params_)
    print("最佳结果为: \n", estimator.best_score_)
    print("最佳估计器为: \n", estimator.best_estimator_)
    print("最终交叉验证结果为: \n", estimator.cv_results_)

    return None


In [5]:
if __name__ == "__main__":
    # code1: 用 KNN 算法对鸢尾花进行分类
    # knn_iris()
    # code2: 利用KNN算法对鸢尾花进行分类, 添加网格搜索和交叉验证
    knn_iris_gscv()

y_predict: 
 [0 2 0 0 2 1 2 0 2 1 2 1 2 2 1 1 2 1 1 0 0 2 0 0 1 1 1 2 0 1 0 1 0 0 1 2 1
 2]
直接比对真实值和预测值: 
 [ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True False  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True False  True
  True  True]
准确率为: 
 0.9473684210526315
最佳参数为: 
 {'n_neighbors': 11}
最佳结果为: 
 0.9734848484848484
最佳估计器为: 
 KNeighborsClassifier(n_neighbors=11)
交叉验证结果为: 
 {'mean_fit_time': array([2.46381760e-04, 9.65833664e-05, 9.41514969e-05, 9.29117203e-05,
       1.10650063e-04, 9.57489014e-05]), 'std_fit_time': array([4.07031047e-04, 2.06875349e-06, 6.34838436e-07, 1.81902595e-06,
       1.96921190e-05, 2.20637041e-06]), 'mean_score_time': array([0.00039165, 0.00025673, 0.00025871, 0.00025949, 0.00029321,
       0.00026369]), 'std_score_time': array([3.05778014e-04, 1.17183641e-05, 1.70773122e-05, 2.35299208e-05,
       4.15034362e-05, 8.68315097e-06]), 'param_n_neighbors': ma