-
Notifications
You must be signed in to change notification settings - Fork 78
/
test_linear_reg_stan.R
78 lines (67 loc) · 1.91 KB
/
test_linear_reg_stan.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
library(testthat)
library(parsnip)
library(rlang)
library(rstanarm)
# Shared fixtures -----------------------------------------------------------

# Numeric predictor columns of `iris` used by the x/y interface tests.
num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length")

# Formula naming a column (`term`) that does not exist in `iris`; the
# execution test below expects fitting with it to raise an error.
iris_bad_form <- Species ~ term

# Plain linear regression specification with no engine-specific arguments.
iris_basic <- linear_reg()

# fit() control objects, differing in `verbosity` (1 vs 0) and `catch`
# (FALSE vs TRUE); see ?fit_control for the exact semantics.
ctrl <- fit_control(verbosity = 1, catch = FALSE)
caught_ctrl <- fit_control(verbosity = 1, catch = TRUE)
quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE)
test_that("stan_glm execution", {
  skip_on_cran()

  # Spec carrying a fixed Stan seed so the sampler is reproducible run to run.
  # All fits below use it (previously it was created but never used).
  iris_basic_stan <- linear_reg(others = list(seed = 1333))

  # NOTE(review): an earlier note said this test passes interactively but not
  # under R CMD check; the cause is not visible from this file.

  # Formula interface with an inline transformation should fit without error
  # (`regexp = NA` asserts that no error is raised).
  expect_error(
    res <- fit(
      iris_basic_stan,
      Sepal.Width ~ log(Sepal.Length) + Species,
      data = iris,
      control = ctrl,
      engine = "stan"
    ),
    regexp = NA
  )

  # x/y interface should fit without error.
  expect_error(
    res <- fit(
      iris_basic_stan,
      x = iris[, num_pred],
      y = iris$Sepal.Length,
      engine = "stan",
      control = ctrl
    ),
    regexp = NA
  )

  # A formula referencing a column absent from the data must raise an error.
  expect_error(
    res <- fit(
      iris_basic_stan,
      iris_bad_form,
      data = iris,
      engine = "stan",
      control = ctrl
    )
  )
})
test_that("stan prediction", {
  # Four MCMC fits: too slow for CRAN, consistent with the execution test.
  skip_on_cran()

  # Reference fits made directly with rstanarm. Their `seed` must equal the
  # seed in the parsnip spec below so both sample the same posterior draws;
  # otherwise the tight tolerance in the comparisons is not reproducible.
  stan_seed <- 123

  uni_stan <- stan_glm(
    Sepal.Length ~ Sepal.Width + Petal.Width + Petal.Length,
    data = iris,
    seed = stan_seed
  )
  uni_pred <- unname(predict(uni_stan, newdata = iris[1:5, ]))

  inl_stan <- stan_glm(
    Sepal.Width ~ log(Sepal.Length) + Species,
    data = iris,
    seed = stan_seed
  )
  inl_pred <- unname(
    predict(inl_stan, newdata = iris[1:5, c("Sepal.Length", "Species")])
  )

  # One seeded spec shared by both parsnip fits.
  seeded_spec <- linear_reg(others = list(seed = stan_seed))

  # x/y interface: predictors in `num_pred` match uni_stan's formula terms.
  res_xy <- fit(
    seeded_spec,
    x = iris[, num_pred],
    y = iris$Sepal.Length,
    engine = "stan",
    control = ctrl
  )
  expect_equal(
    uni_pred,
    predict(res_xy, iris[1:5, num_pred]),
    tolerance = 0.001
  )

  # Formula interface with an inline transformation; must use the same seed
  # as inl_stan (an unseeded spec would draw a different posterior sample).
  res_form <- fit(
    seeded_spec,
    Sepal.Width ~ log(Sepal.Length) + Species,
    data = iris,
    engine = "stan",
    control = ctrl
  )
  expect_equal(inl_pred, predict(res_form, iris[1:5, ]), tolerance = 0.001)
})