# IS 6482 - Week 3 — Classification Metrics and Cross-Validation

**Author:** Varun Gupta

**Agenda:** Confusion Matrix, Classification Metrics, ROC Curve, Cross-Validation  

**Libraries:** `sklearn.tree`, `sklearn.metrics`, `sklearn.model_selection`

**Datasets:** Telco customer churn

---

### Learning goals
By the end of this notebook, you should be able to:

1. Create **train/test splits** so that a hold-out set is available for model evaluation
2. Produce a **confusion matrix** and metrics such as **Precision**, **Recall**, and **F-score**
3. Report the **ROC curve** and **AUC** using a Decision Tree as a probabilistic classifier
4. Use **cross-validation** for model selection (for a Decision Tree, this means choosing the number of leaves or the pruning parameter)

**Dataset**: Telco Customer Churn (IBM sample) mirrored as a CSV in a public repo.

In [1]:
# ============================================================
# 0) Imports + plotting defaults (make plots readable in slides)
# ============================================================

import numpy as np
import pandas as pd

# Matplotlib / Seaborn for plots
import matplotlib.pyplot as plt
import seaborn as sns

# Scikit-learn for modeling
from sklearn import tree
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, KFold, cross_validate
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, RocCurveDisplay, roc_auc_score, get_scorer

# Make results reproducible
RANDOM_STATE = 42

# Part A -- Load the dataset and explore shape, columns

We'll load the Telco churn dataset from a hosted CSV. In a business setting this might come from a database / data warehouse.

Each row is a customer, and the goal is to predict **`Churn`** (whether the customer left).

In [2]:
# ============================
# Load the Telco churn dataset
# ============================

telco_url = "https://raw.githubusercontent.com/plotly/datasets/master/telco-customer-churn-by-IBM.csv"

# Read the CSV into a pandas DataFrame (table)
churn_df = pd.read_csv(telco_url)

# Quick peek
print("Shape:", churn_df.shape)
churn_df.head()

Shape: (7043, 21)


| | customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| 1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.5 | No |
| 2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| 3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.3 | 1840.75 | No |
| 4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.7 | 151.65 | Yes |


## 1) First sanity checks (structure, types, missing values)

Before modeling, always ask:
- What columns do we have?
- Are there missing values?
- Are numeric columns accidentally loaded as strings?

In [3]:
# Basic schema checks
churn_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 21 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   customerID        7043 non-null   object 
 1   gender            7043 non-null   object 
 2   SeniorCitizen     7043 non-null   int64  
 3   Partner           7043 non-null   object 
 4   Dependents        7043 non-null   object 
 5   tenure            7043 non-null   int64  
 6   PhoneService      7043 non-null   object 
 7   MultipleLines     7043 non-null   object 
 8   InternetService   7043 non-null   object 
 9   OnlineSecurity    7043 non-null   object 
 10  OnlineBackup      7043 non-null   object 
 11  DeviceProtection  7043 non-null   object 
 12  TechSupport       7043 non-null   object 
 13  StreamingTV       7043 non-null   object 
 14  StreamingMovies   7043 non-null   object 
 15  Contract          7043 non-null   object 
 16  PaperlessBilling  7043 non-null   object 


## 2) Minimal cleaning for modeling

We'll do only the cleaning necessary to make the dataset usable for scikit-learn:

1. Drop identifiers (e.g., `customerID`) — these usually do not help prediction.
2. Convert `TotalCharges` to numeric (it often loads as an object/string).
3. Deal with rows where conversion created missing values.

In [4]:
# 1) Drop ID column if present
if "customerID" in churn_df.columns:
    churn_df = churn_df.drop(columns=["customerID"])

# 2) Convert TotalCharges to numeric (invalid parses become NaN)
#    This is common when a column is "mostly numeric" but has blanks.
if "TotalCharges" in churn_df.columns:
    churn_df["TotalCharges"] = pd.to_numeric(churn_df["TotalCharges"], errors="coerce")

# Check what changed
churn_df["TotalCharges"].info()

<class 'pandas.core.series.Series'>
RangeIndex: 7043 entries, 0 to 7042
Series name: TotalCharges
Non-Null Count  Dtype  
--------------  -----  
7032 non-null   float64
dtypes: float64(1)
memory usage: 55.2 KB


In [5]:
# 3) Inspect rows where TotalCharges is missing after conversion
rows_with_nulls = churn_df["TotalCharges"].isna()

print("Rows with TotalCharges missing:", rows_with_nulls.sum())
churn_df.loc[rows_with_nulls, ["tenure", "MonthlyCharges", "TotalCharges", "Churn"]].head(10)

Rows with TotalCharges missing: 11


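| | tenure | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|
| 488 | 0 | 52.55 | NaN | No |
| 753 | 0 | 20.25 | NaN | No |
| 936 | 0 | 80.85 | NaN | No |
| 1082 | 0 | 25.75 | NaN | No |
| 1340 | 0 | 56.05 | NaN | No |
| 3331 | 0 | 19.85 | NaN | No |
| 3826 | 0 | 25.35 | NaN | No |
| 4380 | 0 | 20.0 | NaN | No |
| 5218 | 0 | 19.7 | NaN | No |
| 6670 | 0 | 73.35 | NaN | No |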


In this dataset, missing `TotalCharges` is typically associated with **tenure = 0** (new customers).
For our purposes it is safe to drop these (they are very few, and we know they have not churned since they are still in their first month).

> In a real project you would decide this more carefully (e.g., imputing, adding a missingness indicator, etc.).
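
As a rough illustration of the alternative mentioned above, the sketch below imputes the missing value and keeps a missingness indicator instead of dropping rows. The toy DataFrame, the `TotalCharges_missing` column name, and the choice of 0.0 as the fill value are assumptions made for this example, not part of the notebook's pipeline.

```python
import numpy as np
import pandas as pd

# Toy stand-in for a few Telco columns (values are made up for illustration)
toy_df = pd.DataFrame({
    "tenure": [0, 12, 0, 24],
    "MonthlyCharges": [52.55, 70.00, 20.25, 30.00],
    "TotalCharges": [np.nan, 840.00, np.nan, 720.00],
})

# 1) Record which rows were missing BEFORE filling them in
#    ("TotalCharges_missing" is a hypothetical column name)
toy_df["TotalCharges_missing"] = toy_df["TotalCharges"].isna()

# 2) Impute. Assumption: a customer with tenure 0 has not been billed yet,
#    so 0.0 is a plausible fill value here. Other datasets may need a
#    different rule (median, model-based imputation, ...).
toy_df["TotalCharges"] = toy_df["TotalCharges"].fillna(0.0)

print(toy_df)
```

Keeping the indicator column lets a model still learn from the fact that a value was missing, which is lost when the value is filled in silently.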

In [6]:
# Drop rows with missing TotalCharges
churn_df = churn_df.dropna(subset=["TotalCharges"]).copy()

print("Shape after dropping missing TotalCharges rows:", churn_df.shape)

Shape after dropping missing TotalCharges rows: (7032, 20)


## 3) Separate target (`y`) from features (`X`)

Scikit-learn models expect the data split into features and a target, so we will create:
- `y` = the churn label (True/False)
- `X` = all predictor columns

We will also convert the object columns to the `category` dtype.

In [7]:
# Separate the target column
y = churn_df["Churn"].astype(str).str.strip().str.lower()  # normalize strings like " Yes"
y = (y == "yes")                                          # convert to boolean (True = churn)

# A human-friendly version for plots/tables (keeps notebooks readable)
y_label = y.map({False: "No churn", True: "Churn"})

# Drop target from features
X = churn_df.drop(columns=["Churn"])

print("X shape:", X.shape)
print("Churn rate:", y.mean())

X shape: (7032, 19)
Churn rate: 0.26578498293515357


In [None]:
# Convert the columns which should be categorical but are int (SeniorCitizen) or object to category
# First convert Senior Citizen to boolean (True/False) for interpretability later
X["SeniorCitizen"] = (X["SeniorCitizen"] == 1)

# Now convert the object columns and SeniorCitizen to categorical
object_cols = X.select_dtypes(include=["object"]).columns.tolist()
X[["SeniorCitizen"] + object_cols] = X[["SeniorCitizen"] + object_cols].astype("category")
X.info()
# Notice the size decreased from 1 MB to 300 KB

# Part B — Prepare features for scikit-learn (dummy encoding)

Scikit-learn’s decision trees cannot handle categorical variables directly; they need columns that are **numeric or boolean**.

We'll use **one-hot encoding** (a.k.a. dummy variables) to convert categorical columns into 0/1 indicator columns.
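
To make the idea concrete, here is a minimal sketch of what `pd.get_dummies` does to a single categorical column. The toy DataFrame and its values are assumptions chosen for illustration; only the `get_dummies` call mirrors what the notebook uses.

```python
import pandas as pd

# Toy frame with one categorical and one numeric column (made-up values)
toy_X = pd.DataFrame({
    "Contract": pd.Categorical(["Month-to-month", "One year", "Two year", "One year"]),
    "tenure": [1, 34, 60, 12],
})

# Each level of Contract becomes its own indicator column;
# the numeric column is passed through unchanged
toy_encoded = pd.get_dummies(toy_X, columns=["Contract"])

print(toy_encoded.columns.tolist())
print(toy_encoded)
```

Every row has exactly one indicator switched on among the `Contract_*` columns, which is why the encoding is called "one-hot".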

In [None]:
# Work on a copy so that X itself is left unchanged
# Note: If we just say X_encoded = X, X_encoded just points to X instead of making a new copy
X_encoded = X.copy()

# Collect the columns we converted to 'category' dtype earlier; these are the ones to one-hot encode
cat_cols = X_encoded.select_dtypes(include=["category"]).columns.tolist()

# One-hot encode the categorical columns into 0/1 indicator (dummy) columns
X_encoded = pd.get_dummies(X_encoded, columns=cat_cols)

print("Encoded X shape:", X_encoded.shape)
X_encoded.head()

# Part C — Train a few trees (different pruning strengths)

Decision trees can overfit easily. To control the complexity, we use the cost-complexity pruning parameter `ccp_alpha`.

- **Small `ccp_alpha`** → less pruning → larger tree (more complex)
- **Large `ccp_alpha`** → more pruning → smaller tree (simpler)
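
The relationship between `ccp_alpha` and tree size can be checked directly. The sketch below uses `cost_complexity_pruning_path` to list candidate alphas and then counts the leaves of the tree fitted at each one. It runs on synthetic data from `make_classification` (an assumption, so the example is self-contained); the names `X_demo`, `y_demo`, and `leaf_counts` are hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, used only so this sketch runs on its own
X_demo, y_demo = make_classification(n_samples=500, n_features=10, random_state=42)

# Effective alphas at which the fully grown tree loses another subtree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_demo, y_demo)

# Guard against tiny negative values caused by floating-point round-off
candidate_alphas = np.clip(path.ccp_alphas, 0.0, None)

# Fit one tree per alpha and record how many leaves survive pruning
leaf_counts = []
for alpha in candidate_alphas:
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X_demo, y_demo)
    leaf_counts.append(pruned.get_n_leaves())

# Print every fifth pair to keep the output short
for alpha, n_leaves in zip(candidate_alphas[::5], leaf_counts[::5]):
    print(f"ccp_alpha={alpha:.5f} -> {n_leaves} leaves")
```

The leaf count should never increase as alpha grows, which matches the two bullets above.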

In [None]:
# Train the first of two trees with different pruning strengths (alpha values)
tree_alpha_high = 0.05   # more pruning (simpler tree)

# impurity criterion is gini by default
# we can optionally pass arguments to enable Pre-Pruning
# e.g. max_depth, min_samples_split, min_samples_leaf, max_leaf_nodes,
#      min_impurity_decrease
tree_model_simple = tree.DecisionTreeClassifier(
    random_state=RANDOM_STATE,
    ccp_alpha=tree_alpha_high,
    criterion="gini"
)

# Note: fitting is a complex process with a lot going on behind the scenes.
# The first argument is a DataFrame of the predictors; the second is a
# Series holding the target (y) variable.
tree_model_simple.fit(X_encoded, y)

# Visualize the tree
fig = plt.figure(figsize=(7,7))
_ = tree.plot_tree(tree_model_simple,
                   feature_names=X_encoded.columns.to_list(), # make sure the feature names are in output
                   filled=True) # filled=True colors each node by its majority class; darker shading means that class is more dominant in the node

In [None]:
# Now we build a slightly more complex tree by reducing the complexity penalty
tree_alpha_low  = 0.0   # less pruning (more complex tree)

tree_model_complex = tree.DecisionTreeClassifier(
    random_state=RANDOM_STATE,
    ccp_alpha=tree_alpha_low,
    max_leaf_nodes=10,
    criterion="gini"
)
tree_model_complex.fit(X_encoded, y)
fig = plt.figure(figsize=(7,7))
_ = tree.plot_tree(tree_model_complex,
                   feature_names=X_encoded.columns.to_list(), # make sure the feature names are in output
                   filled=True) # filled=True colors each node by its majority class; darker shading means that class is more dominant in the node


# Part D -- Confusion Matrix and Classification report

## Confusion Matrix

In [None]:
# First produce the predictions from the model
model_preds = tree_model_simple.predict(X_encoded)
model_cm = confusion_matrix(y_true=y, y_pred = model_preds)
model_cm

In [None]:
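# Print the confusion matrix in a prettier format by converting it to a DataFrame
# Use the model's classes_ so the labels match confusion_matrix's sorted order (False, True);
# y.unique() returns labels in order of appearance, which is not guaranteed to match
class_names = tree_model_simple.classes_
cm_df = pd.DataFrame(model_cm, index = class_names, columns = class_names)
display(cm_df)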

In [None]:
ConfusionMatrixDisplay.from_estimator(tree_model_simple, X_encoded, y)

## Classification Report
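
Before calling `classification_report`, it can help to see where its numbers come from. The sketch below derives precision, recall, and F1 for the positive class from the four confusion-matrix counts and compares them with scikit-learn's own functions. The label arrays are made up for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Made-up true labels and predictions (True = churn)
y_true_demo = np.array([True, True, True, False, False, False, False, True])
y_pred_demo = np.array([True, False, True, False, True, False, False, True])

# With labels=[False, True], ravel() returns the counts in the order TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(y_true_demo, y_pred_demo, labels=[False, True]).ravel()

precision = tp / (tp + fp)   # of the predicted churners, how many actually churned
recall = tp / (tp + fn)      # of the actual churners, how many we caught
f1 = 2 * precision * recall / (precision + recall)   # harmonic mean of the two

print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
print(f"precision={precision:.3f}, recall={recall:.3f}, f1={f1:.3f}")

# The hand-computed values agree with scikit-learn's helpers
print(precision_score(y_true_demo, y_pred_demo),
      recall_score(y_true_demo, y_pred_demo),
      f1_score(y_true_demo, y_pred_demo))
```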

In [None]:
# Produce the Classification report which has Precision, Recall, Accuracy, F-score
model_cr = classification_report(y, model_preds)
print(model_cr)

In [None]:
# Same thing for tree_model_complex
# Predict labels
model_preds = tree_model_complex.predict(X_encoded)
# Compute the confusion matrix by passing true labels and predictions
model_cm = confusion_matrix(y_true=y, y_pred = model_preds)
# Rows are true classes, columns are predicted classes, in classes_ order (False, True)
class_names = tree_model_complex.classes_
cm_df = pd.DataFrame(model_cm, index = class_names, columns = class_names)
display(cm_df)
# Classification report
model_cr = classification_report(y, model_preds)
print(model_cr)

# Part E — ROC curve and AUC

We will use `RocCurveDisplay` from `sklearn.metrics` to produce an ROC curve. Its `from_estimator` method obtains the predicted class probabilities from the model, sweeps the decision threshold applied to the positive-class probability, plots the resulting TPR against FPR, and also computes the Area Under the Curve (AUC).
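
The threshold sweep behind an ROC curve can be reproduced by hand. The sketch below is an illustration on made-up scores, not the library's internal implementation: it computes TPR and FPR at each threshold that `roc_curve` reports and checks that the points agree. The helper `tpr_fpr_at` and the arrays are hypothetical names introduced for this example.

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Made-up true labels (1 = positive) and predicted positive-class probabilities
y_true_demo = np.array([0, 0, 1, 1, 0, 1])
scores_demo = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7])

def tpr_fpr_at(threshold):
    """Classify as positive when score >= threshold, then return (TPR, FPR)."""
    y_hat = scores_demo >= threshold
    tp = np.sum(y_hat & (y_true_demo == 1))
    fp = np.sum(y_hat & (y_true_demo == 0))
    return tp / np.sum(y_true_demo == 1), fp / np.sum(y_true_demo == 0)

# scikit-learn picks the thresholds; we recompute each point ourselves
fpr, tpr, thresholds = roc_curve(y_true_demo, scores_demo)
manual_points = np.array([tpr_fpr_at(t) for t in thresholds])

for t, (m_tpr, m_fpr) in zip(thresholds, manual_points):
    print(f"threshold={t:.2f}  TPR={m_tpr:.2f}  FPR={m_fpr:.2f}")

# Area under the piecewise-linear curve equals roc_auc_score
print("AUC:", auc(fpr, tpr), roc_auc_score(y_true_demo, scores_demo))
```

Lowering the threshold moves along the curve from the bottom-left corner (nothing predicted positive) toward the top-right corner (everything predicted positive).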

In [None]:
# Let us first see how we can extract the predictions from the Decision Tree we fitted
model_preds_prob_sample = tree_model_complex.predict_proba(X_encoded.iloc[0:5,:])
model_preds_sample = tree_model_complex.predict(X_encoded.iloc[0:5,:])
print(model_preds_prob_sample)
print(model_preds_sample)

In [None]:
# Plot the ROC curve
roc_display = RocCurveDisplay.from_estimator(tree_model_complex, X_encoded, y, name='Decision Tree')
# Print the AUC score
auc_score = roc_display.roc_auc
print(f'AUC for the Decision Tree Classifier: {auc_score:.3f}')

In [None]:
# Can also get auc score directly by passing true target values, probabilities (or scores)
auc_score = roc_auc_score(y, tree_model_complex.predict_proba(X_encoded)[:,1])
auc_score

# Part F — Train/Test split -- Model evaluation

We care about performance on **future unseen customers** (population/generalization), not performance on the training sample.

If we evaluate on the same data we trained on, performance is usually **too optimistic**.

We will:
1. Split our data into train and test sets
2. Train a simple and a complex tree model on the train split
3. Compute the classification metrics on test split

In [None]:
# Train/test split
# Saving 30% of data for testing
X_train, X_test, y_train, y_test = train_test_split(
    X_encoded, y,
    test_size=0.30,
    random_state=RANDOM_STATE,
    stratify=y  # keep class balance similar in train and test
)

print("Train shape:", X_train.shape, "| Test shape:", X_test.shape)
# The following command shows that the splitting preserved class balance across
# test and train splits
print("Train churn rate:", y_train.mean(), "| Test churn rate:", y_test.mean())

In [None]:
# Fit the simple and complex models
tree_model_simple.fit(X_train, y_train)
tree_model_complex.fit(X_train, y_train)

In [None]:
# For the simple model produce the Classification Report on train and test data
print(classification_report(y_train, tree_model_simple.predict(X_train)))
print(classification_report(y_test, tree_model_simple.predict(X_test)))

In [None]:
# For the complex model produce the Classification Report on train and test data

# Note: You should not report classification metrics on training data, we are doing it here just to contrast with metrics on test split
print(classification_report(y_train, tree_model_complex.predict(X_train)))
# Note: This is what you should actually look at to estimate the performance you should expect on unseen data
print(classification_report(y_test, tree_model_complex.predict(X_test)))

In [None]:
# Let us check ROC, AUC on the test set as well
roc_display = RocCurveDisplay.from_estimator(tree_model_complex, X_test, y_test, name='Decision Tree')
# Print the AUC score
auc_score = roc_display.roc_auc
print(f'AUC for the Decision Tree Classifier: {auc_score:.3f}')

# Part G — Model Selection using Cross-Validation

Model selection means deciding how complex a Decision Tree should be for the data set we have. The steps are as follows:

1. First do a train/test split, say 80/20, and set the test set aside. Make sure class balance is maintained in this split (i.e. the fraction of y=Yes is the same in train and test). For reproducibility, explicitly pass the random number seed.
2. Split the train set into K folds of (roughly) equal size. Again, stratify on the target. For reproducibility, explicitly pass the random number seed.
3. For each model complexity parameter (e.g. number of leaves),
    1. For each i in (1,...,K):
        - train the Decision Tree on training set with fold i excluded
        - compute the accuracy of this model on fold i. This is an estimate of the generalization performance (error on unseen test data) for this model complexity.
        - save this accuracy metric
    2. Compute summary of the K accuracy measurements (mean, standard deviation)
4. Based on the cross-validation performance, choose the **model class** which is estimated to generalize the best
5. Train the model again, this time on the entire training set.
6. Compute the performance on the hold-out test set (from step 1). This is the estimate of your model's performance on the population distribution.
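
The six steps above can be sketched as an explicit loop. This is a minimal illustration under stated assumptions: it uses synthetic data from `make_classification`, takes `max_leaf_nodes` as the complexity parameter, and all variable names (`X_tr`, `cv_summary`, `best_leaves`, ...) are hypothetical. The cells that follow in this notebook let `cross_val_score` handle the inner loop instead.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, mildly imbalanced data so the sketch runs on its own
X_demo, y_demo = make_classification(n_samples=600, n_features=10,
                                     weights=[0.75, 0.25], random_state=42)

# Step 1: stratified 80/20 split; the test set is not touched until step 6
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.20, stratify=y_demo, random_state=42)

# Step 2: stratified K folds on the training set
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Step 3: for each complexity setting, score a tree on every held-out fold
candidate_leaves = [2, 4, 8, 16, 32]
cv_summary = {}
for n_leaves in candidate_leaves:
    fold_acc = []
    for fit_idx, val_idx in folds.split(X_tr, y_tr):
        model = DecisionTreeClassifier(max_leaf_nodes=n_leaves, random_state=42)
        model.fit(X_tr[fit_idx], y_tr[fit_idx])                      # train without fold i
        fold_acc.append(model.score(X_tr[val_idx], y_tr[val_idx]))   # accuracy on fold i
    cv_summary[n_leaves] = (np.mean(fold_acc), np.std(fold_acc, ddof=1))

# Step 4: choose the setting with the best mean CV accuracy
best_leaves = max(cv_summary, key=lambda k: cv_summary[k][0])

# Step 5: refit on the entire training set with the chosen setting
final_demo_model = DecisionTreeClassifier(max_leaf_nodes=best_leaves, random_state=42)
final_demo_model.fit(X_tr, y_tr)

# Step 6: a single evaluation on the hold-out test set
test_acc = final_demo_model.score(X_te, y_te)
print(f"best max_leaf_nodes={best_leaves}, test accuracy={test_acc:.3f}")
```

Note that the test set is scored exactly once, after the complexity setting has been fixed by cross-validation.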

In [None]:
# Setting up a random K-fold and stratified K-fold samplers

num_folds = 3

print("Stratified K-Fold")
print("-"*20)

# Define how we will create the folds (this is where we can pass the seed for the random generator)
cv_strategy = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=RANDOM_STATE)
# Iterate through the folds and report fraction yes
for i, (train_indices, test_indices) in enumerate(cv_strategy.split(X_train, y_train)):
    print(f"Fold {i}:")
    print(f"  Test indices = {test_indices}")
    print(f"  Fraction Yes = {y_train.iloc[test_indices].mean()}")

# get_n_splits returns the number of folds (an integer), not the fold indices
cv_strategy = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=RANDOM_STATE)
n_splits = cv_strategy.get_n_splits(X_train, y_train)

print("="*80)
print("Random K-Fold")
print("-"*20)

# Define how we will create the folds (this is where we can pass the seed for the random generator)
cv_strategy = KFold(n_splits=num_folds, shuffle=True, random_state=RANDOM_STATE)
# Iterate through the folds and report fraction yes
for i, (train_indices, test_indices) in enumerate(cv_strategy.split(X_train, y_train)):
    print(f"Fold {i}:")
    print(f"  Test indices = {test_indices}")
    print(f"  Fraction Yes = {y_train.iloc[test_indices].mean()}")

In [None]:
# Do cross-validation
# ====================

# List of -log_2 ccp_alphas we will try (ccp_alpha will be: 1/4, 1/16, 1/64, ...)
# We can also use cost_complexity_pruning_path on training data to get a potential list of alphas
minus_log2_alpha_list = np.arange(2,21,2)

# Number of folds for k-fold cross validation
num_folds = 10

# Define how we will create the folds (this is where we can pass the seed for random generator)
cv_strategy = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=RANDOM_STATE)

# This dataframe will store our results
cv_results_df = pd.DataFrame(columns=['minus_log2_alpha', 'mean', 'stddev', 'test_accuracy', 'test_auc'])

# Iterate over ccp_alpha parameter
for i in minus_log2_alpha_list:
    # Define the Tree classifier object
    tree_model = tree.DecisionTreeClassifier(random_state = RANDOM_STATE, ccp_alpha = 2.0**(-i))

    # Do the cross-validation runs and collate scores for the num_folds runs
    # There is a lot that is happening in this function!!
    result_list = cross_val_score(tree_model, X_train, y_train, cv = cv_strategy, scoring = 'accuracy')

    # ============================================================
    # Fit the model with this ccp_alpha on the full training data
    # In real workflows you will NOT do this for every alpha, but for the alpha
    # corresponding to the model you select based on Cross-validation!
    # We are doing this here only for illustration
    tree_model.fit(X_train, y_train)

    # ============================================================
    # Find accuracy on the hold out test set
    # NOTE: We are peeking at the test set across alphas for illustration only; do not do this in real workflows.
    # In real workflows you will only do this at the end for the one alpha you
    # select as your chosen model complexity parameter!
    test_score = tree_model.score(X_test, y_test)

    # ============================================================
    # Let us also look at the auc score on test data
    # In real workflows you will only do this at the end for the one alpha you
    # select as your chosen model complexity parameter!
    auc_score = roc_auc_score(y_test, tree_model.predict_proba(X_test)[:,1])

    # Add this row to the data frame
    new_data = {'minus_log2_alpha': i, 'mean': result_list.mean() ,
                'stddev' : result_list.std(ddof=1) , 'test_accuracy' : test_score,
                'test_auc' : auc_score}
    cv_results_df.loc[len(cv_results_df)] = new_data

In [None]:
# Plot results
# ===================

fig, ax = plt.subplots(figsize=(7, 5))

# Produce a plot of mean cross-validation accuracy for each model class
ax.plot(cv_results_df['minus_log2_alpha'], cv_results_df['mean'], marker = 'o', label = 'Mean CV accuracy')

# Overlay "Error bars" using standard deviation of CV accuracy on the folds
ax.fill_between(cv_results_df['minus_log2_alpha'],
                 cv_results_df['mean'] - cv_results_df['stddev'],
                 cv_results_df['mean'] + cv_results_df['stddev'],
                 alpha=0.2,
                 label="CV ±1 SD")

# Overlay Test Accuracy (on held out test data)
ax.plot(cv_results_df['minus_log2_alpha'], cv_results_df['test_accuracy'], marker = 'o', linestyle ='--' , label = 'Test accuracy')
# ticks and labels
ax.set_xticks(minus_log2_alpha_list)
ax.set_xlabel(r'$-\log_2 \alpha$')
ax.set_ylabel('Accuracy')
ax.legend()

# Add plot for AUC on test data using secondary axis
color2 = 'tab:red'
ax2 = ax.twinx()
ax2.plot(cv_results_df['minus_log2_alpha'], cv_results_df['test_auc'], marker = 'o', color = color2, label = 'AUC')
ax2.set_ylabel('AUC', color = color2)
ax2.tick_params(axis='y', labelcolor=color2)

# Call tight_layout before show; once show() has rendered the figure it has no effect
plt.tight_layout()
plt.show()


In [None]:
# YOUDO: Experiment with this block
# The cross_validate function gives more flexibility in using different scoring
# rules to perform cross validation

num_folds = 10

# 1. Pick ONE scoring rule (string) for the sweep:
scoring_rule = "roc_auc"
# scoring_rule = "accuracy"
# scoring_rule = "balanced_accuracy"
# scoring_rule = "precision"
# scoring_rule = "recall"
# scoring_rule = "f1"
# scoring_rule = "f1_macro"
# scoring_rule = "average_precision"  # PR-AUC

# You can also see MULTIPLE metrics in one CV call and then select one to plot or use.
# Uncomment this and set metric_to_plot accordingly.
# scoring = {
#     "acc": "accuracy",
#     "bal_acc": "balanced_accuracy",
#     "prec": "precision",
#     "rec": "recall",
#     "f1": "f1",
#     "auc": "roc_auc",
#     "ap": "average_precision",
# }
# metric_to_plot = "auc"   # must be one of the keys above

cv = StratifiedKFold(n_splits = num_folds, shuffle=True, random_state=RANDOM_STATE)

# A scorer object lets you evaluate the *test set* using the exact same scoring rule string.
# (Works for metrics like roc_auc that need predict_proba/decision_function too.)
scorer = get_scorer(scoring_rule)

# List of -log_2 ccp_alphas we will try (ccp_alpha will be: 1/4, 1/16, 1/64, ...)
# We can also use cost_complexity_pruning_path on training data to get a potential list of alphas
minus_log2_alpha_list = np.arange(2,21,2)
alphas = 0.5**minus_log2_alpha_list

mean_cv_scores = []
std_cv_scores  = []
test_scores    = []

for alpha in alphas:
    # Create a fresh estimator for this alpha
    tree_model = tree.DecisionTreeClassifier(
        random_state=RANDOM_STATE,
        ccp_alpha=float(alpha),
        # max_leaf_nodes=10,  # keep if you used it elsewhere; otherwise remove for "pure" pruning
    )

    # --- Cross-validated estimate (replaces cross_val_score) ---
    cv_out = cross_validate(
        tree_model,
        X_train,
        y_train,
        cv=cv,
        scoring=scoring_rule,        # or scoring=scoring for multiple metrics
        return_train_score=False
    )

    fold_scores = cv_out["test_score"]  # for multi-metric: cv_out["test_auc"], etc.
    mean_cv_scores.append(np.mean(fold_scores))
    std_cv_scores.append(np.std(fold_scores, ddof=1))

    # --- Test-set score for THIS alpha ---
    # NOTE: This "peeks" at test across alphas. We are doing this for illustration only!
    tree_model.fit(X_train, y_train)
    test_scores.append(scorer(tree_model, X_test, y_test))

mean_cv_scores = np.asarray(mean_cv_scores)
std_cv_scores  = np.asarray(std_cv_scores)
test_scores    = np.asarray(test_scores)

# Pick alpha by CV (max metric)
best_idx = int(np.argmax(mean_cv_scores))
best_alpha = float(alphas[best_idx])

print(f"Best alpha by CV ({scoring_rule}): {best_alpha:g}")
print(f"  CV mean = {mean_cv_scores[best_idx]:.4f}")
print(f"  Test    = {test_scores[best_idx]:.4f}  (for this alpha)")


In [None]:
# -------------------- Plot CV vs Test across alphas --------------------
fig, ax = plt.subplots()

ax.plot(alphas, mean_cv_scores, marker="o", label=f"CV mean ({scoring_rule})")
ax.fill_between(
    alphas,
    mean_cv_scores - std_cv_scores,
    mean_cv_scores + std_cv_scores,
    alpha=0.2,
    label="CV ±1 SD"
)
ax.plot(alphas, test_scores, marker="o", linestyle="--", label=f"Test ({scoring_rule})")

# Log x-axis only if all alphas are > 0 (log(0) not allowed)
if np.all(alphas > 0):
    ax.set_xscale("log")

ax.set_xlabel("ccp_alpha")
ax.set_ylabel(scoring_rule)
ax.set_title(f"Decision Tree pruning sweep: {scoring_rule}")
ax.legend()
plt.show()

# -------------------- Plot ROC curve on TEST for best alpha --------------------
final_model = tree.DecisionTreeClassifier(random_state=RANDOM_STATE, ccp_alpha=best_alpha)
final_model.fit(X_train, y_train)
RocCurveDisplay.from_estimator(final_model, X_test, y_test)
plt.title(f"Test ROC curve (best alpha = {best_alpha:g})")
plt.show()


# Summary / takeaway

- Decision trees can fit complex patterns, but they can also **overfit**.
- Use a **validation set / cross-validation** to choose model complexity properly.
- Next week: Compare a single tree to ensembles (Random Forest, Gradient Boosting).