In [13]:
# importing libraries
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier

# Configuration: one seed shared by the split and the model so a fresh
# Restart & Run All reproduces the printed metrics.
SEED = 72
TEST_SIZE = 0.2
MAX_DEPTH = 10

# Load the OpenML `mnist_784` digit dataset and scale pixel values from
# [0, 255] to [0, 1]; labels are converted to ints.
data_set = fetch_openml('mnist_784', version=1)
X, Y = data_set.data / 255.0, data_set.target.astype(int)

# Indicator (one-hot) encoding of the labels: one column per distinct label,
# in the sorted order returned by np.unique. Used as ROC-AUC ground truth.
Y_bin = label_binarize(Y, classes=np.unique(Y))

X_train, X_test, Y_train, Y_test, Y_bin_train, Y_bin_test = train_test_split(
    X, Y, Y_bin, test_size=TEST_SIZE, random_state=SEED
)

clf = DecisionTreeClassifier(max_depth=MAX_DEPTH, random_state=SEED)
clf.fit(X_train, Y_train)

# Different evaluation methods.
Y_pred = clf.predict(X_test)
Y_prob = clf.predict_proba(X_test)

# Column j of predict_proba belongs to clf.classes_[j], while column j of
# Y_bin belongs to np.unique(Y)[j]. The ROC-AUC below is only meaningful if
# these line up (it would break if a label were absent from the training split).
assert np.array_equal(clf.classes_, np.unique(Y)), "predict_proba columns do not match Y_bin columns"

print(f"Accuracy: {accuracy_score(Y_test, Y_pred):.4f}")
print(f"Balanced Accuracy: {balanced_accuracy_score(Y_test, Y_pred):.4f}")
# Y_bin_test is a multilabel-indicator matrix, so roc_auc_score averages the
# per-column (one-vs-rest) AUCs; `multi_class` is only consulted for 1-D
# multiclass targets, so the effective setting is spelled out instead.
print(f"ROC-AUC: {roc_auc_score(Y_bin_test, Y_prob, average='macro'):.4f}")
print("\nConfusion Matrix:\n", confusion_matrix(Y_test, Y_pred))
print("\nClassification Report:\n", classification_report(Y_test, Y_pred))




Accuracy: 0.8596
Balanced Accuracy: 0.8575
ROC-AUC: 0.9529

Confusion Matrix:
 [[1266    0   23   10    2   19   17    0   18   10]
 [   2 1555   13    9   21    7    0    8   16    6]
 [  26   15 1175   24   20   10   24   41   48   18]
 [  18    8   32 1180    6   82    6   57   32   34]
 [   2    4   15   10 1160   19   16   16   24  114]
 [  16   26   18   82   17  964   26    6   40   24]
 [  20    2   31   15   19   27 1172    3   24    4]
 [   4    8   32    7   16   10    2 1288    7   46]
 [  10   33   56   47   20   28   16   11 1085   73]
 [   8    6   18   27   62   36   11   51   19 1189]]

Classification Report:
               precision    recall  f1-score   support

           0       0.92      0.93      0.93      1365
           1       0.94      0.95      0.94      1637
           2       0.83      0.84      0.84      1401
           3       0.84      0.81      0.82      1455
           4       0.86      0.84      0.85      1380
           5       0.80      0.79      0