In [1]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.preprocessing import StandardScaler

In [6]:
# 1、获取数据集
iris = load_iris()
# 2、数据基本处理 -- 划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)

# 3、特征工程：标准化
# 实例化一个转换器类
transfer = StandardScaler()
# 调用fit_transform
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4、KNN预估器流程
#  4.1 实例化预估器类
estimator = KNeighborsClassifier()

# 4.2 模型选择与调优——网格搜索和交叉验证 
# 准备要调的超参数
param_dict = {'n_neighbors':[1,3,5,7]}
estimator = GridSearchCV(estimator,param_grid=param_dict,cv=5)

# 4.3 fit数据进行训练
estimator.fit(x_train,y_train)

# 5、评估模型效果
# 方法a：比对预测结果和真实值
y_predict = estimator.predict(x_test)
print("比对预测结果和真实值：\n", y_predict == y_test)

# 方法b：直接计算准确率
score = estimator.score(x_test, y_test)
print("直接计算准确率：\n", score) #这里是训练集训练后在测试集计算

#然后进行评估查看最终选择的结果和交叉验证的结果
print("在交叉验证中验证的最好结果：\n", estimator.best_score_)  # 训练集分为训练集+验证集
print("最好的参数模型：\n", estimator.best_estimator_)
print("每次交叉验证后的准确率结果：\n", estimator.cv_results_)

比对预测结果和真实值：
 [ True  True  True  True  True  True  True False  True  True  True  True
  True  True  True  True  True  True False  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True]
直接计算准确率：
 0.9473684210526315
在交叉验证中验证的最好结果：
 0.9553359683794467
最好的参数模型：
 KNeighborsClassifier()
每次交叉验证后的准确率结果：
 {'mean_fit_time': array([0.00084419, 0.00078721, 0.00080686, 0.0007925 ]), 'std_fit_time': array([5.85565738e-05, 1.23779432e-05, 1.56104280e-05, 1.44631871e-06]), 'mean_score_time': array([0.00314808, 0.00352588, 0.00313988, 0.00318003]), 'std_score_time': array([1.23808636e-04, 7.56915459e-04, 2.61585929e-05, 3.74309576e-05]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7],
             mask=[False, False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}], 'split0_test_score': array([0.95652174, 0.95652174, 1.        , 1. 