Python code for tree ensemble interpretation
Switch branches/tags
Nothing to show
Clone or download
Fetching latest commit…
Cannot retrieve the latest commit at this time.
Failed to load latest commit information.


Python code for tree ensemble interpretation proposed in the following paper.


To use defragTrees:

  • Python3.x
  • Numpy
  • Pandas

To run example codes in example directory:

  • Python: XGBoost, Scikit-learn
  • R: randomForest

To replicate paper results in paper directory:

  • Python: Scikit-learn, Matplotlib, pylab
  • R: randomForest, inTrees, nodeHarvest


Prepare data:

  • Input X: feature matrix, numpy array of size (num, dim).
  • Output y: output array, numpy array of size (num,).
    • For regression, y is real value.
    • For classification, y is class index (i.e., 0, 1, 2, ..., C-1, for C classes).
  • Splitter splitter: thresholds of tree ensembles, numpy array of size (# of split rules, 2).
    • Each row of splitter is (feature index, threshold). Suppose the split rule is second feature < 0.5, the row of splitter is then (1, 0.5).

Import the class:

from defragTrees import DefragModel

Fit the simplified model:

Kmax = 10 # uppder-bound number of rules to be fitted
mdl = DefragModel(modeltype='regression') # change to 'classification' if necessary., y, splitter, Kmax), y, splitter, Kmax, fittype='EM') # use this when one wants exactly Kmax rules to be fitted

Check the learned rules:


For further deitals, see In IPython, one can check:

import defragTrees


Simple Examples

See example directory.

Replicating Paper Results

See paper directory.