# Building a Decision Tree from Scratch: A Visual Guide

## What Are We Building?

Imagine you're sorting fruit. You might ask: "Is it red?" If yes, it's probably an apple. If no, ask: "Is it yellow?" If yes, maybe a banana. This chain of yes/no questions is exactly how a decision tree works—it's an algorithm that learns to ask the right questions to classify data.

In this guide, we'll build a decision tree classifier from the ground up, understanding each piece as we go.

---

## Part 1: Measuring "Messiness" with Gini Impurity

Before we can build our tree, we need a way to measure how "mixed up" our data is. Enter **Gini impurity**.

### The Intuition

Imagine you have a bag of colored marbles:
- **Bag A**: All red marbles → Very "pure" (Gini = 0)
- **Bag B**: 50% red, 50% blue → Very "messy" (Gini = 0.5)
- **Bag C**: 90% red, 10% blue → Mostly pure (Gini = 0.18)

Gini impurity gives us a number that tells us how mixed our labels are. The formula is:

**Gini = 1 - Σ(pᵢ²)**

Here pᵢ is the fraction of samples that belong to class i, and the sum runs over all classes.
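To make the formula concrete, here is a small sketch that applies it directly to class proportions instead of a list of labels. The helper name `gini_from_proportions` is an assumption made for this illustration and is not used elsewhere in the guide.

```python
def gini_from_proportions(proportions):
    """Gini impurity from class proportions that sum to 1 (hypothetical helper)."""
    return 1 - sum(p ** 2 for p in proportions)


# The three marble bags from the intuition above
bags = {
    "Bag A": [1.0],        # all red
    "Bag B": [0.5, 0.5],   # half red, half blue
    "Bag C": [0.9, 0.1],   # mostly red
}

for name, proportions in bags.items():
    # Round to two decimals to hide floating-point noise
    print(name, round(gini_from_proportions(proportions), 2))
```

After rounding, the three values come out to 0.0, 0.5, and 0.18, matching the bags listed earlier.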

### The Code

```python
def gini(y):
    """Calculate how mixed up our labels are"""
    if len(y) == 0:
        return 0

    # Count each label
    counts = {}
    for label in y:
        counts[label] = counts.get(label, 0) + 1

    # Convert counts to probabilities
    total = len(y)
    probabilities = [count / total for count in counts.values()]

    # Calculate impurity
    impurity = 1 - sum(p ** 2 for p in probabilities)

    return impurity
```

### Example

```python
pure_labels = [0, 0, 0, 0]      # Gini = 0 (perfect!)
mixed_labels = [0, 1, 0, 1]     # Gini = 0.5 (messy!)
mostly_pure = [0, 0, 0, 1]      # Gini = 0.375 (pretty good)

print(gini(pure_labels))
print(gini(mixed_labels))
print(gini(mostly_pure))
```

```
0.0
0.5
0.375
```


## Part 2: Finding the Best Split

Now that we can measure messiness, how do we decide where to split our data? We try every possible split and pick the one that reduces messiness the most.

### The Strategy

Think of it like organizing a messy closet. You could organize by:
- Color (shirts on left, pants on right)
- Season (summer on left, winter on right)
- Formality (casual vs. formal)

We try all options and pick the one that makes the most sense (reduces chaos the most).

### Gini Gain

**Gini Gain** = How much messiness did we remove?

```
Gain = (Messiness Before) - (Weighted Average Messiness After)
```
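The "weighted average" part is easy to miss: each side of the split counts in proportion to how many samples it holds. Here is a minimal sketch of that arithmetic, assuming a compact `gini` built on `collections.Counter` and a hypothetical `gini_gain` helper (neither name comes from the code in this guide).

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of labels."""
    total = len(labels)
    if total == 0:
        return 0.0
    return 1 - sum((count / total) ** 2 for count in Counter(labels).values())


def gini_gain(parent, left, right):
    """Impurity before the split minus the size-weighted impurity after it."""
    n = len(parent)
    weighted_after = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted_after


parent = [0, 0, 0, 1, 1, 1]          # Gini = 0.5
left, right = [0, 0, 0, 1], [1, 1]   # Gini = 0.375 and 0.0

# 0.5 - (4/6 * 0.375 + 2/6 * 0.0) = 0.25
print(round(gini_gain(parent, left, right), 3))
```

The left side holds four of the six samples, so its leftover impurity counts for two thirds of the "after" score even though the right side is perfectly pure.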

### The Code

```python
def best_split(X, y):
    """Try every possible split and return the best one"""
    best_gain = 0
    best_feature = None
    best_threshold = None
    current_impurity = gini(y)

    # Try splitting on each feature
    for feature in range(len(X[0])):
        # Try each unique value as a threshold
        thresholds = set(row[feature] for row in X)

        for threshold in thresholds:
            # Split data: left (≤ threshold) and right (> threshold)
            left_y = [y[i] for i in range(len(y)) if X[i][feature] <= threshold]
            right_y = [y[i] for i in range(len(y)) if X[i][feature] > threshold]

            # Skip if split leaves one side empty
            if len(left_y) == 0 or len(right_y) == 0:
                continue

            # Calculate weighted impurity of the split
            left_impurity = gini(left_y)
            right_impurity = gini(right_y)
            weighted_impurity = (len(left_y) / len(y)) * left_impurity + \
                               (len(right_y) / len(y)) * right_impurity

            # How much did we improve?
            gain = current_impurity - weighted_impurity

            # Keep track of the best split so far
            if gain > best_gain:
                best_gain = gain
                best_feature = feature
                best_threshold = threshold

    return best_feature, best_threshold, best_gain
```
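`best_split` uses the observed feature values themselves as thresholds. A common variation, which is not what this guide's code does, is to place candidate thresholds halfway between adjacent distinct values. Here is a sketch of that alternative; `midpoint_thresholds` is a hypothetical helper name.

```python
def midpoint_thresholds(values):
    """Candidate thresholds halfway between adjacent distinct values (hypothetical helper)."""
    distinct = sorted(set(values))
    return [(low + high) / 2 for low, high in zip(distinct, distinct[1:])]


# Feature_0 values from a small six-sample dataset
feature_values = [2.5, 3.5, 1.5, 4.5, 3.0, 1.0]
print(midpoint_thresholds(feature_values))  # [1.25, 2.0, 2.75, 3.25, 4.0]
```

Every midpoint leaves at least one sample on each side, so the empty-side check never triggers. The training data is partitioned the same way as with value thresholds; what changes is where an unseen sample that falls between two training values ends up.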

### Visual Example

```
Before split: [0, 0, 1, 1, 0, 1]  (Gini = 0.5)

Try: Feature_0 <= 2.5
├─ Left:  [0, 0, 0]  (Gini = 0)    ← Pure!
└─ Right: [1, 1, 1]  (Gini = 0)    ← Pure!

Gain = 0.5 - 0 = 0.5  ← Excellent split!
```


## Part 3: Building the Tree Recursively

Now comes the magic. We'll build our tree by recursively splitting the data until we can't (or shouldn't) split anymore.

### When to Stop Splitting

We stop when:
1. **All labels are the same** (pure node—nothing to gain)
2. **We've gone too deep** (max_depth reached—prevent overfitting)
3. **Too few samples** (fewer than `2 * min_samples_leaf`, so the two children could not both meet the minimum leaf size)
4. **No gain from splitting** (can't improve anymore)
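The first three checks can run before any split search, while the fourth needs the result of `best_split`. As a rough sketch, the early checks could be collected into one function that reports why a node should become a leaf. The name `pre_split_stop_reason` and the reason strings are assumptions for illustration; the "too few samples" test mirrors the `2 * min_samples_leaf` rule used in this guide's `build_tree`.

```python
def pre_split_stop_reason(y, depth, max_depth, min_samples_leaf):
    """Return why a node should become a leaf before searching for a split, or None."""
    if len(set(y)) <= 1:
        return "pure"                # condition 1: all labels identical
    if depth >= max_depth:
        return "max_depth"           # condition 2: tree is deep enough
    if len(y) < 2 * min_samples_leaf:
        return "too_few_samples"     # condition 3: children could not both be large enough
    return None                      # keep going: look for a split


print(pre_split_stop_reason([0, 0, 0], depth=1, max_depth=5, min_samples_leaf=1))     # pure
print(pre_split_stop_reason([0, 1, 0], depth=5, max_depth=5, min_samples_leaf=1))     # max_depth
print(pre_split_stop_reason([0, 1, 0], depth=1, max_depth=5, min_samples_leaf=2))     # too_few_samples
print(pre_split_stop_reason([0, 1, 0, 1], depth=1, max_depth=5, min_samples_leaf=2))  # None
```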

### The Code



```python
def most_common_label(y):
    """Return the most frequent label"""
    counts = {}
    for label in y:
        counts[label] = counts.get(label, 0) + 1
    return max(counts.items(), key=lambda x: x[1])[0]


def build_tree(X, y, depth=0, max_depth=5, min_samples_leaf=1):
    """Recursively build the decision tree"""

    # STOPPING CONDITION 1: Pure node
    if gini(y) == 0:
        return {
            'leaf': True,
            'class': y[0],
            'samples': len(y),
            'impurity': 0
        }

    # STOPPING CONDITION 2: Max depth reached
    if depth >= max_depth:
        return {
            'leaf': True,
            'class': most_common_label(y),
            'samples': len(y),
            'impurity': gini(y)
        }

    # STOPPING CONDITION 3: Too few samples
    if len(y) < 2 * min_samples_leaf:
        return {
            'leaf': True,
            'class': most_common_label(y),
            'samples': len(y),
            'impurity': gini(y)
        }

    # Find the best way to split
    feature, threshold, gain = best_split(X, y)

    # STOPPING CONDITION 4: No improvement
    if gain == 0 or feature is None:
        return {
            'leaf': True,
            'class': most_common_label(y),
            'samples': len(y),
            'impurity': gini(y)
        }

    # Split the data
    left_indices = [i for i in range(len(y)) if X[i][feature] <= threshold]
    right_indices = [i for i in range(len(y)) if X[i][feature] > threshold]

    # Check minimum samples constraint
    if len(left_indices) < min_samples_leaf or len(right_indices) < min_samples_leaf:
        return {
            'leaf': True,
            'class': most_common_label(y),
            'samples': len(y),
            'impurity': gini(y)
        }

    # Recursively build left and right subtrees
    left_X = [X[i] for i in left_indices]
    left_y = [y[i] for i in left_indices]
    right_X = [X[i] for i in right_indices]
    right_y = [y[i] for i in right_indices]

    return {
        'leaf': False,
        'feature': feature,
        'threshold': threshold,
        'samples': len(y),
        'impurity': gini(y),
        'gain': gain,
        'left': build_tree(left_X, left_y, depth + 1, max_depth, min_samples_leaf),
        'right': build_tree(right_X, right_y, depth + 1, max_depth, min_samples_leaf)
    }
```
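`build_tree` writes out the same leaf dictionary several times. One possible tidy-up, shown here only as a sketch, is a small factory function that returns a leaf with the same four keys. The name `make_leaf` is hypothetical, and this version uses `collections.Counter` for both the majority label and the impurity.

```python
from collections import Counter


def make_leaf(y):
    """Build a leaf dict with the same keys build_tree uses (hypothetical helper)."""
    if not y:
        raise ValueError("cannot build a leaf from an empty label list")
    counts = Counter(y)
    total = len(y)
    return {
        'leaf': True,
        'class': counts.most_common(1)[0][0],   # majority label
        'samples': total,
        'impurity': 1 - sum((c / total) ** 2 for c in counts.values()),
    }


print(make_leaf([0, 0, 1]))
```

With a helper like this, each stopping condition in `build_tree` could shrink to a single `return make_leaf(y)` line.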

## Part 4: Making Predictions

Once we've built our tree, making predictions is like following a flowchart.

### The Process

1. Start at the root (top) of the tree
2. Check the feature value of your sample
3. Go left if value ≤ threshold, right if value > threshold
4. Repeat until you reach a leaf
5. Return the leaf's class prediction

### The Code

```python
def predict_single(tree, sample):
    """Make a prediction for one sample by walking down the tree"""

    # Base case: we've reached a leaf node
    if tree['leaf']:
        return tree['class']

    # Recursive case: keep going left or right
    if sample[tree['feature']] <= tree['threshold']:
        return predict_single(tree['left'], sample)
    else:
        return predict_single(tree['right'], sample)


def predict(tree, X):
    """Make predictions for multiple samples"""
    return [predict_single(tree, sample) for sample in X]
```

### Visual Example

```
Sample: [2.5, 3.0]

         [Feature_0 <= 3.0]
         /                \
    YES (2.5 ≤ 3.0)        NO
       /                      \
  Predict: 0             Predict: 1
      ↑
   We land here!
```
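The same walk can also be written as a loop instead of recursion. The sketch below rebuilds the one-question tree from the diagram by hand (the `stump` dictionary is made up for illustration and only carries the keys that prediction needs) and follows the sample down to its leaf. `predict_iterative` is a hypothetical name.

```python
def predict_iterative(tree, sample):
    """Walk from the root to a leaf with a loop and return the leaf's class."""
    node = tree
    while not node['leaf']:
        if sample[node['feature']] <= node['threshold']:
            node = node['left']
        else:
            node = node['right']
    return node['class']


# Hand-built tree matching the diagram: Feature_0 <= 3.0
stump = {
    'leaf': False, 'feature': 0, 'threshold': 3.0,
    'left': {'leaf': True, 'class': 0},
    'right': {'leaf': True, 'class': 1},
}

print(predict_iterative(stump, [2.5, 3.0]))  # 0 (2.5 <= 3.0, so go left)
print(predict_iterative(stump, [3.5, 3.0]))  # 1 (3.5 > 3.0, so go right)
```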

## Part 5: Putting It All Together
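The walkthrough in this part calls a `print_tree` helper whose definition does not appear in this guide. Here is a minimal sketch of what it might look like, reconstructed from the printed output format shown in this part; the intermediate `tree_lines` function is an assumption added to keep the formatting separate from the printing.

```python
def tree_lines(tree, feature_names, indent=0):
    """Return the text lines describing a tree (assumed helper, not from the guide)."""
    pad = " " * indent
    if tree['leaf']:
        return [f"{pad}→ Predict: {tree['class']} "
                f"(samples={tree['samples']}, impurity={tree['impurity']:.3f})"]

    name = feature_names[tree['feature']]
    lines = [f"{pad}[{name} <= {tree['threshold']:.2f}] "
             f"(samples={tree['samples']}, impurity={tree['impurity']:.3f}, "
             f"gain={tree['gain']:.3f})"]
    # Children are labelled, then indented four spaces under their parent
    lines.append(f"{pad}  Left:")
    lines.extend(tree_lines(tree['left'], feature_names, indent + 4))
    lines.append(f"{pad}  Right:")
    lines.extend(tree_lines(tree['right'], feature_names, indent + 4))
    return lines


def print_tree(tree, feature_names):
    """Print a nested-dict tree in an indented, readable form."""
    print("\n".join(tree_lines(tree, feature_names)))
```

It expects the nested dictionaries that `build_tree` returns, plus a list of feature names indexed by feature number.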



```python
# Sample data: 6 points with 2 features each
X = [
    [2.5, 1.5],  # Class 0
    [3.5, 2.5],  # Class 1
    [1.5, 3.5],  # Class 0
    [4.5, 4.5],  # Class 1
    [3.0, 1.0],  # Class 0
    [1.0, 4.0]   # Class 1
]
y = [0, 1, 0, 1, 0, 1]

# Build the tree
tree = build_tree(X, y, max_depth=3, min_samples_leaf=1)

# Visualize it (assumes a print_tree display helper is available)
print("TREE STRUCTURE:")
print_tree(tree, ["Feature_0", "Feature_1"])

# Make predictions
predictions = predict(tree, X)

# Check accuracy on the training samples
for i, (sample, true, pred) in enumerate(zip(X, y, predictions)):
    match = "✓" if true == pred else "✗"
    print(f"Sample {i}: {sample} → True: {true}, Pred: {pred} {match}")

accuracy = sum(1 for t, p in zip(y, predictions) if t == p) / len(y)
print(f"\nAccuracy: {accuracy:.2%}")

# Test on new data
new_samples = [[2.0, 2.0], [4.0, 4.0], [1.0, 1.0]]
new_predictions = predict(tree, new_samples)
```

```
TREE STRUCTURE:
[Feature_0 <= 3.00] (samples=6, impurity=0.500, gain=0.250)
  Left:
    [Feature_0 <= 1.00] (samples=4, impurity=0.375, gain=0.375)
      Left:
        → Predict: 1 (samples=1, impurity=0.000)
      Right:
        → Predict: 0 (samples=3, impurity=0.000)
  Right:
    → Predict: 1 (samples=2, impurity=0.000)
Sample 0: [2.5, 1.5] → True: 0, Pred: 0 ✓
Sample 1: [3.5, 2.5] → True: 1, Pred: 1 ✓
Sample 2: [1.5, 3.5] → True: 0, Pred: 0 ✓
Sample 3: [4.5, 4.5] → True: 1, Pred: 1 ✓
Sample 4: [3.0, 1.0] → True: 0, Pred: 0 ✓
Sample 5: [1.0, 4.0] → True: 1, Pred: 1 ✓

Accuracy: 100.00%
```


## Understanding the Hyperparameters

### max_depth
- **What it does**: Limits how many questions the tree can ask
- **Small value (e.g., 1-2)**: Simple tree, may underfit
- **Large value (e.g., 10+)**: Complex tree, may overfit
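- **Sweet spot**: Often somewhere around 3-7 for small datasets, though the best value depends on the data and is worth checking on held-out samples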

### min_samples_leaf
- **What it does**: Requires each leaf to have at least N samples
- **Small value (e.g., 1)**: More specific predictions, risk of overfitting
- **Large value (e.g., 5+)**: More general predictions, more robust
- **Use case**: Increase this if your tree is too complex
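To judge whether a tree is "too complex", it helps to measure it. Here is a sketch of two small helpers that walk the nested-dictionary format used in this guide; the names `tree_depth` and `count_leaves` are assumptions, and the example tree is hand-built for illustration.

```python
def tree_depth(tree):
    """Number of questions on the longest root-to-leaf path (a lone leaf has depth 0)."""
    if tree['leaf']:
        return 0
    return 1 + max(tree_depth(tree['left']), tree_depth(tree['right']))


def count_leaves(tree):
    """Number of leaf nodes, i.e. distinct prediction regions."""
    if tree['leaf']:
        return 1
    return count_leaves(tree['left']) + count_leaves(tree['right'])


# Hand-built example: one question at the root, one more on the left branch
example = {
    'leaf': False,
    'left': {
        'leaf': False,
        'left': {'leaf': True},
        'right': {'leaf': True},
    },
    'right': {'leaf': True},
}

print(tree_depth(example))    # 2
print(count_leaves(example))  # 3
```

Comparing these numbers before and after changing `max_depth` or `min_samples_leaf` gives a quick sense of how much each setting constrains the tree.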

---

## Key Takeaways

1. **Decision trees work by asking questions**: Each node tests a feature against a threshold
2. **Gini impurity measures messiness**: Lower is better (0 = pure)
3. **We split to maximize Gini gain**: Pick the split that best separates classes
4. **Recursion builds the tree**: Split, then split the splits, until we should stop
5. **Hyperparameters control complexity**: Use them to prevent overfitting

---

## Next Steps

Now that you understand the basics, you can:
- Add support for continuous target variables (regression trees)
- Implement pruning to simplify overfitted trees
- Build a random forest (multiple trees voting together)
- Add feature importance calculations
- Handle missing values

Happy tree building!