From c59d287c74ffd9f5e259078ba9139b3133f6cc74 Mon Sep 17 00:00:00 2001
From: "Documenter.jl"
Date: Fri, 28 Apr 2023 12:05:27 +0000
Subject: [PATCH] build based on 7a27869

API

Types

SIRUS.StableRulesClassifier (Type)

    StableRulesClassifier(;
        rng::AbstractRNG=default_rng(),
        partial_sampling::Real=0.7,
        n_trees::Int=1_000,
        max_depth::Int=2,
        q::Int=10,
        min_data_in_leaf::Int=5,
        max_rules::Int=10
    ) -> MLJModelInterface.Probabilistic

Explainable rule-based model based on a random forest. This SIRUS algorithm extracts rules from a stabilized random forest. See the main page of the documentation for details about how it works.

Example

The classifier satisfies the MLJ interface, so it can be used like any other MLJ model. For example, it can be used to create a machine:

julia> using SIRUS, MLJ
 
julia> mach = machine(StableRulesClassifier(; max_rules=15), X, y);

Arguments
  • rng: Random number generator. StableRNGs are advised.
  • partial_sampling: Ratio of samples to use in each subset of the data. The default of 0.7 should be fine for most cases.
  • n_trees: The number of trees to use. The higher the number, the more likely it is that the correct rules are extracted from the trees, but the longer model fitting will take. In most cases, 1000 trees should be more than enough, but it can be useful to run with 2000 trees once and verify that the model performance does not change much.
  • max_depth: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reduce overfitting).
  • q: Number of cutpoints to use per feature. The default value of 10 should be good for most situations.
  • min_data_in_leaf: Minimum number of data points per leaf.
  • max_rules: This is the most important hyperparameter. In general, the more rules, the more accurate the model. However, more rules will also decrease model interpretability. So, it is important to find a good balance here. In most cases, 10-40 rules should provide reasonable accuracy while remaining interpretable.
  • lambda: The weights of the final rules are determined via a regularized regression over each rule as a binary feature. This hyperparameter specifies the strength of the ridge (L2) regularizer. Since the rules are quite strongly correlated, the ridge regularizer is the most useful to stabilize the weight estimates.
SIRUS.StableForestClassifier (Type)

    StableForestClassifier(;
        rng::AbstractRNG=default_rng(),
        partial_sampling::Real=0.7,
        n_trees::Int=1_000,
        max_depth::Int=2,
        q::Int=10,
        min_data_in_leaf::Int=5
    ) <: MLJModelInterface.Probabilistic

Random forest classifier with a stabilized forest structure (Bénard et al., 2021). This stabilization increases stability when extracting rules. The impact on the predictive accuracy compared to standard random forests should be relatively small.

Note

Just like normal random forests, this model is not easily explainable. If you are interested in an explainable model, use the StableRulesClassifier.

Example

The classifier satisfies the MLJ interface, so it can be used like any other MLJ model. For example, it can be used to create a machine:

julia> using SIRUS, MLJ
 
julia> mach = machine(StableForestClassifier(), X, y);

Arguments

  • rng: Random number generator. StableRNGs are advised.
  • partial_sampling: Ratio of samples to use in each subset of the data. The default of 0.7 should be fine for most cases.
  • n_trees: The number of trees to use.
  • max_depth: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reduce overfitting).
  • q: Number of cutpoints to use per feature. The default value of 10 should be good for most situations.
  • min_data_in_leaf: Minimum number of data points per leaf.

Methods

SIRUS.feature_names (Function)

    feature_names(rule::Rule) -> Vector{String}

Return a vector of feature names; one for each clause in rule.

SIRUS.directions (Function)

    directions(rule::Rule) -> Vector{Symbol}

Return a vector of split directions; one for each clause in rule.

Base.values (Method)

    values(rule::Rule) -> Vector{Float64}

Return a vector of split values; one for each clause in rule.

SIRUS.satisfies (Function)

    satisfies(row::AbstractVector, rule::Rule)

Return whether data row satisfies rule.
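
As a rough usage sketch, these methods can be combined to inspect a fitted model. Note that the way to reach the individual Rule objects below is an assumption for illustration, not a documented accessor:

    # Hypothetical sketch: inspect one rule of a fitted machine `mach`.
    # `mach.fitresult.rules` is an assumed field name; the real one may differ.
    rule = first(mach.fitresult.rules)

    feature_names(rule)                    # feature tested by each clause, e.g. ["nodes"]
    directions(rule)                       # split direction per clause
    values(rule)                           # cutpoint per clause, e.g. [4.0]
    satisfies([31.0, 1959.0, 2.0], rule)   # does this data row satisfy the rule?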
SIRUS

      This package is a pure Julia implementation of the Stable and Interpretable RUle Sets (SIRUS) algorithm. The algorithm was originally created by Clément Bénard, Gérard Biau, Sébastien Da Veiga, and Erwan Scornet. This package has only implemented binary classification for now; regression and multiclass classification will be implemented later. For R users, the original version of the SIRUS algorithm is available via CRAN.


      The algorithm is based on random forests. However, compared to random forests, the model is much more interpretable since the forests are converted to a set of decision rules. This page will provide an overview of the algorithm and describe not only how it can be used but also how it works. To do this, let's start by briefly describing random forests.

Random forests

      Random forests are known to produce accurate predictions especially in settings where the number of features p is close to or higher than the number of observations n (Biau & Scornet, 2016). Let's start by explaining the building blocks of random forests: decision trees. As an example, we take Haberman's Survival Data Set (see the Appendix below for more details):

      data = _haberman()
            age    year    nodes  survival
        1  30.0  1964.0      1.0         1
        2  30.0  1962.0      3.0         1
        3  30.0  1965.0      0.0         1
        4  31.0  1959.0      2.0         1
        5  31.0  1965.0      4.0         1
        6  33.0  1958.0     10.0         1
        7  33.0  1960.0      0.0         1
        8  34.0  1959.0      0.0         0
        9  34.0  1966.0      9.0         0
       10  34.0  1958.0     30.0         1
      ...
      306  83.0  1958.0      2.0         0

      X = data[:, Not(:survival)];

      y = data.survival;
      This dataset contains observations from a study with patients who had breast cancer. The survival column contains a 0 if a patient died within 5 years and a 1 if the patient survived for at least 5 years. The aim is to predict survival based on the age, the year in which the operation was conducted, and the number of detected axillary nodes.
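
      As a quick sanity check of the outcome distribution, the class balance can be inspected via DataFrames (a small sketch, assuming the data was loaded as above):

      # Count the number of rows per survival class.
      combine(groupby(data, :survival), nrow)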


      Via MLJ.jl, we can fit multiple decision trees on this dataset:

      tree_evaluations = let
          model = DecisionTreeClassifier(; max_depth=2, rng=_rng())
          _evaluate(model, X, y)
      end;

      This has fitted various trees to various subsets of the dataset via cross-validation. Here, I've set max_depth=2 to simplify the fitted trees, which makes them easier to explain. Also, for our small dataset, this forces the model to remain simple, which likely reduces overfitting. Let's look at the first tree:

      let
          tree = tree_evaluations.fitted_params_per_fold[1].tree
          _io2text() do io
              DecisionTree.print_tree(io, tree; feature_names=names(data))
          end
      end
      Feature 3: "nodes" < 2.5 ?
      ├─ Feature 1: "age" < 79.5 ?
          ├─ 2 : 151/178
          └─ 1 : 1/1
      └─ Feature 1: "age" < 43.5 ?
          ├─ 2 : 16/20
          └─ 1 : 42/76

      What this shows is that the first tree decided that the nodes feature is the most helpful in deciding who will survive for 5 more years. Next, if the nodes feature is below 2.5, then age will be selected on. If age < 79.5, then the model will predict the second class and if age ≥ 79.5 it will predict the first class. Similarly for age < 43.5. Now, let's see what happens for a slight change in the data. In other words, let's see how the model fitted in the second cross-validation fold looks:

      let
          tree = tree_evaluations.fitted_params_per_fold[2].tree
          _io2text() do io
              DecisionTree.print_tree(io, tree; feature_names=names(data))
          end
      end
      Feature 3: "nodes" < 2.5 ?
      ├─ Feature 1: "age" < 77.0 ?
          ├─ 2 : 147/175
          └─ 1 : 2/2
      └─ Feature 1: "age" < 43.5 ?
          ├─ 2 : 18/21
          └─ 1 : 41/77

      This shows that the features and the values for the splitpoints are not the same for both trees. This property is called stability; or, in this case, the decision tree is considered to be unstable. This instability is problematic in situations where real-world decisions are based on the outcome of the model. Imagine using this model for selecting which students are allowed to enter some university. If the model is updated every year with the data from the last year, then the selection criteria would vary wildly per year. This instability also causes accuracy to fluctuate wildly. Intuitively, this makes sense: if the model changes wildly for small data changes, then model accuracy also changes wildly. This intuitively also implies that the model is more likely to overfit.

      This is why random forests were introduced. Basically, random forests fit a large number of trees and average their predictions to come to a more accurate prediction. The individual trees are obtained by restricting the observations and the features that the trees are allowed to use. For the restriction on the observations, the trees are only allowed to see partial_sampling * n observations. In practice, partial_sampling is often 0.7. The restriction on the features is defined in such a way that it guarantees that not every tree will take the same split at the root of the tree. This makes the trees less correlated (James et al., 2021; Section 8.2.2) and, hence, more accurate.
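
      As a minimal sketch of the observation restriction (not the package's internal code), each tree could draw its subset of rows like this:

      using Random: shuffle
      using StableRNGs: StableRNG

      # Draw the indices of the round(partial_sampling * n) observations
      # that a single tree is allowed to see.
      function tree_subset(rng, n::Int; partial_sampling::Real=0.7)
          k = round(Int, partial_sampling * n)
          return shuffle(rng, 1:n)[1:k]
      end

      tree_subset(StableRNG(1), 306)  # 214 of the 306 Haberman rows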


      Unfortunately, these random forests are hard to interpret. To interpret the model, individuals would need to interpret hundreds to thousands of trees containing multiple levels. Alternatively, methods have been created to visualize these uninterpretable models (for example, see Molnar (2022); Chapters 6, 7 and 8). The most promising of these methods are Shapley values and SHAP. These methods show which features have the highest influence on the prediction. See my blog post on Random forests and Shapley values for more information. Knowing which features have the highest influence is nice, but these methods do not state exactly which feature is used and at what cutoff. Again, this is not good enough for selecting students into universities. For example, what if the government decides to ask for details about the selection? The only answer that you can give is that some features are used for selection more than others and that they are on average used in a certain direction. If the government asks for biases in the model, then these are impossible to report. In practice, the decision is still a black box. SIRUS solves this by extracting easily interpretable rules from the random forests.

Rule-based models

      Rule-based models promise much greater interpretability than random forests. Instead of returning a large number of trees, rule-based models return a set of rules. Each rule can be interpreted on its own and the final model aggregates these rules by summing the prediction of each rule. For example, one rule can be:


      if nodes < 4.5 then chance of survival is 0.6 and if nodes ≥ 4.5 then chance of survival is 0.4.
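
      To make the rule format concrete, this example rule can be written as a one-line Julia sketch (the probabilities are the illustrative values from the sentence above, not fitted values):

      # "if nodes < 4.5 then chance of survival is 0.6 else 0.4" as a function.
      survival_chance(nodes::Real) = nodes < 4.5 ? 0.6 : 0.4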


      Note that these rules can be extracted quite easily from the decision trees. For splits on the second level of the tree, the rule could look like:


      if nodes < 4.5 and age < 38.5 then chance of survival is 0.8 and otherwise the chance of survival is 0.4.


      When applying this rule extraction to a random forest, there will be thousands of rules. Next, via some heuristic, the most important rules can be localized and these rules then result in the final model. See, for example, RuleFit (Friedman & Popescu, 2008). The problem with this approach is that the rules are fitted on the unstable decision trees that were shown above. As an example, one time the tree splits on age < 43.5 and another time on age < 44.5.

Tree stabilization

      In the papers which introduce SIRUS, Bénard et al. (2021a, 2021b) prove that their algorithm is stable and that other algorithms are not. They achieve this stability by restricting the locations at which the splitpoints can be chosen. To see how this works, let's look at the age feature on its own.

      nodes = sort(data.age);

      The default random forest algorithm is allowed to choose any location inside this feature to split on. To avoid having to figure out locations by itself, the algorithm will choose one of the datapoints as a split location. So, for example, the following split indicated by the red vertical line would be a valid choice:

      [Figure: the sorted age values with one valid splitpoint indicated by a red vertical line]

      But what happens if we take a random subset of the data? Say, we take the following subset of length 0.7 * length(nodes):

      [Figure: a random subset of length 0.7 * length(nodes) of the sorted age values]

      Now, the algorithm would choose a different location and, hence, introduce instability. To solve this, Bénard et al. decided to limit the splitpoints that the algorithm can use to split the data to a pre-defined set of points. For each feature, they find q empirical quantiles, where q is typically 10. Let's overlay these quantiles on top of the age feature:

      [Figure: the q = 10 empirical quantile cutpoints overlaid on the age feature]

      Next, let's see where the cutpoints are when we take the same random subset as above:

      [Figure: the cutpoints for the same random subset as above]

      As can be seen, many cutpoints are at the same location as before. Furthermore, compared to the unrestricted range, the chance that two different trees that see different random subsets of the data will select the same cutpoint has increased dramatically.
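
      To make this concrete, here is a minimal sketch of quantile-based cutpoints using Statistics.quantile (the package's internal implementation may differ):

      using Statistics: quantile

      # Restrict candidate splitpoints to q empirical quantiles of a feature,
      # so that different data subsets yield (mostly) the same cutpoints.
      cutpoints(v::AbstractVector, q::Int=10) =
          unique(quantile(v, range(0, 1; length=q)))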


      The benefit of this is that it is now quite easy to extract the most important rules: rule extraction consists of simplifying the rules a bit and ordering them by frequency of occurrence. Let's see how accurate this model is.
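
      A bare-bones sketch of the frequency ordering (plain Julia, not the package's implementation) could look like:

      # Order extracted rules by how often they occur across all trees.
      function order_by_frequency(rules::Vector)
          counts = Dict{eltype(rules),Int}()
          for rule in rules
              counts[rule] = get(counts, rule, 0) + 1
          end
          return sort(collect(keys(counts)); by=rule -> counts[rule], rev=true)
      end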

Benchmarks

      Let's compare the following models:

        • Decision tree (DecisionTreeClassifier)
        • Stabilized random forest (StableForestClassifier)
        • SIRUS (StableRulesClassifier)
        • LightGBM (LGBMClassifier)

      The latter is a state-of-the-art gradient boosting model created by Microsoft. See the Appendix for more details about these results.

      results = let
          df = DataFrame(getproperty.([e1, e2, e3, e4, e5, e6], :row))
          rename!(df, :se => "1.96*SE")
      end
         Model                      Hyperparameters       AUC   1.96*SE
      1  "DecisionTreeClassifier"   "(max_depth = 2,)"    0.64  0.07
      2  "StableRulesClassifier"    "(max_rules = 5,)"    0.68  0.05
      3  "StableRulesClassifier"    "(max_rules = 25,)"   0.68  0.06
      4  "StableRulesClassifier"    "(max_depth = 1,)"    0.7   0.05
      5  "StableForestClassifier"   "(max_depth = 2,)"    0.7   0.05
      6  "LGBMClassifier"           "(;)"                 0.71  0.06

      We can summarize these results as follows:

      [Figure: summary plot of the benchmark scores]

      As can be seen, the score of the stabilized random forest (StableForestClassifier) is almost as good as Microsoft's classifier (LGBMClassifier), but neither is interpretable, since that would require interpreting thousands of trees. With the rule-based classifier (StableRulesClassifier), a small amount of predictive performance can be traded for high interpretability. Note that the rule-based classifier may actually be more accurate in practice because verifying and debugging the model is much easier.


      Regarding the hyperparameters, tuning max_rules and max_depth has the most effect. max_rules specifies the number of rules to which the random forest is simplified. Setting it to a high number such as 999 makes the predictive performance similar to that of a random forest, but also makes the interpretability as bad as a random forest. Therefore, it makes more sense to truncate the rules to somewhere in the range 5 to 40 to obtain accurate models with high interpretability. max_depth specifies how many levels the trees have. For larger datasets, max_depth=2 makes the most sense since it can find more complex patterns in the data. For smaller datasets, max_depth=1 makes more sense since it reduces the chance of overfitting. It also simplifies the rules because with max_depth=1, a rule contains only one conditional (for example, "if A then ...") versus two conditionals (for example, "if A & B then ..."). In some cases, model accuracy can be improved by increasing n_trees. The higher this number, the more trees are fitted and, hence, the higher the chance that the right rules are extracted from the trees. A configuration sketch follows below.
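
      For instance, following this guidance, a configuration for a small dataset could look like the sketch below (the exact values are assumptions to tune per dataset; _rng is the helper defined in the Appendix):

      # Few rules and shallow trees for interpretability on a small dataset;
      # many trees so that rule extraction stays stable.
      model = StableRulesClassifier(; max_rules=10, max_depth=1, n_trees=1_000, rng=_rng())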

Interpretation

      Finally, let's interpret the rules that the model has learned. Since we know that the model performs well on the cross-validations, we can fit our preferred model on the complete dataset:

      let
          model = StableRulesClassifier(; max_depth=1, rng=_rng())
          mach = machine(model, X, y)
          fit!(mach)
          mach.fitresult
      end
      StableRules model with 3 rules:
       if X[i, :nodes] < 4.0 then 0.075 else 0.06 +
       if X[i, :age] < 42.0 then 0.022 else 0.02 +
       if X[i, :year] < 1960.0 then 0.015 else 0.017
      and 2 classes: [0.0, 1.0].
      Note: showing only the probability for class 1.0 since class 0.0 has probability 1 - p.

      The interpretation of the fitted model is as follows. The model has learned three rules for this dataset. To make a prediction for some row i, the model first looks at the value of the nodes feature. If the value is below the listed number, then the number after then is chosen and otherwise the number after else. This is done for all the rules and, finally, the per-rule results are summed to obtain the final prediction.
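
      As a sketch, the three printed rules can be evaluated by hand (transcribed from the output above; this is not the package's prediction code):

      # Sum the per-rule scores for one row, giving the model's combined
      # score for class 1.0.
      function rules_score(age::Real, year::Real, nodes::Real)
          return (nodes < 4.0 ? 0.075 : 0.06) +
                 (age < 42.0 ? 0.022 : 0.02) +
                 (year < 1960.0 ? 0.015 : 0.017)
      end

      rules_score(31.0, 1959.0, 2.0)  # 0.075 + 0.022 + 0.015 = 0.112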

Visualization

      Since our rules are relatively simple with only a binary outcome and only one clause in each rule, the following figure is a way to visualize the obtained rules per fold. For multiple clauses, I would not know how to visualize the rules. Also, this plot is probably not perfect; let me know if you have suggestions.


      This figure shows the model uncertainty. The x-position on the left shows log(else-scores / if-scores), the vertical lines on the right show the threshold, and the histograms on the right show the data. For example, for the nodes, it can be seen that all rules (fitted in the different cross-validation folds) base their decision on whether the nodes are below, roughly, 5. Next, the left side indicates that the individuals who had less than 5 nodes are more likely to survive, according to the model. The sizes of the dots indicate the weight that the rule has, so a bigger dot means that a rule plays a larger role in the final outcome. These dots are sized in such a way that a doubling in weight means a doubling in surface size. Finally, the variables are ordered by the sum of the weights.

      [Figure: the fitted rules per cross-validation fold, showing rule weights, thresholds, and data histograms]

      What this plot shows is that the nodes feature is on average chosen as the feature with the most predictive power. This can be concluded because the nodes feature is shown as the first feature and the thickness of its dots is the biggest. Furthermore, there is unfortunately some instability in the position of the splitpoint for the nodes feature. Some models split the data at around 3 and others at around 5. Depending on the context in which this model is used, it might thus be beneficial to decrease the number of empirical quantiles q that the model can use to split on. By default q=10, but maybe something like q=3 would make more sense here.

Conclusion

      Compared to decision trees, the rule-based classifier is more stable, more accurate and similarly easy to interpret. Compared to the random forest, the rule-based classifier is only slightly less accurate, but much easier to interpret. Due to the interpretability, it is likely that the rule-based classifier will be more accurate in real-world settings. This makes rule-based models highly suitable for many machine learning tasks.

Acknowledgements

      Thanks to Clément Bénard, Gérard Biau, Sébastien Da Veiga and Erwan Scornet for creating the SIRUS algorithm and documenting it extensively. Special thanks to Clément Bénard for answering my questions regarding the implementation. Thanks to Hylke Donker for figuring out a way to visualize these rules. Also thanks to my PhD supervisors Ruud den Hartigh, Peter de Jonge and Frank Blaauw, and to Age de Wit and colleagues at the Dutch Ministry of Defence for providing the data, clarifying the constraints of the problem, and providing many methodological suggestions.

Appendix
      begin\n    ENV[\"DATADEPS_ALWAYS_ACCEPT\"] = \"true\"\n\n    using CairoMakie\n    using CategoricalArrays: categorical\n    using CSV: CSV\n    using DataDeps: DataDeps, DataDep, @datadep_str\n    using DataFrames\n    using LightGBM.MLJInterface: LGBMClassifier\n    using MLJDecisionTreeInterface: DecisionTree, DecisionTreeClassifier\n    using MLJ: CV, MLJ, Not, PerformanceEvaluation, auc, fit!, evaluate, machine\n    using StableRNGs: StableRNG\n    using SIRUS\n    using Statistics: mean, std\nend
      \n\n\n
      function _plot_cutpoints(data::AbstractVector)
          fig = Figure(; resolution=(800, 100))
          ax = Axis(fig[1, 1])
          cutpoints = Float64.(unique(ST._cutpoints(data, 10)))
          scatter!(ax, data, fill(1, length(data)))
          vlines!(ax, cutpoints; color=:black, linestyle=:dash)
          textlocs = [(c, 1.1) for c in cutpoints]
          for cutpoint in cutpoints
              annotation = string(round(cutpoint; digits=2))::String
              text!(ax, cutpoint + 0.2, 1.08; text=annotation, textsize=13)
          end
          ylims!(ax, 0.9, 1.2)
          hideydecorations!(ax)
          return fig
      end;
      _rng(seed::Int=1) = StableRNG(seed);
      function _io2text(f::Function)
          io = IOBuffer()
          f(io)
          s = String(take!(io))
          return Base.Text(s)
      end;
      function _evaluate(model, X, y; nfolds=10)
          resampling = CV(; nfolds, shuffle=true, rng=_rng())
          acceleration = MLJ.CPUThreads()
          evaluate(model, X, y; acceleration, verbosity=0, resampling, measure=auc)
      end;
      function register_haberman()
          name = "Haberman"
          message = "Slightly modified copy of Haberman's Survival Data Set"
          remote_path = "https://github.com/rikhuijzer/haberman-survival-dataset/releases/download/v1.0.0/haberman.csv"
          checksum = "a7e9aeb249e11ac17c2b8ea4fdafd5c9392219d27cb819ffaeb8a869eb727a0f"
          DataDeps.register(DataDep(name, message, remote_path, checksum))
      end;
      function _haberman()
          register_haberman()
          dir = datadep"Haberman"
          path = joinpath(dir, "haberman.csv")
          df = CSV.read(path, DataFrame)
          df[!, :survival] = categorical(df.survival)
          # Need Floats for the LGBMClassifier.
          for col in [:age, :year, :nodes]
              df[!, col] = float.(df[:, col])
          end
          return df
      end;
      _filter_rng(hyper::NamedTuple) = Base.structdiff(hyper, (; rng=:foo));
      _pretty_name(modeltype) = last(split(string(modeltype), '.'));
      function _evaluate(modeltype, hyperparameters, X, y)
          model = modeltype(; hyperparameters...)
          e = _evaluate(model, X, y)
          row = (;
              Model=_pretty_name(modeltype),
              Hyperparameters=_hyper2str(_filter_rng(hyperparameters)),
              AUC=_score(e),
              se=round(only(MLJ.MLJBase._standard_errors(e)); digits=2)
          )
          (; e, row)
      end;
      _hyper2str(hyper::NamedTuple) = hyper == (;) ? "(;)" : string(hyper)::String;
      function _score(e::PerformanceEvaluation)
          return round(only(e.measurement); digits=2)
      end;
      e1 = let
          model = DecisionTreeClassifier
          hyperparameters = (; max_depth=2, rng=_rng())
          _evaluate(model, hyperparameters, X, y)
      end;

      e2 = let
          model = StableRulesClassifier
          hyperparameters = (; max_rules=5, rng=_rng())
          _evaluate(model, hyperparameters, X, y)
      end;

      e3 = let
          model = StableRulesClassifier
          hyperparameters = (; max_rules=25, rng=_rng())
          _evaluate(model, hyperparameters, X, y)
      end;

      e4 = let
          model = StableRulesClassifier
          hyperparameters = (; max_depth=1, rng=_rng())
          _evaluate(model, hyperparameters, X, y)
      end;

      e5 = let
          model = StableForestClassifier
          hyperparameters = (; max_depth=2, rng=_rng())
          _evaluate(model, hyperparameters, X, y)
      end;

      e6 = let
          model = LGBMClassifier
          hyperparameters = (;)
          _evaluate(model, hyperparameters, X, y)
      end;
      \n\n\n","category":"page"},{"location":"","page":"SIRUS","title":"SIRUS","text":"EditURL = \"https://github.com/rikhuijzer/SIRUS.jl/blob/main/docs/src/sirus.jl\"","category":"page"}] +[{"location":"api/#API","page":"API","title":"API","text":"","category":"section"},{"location":"api/#Types","page":"API","title":"Types","text":"","category":"section"},{"location":"api/","page":"API","title":"API","text":"StableRulesClassifier\nStableForestClassifier","category":"page"},{"location":"api/#SIRUS.StableRulesClassifier","page":"API","title":"SIRUS.StableRulesClassifier","text":"StableRulesClassifier(;\n rng::AbstractRNG=default_rng(),\n partial_sampling::Real=0.7,\n n_trees::Int=1_000,\n max_depth::Int=2,\n q::Int=10,\n min_data_in_leaf::Int=5,\n max_rules::Int=10\n) -> MLJModelInterface.Probabilistic\n\nExplainable rule-based model based on a random forest. This SIRUS algorithm extracts rules from a stabilized random forest. See the main page of the documentation for details about how it works.\n\nExample\n\nThe classifier satisfies the MLJ interface, so it can be used like any other MLJ model. For example, it can be used to create a machine:\n\njulia> using SIRUS, MLJ\n\njulia> mach = machine(StableRulesClassifier(; max_rules=15), X, y);\n\nArguments\n\nrng: Random number generator. StableRNGs are advised.\npartial_sampling: Ratio of samples to use in each subset of the data. The default of 0.7 should be fine for most cases.\nn_trees: The number of trees to use. The higher the number, the more likely it is that the correct rules are extracted from the trees, but also the longer model fitting will take. In most cases, 1000 rules should be more than enough, but it might be useful to run 2000 rules one time and verify that the model performance does not change much.\nmax_depth: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reduce overfitting).\nq: Number of cutpoints to use per feature. The default value of 10 should be good for most situations.\nmin_data_in_leaf: Minimum number of data points per leaf.\nmax_rules: This is the most important hyperparameter. In general, the more rules, the more accurate the model. However, more rules will also decrease model interpretability. So, it is important to find a good balance here. In most cases, 10-40 rules should provide reasonable accuracy while remaining interpretable.\nlambda: The weights of the final rules are determined via a regularized regression over each rule as a binary feature. This hyperparameter specifies the strength of the ridge (L2) regularizer. Since the rules are quite strongly correlated, the ridge regularizer is the most useful to stabilize the weight estimates.\n\n\n\n\n\n","category":"type"},{"location":"api/#SIRUS.StableForestClassifier","page":"API","title":"SIRUS.StableForestClassifier","text":"StableForestClassifier(;\n rng::AbstractRNG=default_rng(),\n partial_sampling::Real=0.7,\n n_trees::Int=1_000,\n max_depth::Int=2,\n q::Int=10,\n min_data_in_leaf::Int=5\n) <: MLJModelInterface.Probabilistic\n\nRandom forest classifier with a stabilized forest structure (Bénard et al., 2021). This stabilization increases stability when extracting rules. The impact on the predictive accuracy compared to standard random forests should be relatively small.\n\nnote: Note\nJust like normal random forests, this model is not easily explainable. 
If you are interested in an explainable model, use the StableRulesClassifier.\n\nExample\n\nThe classifier satisfies the MLJ interface, so it can be used like any other MLJ model. For example, it can be used to create a machine:\n\njulia> using SIRUS, MLJ\n\njulia> mach = machine(StableForestClassifier(), X, y);\n\nArguments\n\nrng: Random number generator. StableRNGs are advised.\npartial_sampling: Ratio of samples to use in each subset of the data. The default of 0.7 should be fine for most cases.\nn_trees: The number of trees to use.\nmax_depth: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reduce overfitting).\nq: Number of cutpoints to use per feature. The default value of 10 should be good for most situations.\nmin_data_in_leaf: Minimum number of data points per leaf.\n\n\n\n\n\n","category":"type"},{"location":"api/#Methods","page":"API","title":"Methods","text":"","category":"section"},{"location":"api/","page":"API","title":"API","text":"feature_names\ndirections\nvalues(::SIRUS.Rule)\nsatisfies","category":"page"},{"location":"api/#SIRUS.feature_names","page":"API","title":"SIRUS.feature_names","text":"feature_names(rule::Rule) -> Vector{String}\n\nReturn a vector of feature names; one for each clause in rule.\n\n\n\n\n\n","category":"function"},{"location":"api/#SIRUS.directions","page":"API","title":"SIRUS.directions","text":"directions(rule::Rule) -> Vector{Symbol}\n\nReturn a vector of split directions; one for each clause in rule.\n\n\n\n\n\n","category":"function"},{"location":"api/#Base.values-Tuple{SIRUS.Rule}","page":"API","title":"Base.values","text":"values(rule::Rule) -> Vector{Float64}\n\nReturn a vector split values; one for each clause in rule.\n\n\n\n\n\n","category":"method"},{"location":"api/#SIRUS.satisfies","page":"API","title":"SIRUS.satisfies","text":"satisfies(row::AbstractVector, rule::Rule)\n\nReturn whether data row satisfies rule.\n\n\n\n\n\n","category":"function"},{"location":"","page":"SIRUS","title":"SIRUS","text":"\n\n\n\n\n\n\n\n\n\n\n\n\n

      This package is a pure Julia implementation of the Stable and Interpretable RUle Sets (SIRUS) algorithm. The algorithm was originally created by Clément Bénard, Gérard Biau, Sébastien Da Veiga, and Erwan Scornet. This package has only implemented binary classification for now. Regression and multiclass-classification will be implemented later. For R users, the original version of the SIRUS algorithm is available via CRAN.

      \n

      The algorithm is based on random forests. However, compared to random forests, the model is much more interpretable since the forests are converted to a set of decison rules. This page will provide an overview of the algorithm and describe not only how it can be used but also how it works. To do this, let's start by briefly describing random forests.

      \n
      \n\n","category":"page"},{"location":"#Random-forests","page":"SIRUS","title":"Random forests","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Random forests are known to produce accurate predictions especially in settings where the number of features p is close to or higher than the number of observations n (Biau & Scornet, 2016). Let's start by explaining the building blocks of random forests: decision trees. As an example, we take Haberman's Survival Data Set (see the Appendix below for more details):

      \n
      \n\n
      data = _haberman()
      \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
      ageyearnodessurvival
      130.01964.01.01
      230.01962.03.01
      330.01965.00.01
      431.01959.02.01
      531.01965.04.01
      633.01958.010.01
      733.01960.00.01
      834.01959.00.00
      934.01966.09.00
      1034.01958.030.01
      ...
      30683.01958.02.00
      \n\n\n
      X = data[:, Not(:survival)];
      \n\n\n
      y = data.survival;
      \n\n\n\n

      This dataset contains observations from a study with patients who had breast cancer. The survival column contains a 0 if a patient has died within 5 years and 1 if the patient has survived for at least 5 years. The aim is to predict survival based on the age, the year in which the operation was conducted and the number of detected auxillary nodes.

      \n
      \n\n\n

      Via MLJ.jl, we can fit multiple decision trees on this dataset:

      \n
      \n\n
      tree_evaluations = let\n    model = DecisionTreeClassifier(; max_depth=2, rng=_rng())\n    _evaluate(model, X, y)\nend;
      \n\n\n\n

      This has fitted various trees to various subsets of the dataset via cross-validation. Here, I've set max_depth=2 to simplify the fitted trees which makes the tree more easily explainable. Also, for our small dataset, this forces the model to remain simple so it likely reduces overfitting. Let's look at the first tree:

      \n
      \n\n
      let\n    tree = tree_evaluations.fitted_params_per_fold[1].tree\n    _io2text() do io\n        DecisionTree.print_tree(io, tree; feature_names=names(data))\n    end\nend
      \n
      Feature 3: \"nodes\" < 2.5 ?\n├─ Feature 1: \"age\" < 79.5 ?\n    ├─ 2 : 151/178\n    └─ 1 : 1/1\n└─ Feature 1: \"age\" < 43.5 ?\n    ├─ 2 : 16/20\n    └─ 1 : 42/76\n
      \n\n\n

      What this shows is that the first tree decided that the nodes feature is the most helpful in deciding who will survive for 5 more years. Next, if the nodes feature is below 2.5, then age will be selected on. If age < 79.5, then the model will predict the second class and if age ≥ 79.5 it will predict the first class. Similarly for age < 43.5. Now, let's see what happens for a slight change in the data. In other words, let's see how the fitted model for the second split looks:

      \n
      \n\n
      let\n    tree = tree_evaluations.fitted_params_per_fold[2].tree\n    _io2text() do io\n        DecisionTree.print_tree(io, tree; feature_names=names(data))\n    end\nend
      \n
      Feature 3: \"nodes\" < 2.5 ?\n├─ Feature 1: \"age\" < 77.0 ?\n    ├─ 2 : 147/175\n    └─ 1 : 2/2\n└─ Feature 1: \"age\" < 43.5 ?\n    ├─ 2 : 18/21\n    └─ 1 : 41/77\n
      \n\n\n

      This shows that the features and the values for the splitpoints are not the same for both trees. This is called stability. Or in this case, a decision tree is considered to be unstable. This instability is problematic in situations where real-world decisions are based on the outcome of the model. Imagine using this model for the selecting which students are allowed to enter some university. If the model is updated every year with the data from the last year, then the selection criteria would vary wildly per year. This instability also causes accuracy to fluctuate wildly. Intuitively, this makes sense: if the model changes wildly for small data changes, then model accuracy also changes wildly. This intuitively also implies that the model is more likely to overfit. This is why random forests were introduced. Basically, random forests fit a large number of trees and average their predictions to come to a more accurate prediction. The individual trees are obtained by restricting the observations and the features that the trees are allowed to use. For the restriction on the observations, the trees are only allowed to see partial_sampling * n observations. In practise, partial_sampling is often 0.7. The restriction on the features is defined in such a way that it guarantees that not every tree will take the same split at the root of the tree. This makes the trees less correlated (James et al., 2021; Section 8.2.2) and, hence, more accurate.

      \n

      Unfortunately, these random forests are hard to interpret. To interpret the model, individuals would need to interpret hundreds to thousands of trees containing multiple levels. Alternatively, methods have been created to visualize these uninterpretable models (for example, see Molnar (2022); Chapters 6, 7 and 8). The most promising one of these methods are Shapley values and SHAP. These methods show which features have the highest influence on the prediction. See my blog post on Random forests and Shapley values for more information. Knowing which features have the highest influence is nice, but they do not state exactly what feature is used and at what cutoff. Again, this is not good enough for selecting students into universities. For example, what if the government decides to ask for details about the selection? The only answer that you can give is that some features are used for selection more than others and that they are on average used in a certain direction. If the government asks for biases in the model, then these are impossible to report. In practice, the decision is still a black-box. SIRUS solves this by extracting easily interpretable rules from the random forests.

      \n
      \n\n","category":"page"},{"location":"#Rule-based-models","page":"SIRUS","title":"Rule-based models","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Rule-based models promise much greater interpretability than random forests. Instead of returning a large number of trees, rule-based models return a set of rules. Each rule can be interpreted on its own and the final model aggregates these rules by summing the prediction of each rules. For example, one rule can be:

      \n
      \n

      if nodes < 4.5 then chance of survival is 0.6 and if nodes ≥ 4.5 then chance of survival is 0.4.

      \n
      \n

      Note that these rules can be extracted quite easily from the decision trees. For splits on the second level of the tree, the rule could look like:

      \n
      \n

      if nodes < 4.5 and age < 38.5 then chance of survival is 0.8 and otherwise the chance of survival is 0.4.

      \n
      \n

      When applying this extracting of rules to a random forest, there will be thousands of rules. Next, via some heuristic, the most important rules can be localized and these rules then result in the final model. See, for example, RuleFit (Friedman & Popescu, 2008). The problem with this approach is that they are fitted on the unstable decision trees that were shown above. As an example, on time the tree splits on age < 43.5 and another time on age < 44.5.

      \n
      \n\n","category":"page"},{"location":"#Tree-stabilization","page":"SIRUS","title":"Tree stabilization","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      In the papers which introduce SIRUS, Bénard et al. (2021a, 2021b) proof that their algorithm is stable and that the other algorithms are not. They achieve their stability by restricting the location at which the splitpoints can be chosen. To see how this works, let's look at the age feature on its own.

      \n
      \n\n
      nodes = sort(data.age);
      \n\n\n\n\n\n\n\n\n\n

      The default random forest algorithm is allowed to choose any location inside this feature to split on. To avoid having to figure out locations by itself, the algorithm will choose on of the datapoints as a split location. So, for example, the following split indicated by the red vertical line would be a valid choice:

      \n
      \n\n\n\n\n\n\n\n\n

      But what happens if we take a random subset of the data? Say, we take the following subset of length 0.7 * length(nodes):

      \n
      \n\n\n\n\n\n\n\n\n\n\n\n

      Now, the algorithm would choose a different location and, hence, introduce instability. To solve this, Bénard et al. decided to limit the splitpoints that the algorithm can use to split to data to a pre-defined set of points. For each feature, they find q empirical quantiles where q is typically 10. Let's overlay these quantiles on top of the age feature:

      \n
      \n\n\n\n\n\n

      Next, let's see where the cutpoints are when we take the same random subset as above:

      \n
      \n\n\n\n\n\n

      As can be seen, many cutpoints are at the same location as before. Furthermore, compared to the unrestricted range, the chance that two different trees who see a different random subset of the data will select the same cutpoint has increased dramatically.

      \n

      The benefit of this is that it is now quite easy to extract the most important rules. Rule extraction consists of simplifying them a bit and ordering them by frequency of occurrence. Let's see how accurate this model is.

      \n
      \n\n","category":"page"},{"location":"#Benchmarks","page":"SIRUS","title":"Benchmarks","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Let's compare the following models:

      \n
        \n
      • Decision tree (DecisionTreeClassifier)

        \n
      • \n
      • Stabilized random forest (StableForestClassifier)

        \n
      • \n
      • SIRUS (StableRulesClassifier)

        \n
      • \n
      • LightGBM (LGBMClassifier)

        \n
      • \n
      \n

      The latter is a state-of-the-art gradient boosting model created by Microsoft. See the Appendix for more details about these results.

      \n
      \n\n
      results = let\n    df = DataFrame(getproperty.([e1, e2, e3, e4, e5, e6], :row))\n    rename!(df, :se => \"1.96*SE\")\nend
      \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
      ModelHyperparametersAUC1.96*SE
      1\"DecisionTreeClassifier\"\"(max_depth = 2,)\"0.640.07
      2\"StableRulesClassifier\"\"(max_rules = 5,)\"0.680.05
      3\"StableRulesClassifier\"\"(max_rules = 25,)\"0.680.06
      4\"StableRulesClassifier\"\"(max_depth = 1,)\"0.70.05
      5\"StableForestClassifier\"\"(max_depth = 2,)\"0.70.05
      6\"LGBMClassifier\"\"(;)\"0.710.06
      \n\n\n\n

      We can summarize these results as follows:

      \n
      \n\n\n\n\n\n

      As can be seen, the score of the stabilized random forest (StableForestClassifier) is almost as good as Microsoft's classifier (LGBMClassifier), but both are not interpretable since that requires interpreting thousands of trees. With the rule-based classifier (StableRulesClassifier), a small amount of predictive performance can be traded for high interpretability. Note that the rule-based classifier may actually be more accurate in practice because verifying and debugging the model is much easier.

      \n

      Regarding the hyperparameters, tuning max_rules and max_depth has the most effect. max_rules specifies the number of rules to which the random forest is simplified. Setting to a high number such as 999 makes the predictive performance similar to that of a random forest, but also makes the interpretability as bad as a random forest. Therefore, it makes more sense to truncate the rules to somewhere in the range 5 to 40 to obtain accurate models with high interpretability. max_depth specifies how many levels the trees have. For larger datasets, max_depth=2 makes the most sense since it can find more complex patterns in the data. For smaller datasets, max_depth=1 makes more sense since it reduces the chance of overfitting. It also simplifies the rules because with max_depth=1, the rule will contain only one conditional (for example, "if A then ...") versus two conditionals (for example, "if A & B then ..."). In some cases, model accuracy can be improved by increasing n_trees. The higher this number, the more trees are fitted and, hence, the higher the chance that the right rules are extracted from the trees.

      \n
      \n\n","category":"page"},{"location":"#Interpretation","page":"SIRUS","title":"Interpretation","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Finally, let's interpret the rules that the model has learned. Since we know that the model performs well on the cross-validations, we can fit our preferred model on the complete dataset:

      \n
      \n\n
      let\n    model = StableRulesClassifier(; max_depth=1, rng=_rng())\n    mach = machine(model, X, y)\n    fit!(mach)\n    mach.fitresult\nend
      \n
      StableRules model with 3 rules:\n if X[i, :nodes] < 4.0 then 0.075 else 0.06 +\n if X[i, :age] < 42.0 then 0.022 else 0.02 +\n if X[i, :year] < 1960.0 then 0.015 else 0.017\nand 2 classes: [0.0, 1.0]. \nNote: showing only the probability for class 1.0 since class 0.0 has probability 1 - p.\n
      \n\n\n

      The interpretation of the fitted model is as follows. The model has learned three rules for this dataset. For making a prediction for some value at row i, the model will first look at the value for the nodes feature. If the value is below the listed number, then the number after then is chosen and otherwise the number after else. This is done for all the rules and, finally, the rules are summed to obtain the final prediction.

      \n
      \n\n","category":"page"},{"location":"#Visualization","page":"SIRUS","title":"Visualization","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Since our rules are relatively simple with only a binary outcome and only one clause in each rule, the following figure is a way to visualize the obtained rules per fold. For multiple clauses, I would not know how to visualize the rules. Also, this plot is probably not perfect; let me know if you have suggestions.

      \n

      This figure shows the model uncertainty. The x-position on the left shows log(else-scores / if-scores), the vertical lines on the right show the threshold, and the histograms on the right show the data. For example, for the nodes, it can be seen that all rules (fitted in the different cross-validation folds) base their decision on whether the nodes are below, roughly, 5. Next, the left side indicates that the individuals who had less than 5 nodes are more likely to survive, according to the model. The sizes of the dots indicate the weight that the rule has, so a bigger dot means that a rule plays a larger role in the final outcome. These dots are sized in such a way that a doubling in weight means a doubling in surface size. Finally, the variables are ordered by the sum of the weights.

      \n
      \n\n\n\n\n\n\n\n\n\n\n\n

      What this plot shows is that the nodes feature is on average chosen as the feature with the most predictive power. This can be concluded because the nodes feature is shown as the first feature and the tickness of the dots is the biggest. Furthermore, there is unfortunately some unstability in the position of the splitpoint for the nodes feature. Some models split the data at around 3 and others at around 5. Depending on the context in which this model is used, it might thus be beneficial to decrease the number of empirical quantiles q that the model can use to split on. By default q=10, but maybe something like q=3 would make more sense here.

      \n
      \n\n","category":"page"},{"location":"#Conclusion","page":"SIRUS","title":"Conclusion","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Compared to decision trees, the rule-based classifier is more stable, more accurate and similarly easy to interpet. Compared to the random forest, the rule-based classifier is only slightly less accurate, but much easier to interpet. Due to the interpretability, it is likely that the rule-based classifier will be more accurate in real-world settings. This makes rule-based highly suitable for many machine learning tasks.

      \n
      \n\n\n\n\n\n\n\n\n\n\n\n\n\n","category":"page"},{"location":"#Acknowledgements","page":"SIRUS","title":"Acknowledgements","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n

      Thanks to Clément Bénard, Gérard Biau, Sébastian da Veiga and Erwan Scornet for creating the SIRUS algorithm and documenting it extensively. Special thanks to Clément Bénard for answering my questions regarding the implementation. Thanks to Hylke Donker for figuring out a way to visualize these rules. Also thanks to my PhD supervisors Ruud den Hartigh, Peter de Jonge and Frank Blaauw, and Age de Wit and colleagues at the Dutch Ministry of Defence for providing the data clarifying the constraints of the problem and for providing many methodological suggestions.

      \n
      \n\n","category":"page"},{"location":"#Appendix","page":"SIRUS","title":"Appendix","text":"","category":"section"},{"location":"","page":"SIRUS","title":"SIRUS","text":"
      \n\n
      \n\n
      begin\n    ENV[\"DATADEPS_ALWAYS_ACCEPT\"] = \"true\"\n\n    using CairoMakie\n    using CategoricalArrays: categorical\n    using CSV: CSV\n    using DataDeps: DataDeps, DataDep, @datadep_str\n    using DataFrames\n    using LightGBM.MLJInterface: LGBMClassifier\n    using MLJDecisionTreeInterface: DecisionTree, DecisionTreeClassifier\n    using MLJ: CV, MLJ, Not, PerformanceEvaluation, auc, fit!, evaluate, machine\n    using StableRNGs: StableRNG\n    using SIRUS\n    using Statistics: mean, std\nend
      \n\n\n
      function _plot_cutpoints(data::AbstractVector)\n    fig = Figure(; resolution=(800, 100))\n    ax = Axis(fig[1, 1])\n    cutpoints = Float64.(unique(ST._cutpoints(data, 10)))\n    scatter!(ax, data, fill(1, length(data)))\n    vlines!(ax, cutpoints; color=:black, linestyle=:dash)\n    textlocs = [(c, 1.1) for c in cutpoints]\n    for cutpoint in cutpoints\n        annotation = string(round(cutpoint; digits=2))::String\n        text!(ax, cutpoint + 0.2, 1.08; text=annotation, textsize=13)\n    end\n    ylims!(ax, 0.9, 1.2)\n    hideydecorations!(ax)\n    return fig\nend;
      \n\n\n
      _rng(seed::Int=1) = StableRNG(seed);
      \n\n\n
      function _io2text(f::Function)\n    io = IOBuffer()\n    f(io)\n    s = String(take!(io))\n    return Base.Text(s)\nend;
      \n\n\n
      function _evaluate(model, X, y; nfolds=10)\n    resampling = CV(; nfolds, shuffle=true, rng=_rng())\n    acceleration = MLJ.CPUThreads()\n    evaluate(model, X, y; acceleration, verbosity=0, resampling, measure=auc)\nend;
      \n\n\n\n\n\n
      function register_haberman()\n    name = \"Haberman\"\n    message = \"Slightly modified copy of Haberman's Survival Data Set\"\n    remote_path = \"https://github.com/rikhuijzer/haberman-survival-dataset/releases/download/v1.0.0/haberman.csv\"\n    checksum = \"a7e9aeb249e11ac17c2b8ea4fdafd5c9392219d27cb819ffaeb8a869eb727a0f\"\n    DataDeps.register(DataDep(name, message, remote_path, checksum))\nend;
      \n\n\n
      function _haberman()\n    register_haberman()\n    dir = datadep\"Haberman\"\n    path = joinpath(dir, \"haberman.csv\")\n    df = CSV.read(path, DataFrame)\n    df[!, :survival] = categorical(df.survival)\n    # Need Floats for the LGBMClassifier.\n    for col in [:age, :year, :nodes]\n        df[!, col] = float.(df[:, col])\n    end\n    return df\nend;
      \n\n\n
      _filter_rng(hyper::NamedTuple) = Base.structdiff(hyper, (; rng=:foo));
      \n\n\n
      _pretty_name(modeltype) = last(split(string(modeltype), '.'));
      \n\n\n
      function _evaluate(modeltype, hyperparameters, X, y)\n    model = modeltype(; hyperparameters...)\n    e = _evaluate(model, X, y)\n    row = (;\n        Model=_pretty_name(modeltype),\n        Hyperparameters=_hyper2str(_filter_rng(hyperparameters)),\n        AUC=_score(e),\n        se=round(only(MLJ.MLJBase._standard_errors(e)); digits=2)\n    )\n    (; e, row)\nend;
      \n\n\n
      _hyper2str(hyper::NamedTuple) = hyper == (;) ? \"(;)\" : string(hyper)::String;
      \n\n\n
      function _score(e::PerformanceEvaluation)\n    return round(only(e.measurement); digits=2)\nend;
      \n\n\n
      e1 = let\n    model = DecisionTreeClassifier\n    hyperparameters = (; max_depth=2, rng=_rng())\n    _evaluate(model, hyperparameters, X, y)\nend;
      \n\n\n
      e2 = let\n    model = StableRulesClassifier\n    hyperparameters = (; max_rules=5, rng=_rng())\n    _evaluate(model, hyperparameters, X, y)\nend;
      \n\n\n
      e3 = let\n    model = StableRulesClassifier\n    hyperparameters = (;  max_rules=25, rng=_rng())\n    _evaluate(model, hyperparameters, X, y)\nend;
      \n\n\n
      e4 = let\n    model = StableRulesClassifier\n    hyperparameters = (; max_depth=1, rng=_rng())\n    _evaluate(model, hyperparameters, X, y)\nend;
      \n\n\n
      e5 = let\n    model = StableForestClassifier\n    hyperparameters = (; max_depth=2, rng=_rng())\n    _evaluate(model, hyperparameters, X, y)\nend;
      \n\n\n
      e6 = let\n    model = LGBMClassifier\n    hyperparameters = (; )\n    _evaluate(model, hyperparameters, X, y)\nend;
      \n\n\n","category":"page"},{"location":"","page":"SIRUS","title":"SIRUS","text":"EditURL = \"https://github.com/rikhuijzer/SIRUS.jl/blob/main/docs/src/sirus.jl\"","category":"page"}] }