# アヤメのデータを利用し、グリッドサーチ・クロスバリデーションを使い、決定木のモデル選択を行う

In [69]:
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target

訓練データ、テストデータに分ける

In [70]:
X_train, X_test, y_train, y_test = train_test_split (X, y, test_size=0.3, random_state=123)

決定木を使う

In [71]:
clf = DecisionTreeClassifier()

訓練データに決定木を学習させる

In [72]:
clf.fit(X_train, y_train)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')

テストデータに当てはめ

In [73]:
y_pred=clf.predict(X_test)
y_pred

array([1, 2, 2, 1, 0, 2, 1, 0, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 0, 1,
       0, 2, 0, 0, 0, 2, 2, 0, 2, 1, 0, 0, 1, 1, 2, 0, 0, 1, 1, 0, 2, 2,
       2])

予測精度を見る

In [74]:
from sklearn. metrics import classification_report
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       0.83      1.00      0.91        10
           2       1.00      0.88      0.94        17

    accuracy                           0.96        45
   macro avg       0.94      0.96      0.95        45
weighted avg       0.96      0.96      0.96        45



クロスバリデーションを行う

In [75]:
from sklearn.model_selection import cross_val_score
print(cross_val_score(clf, X_train, y_train, cv=2))

[0.96226415 0.92307692]


In [76]:
print(cross_val_score(clf, X_train, y_train, cv=3))

[0.94444444 0.91428571 0.88235294]


In [77]:
print(cross_val_score(clf, X_train, y_train, cv=4))

[1.         0.92307692 0.96153846 0.84615385]


In [78]:
print(cross_val_score(clf, X_train, y_train, cv=5))

[1.         0.95454545 0.95238095 1.         0.8       ]


グリッドサーチ、クロスバリデーション（10回）を使って、ハイパーパラメータdepthを求める

In [79]:
cv = GridSearchCV(clf, param_grid=param_grid, cv=10)
param_grid = {'max_depth': [3, 4, 5]}
cv.fit(X_train, y_train)

GridSearchCV(cv=10, error_score='raise-deprecating',
             estimator=DecisionTreeClassifier(class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort=False, random_state=None,
                                              splitter='best'),
             iid='warn', n_jobs=None, param_grid={'max_depth': [3, 4, 5]},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=N

最適なdepthを求める

In [80]:
cv.best_params_

{'max_depth': 5}

最適なモデルを求める

In [81]:
cv.best_estimator_

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')

テストデータに当てはめる

In [82]:
y_pred=cv.predict(X_test)
y_pred

array([1, 2, 2, 1, 0, 1, 1, 0, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 0, 1,
       0, 2, 0, 0, 0, 2, 2, 0, 2, 1, 0, 0, 1, 1, 2, 0, 0, 1, 1, 0, 2, 2,
       2])

予測精度を求める

In [83]:
from sklearn. metrics import classification_report
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       0.77      1.00      0.87        10
           2       1.00      0.82      0.90        17

    accuracy                           0.93        45
   macro avg       0.92      0.94      0.92        45
weighted avg       0.95      0.93      0.93        45

