# Decision Tree & Successive Halving Random + Search Example

In [1]:
%load_ext watermark
%watermark -p scikit-learn,mlxtend

scikit-learn: 1.0
mlxtend     : 0.19.0



## Dataset

In [2]:
from sklearn import model_selection
from sklearn.model_selection import train_test_split
from sklearn import datasets


data = datasets.load_breast_cancer()
X, y = data.data, data.target

X_train, X_test, y_train, y_test = \
    train_test_split(X, y, test_size=0.3, random_state=1, stratify=y)

X_train_sub, X_valid, y_train_sub, y_valid = \
    train_test_split(X_train, y_train, test_size=0.2, random_state=1, stratify=y_train)

print('Train/Valid/Test sizes:', y_train.shape[0], y_valid.shape[0], y_test.shape[0])

Train/Valid/Test sizes: 398 80 171


## Successive Halving + Random Search


- More info: 
  - https://scikit-learn.org/stable/modules/grid_search.html#successive-halving-user-guide
  - https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.HalvingRandomSearchCV.html#sklearn.model_selection.HalvingRandomSearchCV

In [3]:
import numpy as np
import scipy.stats

from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingRandomSearchCV

from sklearn.tree import DecisionTreeClassifier


clf = DecisionTreeClassifier(random_state=123)

params =  {
    'min_samples_split': scipy.stats.randint(2, 12),
    'min_impurity_decrease': scipy.stats.uniform(0.0, 0.5),
    'max_depth': [6, 16, None]
}


search = HalvingRandomSearchCV(
    estimator=clf, 
    param_distributions=params,
    n_candidates='exhaust',
    resource='n_samples',
    factor=3,
    random_state=123,
    n_jobs=1)


search.fit(X_train, y_train)

search.best_score_

0.8882539682539681

In [4]:
search.best_params_

{'max_depth': None,
 'min_impurity_decrease': 0.029838948304784174,
 'min_samples_split': 2}

In [5]:
print(f"Training Accuracy: {search.best_estimator_.score(X_train, y_train):0.2f}")
print(f"Test Accuracy: {search.best_estimator_.score(X_test, y_test):0.2f}")

Training Accuracy: 0.95
Test Accuracy: 0.94
