# 2.3 **Train** and Visualize Decision Trees - Predict Student Departure with Decision Trees

## Model Cycle: The 5 Key Steps

### 1. Build the Model : Create the pipeline with a decision tree classifier.  
### **2. Train the Model : Fit the model on the training data.**  
### 3. Generate Predictions : Use the trained model to make predictions.  
### 4. Evaluate the Model : Assess performance using evaluation metrics.  
### 5. Improve the Model : Tune hyperparameters for optimal performance.

## Introduction

In the previous notebook, we built decision tree pipelines. Now we will train these models on our student departure data and explore one of the key advantages of decision trees: **interpretability through visualization**.

Decision trees can be visualized as flowcharts, making them ideal for communicating model logic to non-technical stakeholders like academic advisors, administrators, and faculty.

### Learning Objectives

By the end of this notebook, you will be able to:

1. Train decision tree models on student data
2. Visualize trained trees using multiple methods
3. Extract human-readable decision rules
4. Interpret feature importance from trained trees
5. Compare tree structures across different configurations

## 1. Load Dependencies and Data

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
import numpy as np
import pickle
import os

from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt

pd.options.display.max_columns = None

In [None]:
# Set up file paths
root_filepath = '/content/drive/MyDrive/projects/Applied-Data-Analytics-For-Higher-Education-Course-2/'
data_filepath = f'{root_filepath}data/'
course3_filepath = f'{root_filepath}course_3/'
models_path = f'{course3_filepath}models/'

In [None]:
# Load training data
df_training = pd.read_csv(f'{data_filepath}training.csv')

print(f"Training data shape: {df_training.shape}")
print(f"\nTarget distribution:")
print(df_training['SEM_3_STATUS'].value_counts())

In [None]:
# Define feature matrix (every column except the target) and target
X_train = df_training.drop(columns=['SEM_3_STATUS'])
y_train = df_training['SEM_3_STATUS']

## 2. Load Pre-built Models

In [None]:
# Load the decision tree models we built in notebook 2.2
# (`with` closes each file handle once the model is loaded)
with open(f'{models_path}basic_decision_tree_model.pkl', 'rb') as f:
    basic_dt_model = pickle.load(f)
with open(f'{models_path}constrained_decision_tree_model.pkl', 'rb') as f:
    constrained_dt_model = pickle.load(f)
with open(f'{models_path}balanced_decision_tree_model.pkl', 'rb') as f:
    balanced_dt_model = pickle.load(f)

print("Models loaded successfully!")
print(f"  - Basic Decision Tree")
print(f"  - Constrained Decision Tree (max_depth=5)")
print(f"  - Balanced Decision Tree (class_weight='balanced')")
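The balanced model relies on scikit-learn's `class_weight='balanced'` option. As a rough illustration, the sketch below uses `sklearn.utils.class_weight.compute_class_weight` on made-up labels with an 80/20 split (not our student data; the `'E'`/`'N'` codes are assumed) to show the weights this option assigns: each class gets `n_samples / (n_classes * class_count)`, so the minority class counts more heavily when the tree scores candidate splits.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical labels: 80 enrolled ('E') and 20 not enrolled ('N')
y_toy = np.array(['E'] * 80 + ['N'] * 20)
toy_classes = np.unique(y_toy)

# The weights that class_weight='balanced' uses
toy_weights = compute_class_weight(class_weight='balanced', classes=toy_classes, y=y_toy)

# The same numbers by hand: n_samples / (n_classes * samples_in_class)
class_counts = np.array([np.sum(y_toy == c) for c in toy_classes])
manual_weights = len(y_toy) / (len(toy_classes) * class_counts)

for c, w in zip(toy_classes, toy_weights):
    print(f"{c}: weight {w:.3f}")  # E: weight 0.625, N: weight 2.500
```

In other words, each "Not Enrolled" student counts four times as much as an "Enrolled" one, which offsets the 4:1 imbalance.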

## 3. Train Decision Tree Models

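Training a decision tree runs the recursive partitioning algorithm we discussed in notebook 2.1. At each node, the tree picks the feature and threshold that most reduce impurity (Gini impurity by default). It then repeats the process on each resulting subset until a stopping criterion, such as `max_depth` or a pure node, is met.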
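To make "finding the best split" concrete, here is a minimal sketch that chooses a single split on one numeric feature by trying every midpoint between sorted values and keeping the one with the lowest weighted Gini impurity. The GPA values and labels are hypothetical, and scikit-learn's real implementation is optimized and searches all features at once, so treat this as a teaching sketch rather than the library's code.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(x, y):
    """Return (threshold, weighted child impurity) for the best `x <= threshold` split."""
    order = np.argsort(x)
    x_sorted, y_sorted = x[order], y[order]
    best_threshold, best_impurity = None, np.inf
    # Candidate thresholds: midpoints between consecutive distinct values
    for i in range(1, len(x_sorted)):
        if x_sorted[i] == x_sorted[i - 1]:
            continue
        threshold = (x_sorted[i] + x_sorted[i - 1]) / 2
        left, right = y_sorted[:i], y_sorted[i:]
        # Child impurities weighted by the share of samples in each child
        weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(y_sorted)
        if weighted < best_impurity:
            best_threshold, best_impurity = threshold, weighted
    return best_threshold, best_impurity

# Hypothetical first-semester GPAs; 1 = not enrolled in semester 3, 0 = enrolled
gpa_1 = np.array([0.8, 1.2, 1.5, 2.4, 2.9, 3.1, 3.5, 3.8])
status = np.array([1, 1, 1, 0, 0, 0, 0, 0])

split_threshold, split_impurity = best_split(gpa_1, status)
print(f"Root Gini: {gini(status):.3f}")
print(f"Best split: GPA_1 <= {split_threshold:.2f} (weighted Gini = {split_impurity:.3f})")
```

Here the split drops impurity from about 0.469 at the root to 0, because it separates the two groups perfectly. The full algorithm then recurses into each side.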

### 3.1 Training the Basic Model

In [None]:
# Train the basic (unconstrained) decision tree
print("Training Basic Decision Tree...")
basic_dt_model.fit(X_train, y_train)

# Get the trained classifier
basic_tree = basic_dt_model.named_steps['classifier']

print(f"\nBasic Tree Statistics:")
print(f"  - Tree depth: {basic_tree.get_depth()}")
print(f"  - Number of leaves: {basic_tree.get_n_leaves()}")
print(f"  - Number of input features: {basic_tree.n_features_in_}")
print(f"  - Number of features used in splits: {(basic_tree.feature_importances_ > 0).sum()}")

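**Observation**: The unconstrained tree typically grows much deeper, with far more leaves, than the constrained models. This is a warning sign of overfitting: with no limits, the tree can keep splitting until it has nearly memorized the training data. We will confirm this by comparing performance on test data in notebook 2.4.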
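A quick way to see this memorization effect is to fit trees on synthetic data whose labels contain deliberate noise. In the sketch below (made-up data, not our student records), about 20% of labels are flipped at random. The unconstrained tree still reaches perfect training accuracy, which can only happen if it has learned the noise, while a `max_depth=3` tree cannot.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
# Synthetic data: the label depends on feature 0, but about 20% of labels are flipped
X_toy = rng.normal(size=(500, 4))
y_toy = (X_toy[:, 0] > 0).astype(int)
flip = rng.random(500) < 0.2
y_toy[flip] = 1 - y_toy[flip]

toy_unconstrained = DecisionTreeClassifier(random_state=0).fit(X_toy, y_toy)
toy_constrained = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_toy, y_toy)

for name, clf in [("Unconstrained", toy_unconstrained), ("max_depth=3", toy_constrained)]:
    print(f"{name}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}, "
          f"train accuracy={clf.score(X_toy, y_toy):.3f}")
```

A training accuracy of 1.0 on labels we know are noisy is exactly the memorization the observation describes.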

### 3.2 Training the Constrained Model

In [None]:
# Train the constrained decision tree
print("Training Constrained Decision Tree...")
constrained_dt_model.fit(X_train, y_train)

# Get the trained classifier
constrained_tree = constrained_dt_model.named_steps['classifier']

print(f"\nConstrained Tree Statistics:")
print(f"  - Tree depth: {constrained_tree.get_depth()}")
print(f"  - Number of leaves: {constrained_tree.get_n_leaves()}")
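print(f"  - Number of input features: {constrained_tree.n_features_in_}")
print(f"  - Number of features used in splits: {(constrained_tree.feature_importances_ > 0).sum()}")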

### 3.3 Training the Balanced Model

In [None]:
# Train the balanced decision tree
print("Training Balanced Decision Tree...")
balanced_dt_model.fit(X_train, y_train)

# Get the trained classifier
balanced_tree = balanced_dt_model.named_steps['classifier']

print(f"\nBalanced Tree Statistics:")
print(f"  - Tree depth: {balanced_tree.get_depth()}")
print(f"  - Number of leaves: {balanced_tree.get_n_leaves()}")
print(f"  - Number of input features: {balanced_tree.n_features_in_}")
print(f"  - Number of features used in splits: {(balanced_tree.feature_importances_ > 0).sum()}")

In [None]:
# Compare tree statistics
comparison_df = pd.DataFrame({
    'Model': ['Basic (Unconstrained)', 'Constrained', 'Balanced'],
    'Depth': [basic_tree.get_depth(), constrained_tree.get_depth(), balanced_tree.get_depth()],
    'Leaves': [basic_tree.get_n_leaves(), constrained_tree.get_n_leaves(), balanced_tree.get_n_leaves()],
    'Total Nodes': [basic_tree.tree_.node_count, constrained_tree.tree_.node_count, balanced_tree.tree_.node_count]
})

print("\nTree Structure Comparison:")
print(comparison_df.to_string(index=False))

## 4. Visualizing Decision Trees

One of the greatest strengths of decision trees is their interpretability. Let's explore multiple ways to visualize our trained trees.

### 4.1 Text-based Representation

The simplest way to view a decision tree is as text-based rules.

In [None]:
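# Get feature names after preprocessing.
# The pipeline's preprocessing step was already fitted when we called
# constrained_dt_model.fit() in Section 3, so we can read it directly.
# This assumes all three pipelines share the same preprocessing step,
# as the rest of this notebook does.
preprocessor = constrained_dt_model.named_steps['preprocessing']

# Numerical feature names (passed through unchanged); this order must
# match the column order used by the preprocessor
numerical_columns = [
    'HS_GPA', 'GPA_1', 'GPA_2', 'DFW_RATE_1', 'DFW_RATE_2',
    'UNITS_ATTEMPTED_1', 'UNITS_ATTEMPTED_2'
]

# Categorical feature names (one-hot encoded)
categorical_columns = ['GENDER', 'RACE_ETHNICITY', 'FIRST_GEN_STATUS']
cat_encoder = preprocessor.named_transformers_['cat']
cat_feature_names = cat_encoder.get_feature_names_out(categorical_columns).tolist()

# Combine all feature names and check that they line up with the tree's inputs
feature_names = numerical_columns + cat_feature_names
assert len(feature_names) == constrained_tree.n_features_in_, "Feature names do not match the tree's inputs"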
print(f"Feature names ({len(feature_names)} total):")
for i, name in enumerate(feature_names):
    print(f"  {i}: {name}")

In [None]:
# Text representation of the constrained tree
print("Text Representation of Constrained Decision Tree:")
print("="*60)
tree_rules = export_text(constrained_tree, feature_names=feature_names, max_depth=4)
print(tree_rules)
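To see how `export_text` output maps to tree structure without the full model, here is a small sketch on hypothetical data with two features, `GPA_1` and `DFW_RATE_1`. The class coding (1 = Not Enrolled) is assumed. Each `|---` line is a condition, deeper indentation means a deeper level in the tree, and `class:` lines are leaf predictions.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data: columns are GPA_1 and DFW_RATE_1; 1 = Not Enrolled (assumed coding)
X_toy = np.array([[0.9, 0.6], [1.4, 0.5], [2.8, 0.1],
                  [3.5, 0.0], [3.1, 0.2], [1.1, 0.4]])
y_toy = np.array([1, 1, 0, 0, 0, 1])

toy_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_toy, y_toy)
toy_rules = export_text(toy_tree, feature_names=['GPA_1', 'DFW_RATE_1'])
print(toy_rules)
```

Both features separate this toy data perfectly, so the root split may use either one. On real data the tree picks whichever split reduces impurity the most.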

### 4.2 Graphical Visualization

scikit-learn's `plot_tree` function draws the tree with Matplotlib, giving a clear graphical representation.

In [None]:
# Visualize the constrained tree (more manageable size)
plt.figure(figsize=(20, 12))
plot_tree(constrained_tree, 
          feature_names=feature_names,
          class_names=['Enrolled', 'Not Enrolled'],  # must follow the order of constrained_tree.classes_
          filled=True,
          rounded=True,
          fontsize=8,
          max_depth=3)  # Limit display depth for readability
plt.title('Constrained Decision Tree (First 3 Levels)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Visualize the balanced tree
plt.figure(figsize=(20, 12))
plot_tree(balanced_tree, 
          feature_names=feature_names,
          class_names=['Enrolled', 'Not Enrolled'],  # must follow the order of balanced_tree.classes_
          filled=True,
          rounded=True,
          fontsize=8,
          max_depth=3)
plt.title('Balanced Decision Tree (First 3 Levels)', fontsize=14)
plt.tight_layout()
plt.show()

**Reading the Tree Visualization:**
- **First line**: The decision rule (feature <= threshold)
- **gini**: The Gini impurity at this node
- **samples**: Number of training samples reaching this node
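- **value**: Class distribution at this node, in [Enrolled, Not Enrolled] order. Depending on your scikit-learn version, it is shown as counts or as proportions, and for the balanced model it reflects the class weights rather than raw counts
- **class**: The majority class (the prediction if this node were a leaf)
- **Color**: Each class has its own hue (orange for the first class, Enrolled; blue for the second, Not Enrolled), and darker shading means a purer node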
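The numbers in each box come from the fitted tree's low-level `tree_` arrays. The sketch below reads them for the root node of a tree trained on synthetic, imbalanced data (not our student data) and recomputes the Gini value from the class proportions. Normalizing `value` first makes the check work whether your scikit-learn version stores counts or fractions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic, imbalanced binary data (roughly 80% class 0, 20% class 1)
X_syn, y_syn = make_classification(n_samples=300, n_features=5, weights=[0.8, 0.2], random_state=0)
syn_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_syn, y_syn)

syn_t = syn_tree.tree_
root = 0
value = syn_t.value[root][0]            # counts or fractions, depending on the sklearn version
proportions = value / value.sum()       # normalizing works under either convention
gini_manual = 1.0 - np.sum(proportions ** 2)

print(f"samples: {syn_t.n_node_samples[root]}")
print(f"value (as proportions): {np.round(proportions, 3)}")
print(f"gini stored: {syn_t.impurity[root]:.4f}, recomputed: {gini_manual:.4f}")
print(f"class: {syn_tree.classes_[np.argmax(value)]}")
```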

### 4.3 Interactive Tree Visualization

Let's create an interactive visualization using Plotly to explore tree structure.

In [None]:
from collections import deque

def extract_tree_structure(tree, feature_names, max_depth=None):
    """
    Extract node and edge data from a fitted decision tree for plotting.

    Nodes are positioned breadth-first: each node sits at the midpoint of the
    horizontal interval it owns, and its two children split that interval in half.
    If max_depth is set, nodes deeper than max_depth are not expanded.
    Returns (nodes_data, edges_data, positions).
    """
    tree_ = tree.tree_

    nodes_data = []
    edges_data = []
    positions = {}
    queue = deque([(0, 0, 0.0, 1.0)])  # (node_id, depth, left_bound, right_bound)

    while queue:
        node_id, depth, left, right = queue.popleft()
        x = (left + right) / 2
        y = -depth
        positions[node_id] = (x, y)

        left_child = tree_.children_left[node_id]
        right_child = tree_.children_right[node_id]
        is_leaf = left_child == right_child  # both children are -1 at a leaf

        if is_leaf:
            # Normalize so this works whether tree_.value stores counts or fractions
            values = tree_.value[node_id][0]
            proportions = values / values.sum()
            predicted_class = 'N' if np.argmax(values) == 1 else 'E'
            label = f"Predict: {predicted_class}\n({proportions[0]:.0%} E, {proportions[1]:.0%} N)"
            node_type = 'leaf'
        else:
            feature = feature_names[tree_.feature[node_id]]
            threshold = tree_.threshold[node_id]
            label = f"{feature}\n<= {threshold:.2f}"
            node_type = 'internal'

            # Expand children only within the display depth
            if max_depth is None or depth < max_depth:
                queue.append((left_child, depth + 1, left, x))
                queue.append((right_child, depth + 1, x, right))
                edges_data.append((node_id, left_child, 'Yes'))  # condition true (<=)
                edges_data.append((node_id, right_child, 'No'))  # condition false (>)

        nodes_data.append({
            'id': node_id,
            'x': x,
            'y': y,
            'label': label,
            'type': node_type,
            'samples': int(tree_.n_node_samples[node_id]),
            'gini': float(tree_.impurity[node_id])
        })

    return nodes_data, edges_data, positions

# Extract structure from the balanced tree, limiting display depth for readability
nodes, edges, positions = extract_tree_structure(balanced_tree, feature_names, max_depth=4)

In [None]:
# Create interactive tree visualization
fig = go.Figure()

# Add edges
for parent_id, child_id, label in edges:
    if parent_id in positions and child_id in positions:
        x0, y0 = positions[parent_id]
        x1, y1 = positions[child_id]
        fig.add_trace(go.Scatter(
            x=[x0, x1], y=[y0, y1],
            mode='lines',
            line=dict(color='gray', width=1),
            showlegend=False,
            hoverinfo='skip'
        ))

# Add nodes
internal_nodes = [n for n in nodes if n['type'] == 'internal']
leaf_nodes = [n for n in nodes if n['type'] == 'leaf']

# Internal nodes
fig.add_trace(go.Scatter(
    x=[n['x'] for n in internal_nodes],
    y=[n['y'] for n in internal_nodes],
    mode='markers+text',
    marker=dict(size=40, color='lightblue', line=dict(color='blue', width=2)),
    text=[n['label'].split('\n')[0] for n in internal_nodes],
    textposition='middle center',
    textfont=dict(size=8),
    hovertemplate='<b>%{text}</b><br>Samples: %{customdata[0]}<br>Gini: %{customdata[1]:.3f}<extra></extra>',
    customdata=[[n['samples'], n['gini']] for n in internal_nodes],
    name='Decision Nodes'
))

# Leaf nodes
fig.add_trace(go.Scatter(
    x=[n['x'] for n in leaf_nodes],
    y=[n['y'] for n in leaf_nodes],
    mode='markers+text',
    marker=dict(size=30, color='lightgreen', symbol='square', line=dict(color='green', width=2)),
    text=[n['label'].split('\n')[0] for n in leaf_nodes],
    textposition='middle center',
    textfont=dict(size=7),
    hovertemplate='<b>%{text}</b><br>Samples: %{customdata[0]}<br>Gini: %{customdata[1]:.3f}<extra></extra>',
    customdata=[[n['samples'], n['gini']] for n in leaf_nodes],
    name='Leaf Nodes (Predictions)'
))

fig.update_layout(
    title='Interactive Decision Tree Visualization (Balanced Model)',
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    height=600,
    showlegend=True
)

fig.show()

## 5. Extracting Decision Rules

For stakeholder communication, we can extract human-readable decision rules from the tree.

In [None]:
def extract_rules(tree, feature_names, class_names):
    """
    Extract root-to-leaf decision rules from a trained decision tree.

    Each rule records the split conditions along the path, the predicted class,
    its confidence (share of the leaf's class weight held by the predicted class),
    the number of training samples in the leaf, and the class distribution.
    """
    tree_ = tree.tree_
    rules = []

    def recurse(node, path):
        if tree_.feature[node] != -2:  # -2 marks a leaf; this is an internal node
            feature = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]

            # Left branch (<=)
            recurse(tree_.children_left[node], path + [f"{feature} <= {threshold:.2f}"])

            # Right branch (>)
            recurse(tree_.children_right[node], path + [f"{feature} > {threshold:.2f}"])
        else:
            # Leaf node. tree_.value may hold counts or fractions depending on the
            # scikit-learn version, and it is class-weighted for the balanced model,
            # so normalize it and take the raw sample count from n_node_samples.
            values = tree_.value[node][0]
            proportions = values / values.sum()

            rules.append({
                'conditions': path,
                'prediction': class_names[np.argmax(values)],
                'confidence': proportions.max(),
                'samples': int(tree_.n_node_samples[node]),
                'distribution': f"{proportions[0]:.0%} E, {proportions[1]:.0%} N"
            })

    recurse(0, [])
    return rules

# Extract rules from the balanced tree
class_names = ['Enrolled', 'Not Enrolled']
rules = extract_rules(balanced_tree, feature_names, class_names)
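The rule list covers every leaf. For advising, it is often more useful to trace the single path one student follows. scikit-learn provides `decision_path` (which nodes a sample visits) and `apply` (which leaf it lands in). The sketch below uses a hypothetical toy tree and student, with 1 = Not Enrolled as an assumed coding. The same two calls would work on `balanced_tree` with preprocessed student features.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: columns are GPA_1 and DFW_RATE_1; 1 = Not Enrolled (assumed coding)
toy_feature_names = ['GPA_1', 'DFW_RATE_1']
X_toy = np.array([[0.9, 0.6], [1.4, 0.5], [2.8, 0.1], [3.5, 0.0],
                  [3.1, 0.2], [1.1, 0.4], [2.2, 0.5], [2.5, 0.3]])
y_toy = np.array([1, 1, 0, 0, 0, 1, 1, 0])
toy_tree = DecisionTreeClassifier(random_state=0).fit(X_toy, y_toy)

student = np.array([[1.7, 0.45]])            # one hypothetical student
toy_path = toy_tree.decision_path(student)    # sparse indicator of nodes visited
toy_leaf = toy_tree.apply(student)[0]         # id of the leaf the student lands in

t = toy_tree.tree_
for node_id in toy_path.indices:
    if node_id == toy_leaf:
        print(f"Leaf {node_id}: predict class {toy_tree.predict(student)[0]}")
        break
    f = t.feature[node_id]
    thr = t.threshold[node_id]
    op = '<=' if student[0, f] <= thr else '>'
    print(f"Node {node_id}: {toy_feature_names[f]} = {student[0, f]:.2f} {op} {thr:.2f}")
```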

In [None]:
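# Display rules predicting "Not Enrolled", most confident first.
# min_confidence is an illustrative cutoff: in a binary tree, every leaf's
# predicted class already has confidence >= 0.5, so 0.5 would filter nothing.
min_confidence = 0.7
print(f"High-Confidence Rules for At-Risk Students (Not Enrolled, confidence >= {min_confidence:.0%}):")
print("="*70)

# Filter for "Not Enrolled" predictions above the cutoff and sort by confidence
at_risk_rules = [r for r in rules if r['prediction'] == 'Not Enrolled' and r['confidence'] >= min_confidence]
at_risk_rules = sorted(at_risk_rules, key=lambda x: x['confidence'], reverse=True)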

for i, rule in enumerate(at_risk_rules[:5], 1):
    print(f"\nRule {i} (Confidence: {rule['confidence']:.1%}, Samples: {rule['samples']}):")
    print(f"  IF {' AND '.join(rule['conditions'])}")
    print(f"  THEN Predict: {rule['prediction']}")
    print(f"  Distribution: {rule['distribution']}")

In [None]:
# Display rules predicting "Enrolled" with high confidence
print("High-Confidence Rules for Retained Students (Enrolled):")
print("="*70)

enrolled_rules = [r for r in rules if r['prediction'] == 'Enrolled' and r['confidence'] >= 0.9]
enrolled_rules = sorted(enrolled_rules, key=lambda x: x['samples'], reverse=True)

for i, rule in enumerate(enrolled_rules[:5], 1):
    print(f"\nRule {i} (Confidence: {rule['confidence']:.1%}, Samples: {rule['samples']}):")
    print(f"  IF {' AND '.join(rule['conditions'])}")
    print(f"  THEN Predict: {rule['prediction']}")

## 6. Feature Importance

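Decision trees provide built-in feature importance scores, often called mean decrease in impurity (MDI). Each feature is credited with the total weighted impurity reduction from every split that uses it, and the scores are normalized to sum to 1. Because they are computed from the training data, these scores describe how the tree uses each feature, not necessarily how much that feature helps predict outcomes for new students.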
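To make the definition concrete, the sketch below recomputes `feature_importances_` by hand for a tree trained on synthetic data (not our student data). For every internal node, it adds the node's weighted impurity minus its children's weighted impurities to that node's feature, then normalizes the totals to sum to 1. The result should match scikit-learn's values.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X_syn, y_syn = make_classification(n_samples=400, n_features=6, n_informative=3, random_state=1)
syn_tree = DecisionTreeClassifier(max_depth=4, random_state=1).fit(X_syn, y_syn)

t = syn_tree.tree_
w = t.weighted_n_node_samples            # (weighted) number of samples per node
manual_importance = np.zeros(X_syn.shape[1])

for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == right:                    # leaf: no split, so no impurity decrease
        continue
    # Weighted impurity decrease produced by this node's split
    decrease = (w[node] * t.impurity[node]
                - w[left] * t.impurity[left]
                - w[right] * t.impurity[right])
    manual_importance[t.feature[node]] += decrease

manual_importance /= manual_importance.sum()   # normalize so importances sum to 1
print("Manual:", np.round(manual_importance, 3))
print("sklearn:", np.round(syn_tree.feature_importances_, 3))
```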

In [None]:
# Get feature importances from the balanced tree
importances = balanced_tree.feature_importances_

# Create a DataFrame for visualization
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': importances
}).sort_values('Importance', ascending=True)

# Filter to show only features with non-zero importance
importance_df = importance_df[importance_df['Importance'] > 0]

print("Feature Importances (Balanced Decision Tree):")
print(importance_df.sort_values('Importance', ascending=False).to_string(index=False))

In [None]:
# Visualize feature importance
fig = go.Figure(go.Bar(
    y=importance_df['Feature'],
    x=importance_df['Importance'],
    orientation='h',
    marker=dict(color='steelblue')
))

fig.update_layout(
    title='Feature Importance (Balanced Decision Tree)',
    xaxis_title='Importance (Gini Impurity Reduction)',
    yaxis_title='Feature',
    height=500
)

fig.show()

In [None]:
# Compare feature importance across models
importance_comparison = pd.DataFrame({
    'Feature': feature_names,
    'Basic': basic_tree.feature_importances_,
    'Constrained': constrained_tree.feature_importances_,
    'Balanced': balanced_tree.feature_importances_
})

# Get top features by average importance
importance_comparison['Average'] = importance_comparison[['Basic', 'Constrained', 'Balanced']].mean(axis=1)
top_features = importance_comparison.nlargest(10, 'Average')

print("Top 10 Features by Average Importance Across Models:")
print(top_features[['Feature', 'Basic', 'Constrained', 'Balanced', 'Average']].to_string(index=False))

In [None]:
# Create grouped bar chart comparing importance across models
fig = go.Figure()

top_features_sorted = top_features.sort_values('Average')

for model in ['Basic', 'Constrained', 'Balanced']:
    fig.add_trace(go.Bar(
        y=top_features_sorted['Feature'],
        x=top_features_sorted[model],
        name=model,
        orientation='h'
    ))

fig.update_layout(
    title='Feature Importance Comparison Across Decision Tree Models',
    xaxis_title='Importance',
    yaxis_title='Feature',
    barmode='group',
    height=500,
    legend=dict(orientation='h', yanchor='bottom', y=1.02)
)

fig.show()

## 7. Comparing Tree Structures

In [None]:
# Comprehensive comparison of tree structures
def get_tree_stats(tree, name):
    """Get summary statistics for a fitted decision tree."""
    tree_ = tree.tree_
    is_leaf = tree_.feature == -2  # -2 marks leaf nodes

    return {
        'Model': name,
        'Depth': tree.get_depth(),
        'Leaves': tree.get_n_leaves(),
        'Total Nodes': tree_.node_count,
        'Avg Leaf Samples': tree_.n_node_samples[is_leaf].mean(),
        'Min Leaf Samples': tree_.n_node_samples[is_leaf].min(),
        'Features Used': (tree.feature_importances_ > 0).sum()
    }

stats = [
    get_tree_stats(basic_tree, 'Basic'),
    get_tree_stats(constrained_tree, 'Constrained'),
    get_tree_stats(balanced_tree, 'Balanced')
]

stats_df = pd.DataFrame(stats)
print("\nComprehensive Tree Structure Comparison:")
print(stats_df.to_string(index=False))

In [None]:
# Visualize tree complexity comparison
fig = make_subplots(rows=1, cols=3, subplot_titles=('Tree Depth', 'Number of Leaves', 'Features Used'))

models = ['Basic', 'Constrained', 'Balanced']
colors = ['coral', 'steelblue', 'seagreen']

# Depth
fig.add_trace(go.Bar(x=models, y=stats_df['Depth'], marker_color=colors, showlegend=False), row=1, col=1)

# Leaves
fig.add_trace(go.Bar(x=models, y=stats_df['Leaves'], marker_color=colors, showlegend=False), row=1, col=2)

# Features Used
fig.add_trace(go.Bar(x=models, y=stats_df['Features Used'], marker_color=colors, showlegend=False), row=1, col=3)

fig.update_layout(height=400, title_text='Decision Tree Complexity Comparison')
fig.show()

## 8. Save Trained Models

In [None]:
# Save trained models
trained_models = {
    'basic_decision_tree_trained': basic_dt_model,
    'constrained_decision_tree_trained': constrained_dt_model,
    'balanced_decision_tree_trained': balanced_dt_model
}

for name, model in trained_models.items():
    filepath = f'{models_path}{name}.pkl'
    # `with` guarantees the file is flushed and closed after writing
    with open(filepath, 'wb') as f:
        pickle.dump(model, f)
    print(f"Saved: {filepath}")

In [None]:
# Save feature names for later use
feature_names_dict = {
    'feature_names': feature_names,
    'numerical_columns': numerical_columns,
    'categorical_columns': ['GENDER', 'RACE_ETHNICITY', 'FIRST_GEN_STATUS']
}

with open(f'{models_path}decision_tree_feature_names.pkl', 'wb') as f:
    pickle.dump(feature_names_dict, f)
print(f"\nSaved feature names to: {models_path}decision_tree_feature_names.pkl")
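Before relying on saved files in later notebooks, it can be worth checking that a model survives a save/load round trip unchanged. The sketch below does this with a small hypothetical pipeline in a temporary directory. It uses `with open(...)` blocks so files are always closed. Keep in mind that pickle files should only be loaded from trusted sources, and ideally with the same scikit-learn version that saved them.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy pipeline (not one of our saved models)
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(50, 3))
y_toy = (X_toy[:, 0] > 0).astype(int)
toy_model = Pipeline([
    ('preprocessing', StandardScaler()),
    ('classifier', DecisionTreeClassifier(max_depth=3, random_state=0)),
]).fit(X_toy, y_toy)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'toy_model.pkl')
    with open(path, 'wb') as f:
        pickle.dump(toy_model, f)
    with open(path, 'rb') as f:
        reloaded = pickle.load(f)

toy_same = np.array_equal(toy_model.predict(X_toy), reloaded.predict(X_toy))
print(f"Reloaded model gives identical predictions: {toy_same}")
```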

## 9. Summary

In this notebook, we trained and visualized decision tree models for student departure prediction.

### Key Findings

| Model | Depth | Leaves | Key Insight |
|:------|:------|:-------|:------------|
| **Basic** | Very Deep | Many | Overfits to training data |
| **Constrained** | ≤ 5 (max_depth=5) | Moderate | Better generalization potential |
| **Balanced** | ≤ 5 | Moderate | Handles class imbalance |

### Most Important Features

The decision trees identified several key predictors of student departure:
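- **GPA_1 and GPA_2**: First- and second-semester GPAs
- **DFW_RATE_1 and DFW_RATE_2**: First- and second-semester DFW rates (the share of courses ending in a D, F, or withdrawal)
- **UNITS_ATTEMPTED_1 and UNITS_ATTEMPTED_2**: Course load in each semester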

### Interpretation Advantage

Decision trees produce interpretable rules like the following illustrative example (actual thresholds come from your trained tree):
- "If GPA_1 <= 1.8 AND DFW_RATE_1 > 0.3, then predict Not Enrolled"

Once they have been validated on held-out data (notebook 2.4), rules like these can be shared directly with advisors and used to inform early alert systems.
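In an early alert workflow, a rule like the illustrative one above becomes a simple filter over student records. The sketch below applies it with a pandas boolean mask to a hypothetical table. The `STUDENT_ID` column and all values are made up, and the thresholds are the illustrative ones from the text, not outputs of our model.

```python
import pandas as pd

# Hypothetical students; thresholds are the illustrative ones from the text
students = pd.DataFrame({
    'STUDENT_ID': [101, 102, 103, 104],
    'GPA_1': [1.5, 3.2, 1.7, 2.4],
    'DFW_RATE_1': [0.5, 0.0, 0.2, 0.4],
})

# IF GPA_1 <= 1.8 AND DFW_RATE_1 > 0.3 THEN flag as at risk
at_risk = (students['GPA_1'] <= 1.8) & (students['DFW_RATE_1'] > 0.3)
flagged = students.loc[at_risk, 'STUDENT_ID'].tolist()
print(f"Flagged for advisor outreach: {flagged}")  # [101]
```

Student 103 has a low GPA but a DFW rate under the threshold, so both conditions must hold before a student is flagged.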

### Connection to ML Cycle

We completed **Step 2: Train the Model**:
- Trained three decision tree variants
- Visualized tree structures
- Extracted decision rules
- Analyzed feature importance

### Next Steps

In the next notebook, we will:
1. Evaluate model performance on test data
2. Tune hyperparameters using cross-validation
3. Compare decision trees to logistic regression

**Proceed to:** `2.4 Evaluate and Tune Decision Trees`