Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test sirus via RCall #58

Merged
merged 12 commits into from
Oct 6, 2023
Merged
7 changes: 7 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ jobs:
with:
version: ${{ matrix.version }}
- uses: julia-actions/cache@v1
- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true
r-version: '4'
- run: echo "LD_LIBRARY_PATH=$(R RHOME)/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV
if: matrix.os == 'ubuntu-latest'
- run: Rscript -e 'install.packages("sirus")'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ tmp**

paper/*.pdf
paper/*.jats

.RData
.Rhistory
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Expand All @@ -29,6 +31,7 @@ MLJDecisionTreeInterface = "0.4"
MLJLinearModels = "0.9"
MLJTestInterface = "0.2"
MLJXGBoostInterface = "0.3.8"
RCall = "0.13"
StableRNGs = "1"
StatisticalMeasures = "0.1"
Tables = "1.7"
47 changes: 34 additions & 13 deletions test/mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ end
@test classes isa Vector{<:Int}
end

# Collapse a single-measure evaluation into one number rounded to two
# significant digits. `only` throws unless `e.measurement` holds exactly one value.
_score(e::PerformanceEvaluation) = round(only(e.measurement); sigdigits=2)

function _with_trailing_zero(score::Real)::String
text = string(score)::String
if length(text) == 3
Expand All @@ -87,12 +83,6 @@ function _with_trailing_zero(score::Real)::String
end
end

# Cross-validate `model` on `(X, y)` and return the result of `evaluate`.
# Folds are shuffled with a freshly seeded `_rng()`, evaluation is silent
# (`verbosity=0`), and folds are run with `MLJBase.CPUThreads()` acceleration.
function _evaluate(model, X, y, nfolds=10, measure=auc)
    cv = CV(; nfolds, shuffle=true, rng=_rng())
    return evaluate(
        model, X, y;
        acceleration=MLJBase.CPUThreads(),
        verbosity=0,
        resampling=cv,
        measure,
    )
end

results = DataFrame(;
Dataset=String[],
Model=String[],
Expand All @@ -117,7 +107,7 @@ function _evaluate!(
X, y = datasets[dataset]
nfolds = 10
model = modeltype(; hyperparameters...)
e = _evaluate(model, X, y, nfolds, measure)
e = _evaluate(model, X, y; nfolds, measure)
score = _with_trailing_zero(_score(e))
se = let
val = round(only(MLJBase._standard_errors(e)); digits=2)
Expand Down Expand Up @@ -207,6 +197,11 @@ let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper)
@test 0.60 < _score(e)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper)
end
end

let
Expand Down Expand Up @@ -234,6 +229,11 @@ let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper)
@test 0.79 < _score(e)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper)
end
end

let
Expand All @@ -260,6 +260,11 @@ let

hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper; measure)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper; measure)
end
end

let
Expand All @@ -284,9 +289,14 @@ let

hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
e = _evaluate!(results, data, RSirusClassifier, hyper)
end
end

e_iris = let
let
data = "iris"
measure = accuracy

Expand All @@ -312,7 +322,8 @@ e_iris = let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
e = _evaluate!(results, data, StableRulesClassifier, hyper; measure)
@test 0.62 < _score(e)
e

# R sirus doesn't appear to support multiclass classification.
end

rulesmodel = StableRulesRegressor(; max_depth=2, max_rules=30, rng=_rng())
Expand Down Expand Up @@ -351,6 +362,11 @@ let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
er = _evaluate!(results, data, StableRulesRegressor, hyper; measure=rsq)
@test 0.55 < _score(er)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
_evaluate!(results, data, RSirusRegressor, hyper; measure=rsq)
end
end

emr = let
Expand All @@ -377,6 +393,11 @@ emr = let
hyper = (; rng=_rng(), max_depth=2, max_rules=10)
er = _evaluate!(results, data, StableRulesRegressor, hyper; measure)
@test 0.50 < _score(er)

if CAN_RUN_R_SIRUS
hyper = (; max_depth=2, max_rules=10)
_evaluate!(results, data, RSirusRegressor, hyper; measure=rsq)
end
end

pretty = rename(results, :se => "1.96*SE")
Expand Down
12 changes: 12 additions & 0 deletions test/preliminaries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Base

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

const CAN_RUN_R_SIRUS = v"1.8" < VERSION

using CategoricalArrays:
CategoricalValue,
CategoricalVector,
Expand Down Expand Up @@ -53,6 +55,16 @@ using Test
const S = SIRUS
_rng(seed::Int=1) = StableRNG(seed)

# Collapse a single-measure evaluation into one number rounded to two
# significant digits. `only` throws unless `e.measurement` holds exactly one value.
_score(e::PerformanceEvaluation) = round(only(e.measurement); sigdigits=2)

# Cross-validate `model` on `(X, y)` and return the result of `evaluate`.
#
# Keywords:
# - `nfolds`: number of CV folds (default 10).
# - `measure`: measure forwarded to `evaluate` (default `auc`).
#
# Folds are shuffled with a freshly seeded `_rng()` and evaluation is silent
# (`verbosity=0`).
# NOTE(review): folds run with `MLJBase.CPUThreads()`; presumably models that call
# into an embedded R session are only safe here with a single Julia thread --
# TODO confirm.
function _evaluate(model, X, y; nfolds::Number=10, measure=auc)
    cv = CV(; nfolds, shuffle=true, rng=_rng())
    return evaluate(
        model, X, y;
        acceleration=MLJBase.CPUThreads(),
        verbosity=0,
        resampling=cv,
        measure,
    )
end

if !haskey(ENV, "REGISTERED_CANCER")
name = "Cancer"
message = "Wisconsin Diagnostic Breast Cancer (WDBC) dataset"
Expand Down
184 changes: 184 additions & 0 deletions test/rcall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#
# This file defines the MLJ wrappers around R sirus and tests them.
# Actual comparisons against other models are done in test/mlj.jl.
#

import MLJModelInterface:
MLJModelInterface,
fit,
predict,
metadata_model,
metadata_pkg

using CategoricalArrays:
CategoricalArray,
CategoricalPool,
CategoricalValue
using MLJModelInterface:
MLJModelInterface,
UnivariateFinite,
Continuous,
Count,
Deterministic,
Finite,
Probabilistic,
Table
using RCall

const MMI = MLJModelInterface

# @rlibrary sirus

# MLJ model type wrapping the R `sirus` package for regression.
# Declared as `Deterministic`; the matching `predict` method in this file returns
# the values copied back from R's `sirus.predict`.
# - `max_depth`: forwarded to `sirus.fit` as `max.depth`.
# - `max_rules`: forwarded to `sirus.fit` as both `num.rule` and `num.rule.max`.
MMI.@mlj_model mutable struct RSirusRegressor <: Deterministic
max_depth::Int=2
max_rules::Int=10
end

# Attach the R package in the embedded R session so that `sirus.fit` and
# `sirus.predict` are available to the R"..." calls below.
R"library('sirus')"

# Small synthetic dataset for the smoke tests in this file.
# NOTE(review): assuming `_rng` is the helper from test/preliminaries.jl, every
# `_rng()` call builds a new generator with the same default seed, so `A`, `B`
# and `y` would hold identical values -- TODO confirm this is intended.
n = 100
A = rand(_rng(), n)
B = rand(_rng(), n)
X = DataFrame(; A, B)
y = rand(_rng(), n)

# Fit R's `sirus.fit` in regression mode (`type="reg"`) for the MLJ interface.
#
# Arguments:
# - `model`: supplies `max_rules` (passed as both `num.rule` and `num.rule.max`)
#   and `max_depth` (passed as `max.depth`).
# - `verbosity`: accepted to satisfy the MLJ `fit` signature; not used here.
# - `X`: any Tables.jl-compatible table; converted to a `DataFrame` before being
#   handed to R.
# - `y`: target passed to R unchanged.
#
# Returns `(fitresult, cache, report)`, where `fitresult` is the value of the R
# call and both `cache` and `report` are `nothing`.
#
# Throws an `ErrorException` when `X` is not a table.
function fit(
    model::RSirusRegressor,
    verbosity::Int,
    X,
    y
)
    if !Tables.istable(X)
        # Report the type of the argument that was actually checked; the message
        # previously referenced `Xnew`, which does not exist in this method.
        error("Expected a Table but got $(typeof(X))")
    end
    df = DataFrame(X)
    # The R-side seed and thread count are fixed in the call below.
    fitted_model = R"""
    fitted.model <- sirus.fit(
        $df,
        $y,
        type="reg",
        num.rule=$(model.max_rules),
        p0=NULL,
        num.rule.max=$(model.max_rules),
        q=4,
        max.depth=$(model.max_depth),
        num.trees=NULL,
        num.threads=1,
        verbose=FALSE,
        seed=1
    )
    # print(sirus.print(fitted.model))
    fitted.model
    """
    fitresult = fitted_model
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

# Silence MLJ output for the smoke tests below.
verbosity = 0

# Bind the regressor to the synthetic data and train it once.
model = RSirusRegressor()
# NOTE(review): `verbosity` is passed to `machine` as an extra positional
# argument here, while `fit!` below also receives it as a keyword -- confirm
# that the positional use is intended.
mach = machine(model, X, y, verbosity)
fit!(mach; verbosity)
# mach.fitresult

# Predict with a fitted R sirus regression model.
#
# Arguments:
# - `model`: used for dispatch only.
# - `fitresult`: the R object returned as `fitresult` by the regressor's `fit`.
# - `Xnew`: any Tables.jl-compatible table; converted to a `DataFrame` before
#   being handed to R.
#
# Returns the output of R's `sirus.predict` converted to Julia via `rcopy`.
# Throws an `ErrorException` when `Xnew` is not a table.
function predict(
    model::RSirusRegressor,
    fitresult::RObject,
    Xnew
)
    Tables.istable(Xnew) || error("Expected a Table but got $(typeof(Xnew))")
    df = DataFrame(Xnew)
    r_output = R"""
    sirus.predict($fitresult, $df)
    """
    return rcopy(r_output)
end

# Smoke test: predicting on the training table must run without throwing.
predict(mach, X)
# Cross-validated R^2 of the R sirus regressor on the synthetic data.
e = _evaluate(model, X, y; measure=rsq)
@test 0.6 < _score(e)

# MLJ model type wrapping the R `sirus` package for classification.
# Declared as `Probabilistic`; the matching `predict` method in this file returns
# a `UnivariateFinite` built from the output of R's `sirus.predict`.
# - `max_depth`: forwarded to `sirus.fit` as `max.depth`.
# - `max_rules`: forwarded to `sirus.fit` as both `num.rule` and `num.rule.max`.
MMI.@mlj_model mutable struct RSirusClassifier <: Probabilistic
max_depth::Int=2
max_rules::Int=10
end

# Fit R's `sirus.fit` in classification mode (`type="classif"`) for the MLJ
# interface.
#
# Arguments:
# - `model`: supplies `max_rules` (passed as both `num.rule` and `num.rule.max`)
#   and `max_depth` (passed as `max.depth`).
# - `verbosity`: accepted to satisfy the MLJ `fit` signature; not used here.
# - `X`: any Tables.jl-compatible table; converted to a `DataFrame` before being
#   handed to R.
# - `y`: categorical target; its raw values are extracted before going to R.
#
# Returns `(fitresult, cache, report)`, where `fitresult` is the tuple
# `(fitted_model, a_target_element)` and both `cache` and `report` are `nothing`.
# The stored target element lets `predict` recover the class pool later.
#
# Throws an `ErrorException` when `X` is not a table.
function fit(
    model::RSirusClassifier,
    verbosity::Int,
    X,
    y::CategoricalArray
)
    # Based on MLJXGBoostInterface.
    a_target_element = y[1]
    @assert a_target_element isa CategoricalValue

    if !Tables.istable(X)
        # Report the type of the argument that was actually checked; the message
        # previously referenced `Xnew`, which does not exist in this method.
        error("Expected a Table but got $(typeof(X))")
    end
    df = DataFrame(X)
    # Strip the categorical wrapper so that R receives the underlying values.
    outcomes = get.(y)
    # The R-side seed and thread count are fixed in the call below.
    fitted_model = R"""
    fitted.model <- sirus.fit(
        $df,
        $outcomes,
        type="classif",
        num.rule=$(model.max_rules),
        p0=NULL,
        num.rule.max =$(model.max_rules),
        q=4,
        max.depth=$(model.max_depth),
        num.trees=NULL,
        num.threads=1,
        verbose=FALSE,
        seed=1
    )
    # print(sirus.print(fitted.model))
    fitted.model
    """
    fitresult = (fitted_model, a_target_element)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

# Binary target for the classifier smoke test: `n` draws from `[0, 1]`, wrapped
# as a categorical vector to match the `y::CategoricalArray` method of `fit`.
# This rebinds `y`, which previously held the regression target.
y = categorical(rand(_rng(), [0, 1], n))

# Rebind `model` and `mach`, which previously held the regressor.
model = RSirusClassifier()
# NOTE(review): `verbosity` is passed to `machine` as an extra positional
# argument here, while `fit!` below also receives it as a keyword -- confirm
# that the positional use is intended.
mach = machine(model, X, y, verbosity)
fit!(mach; verbosity)

# Predict class probabilities with a fitted R sirus classification model.
#
# Arguments:
# - `model`: used for dispatch only.
# - `fitresult`: the `(fitted_model, a_target_element)` tuple produced by the
#   classifier's `fit`.
# - `Xnew`: any Tables.jl-compatible table; converted to a `DataFrame` before
#   being handed to R.
#
# Returns a `UnivariateFinite` over the classes of the stored target element.
# Throws an `ErrorException` when `Xnew` is not a table.
function predict(
    model::RSirusClassifier,
    fitresult::Tuple{RObject, CategoricalValue},
    Xnew
)
    fitted_model, a_target_element = fitresult
    if !Tables.istable(Xnew)
        error("Expected a Table but got $(typeof(Xnew))")
    end
    df = DataFrame(Xnew)
    rpredictions = R"""
    sirus.predict($fitted_model, $df)
    """
    classes = MMI.classes(a_target_element)
    predictions = rcopy(rpredictions)
    # A one-dimensional result carries a single probability per row, so ask
    # `UnivariateFinite` to augment it into a full distribution.
    # NOTE(review): presumably these are the probabilities of the second class,
    # as `augment=true` expects -- confirm against the R sirus documentation.
    augment = ndims(predictions) == 1
    # Diagnostics go to the debug log rather than being printed on every call
    # (this method runs once per fold during evaluation).
    @debug "R sirus classifier predictions" classes predictions augment
    return UnivariateFinite(classes, predictions; augment)
end

# Smoke test: predicting on the training table must run without throwing.
predict(mach, X)

# Cross-validated AUC of the R sirus classifier on the synthetic binary target.
e = _evaluate(model, X, y; measure=auc)
@test 0.5 < _score(e)

# Looks like sirus does not support multiclass classification.
# y = categorical(rand(_rng(), [0, 1, 2], n))
# _evaluate(model, X, y; measure=accuracy)
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ end
include("weights.jl")
end

if CAN_RUN_R_SIRUS
@testset "rcall" begin
include("rcall.jl")
end
end

@testset "mlj" begin
include("mlj.jl")
end
Expand Down
Loading