# Run procedures

## Imports

In [39]:
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.metrics import classification_report

In [32]:
dirDatasets = Path("../datasets/")
features    = ["n_users", "E"]
n_jobs      = -1
np.random.seed(0)

In [33]:
df_train = pd.read_csv(dirDatasets.joinpath("procedure_1_train.csv"))
X_train  = df_train[features]
y_train  = df_train["target"]

In [34]:
gs = GridSearchCV(
    estimator  = DecisionTreeClassifier(),
    param_grid = {"max_depth": np.arange(1, 10)},
    scoring    = "recall",
    n_jobs     = n_jobs,
    cv         = 10
)
gs.fit(X_train, y_train)
print(gs.best_params_)

{'max_depth': 3}


In [35]:
clf = DecisionTreeClassifier(**gs.best_params_)
clf.fit(X_train, y_train)
print(export_text(clf, feature_names=features))

|--- E <= 1.29
|   |--- E <= 1.25
|   |   |--- E <= 1.23
|   |   |   |--- class: False
|   |   |--- E >  1.23
|   |   |   |--- class: False
|   |--- E >  1.25
|   |   |--- n_users <= 216.00
|   |   |   |--- class: True
|   |   |--- n_users >  216.00
|   |   |   |--- class: False
|--- E >  1.29
|   |--- E <= 1.31
|   |   |--- n_users <= 216.00
|   |   |   |--- class: True
|   |   |--- n_users >  216.00
|   |   |   |--- class: False
|   |--- E >  1.31
|   |   |--- E <= 1.34
|   |   |   |--- class: True
|   |   |--- E >  1.34
|   |   |   |--- class: True



In [46]:
df_test_1 = pd.read_csv(dirDatasets.joinpath("procedure_1_test_1.csv"))
X_test_1  = df_test_1[features]
y_test_1  = df_test_1["target"]

In [47]:
y_pred_1 = clf.predict(X_test_1)

In [48]:
print(classification_report(y_test_1, y_pred_1))

              precision    recall  f1-score   support

       False       1.00      1.00      1.00      4100
        True       1.00      1.00      1.00     20500

    accuracy                           1.00     24600
   macro avg       1.00      1.00      1.00     24600
weighted avg       1.00      1.00      1.00     24600

