module MLJImplementation

import MLJModelInterface:
    fit,
    predict,
    metadata_model,
    metadata_pkg
using CategoricalArrays:
    CategoricalArray,
    CategoricalValue,
    categorical,
    levelcode,
    unwrap
using MLJModelInterface:
    MLJModelInterface,
    UnivariateFinite,
    Continuous,
    Count,
    Finite,
    Probabilistic,
    Table
using Random: AbstractRNG, default_rng
using SIRUS:
    StableForest,
    StableRules,
    _colnames,
    _forest,
    _mean,
    _predict,
    _rules,
    _process_rules
using Tables: Tables, matrix
"""
StableForestClassifier(;
rng::AbstractRNG=default_rng(),
partial_sampling::Real=0.7,
n_trees::Int=1_000,
max_depth::Int=2,
q::Int=10,
min_data_in_leaf::Int=5
) <: MLJModelInterface.Probabilistic
Random forest classifier with a stabilized forest structure (Bénard et al., [2021](http://proceedings.mlr.press/v130/benard21a.html)).
This stabilization increases stability when extracting rules.
The impact on the predictive accuracy compared to standard random forests should be relatively small.
!!! note
Just like normal random forests, this model is not easily explainable.
If you are interested in an explainable model, use the `StableRulesClassifier`.
# Example
The classifier satisfies the MLJ interface, so it can be used like any other MLJ model.
For example, it can be used to create a machine:
```julia
julia> using SIRUS, MLJ
julia> mach = machine(StableForestClassifier(), X, y);
```
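
A fuller (hypothetical) workflow might look as follows; here the data comes from MLJ's `make_blobs` purely for illustration, any Tables.jl-compatible table of `Continuous`/`Count` features with a categorical target works, and the exact predictions depend on the data and the random number generator:

```julia
julia> using SIRUS, MLJ
julia> X, y = make_blobs(200, 4; centers=2);  # synthetic data, for illustration only
julia> mach = machine(StableForestClassifier(; n_trees=100), X, y);
julia> fit!(mach);
julia> predict(mach, X)  # probabilistic predictions (UnivariateFinite distributions)
```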
# Arguments

- `rng`: Random number generator. `StableRNGs` are advised.
- `partial_sampling`:
    Ratio of samples to use in each subset of the data.
    The default of 0.7 should be fine for most cases.
- `n_trees`: The number of trees to use.
- `max_depth`:
    The maximum depth of each tree.
    A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (by reducing overfitting).
- `q`:
    Number of cutpoints to use per feature.
    The default value of 10 should be good for most situations.
- `min_data_in_leaf`: Minimum number of data points per leaf.
"""
Base.@kwdef mutable struct StableForestClassifier <: Probabilistic
    rng::AbstractRNG=default_rng()
    partial_sampling::Real=0.7
    n_trees::Int=1_000
    max_depth::Int=2
    q::Int=10
    min_data_in_leaf::Int=5
end
"""
StableRulesClassifier(;
rng::AbstractRNG=default_rng(),
partial_sampling::Real=0.7,
n_trees::Int=1_000,
max_depth::Int=2,
q::Int=10,
min_data_in_leaf::Int=5,
max_rules::Int=10
) -> MLJModelInterface.Probabilistic
Explainable rule-based model based on a random forest.
This SIRUS algorithm extracts rules from a stabilized random forest.
See the [main page of the documentation](https://huijzer.xyz/StableTrees.jl/dev/) for details about how it works.
# Example
The classifier satisfies the MLJ interface, so it can be used like any other MLJ model.
For example, it can be used to create a machine:
```julia
julia> using SIRUS, MLJ
julia> mach = machine(StableRulesClassifier(; max_rules=15), X, y);
```
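
A fuller (hypothetical) workflow might fit the machine and then inspect the result; here `X` and `y` stand for any Tables.jl-compatible feature table and categorical target, and with MLJ's default fallback `fitted_params(mach).fitresult` should hold the fitted `StableRules` object, that is, the extracted rules and their weights:

```julia
julia> using SIRUS, MLJ
julia> model = StableRulesClassifier(; max_rules=15);
julia> mach = machine(model, X, y);
julia> fit!(mach);
julia> fitted_params(mach).fitresult  # the extracted rules and their weights
julia> predict(mach, X)  # probabilistic predictions based on the rules
```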
# Arguments

- `rng`: Random number generator. `StableRNGs` are advised.
- `partial_sampling`:
    Ratio of samples to use in each subset of the data.
    The default of 0.7 should be fine for most cases.
- `n_trees`:
    The number of trees to use.
    The higher the number, the more likely it is that the correct rules are extracted from the trees, but also the longer model fitting will take.
    In most cases, 1,000 trees should be more than enough, but it might be useful to run 2,000 trees once and verify that the model performance does not change much.
- `max_depth`:
    The maximum depth of each tree.
    A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (by reducing overfitting).
- `q`:
    Number of cutpoints to use per feature.
    The default value of 10 should be good for most situations.
- `min_data_in_leaf`: Minimum number of data points per leaf.
- `max_rules`:
    This is the most important hyperparameter.
    In general, more rules give a more accurate model, but also a less interpretable one, so it is important to find a good balance.
    In most cases, 10-40 rules should provide reasonable accuracy while remaining interpretable.
- `lambda`:
    The weights of the final rules are determined via a regularized regression over each rule as a binary feature (see the sketch below).
    This hyperparameter specifies the strength of the ridge (L2) regularizer.
    Since the rules are quite strongly correlated, the ridge regularizer is the most useful for stabilizing the weight estimates.
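
As a rough sketch of what this means (illustrative notation, not taken from the package): let ``R`` be the binary matrix in which ``R_{ij} = 1`` when rule ``j`` fires for sample ``i``. The rule weights are then, conceptually, the ``\\beta`` minimizing the ridge objective

```math
\\|y - R\\beta\\|_2^2 + \\lambda \\|\\beta\\|_2^2,
```

so larger `lambda` shrinks the weights more strongly.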
"""
Base.@kwdef mutable struct StableRulesClassifier <: Probabilistic
    rng::AbstractRNG=default_rng()
    partial_sampling::Real=0.7
    n_trees::Int=1_000
    max_depth::Int=2
    q::Int=10
    min_data_in_leaf::Int=5
    max_rules::Int=10
    lambda::Float64=5
end
metadata_model(
    StableForestClassifier;
    input_scitype=Table(Continuous, Count),
    target_scitype=AbstractVector{<:Finite},
    supports_weights=false,
    docstring="Random forest classifier with a stabilized forest structure",
    path="SIRUS.StableForestClassifier"
)

metadata_model(
    StableRulesClassifier;
    input_scitype=Table(Continuous, Count),
    target_scitype=AbstractVector{<:Finite},
    supports_weights=false,
    docstring="Stable rule-based classifier",
    path="SIRUS.StableRulesClassifier"
)

metadata_pkg.(
    [StableForestClassifier, StableRulesClassifier];
    name="SIRUS",
    uuid="9113e207-2504-4b06-8eee-d78e288bee65",
    url="https://github.com/rikhuijzer/SIRUS.jl",
    julia=true,
    license="MIT",
    is_wrapper=false
)
"""
Return a floating point vector of `A`.
This method patches the version from CategoricalArrays.jl for `AbstractString`s.
"""
function _float(A::CategoricalArray{T}) where T
if !isconcretetype(T)
msg = "`float` not defined on abstractly-typed arrays; please convert to a more specific type"
throw(ArgumentError(msg))
end
if T isa Type{String}
msg = "Cannot automatically convert $(typeof(A)) to an array containing `Float`s."
throw(ArgumentError(msg))
end
return float(A)
end
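
# Example of the behavior above (a sketch, `_float` is an internal helper and not part of the
# public API): calling `_float` on a `CategoricalArray` with `String` levels throws an
# `ArgumentError`, because string levels cannot be converted to floats automatically:
#
#     _float(categorical(["a", "b", "a"]))  # throws ArgumentError
#
# whereas arrays with numeric levels are converted to a floating-point vector via `float`.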
function fit(model::StableForestClassifier, verbosity::Int, X, y)
    # Fit a stabilized random forest on the feature matrix and the float-encoded target.
    forest = _forest(
        model.rng,
        matrix(X),
        _float(y),
        _colnames(X);
        model.partial_sampling,
        model.n_trees,
        model.max_depth,
        model.q,
        model.min_data_in_leaf
    )
    fitresult = forest
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

function predict(model::StableForestClassifier, fitresult::StableForest, Xnew)
    forest = fitresult
    return _predict(forest, matrix(Xnew))
end
function fit(model::StableRulesClassifier, verbosity::Int, X, y)
    data = matrix(X)
    outcome = _float(y)
    # First fit a stabilized forest, then extract and weight the rules from it.
    forest = _forest(
        model.rng,
        data,
        outcome,
        _colnames(X);
        model.partial_sampling,
        model.n_trees,
        model.max_depth,
        model.q,
        model.min_data_in_leaf
    )
    fitresult = StableRules(forest, data, outcome, model)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end
function predict(model::StableRulesClassifier, fitresult::StableRules, Xnew)
    isempty(fitresult.rules) && error("Zero rules")
    return _predict(fitresult, matrix(Xnew))
end
end # module