In [1]:
import pandas as pd

In [2]:
from sklearn.model_selection import train_test_split

In [3]:
from sklearn.tree import DecisionTreeClassifier

In [4]:
from sklearn.model_selection import GridSearchCV

In [5]:
from sklearn.metrics import classification_report

In [6]:
df = pd.read_csv("data/iris.csv")
df.head()

Unnamed: 0,Sepal_Length,Sepal_Width,Petal_Length,Petal_Width,Species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [7]:
features = df.columns[:-1].tolist()
x = df[features]
y = df["Species"]

In [8]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

In [9]:
parameters = [{"criterion":['gini', 'entropy'], 
               "max_depth": [2, 3, 4], 
               "min_samples_leaf": [2, 3, 4, 5]}]
clf = GridSearchCV(DecisionTreeClassifier(), param_grid=parameters, cv=5)

In [10]:
clf.fit(x_train, y_train)

GridSearchCV(cv=5, estimator=DecisionTreeClassifier(),
             param_grid=[{'criterion': ['gini', 'entropy'],
                          'max_depth': [2, 3, 4],
                          'min_samples_leaf': [2, 3, 4, 5]}])

In [11]:
clf.best_params_

{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 2}

In [12]:
print(classification_report(y_train, clf.predict(x_train)))

              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        39
  versicolor       1.00      0.94      0.97        33
   virginica       0.94      1.00      0.97        33

    accuracy                           0.98       105
   macro avg       0.98      0.98      0.98       105
weighted avg       0.98      0.98      0.98       105



In [13]:
print(classification_report(y_test, clf.predict(x_test)))

              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        11
  versicolor       1.00      0.82      0.90        17
   virginica       0.85      1.00      0.92        17

    accuracy                           0.93        45
   macro avg       0.95      0.94      0.94        45
weighted avg       0.94      0.93      0.93        45



In [14]:
clf.cv_results_

{'mean_fit_time': array([0.00473962, 0.00335541, 0.0034627 , 0.00341463, 0.0039773 ,
        0.00335274, 0.00336151, 0.0030499 , 0.00304198, 0.00370255,
        0.00396204, 0.00345945, 0.00319061, 0.00320544, 0.00324368,
        0.00321021, 0.00322104, 0.00321388, 0.00322151, 0.00342722,
        0.00322623, 0.00323215, 0.00324464, 0.00318022]),
 'std_fit_time': array([1.44700198e-03, 7.23925499e-05, 5.34388099e-04, 2.24704673e-04,
        1.43692291e-03, 5.25384106e-05, 6.66759883e-05, 3.74366472e-04,
        4.29299780e-04, 3.46338504e-04, 1.47953204e-04, 2.95965722e-05,
        2.10951810e-05, 2.51766407e-05, 5.71311811e-05, 2.15589465e-05,
        3.74986269e-05, 3.44298222e-05, 3.50661780e-05, 3.72222378e-04,
        4.04664252e-05, 2.61412029e-05, 3.95570840e-05, 5.28006244e-05]),
 'mean_score_time': array([0.00295057, 0.00254006, 0.00261984, 0.00256524, 0.00237694,
        0.00243626, 0.00243092, 0.00229945, 0.00239587, 0.00269108,
        0.00280781, 0.00250826, 0.00232158, 0.00

In [15]:
for param, acc, std in zip(clf.cv_results_['params'], clf.cv_results_['mean_test_score'], clf.cv_results_['std_test_score']):
    print(param, acc, '+/-', 3 * std)

{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 2} 0.9333333333333332 +/- 0.06998542122237644
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 3} 0.9333333333333332 +/- 0.06998542122237644
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 4} 0.9333333333333332 +/- 0.06998542122237644
{'criterion': 'gini', 'max_depth': 2, 'min_samples_leaf': 5} 0.9333333333333332 +/- 0.06998542122237644
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 2} 0.9619047619047618 +/- 0.10690449676496977
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 3} 0.9619047619047618 +/- 0.10690449676496977
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 4} 0.9523809523809523 +/- 0.12777531299998796
{'criterion': 'gini', 'max_depth': 3, 'min_samples_leaf': 5} 0.9428571428571428 +/- 0.10690449676496973
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 2} 0.9428571428571428 +/- 0.057142857142857065
{'criterion': 'gini', 'max_depth': 4, 'min_samples_leaf': 3} 0.