Skip to content

Commit

Permalink
Merge pull request #606 from stan-dev/fix-1D-unit-vector
Browse files Browse the repository at this point in the history
avoid error for 1-D unit_vector
  • Loading branch information
bgoodri committed Oct 30, 2023
2 parents 92a877c + 48f590e commit a9889e6
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/stan_files/lm.stan
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ transformed data {
}
parameters {
// must not call with init="0"
array[K > 1 ? J : 0] unit_vector[K] u; // primitives for coefficients
// https://github.com/stan-dev/rstanarm/issues/603#issuecomment-1785928224
array[K > 1 ? J : 0] unit_vector[K > 1 ? K : 2] u; // primitives for coefficients
array[J * has_intercept] real z_alpha; // primitives for intercepts
array[J] real<lower=(K > 1 ? 0 : -1), upper=1> R2; // proportions of variance explained
vector[J * (1 - prior_PD)] log_omega; // under/overfitting factors
Expand Down
4 changes: 3 additions & 1 deletion src/stan_files/polr.stan
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ transformed data {
}
parameters {
simplex[J] pi;
array[K > 1] unit_vector[K] u;
// avoid error by making unit_vector have 2 elements when K <= 1
// https://github.com/stan-dev/rstanarm/issues/603#issuecomment-1785928224
array[K > 1] unit_vector[K > 1 ? K : 2] u;
real<lower=(K > 1 ? 0 : -1), upper=1> R2;
array[is_skewed] real<lower=0> alpha;
}
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test_stan_lm.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ test_that("stan_lm doesn't break with vb algorithms", {
expect_stanreg(fit2)
})

test_that("stan_lm works with 1 predictor", {
SW(fit <- stan_lm(mpg ~ wt, data = mtcars,
prior = R2(0.5, "mean"), refresh = 0,
seed = SEED))
expect_stanreg(fit)
})

test_that("stan_lm throws error if only intercept", {
expect_error(stan_lm(mpg ~ 1, data = mtcars, prior = R2(location = 0.75)),
regexp = "not suitable for estimating a mean")
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test_stan_polr.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ test_that("stan_polr runs for esoph example", {
expect_stanreg(fit2vb)
})

test_that("stan_polr runs with 1 predictor", {
esoph$x1 <- rnorm(nrow(esoph))
expect_stanreg(stan_polr(tobgp ~ x1, data = esoph, prior = R2(0.5, "mean")))
})

test_that("stan_polr throws error if formula excludes intercept", {
expect_error(stan_polr(tobgp ~ 0 + agegp + alcgp, data = esoph,
method = "loglog", prior = R2(0.4, "median")),
Expand Down

0 comments on commit a9889e6

Please sign in to comment.