# Arborium: Simplified Tree Representations

This notebook demonstrates how to use Arborium to create simplified tree representations of complex XGBoost models. A single shallow tree that approximates the ensemble is much easier to read and explain than a model made of many deep trees.

## Installation

If you're running this notebook in Colab or outside the arborium repository, uncomment and run the following cell to install the package:

In [None]:
# Uncomment if running in Colab or if you haven't installed arborium yet
# !pip install arborium[xgboost]

## Importing Libraries

First, let's import the necessary libraries:

In [None]:
from arborium import XGBTreeVisualizer
import xgboost as xgb
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

## Loading and Preparing Data

We'll use the California Housing dataset for this example, which has more samples and features than our previous examples:

In [None]:
# Load the California Housing dataset (20,640 samples, 8 features)
housing = fetch_california_housing()
X, y = housing.data, housing.target
feature_names = housing.feature_names

# Split the data for evaluation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Take a quick look at our data
print(f"Number of samples: {X.shape[0]}")
print(f"Number of features: {X.shape[1]}")
print(f"Feature names: {feature_names}")

## Training a Complex XGBoost Model

Let's train a more complex XGBoost model with many deep trees:

In [None]:
# Train a complex model
model = xgb.XGBRegressor(n_estimators=100, max_depth=8)
model.fit(X_train, y_train)

# Check model performance
y_pred = model.predict(X_test)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print(f"Model trained with {model.n_estimators} trees of max depth {model.max_depth}")
print(f"Test RMSE: {rmse:.4f}")

## Creating a Visualizer

Now, let's create an Arborium visualizer for the model:

In [None]:
# Create a visualizer
visualizer = XGBTreeVisualizer(model, X_train, y_train, feature_names=feature_names)

## Visualizing a Regular Tree

Let's first look at one of the regular trees in the model:

In [None]:
# Show a regular tree (the 10th tree)
visualizer.show_tree(9)

As you can see, individual trees in this complex model can be quite deep and hard to interpret. This is where simplified trees come in handy.

## Creating a Simplified Tree Representation

Arborium can create a simplified decision tree that approximates the behavior of the entire ensemble:

In [None]:
# Show a simplified representation of the entire model
simplified_tree = visualizer.show_simplified_tree(
    max_depth=3,              # Control the depth of the simplified tree
    n_components=None,        # Use all features (no dimensionality reduction)
    n_samples=5000            # Use 5000 samples to build the simplified model
)
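For intuition, a common way to approximate an ensemble with one tree is to fit a shallow "surrogate" tree to the ensemble's predictions rather than to the true targets. The sketch below illustrates that general idea only; it is not taken from Arborium's source, and whether `show_simplified_tree` works this way internally is an assumption. To stay self-contained it uses synthetic data and scikit-learn's `GradientBoostingRegressor` as a stand-in for the XGBoost model; all variable names are illustrative.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.tree import DecisionTreeRegressor

# Synthetic stand-in data so the sketch runs without any downloads
X_demo, y_demo = make_regression(n_samples=2000, n_features=8, noise=10.0, random_state=0)

# Stand-in for the complex ensemble (the notebook itself uses an XGBRegressor)
ensemble = GradientBoostingRegressor(n_estimators=50, max_depth=4, random_state=0)
ensemble.fit(X_demo, y_demo)

# Subsample rows, similar in spirit to the n_samples argument above (assumption)
rng = np.random.default_rng(0)
idx = rng.choice(X_demo.shape[0], size=1000, replace=False)
X_sub = X_demo[idx]

# Key step: the shallow tree learns the ensemble's outputs, not the true targets
surrogate = DecisionTreeRegressor(max_depth=3, random_state=0)
surrogate.fit(X_sub, ensemble.predict(X_sub))

print(f"Surrogate depth: {surrogate.get_depth()}, leaves: {surrogate.get_n_leaves()}")
```

A depth-3 tree has at most 8 leaves, so the whole approximation can be read as at most 8 decision rules.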

## Using the Simplified Model for Predictions

The simplified model can also be used to make predictions. Let's see how it compares to the full model:

In [None]:
# Use the simplified model for predictions on the test set
simplified_predictions = visualizer.predict_with_simplified_tree(X_test)

# Compare with the full model
full_predictions = model.predict(X_test)

# Calculate metrics
simplified_rmse = np.sqrt(mean_squared_error(y_test, simplified_predictions))
full_rmse = np.sqrt(mean_squared_error(y_test, full_predictions))

print(f"Full model RMSE: {full_rmse:.4f}")
print(f"Simplified model RMSE: {simplified_rmse:.4f}")
print(f"RMSE increase relative to full model: {((simplified_rmse - full_rmse) / full_rmse * 100):.2f}%")
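The comparison above measures accuracy: both models are scored against the true targets. A complementary question is fidelity, meaning how closely the simplified model tracks the full model's own predictions. The sketch below shows one way to compute both, again on synthetic data with a scikit-learn ensemble standing in for the XGBoost model; the surrogate construction is the generic one and is not claimed to be Arborium's.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X_demo, y_demo = make_regression(n_samples=2000, n_features=8, noise=10.0, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.2, random_state=0)

# Full model and a shallow surrogate trained on the full model's predictions
full = GradientBoostingRegressor(n_estimators=50, max_depth=4, random_state=0).fit(X_tr, y_tr)
surrogate = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_tr, full.predict(X_tr))

full_pred = full.predict(X_te)
surr_pred = surrogate.predict(X_te)

# Accuracy: surrogate error against the ground truth
accuracy_rmse = np.sqrt(mean_squared_error(y_te, surr_pred))

# Fidelity: surrogate error against the full model's predictions
fidelity_rmse = np.sqrt(mean_squared_error(full_pred, surr_pred))
fidelity_r2 = r2_score(full_pred, surr_pred)

print(f"Accuracy RMSE (vs. y): {accuracy_rmse:.4f}")
print(f"Fidelity RMSE (vs. full model): {fidelity_rmse:.4f}")
print(f"Fidelity R^2 (vs. full model): {fidelity_r2:.4f}")
```

In this notebook, the same fidelity figure can be obtained by passing `full_predictions` and `simplified_predictions` from the cell above to `mean_squared_error` or `r2_score`.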

## Experimenting with Different Simplification Parameters

Let's try different parameters for the simplified tree:

In [None]:
# Try a deeper simplified tree
deeper_tree = visualizer.show_simplified_tree(
    max_depth=5,
    n_samples=5000
)

In [None]:
# Try with dimensionality reduction
pca_tree = visualizer.show_simplified_tree(
    max_depth=3,
    n_components=4,  # Reduce to 4 principal components
    n_samples=5000
)
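To see what dimensionality reduction changes, the following sketch chains scikit-learn's `PCA` with a shallow decision tree. It is an illustration under assumptions: the pipeline, the synthetic data, and the stand-in ensemble are not Arborium code, and Arborium may preprocess the features differently (for example, by scaling them first).

```python
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeRegressor

X_demo, y_demo = make_regression(n_samples=2000, n_features=8, noise=10.0, random_state=0)

# Stand-in for the complex ensemble
ensemble = GradientBoostingRegressor(n_estimators=50, max_depth=4, random_state=0)
ensemble.fit(X_demo, y_demo)

# Project 8 features onto 4 principal components, then fit the shallow tree
pca_surrogate = make_pipeline(
    PCA(n_components=4, random_state=0),
    DecisionTreeRegressor(max_depth=3, random_state=0),
)
pca_surrogate.fit(X_demo, ensemble.predict(X_demo))

pca_step = pca_surrogate.named_steps["pca"]
tree_step = pca_surrogate.named_steps["decisiontreeregressor"]

print(f"Variance retained by 4 components: {pca_step.explained_variance_ratio_.sum():.3f}")
print(f"Features seen by the tree: {tree_step.n_features_in_}")
```

In a setup like this, the tree splits on principal components, which are linear combinations of the original features, so each threshold no longer maps to a single named feature. That is worth keeping in mind when reading a tree built with `n_components` set.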

## Getting the Simplified Model

You can also access the simplified model directly, which is a scikit-learn decision tree:

In [None]:
# Get the most recently created simplified model (here, the PCA-based tree from the previous cell)
dt_model = visualizer.get_simplified_model()

# Show information about the model
print(f"Type: {type(dt_model).__name__}")
print(f"Max depth: {dt_model.max_depth}")
print(f"Number of leaves: {dt_model.get_n_leaves()}")
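Because the simplified model is a scikit-learn decision tree, the standard scikit-learn inspection tools should apply to it. The sketch below uses a small tree trained on synthetic data, with hypothetical feature names, to show `export_text` and `feature_importances_`; the tree here is a stand-in, not one returned by Arborium.

```python
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_text

# Small synthetic problem with hypothetical feature names
X_demo, y_demo = make_regression(n_samples=500, n_features=4, noise=5.0, random_state=0)
demo_names = ["f0", "f1", "f2", "f3"]

demo_tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X_demo, y_demo)

# Print the decision rules as indented text
rules = export_text(demo_tree, feature_names=demo_names)
print(rules)

# Impurity-based importance of each feature within this tree
importances = dict(zip(demo_names, demo_tree.feature_importances_))
for name, value in importances.items():
    print(f"{name}: {value:.3f}")
```

If the simplified model was built with `n_components` set, its inputs are presumably principal components, so passing the original feature names to `export_text` would mislabel the splits.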

## Conclusion

You've now learned how to use Arborium to create simplified tree representations of complex XGBoost models. These simplified trees can help with:

1. Model interpretation and explanation
2. Understanding the most important features and decision rules
3. Creating approximate but more interpretable models

Simplified trees sacrifice some predictive performance compared to the full ensemble. In exchange, they show how the model makes its predictions, which helps when explaining model behavior to stakeholders or debugging model issues.