### MLP多层感知机

In [3]:
from utils.dataset_utils import get_classes_indexes_counts
from sklearn.metrics import confusion_matrix
# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report, accuracy_score

# Load the iris dataset and pull out the feature matrix and label vector.
data = load_iris()
X, y = data.data, data.target

# Hold out 30% of the samples for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

# Summarize the test labels with the project helper and show the counts
# (presumably one count per class -- verify against utils.dataset_utils).
classes, counts = get_classes_indexes_counts(y_test)
print(counts)

# Standardize the features. The scaler is fitted on the training split only
# and then applied to both splits, so no test-set statistics leak into it.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Build and train an MLP with two hidden layers (10 and 20 units).
mlp = MLPClassifier(
    hidden_layer_sizes=(10, 20),
    max_iter=1000,
    random_state=42,
).fit(X_train, y_train)

# Predict hard labels and per-class probabilities for the test split.
y_pred = mlp.predict(X_test)
index_pred_proba = mlp.predict_proba(X_test)

# Report accuracy and the per-class classification report.
print("准确率:", accuracy_score(y_test, y_pred))
print("\n分类报告:\n", classification_report(y_test, y_pred))

# Print the confusion matrix.
print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))

# Show the shape of the probability matrix, then leave the matrix itself as
# the last expression so the notebook displays it.
print(index_pred_proba.shape)
index_pred_proba

[19 13 13]
准确率: 1.0

分类报告:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      1.00      1.00        13
           2       1.00      1.00      1.00        13

    accuracy                           1.00        45
   macro avg       1.00      1.00      1.00        45
weighted avg       1.00      1.00      1.00        45

Confusion Matrix:
[[19  0  0]
 [ 0 13  0]
 [ 0  0 13]]
(45, 3)


array([[2.91056625e-03, 9.91228682e-01, 5.86075128e-03],
       [9.99108751e-01, 8.77196677e-04, 1.40524637e-05],
       [5.74537035e-08, 4.27430270e-05, 9.99957200e-01],
       [4.67312665e-03, 9.34363883e-01, 6.09629901e-02],
       [1.97859873e-03, 9.42066591e-01, 5.59548102e-02],
       [9.98202675e-01, 1.77834615e-03, 1.89785227e-05],
       [5.29061246e-03, 9.91297100e-01, 3.41228743e-03],
       [2.82604279e-05, 2.94536771e-03, 9.97026372e-01],
       [2.98332660e-04, 6.47006190e-01, 3.52695478e-01],
       [1.96331651e-03, 9.96794197e-01, 1.24248655e-03],
       [8.82234971e-04, 6.62609508e-02, 9.32856814e-01],
       [9.98674599e-01, 1.31881849e-03, 6.58259368e-06],
       [9.99447127e-01, 5.47749496e-04, 5.12321192e-06],
       [9.98673841e-01, 1.31787978e-03, 8.27946659e-06],
       [9.99567107e-01, 4.20681281e-04, 1.22121393e-05],
       [1.50547672e-02, 8.94667075e-01, 9.02781573e-02],
       [1.85193646e-05, 2.10973036e-03, 9.97871750e-01],
       [1.37829723e-03, 9.97811