Skip to content

Commit

Permalink
Test sirus via RCall (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Oct 6, 2023
1 parent a353bcc commit 5c87eda
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 13 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ jobs:
with:
version: ${{ matrix.version }}
- uses: julia-actions/cache@v1
- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true
r-version: '4'
- run: echo "LD_LIBRARY_PATH=$(R RHOME)/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV
if: matrix.os == 'ubuntu-latest'
- run: Rscript -e 'install.packages("sirus")'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ tmp**

paper/*.pdf
paper/*.jats

.RData
.Rhistory
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Expand All @@ -29,6 +31,7 @@ MLJDecisionTreeInterface = "0.4"
MLJLinearModels = "0.9"
MLJTestInterface = "0.2"
MLJXGBoostInterface = "0.3.8"
RCall = "0.13"
StableRNGs = "1"
StatisticalMeasures = "0.1"
Tables = "1.7"
47 changes: 34 additions & 13 deletions test/mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ end
@test classes isa Vector{<:Int}
end

function _score(e::PerformanceEvaluation)
return round(only(e.measurement); sigdigits=2)
end

function _with_trailing_zero(score::Real)::String
text = string(score)::String
if length(text) == 3
Expand All @@ -87,12 +83,6 @@ function _with_trailing_zero(score::Real)::String
end
end

function _evaluate(model, X, y, nfolds=10, measure=auc)
resampling = CV(; nfolds, shuffle=true, rng=_rng())
acceleration = MLJBase.CPUThreads()
evaluate(model, X, y; acceleration, verbosity=0, resampling, measure)
end

results = DataFrame(;
Dataset=String[],
Model=String[],
Expand All @@ -117,7 +107,7 @@ function _evaluate!(
X, y = datasets[dataset]
nfolds = 10
model = modeltype(; hyperparameters...)
e = _evaluate(model, X, y, nfolds, measure)
e = _evaluate(model, X, y; nfolds, measure)
score = _with_trailing_zero(_score(e))
se = let
val = round(only(MLJBase._standard_errors(e)); digits=2)
Expand Down Expand Up @@ -207,6 +197,11 @@ let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper)
@test 0.60 < _score(e)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper)
end
end

let
Expand Down Expand Up @@ -234,6 +229,11 @@ let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper)
@test 0.79 < _score(e)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper)
end
end

let
Expand All @@ -260,6 +260,11 @@ let

hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper; measure)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper; measure)
end
end

let
Expand All @@ -284,9 +289,14 @@ let

hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper)
end
end

e_iris = let
let
data = "iris"
measure = accuracy

Expand All @@ -312,7 +322,8 @@ e_iris = let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper; measure)
@test 0.62 < _score(e)
e

# R sirus doesn't appear to support multiclass classification.
end

rulesmodel = StableRulesRegressor(; max_depth=2, max_rules=30, rng=_rng())
Expand Down Expand Up @@ -351,6 +362,11 @@ let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
er = _evaluate!(results, data, StableRulesRegressor, hyper; measure=rsq)
@test 0.55 < _score(er)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
_evaluate!(results, data, RSirusRegressor, hyper; measure=rsq)
end
end

emr = let
Expand All @@ -377,6 +393,11 @@ emr = let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
er = _evaluate!(results, data, StableRulesRegressor, hyper; measure)
@test 0.50 < _score(er)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
_evaluate!(results, data, RSirusRegressor, hyper; measure=rsq)
end
end

pretty = rename(results, :se => "1.96*SE")
Expand Down
12 changes: 12 additions & 0 deletions test/preliminaries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Base

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

const CAN_RUN_R_SIRUS = v"1.8" < VERSION

using CategoricalArrays:
CategoricalValue,
CategoricalVector,
Expand Down Expand Up @@ -53,6 +55,16 @@ using Test
const S = SIRUS
_rng(seed::Int=1) = StableRNG(seed)

function _score(e::PerformanceEvaluation)
return round(only(e.measurement); sigdigits=2)
end

function _evaluate(model, X, y; nfolds::Number=10, measure=auc)
resampling = CV(; nfolds, shuffle=true, rng=_rng())
acceleration = MLJBase.CPUThreads()
evaluate(model, X, y; acceleration, verbosity=0, resampling, measure)
end

if !haskey(ENV, "REGISTERED_CANCER")
name = "Cancer"
message = "Wisconsin Diagnostic Breast Cancer (WDBC) dataset"
Expand Down
184 changes: 184 additions & 0 deletions test/rcall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#
# This file defines the MLJ wrappers around R sirus and tests them.
# Actual comparisons against other models are done in test/mlj.jl.
#

import MLJModelInterface:
MLJModelInterface,
fit,
predict,
metadata_model,
metadata_pkg

using CategoricalArrays:
CategoricalArray,
CategoricalPool,
CategoricalValue
using MLJModelInterface:
MLJModelInterface,
UnivariateFinite,
Continuous,
Count,
Deterministic,
Finite,
Probabilistic,
Table
using RCall

const MMI = MLJModelInterface

# @rlibrary sirus

MMI.@mlj_model mutable struct RSirusRegressor <: Deterministic
max_depth::Int=2
max_rules::Int=10
end

R"library('sirus')"

n = 100
A = rand(_rng(), n)
B = rand(_rng(), n)
X = DataFrame(; A, B)
y = rand(_rng(), n)

function fit(
model::RSirusRegressor,
verbosity::Int,
X,
y
)
if !Tables.istable(X)
error("Expected a Table but got $(typeof(Xnew))")
end
df = DataFrame(X)
fitted_model = R"""
fitted.model <- sirus.fit(
$df,
$y,
type="reg",
num.rule=$(model.max_rules),
p0=NULL,
num.rule.max=$(model.max_rules),
q=4,
max.depth=$(model.max_depth),
num.trees=NULL,
num.threads=1,
verbose=FALSE,
seed=1
)
# print(sirus.print(fitted.model))
fitted.model
"""
fitresult = fitted_model
cache = nothing
report = nothing
return fitresult, cache, report
end

verbosity = 0

model = RSirusRegressor()
mach = machine(model, X, y, verbosity)
fit!(mach; verbosity)
# mach.fitresult

function predict(
model::RSirusRegressor,
fitresult::RObject,
Xnew
)
if !Tables.istable(Xnew)
error("Expected a Table but got $(typeof(Xnew))")
end
df = DataFrame(Xnew)
predictions = R"""
sirus.predict($fitresult, $df)
"""
return rcopy(predictions)
end

predict(mach, X)
e = _evaluate(model, X, y; measure=rsq)
@test 0.6 < _score(e)

MMI.@mlj_model mutable struct RSirusClassifier <: Probabilistic
max_depth::Int=2
max_rules::Int=10
end

function fit(
model::RSirusClassifier,
verbosity::Int,
X,
y::CategoricalArray
)
# Based on MLJXGBoostInterface.
a_target_element = y[1]
@assert a_target_element isa CategoricalValue

if !Tables.istable(X)
error("Expected a Table but got $(typeof(Xnew))")
end
df = DataFrame(X)
outcomes = get.(y)
fitted_model = R"""
fitted.model <- sirus.fit(
$df,
$outcomes,
type="classif",
num.rule=$(model.max_rules),
p0=NULL,
num.rule.max =$(model.max_rules),
q=4,
max.depth=$(model.max_depth),
num.trees=NULL,
num.threads=1,
verbose=FALSE,
seed=1
)
# print(sirus.print(fitted.model))
fitted.model
"""
fitresult = (fitted_model, a_target_element)
cache = nothing
report = nothing
return fitresult, cache, report
end

y = categorical(rand(_rng(), [0, 1], n))

model = RSirusClassifier()
mach = machine(model, X, y, verbosity)
fit!(mach; verbosity)

function predict(
model::RSirusClassifier,
fitresult::Tuple{RObject, CategoricalValue},
Xnew
)
fitted_model, a_target_element = fitresult
if !Tables.istable(Xnew)
error("Expected a Table but got $(typeof(Xnew))")
end
df = DataFrame(Xnew)
rpredictions = R"""
sirus.predict($fitted_model, $df)
"""
classes = MMI.classes(a_target_element)
predictions = rcopy(rpredictions)
augment = ndims(predictions) == 1
@show classes
@show predictions
@show augment
return UnivariateFinite(classes, predictions; augment)
end

predict(mach, X)

e = _evaluate(model, X, y; measure=auc)
@test 0.5 < _score(e)

# Looks like sirus does not support multiclass classification.
# y = categorical(rand(_rng(), [0, 1, 2], n))
# _evaluate(model, X, y; measure=accuracy)
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ end
include("weights.jl")
end

if CAN_RUN_R_SIRUS
@testset "rcall" begin
include("rcall.jl")
end
end

@testset "mlj" begin
include("mlj.jl")
end
Expand Down

1 comment on commit 5c87eda

@rikhuijzer
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future reference, the output is

55×7 DataFrame
 Row │ Dataset          Model                   Hyperparameters                  measure   score   1.96*SE  nfolds
     │ String           String                  String                           String    String  String   Int64
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ haberman         DecisionTreeClassifier  (;)                              auc       0.54    0.06         10
   2 │ haberman         LogisticClassifier      (;)                              auc       0.69    0.06         10
   3 │ haberman         XGBoostClassifier       (;)                              auc       0.65    0.04         10
   4 │ haberman         XGBoostClassifier       (max_depth = 2,)                 auc       0.63    0.04         10
   5 │ haberman         StableForestClassifier  (max_depth = 2,)                 auc       0.70    0.05         10
   6 │ haberman         StableRulesClassifier   (max_depth = 2, max_rules = 30)  auc       0.70    0.07         10
   7 │ haberman         StableRulesClassifier   (max_depth = 2, max_rules = 10)  auc       0.67    0.06         10
   8 │ haberman         RSirusClassifier        (max_depth = 2, max_rules = 10)  auc       0.66    0.05         10
   9 │ titanic          DecisionTreeClassifier  (;)                              auc       0.76    0.05         10
  10 │ titanic          LogisticClassifier      (;)                              auc       0.84    0.02         10
  11 │ titanic          XGBoostClassifier       (;)                              auc       0.86    0.03         10
  12 │ titanic          XGBoostClassifier       (max_depth = 2,)                 auc       0.87    0.03         10
  13 │ titanic          StableForestClassifier  (max_depth = 2,)                 auc       0.85    0.02         10
  14 │ titanic          StableRulesClassifier   (max_depth = 2, max_rules = 30)  auc       0.83    0.02         10
  15 │ titanic          StableRulesClassifier   (max_depth = 2, max_rules = 10)  auc       0.83    0.02         10
  16 │ titanic          RSirusClassifier        (max_depth = 2, max_rules = 10)  auc       0.81    0.02         10
  17 │ cancer           DecisionTreeClassifier  (;)                              auc       0.92    0.03         10
  18 │ cancer           MultinomialClassifier   (;)                              auc       0.98    0.01         10
  19 │ cancer           XGBoostClassifier       (;)                              auc       0.99    0.00         10
  20 │ cancer           XGBoostClassifier       (max_depth = 2,)                 auc       0.99    0.00         10
  21 │ cancer           StableForestClassifier  (max_depth = 2,)                 auc       0.99    0.01         10
  22 │ cancer           StableRulesClassifier   (max_depth = 2, max_rules = 30)  auc       0.98    0.01         10
  23 │ cancer           StableRulesClassifier   (max_depth = 2, max_rules = 10)  auc       0.98    0.01         10
  24 │ cancer           RSirusClassifier        (max_depth = 2, max_rules = 10)  auc       0.96    0.02         10
  25 │ diabetes         DecisionTreeClassifier  (;)                              auc       0.67    0.05         10
  26 │ diabetes         LogisticClassifier      (;)                              auc       0.70    0.06         10
  27 │ diabetes         XGBoostClassifier       (;)                              auc       0.80    0.04         10
  28 │ diabetes         XGBoostClassifier       (max_depth = 2,)                 auc       0.82    0.03         10
  29 │ diabetes         StableForestClassifier  (max_depth = 2,)                 auc       0.82    0.03         10
  30 │ diabetes         StableRulesClassifier   (max_depth = 2, max_rules = 30)  auc       0.78    0.04         10
  31 │ diabetes         StableRulesClassifier   (max_depth = 2, max_rules = 10)  auc       0.75    0.05         10
  32 │ diabetes         RSirusClassifier        (max_depth = 2, max_rules = 10)  auc       0.80    0.02         10
  33 │ iris             DecisionTreeClassifier  (;)                              accuracy  0.95    0.03         10
  34 │ iris             MultinomialClassifier   (;)                              accuracy  0.97    0.03         10
  35 │ iris             XGBoostClassifier       (;)                              accuracy  0.94    0.04         10
  36 │ iris             XGBoostClassifier       (max_depth = 2,)                 accuracy  0.93    0.04         10
  37 │ iris             StableForestClassifier  (max_depth = 2,)                 accuracy  0.95    0.04         10
  38 │ iris             StableRulesClassifier   (max_depth = 2, max_rules = 30)  accuracy  0.83    0.10         10
  39 │ iris             StableRulesClassifier   (max_depth = 2, max_rules = 10)  accuracy  0.77    0.08         10
  40 │ boston           DecisionTreeRegressor   (;)                              R²        0.74    0.11         10
  41 │ boston           LinearRegressor         (;)                              R²        0.70    0.05         10
  42 │ boston           XGBoostRegressor        (;)                              R²        0.87    0.05         10
  43 │ boston           XGBoostRegressor        (max_depth = 2,)                 R²        0.86    0.05         10
  44 │ boston           StableForestRegressor   (max_depth = 2,)                 R²        0.67    0.09         10
  45 │ boston           StableRulesRegressor    (max_depth = 2, max_rules = 30)  R²        0.57    0.07         10
  46 │ boston           StableRulesRegressor    (max_depth = 2, max_rules = 10)  R²        0.61    0.09         10
  47 │ boston           RSirusRegressor         (max_depth = 2, max_rules = 10)  R²        0.63    0.07         10
  48 │ make_regression  DecisionTreeRegressor   (;)                              R²        0.90    0.02         10
  49 │ make_regression  LinearRegressor         (;)                              R²        1.00    0.00         10
  50 │ make_regression  XGBoostRegressor        (;)                              R²        0.97    0.01         10
  51 │ make_regression  XGBoostRegressor        (max_depth = 2,)                 R²        0.98    0.00         10
  52 │ make_regression  StableForestRegressor   (max_depth = 2,)                 R²        0.68    0.05         10
  53 │ make_regression  StableRulesRegressor    (max_depth = 2, max_rules = 30)  R²        0.46    0.05         10
  54 │ make_regression  StableRulesRegressor    (max_depth = 2, max_rules = 10)  R²        0.53    0.05         10
  55 │ make_regression  RSirusRegressor         (max_depth = 2, max_rules = 10)  R²        0.71    0.05         10

with Julia version 1.9.3

The small differences for XGBoost performance compared to earlier versions are most likely due to dmlc/XGBoost.jl#191.

Please sign in to comment.