# ML501: Decision Trees - CART for Classification & Regression

---

## Learning Objectives

By the end of this notebook, you will be able to:

1. Explain the CART algorithm and recursive binary splitting
2. Understand Gini impurity and entropy as classification split criteria
3. Understand MSE as a regression split criterion
4. Train and tune `DecisionTreeClassifier` and `DecisionTreeRegressor` in scikit-learn
5. Visualize decision trees and interpret feature importances
6. Identify and mitigate overfitting in decision trees

## Prerequisites

- Python fundamentals (loops, functions, classes)
- NumPy and pandas basics
- Basic understanding of supervised learning (classification vs regression)
- Familiarity with train/test splitting and model evaluation

## Table of Contents

1. [CART Algorithm Theory](#1-cart-algorithm-theory)
2. [Classification Split Criteria](#2-classification-split-criteria)
3. [Regression Split Criteria](#3-regression-split-criteria)
4. [Classification Demo: Iris Dataset](#4-classification-demo-iris-dataset)
5. [Tree Visualization and Feature Importance](#5-tree-visualization-and-feature-importance)
6. [Overfitting: Deep vs Pruned Trees](#6-overfitting-deep-vs-pruned-trees)
7. [Regression Demo: Synthetic Data](#7-regression-demo-synthetic-data)
8. [Common Mistakes](#8-common-mistakes)
9. [Exercises](#9-exercises)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris, make_regression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, mean_squared_error

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12
sns.set_style('whitegrid')
np.random.seed(42)

## 1. CART Algorithm Theory

**CART (Classification and Regression Trees)** builds a binary tree by recursively splitting the data.

### Recursive Binary Splitting

At each node, the algorithm:
1. Considers every feature and every possible split point
2. Selects the split that minimizes a cost function (impurity for classification, MSE for regression)
3. Partitions the data into two child nodes
4. Repeats until a stopping criterion is met (max depth, min samples, pure node, etc.)

This is a **greedy** algorithm -- it picks the locally optimal split at each step without looking ahead.
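
The four steps above can be sketched for a single node as follows. This is a minimal illustration, not scikit-learn's implementation: the helper names `gini_impurity` and `best_split` and the toy arrays are assumptions made for this example, and candidate thresholds are taken as midpoints between consecutive distinct feature values.

```python
import numpy as np

def gini_impurity(labels):
    """Gini impurity of an array of non-negative integer class labels."""
    if labels.size == 0:
        return 0.0
    p = np.bincount(labels) / labels.size
    return 1.0 - np.sum(p ** 2)

def best_split(X_node, y_node):
    """Exhaustively search every feature and every midpoint threshold.

    Returns (feature_index, threshold, weighted_child_impurity).
    """
    n_samples, n_features = X_node.shape
    best = (None, None, np.inf)
    for j in range(n_features):
        values = np.unique(X_node[:, j])
        # Candidate thresholds: midpoints between consecutive distinct values
        thresholds = (values[:-1] + values[1:]) / 2.0
        for t in thresholds:
            left = y_node[X_node[:, j] <= t]
            right = y_node[X_node[:, j] > t]
            # Cost of the split: child impurities weighted by child size
            cost = (left.size * gini_impurity(left)
                    + right.size * gini_impurity(right)) / n_samples
            if cost < best[2]:
                best = (j, t, cost)
    return best

# Toy data: feature 0 separates the classes, feature 1 does not
X_toy = np.array([[1.0, 5.0], [2.0, 1.0], [3.0, 4.0],
                  [6.0, 2.0], [7.0, 5.0], [8.0, 1.0]])
y_toy = np.array([0, 0, 0, 1, 1, 1])

feature, threshold, cost = best_split(X_toy, y_toy)
print(f"Best split: feature {feature} <= {threshold} (weighted Gini = {cost:.3f})")
```

A full tree builder would call `best_split` recursively on each child until a stopping criterion is met. Because every call only looks at the current node, the result is greedy in exactly the sense described above.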

### Key Properties of Decision Trees

| Property | Description |
|----------|-------------|
| Non-parametric | No assumptions about data distribution |
| Interpretable | Easy to visualize and explain |
| No scaling needed | Features do not require normalization |
| Handles mixed types | CART supports numerical and categorical features; scikit-learn's trees require categorical features to be numerically encoded first |
| Prone to overfitting | Can memorize training data if unconstrained |

## 2. Classification Split Criteria

For classification, CART uses impurity measures to decide the best split.

### Gini Impurity

$$G = 1 - \sum_{k=1}^{K} p_k^2$$

where $p_k$ is the proportion of class $k$ in the node. Gini ranges from 0 (pure node) to $1 - 1/K$ (maximally impure).

### Entropy (Information Gain)

$$H = -\sum_{k=1}^{K} p_k \log_2 p_k$$

Entropy ranges from 0 (pure) to $\log_2 K$ (maximally impure).

In practice, Gini and entropy usually produce very similar trees. Gini is the default in scikit-learn and is slightly cheaper to compute because it needs no logarithm.
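
As a quick numeric check of the two formulas and their stated ranges, the sketch below evaluates both measures for a few class-count vectors. The helper name `node_impurities` and the example counts are assumptions chosen for illustration.

```python
import numpy as np

def node_impurities(class_counts):
    """Return (Gini impurity, entropy in bits) for a vector of class counts."""
    counts = np.asarray(class_counts, dtype=float)
    p = counts / counts.sum()
    gini_value = 1.0 - np.sum(p ** 2)
    nonzero = p[p > 0]  # the term 0 * log2(0) is taken as 0
    # Adding 0.0 turns a possible -0.0 into 0.0 for cleaner printing
    entropy_value = float(-np.sum(nonzero * np.log2(nonzero))) + 0.0
    return float(gini_value), entropy_value

# Pure node, two-class tie, and three-class tie (K = 3)
for counts in ([50, 0, 0], [25, 25, 0], [50, 50, 50]):
    g, h = node_impurities(counts)
    print(f"counts={counts}: Gini={g:.3f}, entropy={h:.3f}")
```

For the three-class tie, the values match the upper bounds given above: $1 - 1/3 \approx 0.667$ for Gini and $\log_2 3 \approx 1.585$ for entropy.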

In [None]:
# Visualize Gini vs Entropy for a binary classification problem
p = np.linspace(0.001, 0.999, 200)
gini = 1 - p**2 - (1 - p)**2
entropy = -p * np.log2(p) - (1 - p) * np.log2(1 - p)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(p, gini, label='Gini Impurity', linewidth=2)
ax.plot(p, entropy, label='Entropy', linewidth=2)
ax.set_xlabel('Proportion of Class 1 ($p$)')
ax.set_ylabel('Impurity')
ax.set_title('Gini Impurity vs Entropy (Binary Classification)')
# Draw the reference line before building the legend so its label is included
ax.axvline(x=0.5, color='gray', linestyle='--', alpha=0.5, label='Max impurity')
ax.legend()
plt.tight_layout()
plt.show()

## 3. Regression Split Criteria

For regression trees, the split criterion is **Mean Squared Error (MSE)**:

$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \bar{y})^2$$

At each node, the algorithm finds the split that minimizes the sample-weighted average of the MSE in the two child nodes, where $\bar{y}$ is the mean target of the $n$ samples in the node being scored. The prediction at each leaf is the mean of the target values in that leaf.
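
To make the regression criterion concrete, the sketch below scores two candidate thresholds on a tiny dataset. The helpers `node_mse` and `split_cost` and the arrays `x_demo` and `y_demo` are assumptions made up for this example.

```python
import numpy as np

def node_mse(values):
    """MSE of a node: mean squared deviation from the node's own mean."""
    if values.size == 0:
        return 0.0
    return float(np.mean((values - values.mean()) ** 2))

def split_cost(x, targets, threshold):
    """Sample-weighted average of the child MSEs for the split x <= threshold."""
    left = targets[x <= threshold]
    right = targets[x > threshold]
    return (left.size * node_mse(left) + right.size * node_mse(right)) / targets.size

# Two clear groups of targets, around 1.0 and around 5.0
x_demo = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y_demo = np.array([1.0, 1.2, 0.8, 5.0, 5.2, 4.8])

parent_mse = node_mse(y_demo)
good_cost = split_cost(x_demo, y_demo, threshold=6.5)  # separates the two groups
poor_cost = split_cost(x_demo, y_demo, threshold=1.5)  # isolates a single point

print(f"Parent MSE:           {parent_mse:.4f}")
print(f"Split at 6.5 (good):  {good_cost:.4f}")
print(f"Split at 1.5 (poor):  {poor_cost:.4f}")
print(f"Leaf predictions for the good split: "
      f"{y_demo[x_demo <= 6.5].mean():.2f} and {y_demo[x_demo > 6.5].mean():.2f}")
```

The split at 6.5 removes almost all of the parent's variance, so a CART regressor would prefer it over the split at 1.5.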

## 4. Classification Demo: Iris Dataset

In [None]:
# Load and explore the Iris dataset
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
target_names = iris.target_names

print(f"Features: {feature_names}")
print(f"Classes: {target_names}")
print(f"Shape: {X.shape}")
print(f"Class distribution: {np.bincount(y)}")

In [None]:
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Train a decision tree with limited depth for interpretability
clf = DecisionTreeClassifier(
    criterion='gini',
    max_depth=3,
    min_samples_split=5,
    min_samples_leaf=2,
    random_state=42
)
clf.fit(X_train, y_train)

train_acc = accuracy_score(y_train, clf.predict(X_train))
test_acc = accuracy_score(y_test, clf.predict(X_test))
print(f"Training accuracy: {train_acc:.4f}")
print(f"Test accuracy:     {test_acc:.4f}")

### Key Parameters of `DecisionTreeClassifier`

| Parameter | Description | Typical Values |
|-----------|-------------|----------------|
| `criterion` | Split quality measure | `'gini'` (default), `'entropy'` |
| `max_depth` | Maximum depth of the tree | 3-10 or `None` |
| `min_samples_split` | Min samples to split a node | 2-20 |
| `min_samples_leaf` | Min samples in a leaf node | 1-10 |

## 5. Tree Visualization and Feature Importance

In [None]:
# Visualize the decision tree
fig, ax = plt.subplots(figsize=(16, 8))
plot_tree(
    clf,
    feature_names=feature_names,
    class_names=target_names,
    filled=True,
    rounded=True,
    fontsize=10,
    ax=ax
)
ax.set_title('Decision Tree on Iris Dataset (max_depth=3)')
plt.tight_layout()
plt.show()
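
When a figure is inconvenient (for example in logs or a terminal), the same tree can be printed as indented text rules with `sklearn.tree.export_text`. The sketch below refits a tree with the same settings as the classifier above so that it runs on its own. The variable names `iris_data`, `text_clf`, and `rules` are assumptions for this example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris_data = load_iris()
X_tr, X_te, y_tr, y_te = train_test_split(
    iris_data.data, iris_data.target,
    test_size=0.3, random_state=42, stratify=iris_data.target
)

text_clf = DecisionTreeClassifier(
    max_depth=3, min_samples_split=5, min_samples_leaf=2, random_state=42
)
text_clf.fit(X_tr, y_tr)

# export_text returns one string; each indentation level is one tree depth
rules = export_text(text_clf, feature_names=list(iris_data.feature_names))
print(rules)
```

Each line is either a threshold test on a feature or a leaf with its predicted class index, so the output can be read top to bottom as nested if/else rules.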

In [None]:
# Feature importance
importances = clf.feature_importances_
feature_imp = pd.Series(importances, index=feature_names).sort_values(ascending=True)

fig, ax = plt.subplots(figsize=(8, 5))
feature_imp.plot(kind='barh', ax=ax, color='steelblue')
ax.set_xlabel('Feature Importance (Gini)')
ax.set_title('Feature Importance from Decision Tree')
plt.tight_layout()
plt.show()

print("Feature importances:")
for name, imp in zip(feature_names, importances):
    print(f"  {name}: {imp:.4f}")
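
The values in `feature_importances_` are impurity-based: each split's weighted impurity decrease is credited to the feature it splits on, and the totals are normalized to sum to 1. The sketch below recomputes them from the fitted tree's low-level `tree_` arrays as a sanity check. It refits its own model so it is self-contained, and the names `iris_fi`, `fi_clf`, and `manual_importances` are assumptions for this example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris_fi = load_iris()
fi_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
fi_clf.fit(iris_fi.data, iris_fi.target)

t = fi_clf.tree_
manual_importances = np.zeros(iris_fi.data.shape[1])

for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # -1 marks a leaf: no split, no impurity decrease
        continue
    # Weighted impurity before the split minus weighted impurity after it
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    manual_importances[t.feature[node]] += decrease

manual_importances /= manual_importances.sum()  # normalize to sum to 1

for name, value in zip(iris_fi.feature_names, manual_importances):
    print(f"  {name}: {value:.4f}")
print("Matches feature_importances_:",
      np.allclose(manual_importances, fi_clf.feature_importances_))
```

Because the measure only counts splits the tree actually made, a feature that is informative but correlated with one already used may receive an importance of zero.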

## 6. Overfitting: Deep vs Pruned Trees

In [None]:
# Compare performance across different max_depth values
depths = range(1, 16)
train_scores = []
test_scores = []

for d in depths:
    tree = DecisionTreeClassifier(max_depth=d, random_state=42)
    tree.fit(X_train, y_train)
    train_scores.append(accuracy_score(y_train, tree.predict(X_train)))
    test_scores.append(accuracy_score(y_test, tree.predict(X_test)))

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(depths, train_scores, 'o-', label='Training Accuracy', linewidth=2)
ax.plot(depths, test_scores, 's-', label='Test Accuracy', linewidth=2)
ax.set_xlabel('max_depth')
ax.set_ylabel('Accuracy')
ax.set_title('Overfitting: Training vs Test Accuracy by Tree Depth')
ax.legend()
ax.set_xticks(list(depths))
plt.tight_layout()
plt.show()

print("Notice: training accuracy reaches 1.0 quickly,")
print("but test accuracy plateaus or even decreases -- classic overfitting.")

In [None]:
# Deep tree vs pruned tree ("pruned" here means pre-pruned: growth is limited up front)
deep_tree = DecisionTreeClassifier(max_depth=None, random_state=42)
deep_tree.fit(X_train, y_train)

pruned_tree = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, random_state=42)
pruned_tree.fit(X_train, y_train)

print("Deep Tree (no constraints):")
print(f"  Depth: {deep_tree.get_depth()}, Leaves: {deep_tree.get_n_leaves()}")
print(f"  Train acc: {accuracy_score(y_train, deep_tree.predict(X_train)):.4f}")
print(f"  Test acc:  {accuracy_score(y_test, deep_tree.predict(X_test)):.4f}")
print()
print("Pruned Tree (max_depth=3, min_samples_leaf=5):")
print(f"  Depth: {pruned_tree.get_depth()}, Leaves: {pruned_tree.get_n_leaves()}")
print(f"  Train acc: {accuracy_score(y_train, pruned_tree.predict(X_train)):.4f}")
print(f"  Test acc:  {accuracy_score(y_test, pruned_tree.predict(X_test)):.4f}")
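
The comparison above limits growth while the tree is being built. A different technique, not used elsewhere in this notebook, is minimal cost-complexity pruning, which grows the full tree and then removes branches according to the `ccp_alpha` parameter. The sketch below is a hedged illustration on the same Iris split. The variable names are assumptions, and the alphas are clipped at zero to guard against tiny negative values from floating-point error.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_iris, y_iris, test_size=0.3, random_state=42, stratify=y_iris
)

# Candidate alphas at which the fully grown tree loses another subtree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_tr, y_tr)
alphas = np.clip(path.ccp_alphas, 0.0, None)

leaf_counts = []
for alpha in alphas:
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=float(alpha))
    pruned.fit(X_tr, y_tr)
    leaf_counts.append(pruned.get_n_leaves())
    print(f"ccp_alpha={alpha:.4f}  leaves={pruned.get_n_leaves():2d}  "
          f"train acc={pruned.score(X_tr, y_tr):.3f}  "
          f"test acc={pruned.score(X_te, y_te):.3f}")
```

Larger values of `ccp_alpha` give smaller trees. In practice the value would be chosen by cross-validation on the training data rather than by inspecting test accuracy.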

## 7. Regression Demo: Synthetic Data

In [None]:
# Generate synthetic non-linear data
np.random.seed(42)
X_reg = np.sort(5 * np.random.rand(200, 1), axis=0)
y_reg = np.sin(X_reg).ravel() + 0.3 * np.random.randn(200)

# Fit regression trees with different depths
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

for ax, depth in zip(axes, [2, 5, None]):
    reg = DecisionTreeRegressor(max_depth=depth, random_state=42)
    reg.fit(X_reg, y_reg)
    
    X_test_reg = np.linspace(0, 5, 500).reshape(-1, 1)
    y_pred = reg.predict(X_test_reg)
    
    ax.scatter(X_reg, y_reg, s=15, alpha=0.5, label='Data')
    ax.plot(X_test_reg, y_pred, color='red', linewidth=2, label='Prediction')
    ax.set_title(f'max_depth={depth}')
    ax.legend(fontsize=9)
    ax.set_xlabel('X')
    ax.set_ylabel('y')

plt.suptitle('DecisionTreeRegressor: Effect of max_depth', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("Left: underfitting (too shallow), Middle: reasonable fit, Right: overfitting (fits the noise)")

In [None]:
# Demonstrate that decision trees cannot extrapolate
np.random.seed(42)
X_extrap = np.linspace(0, 5, 100).reshape(-1, 1)
y_extrap = 2 * X_extrap.ravel() + 1 + 0.5 * np.random.randn(100)

reg_extrap = DecisionTreeRegressor(max_depth=5, random_state=42)
reg_extrap.fit(X_extrap, y_extrap)

# Predict beyond training range
X_full = np.linspace(-1, 8, 300).reshape(-1, 1)
y_pred_full = reg_extrap.predict(X_full)

fig, ax = plt.subplots(figsize=(9, 5))
ax.scatter(X_extrap, y_extrap, s=20, alpha=0.6, label='Training data')
ax.plot(X_full, y_pred_full, color='red', linewidth=2, label='Tree prediction')
ax.plot(X_full, 2 * X_full.ravel() + 1, 'g--', linewidth=1.5, label='True linear trend')
ax.axvspan(-1, 0, alpha=0.1, color='red', label='Extrapolation zone')
ax.axvspan(5, 8, alpha=0.1, color='red')
ax.set_xlabel('X')
ax.set_ylabel('y')
ax.set_title('Decision Trees Cannot Extrapolate')
ax.legend()
plt.tight_layout()
plt.show()

print("Outside the training range, the tree predicts a constant (the nearest leaf value).")

## 8. Common Mistakes

1. **Not limiting tree depth**: An unconstrained tree keeps splitting until its leaves are pure, which usually memorizes the training data and generalizes poorly. Constrain growth with `max_depth`, `min_samples_split`, or `min_samples_leaf`.

2. **Ignoring overfitting signals**: If training accuracy is much higher than test accuracy, the tree is too complex. Use cross-validation to select hyperparameters.

3. **Using trees for extrapolation**: Decision trees predict constant values in each leaf. They cannot extrapolate beyond the range of training data -- predictions outside that range are simply the nearest leaf's mean.

4. **Assuming feature importance equals causation**: A feature being "important" in the tree means it is useful for splitting, not that it causes the outcome.

5. **Ignoring class imbalance**: Decision trees can be biased toward the majority class. Use `class_weight='balanced'` or resample the data.
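
As a sketch of the cross-validation advice in mistake 2, the loop below scores several `max_depth` values with 5-fold cross-validation on the training split only, so the test set stays untouched. The depth range and the variable names (`cv_means`, `best_depth`) are assumptions chosen for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X_all, y_all = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_all, y_all, test_size=0.3, random_state=42, stratify=y_all
)

cv_means = {}
for depth in range(1, 8):
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    # For classifiers, cv=5 uses stratified folds by default
    scores = cross_val_score(model, X_tr, y_tr, cv=5)
    cv_means[depth] = scores.mean()
    print(f"max_depth={depth}: mean CV accuracy = {scores.mean():.4f} "
          f"(+/- {scores.std():.4f})")

# Ties go to the smallest depth because the dict is filled in increasing order
best_depth = max(cv_means, key=cv_means.get)
print(f"Selected max_depth: {best_depth}")
```

Once a depth is selected, the model would be refit on the full training split and evaluated once on the held-out test set.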

## 9. Exercises

### Exercise 1: Entropy vs Gini
Train two `DecisionTreeClassifier` models on the Iris dataset -- one with `criterion='gini'` and one with `criterion='entropy'`. Compare their test accuracy and tree structure (depth, number of leaves). Are the results meaningfully different?

### Exercise 2: Hyperparameter Tuning
Use `GridSearchCV` to find the best combination of `max_depth` (1-10), `min_samples_split` (2, 5, 10), and `min_samples_leaf` (1, 3, 5) for a `DecisionTreeClassifier` on the Iris dataset. Report the best parameters and CV score.

### Exercise 3: Regression Tree
Generate a noisy quadratic dataset: `y = 2*x**2 - 3*x + 1 + noise`. Fit `DecisionTreeRegressor` models with depths 2, 4, 6, and None. Plot the predictions and discuss which depth gives the best bias-variance trade-off.

In [None]:
# Exercise 1 starter code
# clf_gini = DecisionTreeClassifier(criterion='gini', max_depth=4, random_state=42)
# clf_entropy = DecisionTreeClassifier(criterion='entropy', max_depth=4, random_state=42)
# ... fit, predict, compare ...