<a href="https://colab.research.google.com/github/statlib/learn-rules/blob/main/notebooks/rulefit-tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RuleFit
Implementation of a rule-based prediction algorithm built on the RuleFit algorithm of Friedman and Popescu ([PDF](http://statweb.stanford.edu/~jhf/ftp/RuleFit.pdf)).

The algorithm predicts an output vector y from an input matrix X. In the first step, a tree ensemble is generated with gradient boosting. The trees are then turned into rules: the path from the root to each node in each tree forms one rule. A rule is a binary indicator of whether an observation falls into that node, so it depends only on the input features used in the splits along the path. The rules, together with the original input features, are then used as inputs to an L1-regularized linear model (the Lasso), which estimates the effect of each rule on the target while setting many of those effects exactly to zero.
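The three steps above (boosted trees, node paths as binary rules, Lasso on rules plus original features) can be sketched with scikit-learn alone. This is a simplified illustration under stated assumptions, not the rulefit package's implementation: it leaves out refinements described by Friedman and Popescu, such as varying tree sizes across the ensemble and winsorizing and scaling the linear terms, and it does not remove duplicate rules. The helper names `extract_rules`, `rule_matrix`, and `fit_rulefit_sketch` are hypothetical.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LassoCV


def extract_rules(tree):
    """Return one rule per non-root node of a fitted scikit-learn tree.

    A rule is the list of (feature_index, op, threshold) conditions on the
    path from the root to that node.
    """
    t = tree.tree_
    rules = []

    def walk(node, conditions):
        if conditions:  # the root has an empty path, so it yields no rule
            rules.append(conditions)
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaf: no further splits below this node
            return
        j, thr = t.feature[node], t.threshold[node]
        # scikit-learn sends x_j <= threshold to the left child.
        walk(left, conditions + [(j, "<=", thr)])
        walk(right, conditions + [(j, ">", thr)])

    walk(0, [])
    return rules


def rule_matrix(rules, X):
    """Binary matrix whose entry (i, k) is 1.0 if row i of X satisfies rule k."""
    R = np.ones((X.shape[0], len(rules)))
    for k, rule in enumerate(rules):
        for j, op, thr in rule:
            hit = X[:, j] <= thr if op == "<=" else X[:, j] > thr
            R[:, k] *= hit  # a rule holds only if every condition holds
    return R


def fit_rulefit_sketch(X, y, n_trees=20, max_depth=3, random_state=0):
    # Step 1: grow a gradient-boosted tree ensemble.
    gb = GradientBoostingRegressor(
        n_estimators=n_trees, max_depth=max_depth, random_state=random_state
    ).fit(X, y)
    # Step 2: every non-root node of every tree becomes one rule.
    rules = [r for tree in gb.estimators_[:, 0] for r in extract_rules(tree)]
    # Step 3: Lasso on the original features plus the binary rule features.
    Z = np.hstack([X, rule_matrix(rules, X)])
    lasso = LassoCV(cv=3, max_iter=10_000, random_state=random_state).fit(Z, y)
    return rules, lasso


# Usage on synthetic data: the target has a step in x0 and a linear term in x1.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 3))
y_demo = 3.0 * (X_demo[:, 0] > 0) + X_demo[:, 1] + rng.normal(scale=0.1, size=200)
rules, lasso = fit_rulefit_sketch(X_demo, y_demo)
n_kept = int(np.sum(lasso.coef_[X_demo.shape[1]:] != 0))
print(f"{len(rules)} candidate rules, {n_kept} kept by the Lasso")
```

Because each rule is a conjunction along one root-to-node path, a tree of depth d can only produce rules with at most d conditions; the Lasso then decides which of the many candidate rules carry any weight.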

You can use RuleFit to predict a numeric response (categorical responses are not yet implemented). The input has to be a NumPy array containing only numeric values.
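If your predictors include categorical columns, one common way to obtain a purely numeric matrix is to one-hot encode them with pandas before calling `fit`. The sketch below uses a made-up toy table; its column names (`rooms`, `zone`, `price`) and the variable names are hypothetical and chosen so they do not overwrite the notebook's `X` and `y`.

```python
import numpy as np
import pandas as pd

# Hypothetical toy table: one numeric predictor, one categorical predictor,
# and a numeric response.
df = pd.DataFrame({
    "rooms": [5.5, 6.2, 7.1, 4.9],
    "zone": ["a", "b", "a", "c"],
    "price": [20.1, 24.3, 33.0, 15.8],
})

y_toy = df["price"].to_numpy()
# One-hot encode "zone" into 0/1 columns zone_a, zone_b, zone_c.
X_toy_df = pd.get_dummies(df.drop(columns="price"), columns=["zone"], dtype=float)
X_toy = X_toy_df.to_numpy(dtype=float)  # purely numeric matrix
toy_features = list(X_toy_df.columns)   # keeps rule labels readable
print(toy_features)
```

The resulting `X_toy`, `y_toy`, and `toy_features` can then be passed to `fit` in the same way as `X`, `y`, and `features` below. This only addresses categorical predictors; the response must still be numeric.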

In [1]:
!pip install --upgrade rulefit scikit-learn &> /dev/null

In [4]:
import numpy as np
import pandas as pd
from rulefit import RuleFit

boston_data = pd.read_csv(
    "https://raw.githubusercontent.com/christophM/rulefit/master/boston.csv", 
    index_col=0
)

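y = boston_data.medv.values            # target: median home value (medv)
_X = boston_data.drop("medv", axis=1)  # all remaining columns are predictors
X = _X.values                          # RuleFit expects a numeric NumPy array
features = _X.columns                  # names used to label the extracted rules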

In [5]:
rf = RuleFit()  # default settings: RuleFit builds its own tree ensemble
rf.fit(X, y, feature_names=features)

  model = cd_fast.enet_coordinate_descent(

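The `cd_fast.enet_coordinate_descent` line above is what remains of warnings raised inside scikit-learn's coordinate-descent Lasso solver, which performs the final L1-regularized fit; such warnings are typically `ConvergenceWarning`s reporting that the objective did not converge. If you have checked that the fit is acceptable and only want to hide them, one option is a scoped warnings filter, sketched below. The helper `fit_quietly` is hypothetical, and the demo uses a plain scikit-learn `Lasso` so the block runs without the rulefit package. Note that hiding the warning does not make the solver converge.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso


def fit_quietly(model, X, y, **fit_kwargs):
    """Fit `model` while hiding scikit-learn ConvergenceWarnings only.

    Hypothetical helper: other warning categories are still shown, and the
    previous warning filters are restored when the `with` block exits.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        return model.fit(X, y, **fit_kwargs)


# Demo: a Lasso capped at two iterations, so it may stop before converging.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(50, 10))
y_demo = X_demo @ rng.normal(size=10)
lasso = fit_quietly(Lasso(alpha=1e-6, max_iter=2), X_demo, y_demo)

# In the notebook, the same helper could wrap the RuleFit call:
# rf = RuleFit()
# fit_quietly(rf, X, y, feature_names=features)
```

The filter lives only inside the `with` block, so warnings raised elsewhere in the notebook are unaffected.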

In [6]:
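from sklearn.ensemble import GradientBoostingRegressor

# Pass a custom gradient-boosting ensemble to generate the rules.
# Deeper trees (larger max_depth) can produce rules with more conditions.
gb = GradientBoostingRegressor(n_estimators=2000, max_depth=20, learning_rate=0.01)
rf = RuleFit(tree_generator=gb)
rf.fit(X, y, feature_names=features)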

  model = cd_fast.enet_coordinate_descent(


In [8]:
rf.predict(X)  # in-sample predictions for the training rows

array([24.84525322, 21.79219054, 33.79334629, 33.3507348 , 34.48392686,
       27.36634752, 21.4936276 , 21.46011902, 15.06990971, 18.36191622,
       18.69329684, 19.05824824, 22.44553211, 19.2209587 , 18.26496264,
       19.52549054, 21.27710598, 16.83997717, 19.29568326, 18.20163434,
       14.46594732, 16.87230214, 17.27774152, 14.39646377, 16.08519968,
       16.4330627 , 16.31723535, 16.29416589, 18.31457781, 20.64407444,
       13.70753508, 17.28268677, 14.00804441, 16.23248276, 15.4639173 ,
       19.6180256 , 20.21380616, 21.86556826, 21.47426804, 29.56741806,
       34.19998883, 29.22231899, 25.89646458, 25.89809521, 23.15052032,
       21.74848228, 21.3973243 , 18.41404529, 14.73062923, 19.34817873,
       20.71575362, 21.70138229, 24.98326001, 22.6259006 , 17.81538053,
       32.53036334, 21.94811158, 31.31474447, 23.89870741, 21.02288382,
       18.77268449, 17.38056579, 22.34579253, 25.60002292, 31.58241969,
       23.38983105, 19.25832335, 20.72374606, 19.21500152, 20.68

In [9]:
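rules = rf.get_rules()
# Keep terms with non-zero Lasso coefficients, sorted by support
# (for rules, the fraction of observations that satisfy them).
rules = rules[rules.coef != 0].sort_values("support", ascending=False)
print(rules)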

                                                  rule    type      coef  \
6                                                  age  linear -0.045068   
10                                             ptratio  linear  0.300844   
11                                               black  linear  0.000147   
7                                                  dis  linear -0.315272   
9                                                  tax  linear  0.000006   
..                                                 ...     ...       ...   
575  ptratio > 15.25 & tax > 222.5 & rm <= 4.753999...    rule  2.068024   
721  dis <= 2.0642999410629272 & tax > 222.5 & lsta...    rule -0.686256   
442  rm <= 7.83650016784668 & lstat <= 21.489999771...    rule -1.229255   
761  black > 105.23999786376953 & ptratio <= 20.949...    rule -1.456458   
172  rm <= 6.955499887466431 & rm > 6.8380000591278...    rule  0.299451   

      support  importance  
6    1.000000    1.260560  
10   1.000000    0.647410  
11 