Skip to content

Commit

Permalink
check for boundary violations in cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
tnagler committed Mar 10, 2024
1 parent e1b914c commit 367cc54
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 28 deletions.
14 changes: 2 additions & 12 deletions R/kde1d-methods.R
Expand Up @@ -32,25 +32,15 @@
#' @export
dkde1d <- function(x, obj) {
x <- prep_eval_arg(x, obj)
sel <- is_below_support(x, obj) | is_above_support(x, obj)
x[sel] <- NA
d <- dkde1d_cpp(x, obj)
d[sel] <- 0
d
dkde1d_cpp(x, obj)
}

#' @param q vector of quantiles.
#' @rdname dkde1d
#' @export
pkde1d <- function(q, obj) {
q <- prep_eval_arg(q, obj)
below <- is_below_support(q, obj)
above <- is_above_support(q, obj)
q[below | above] <- NA
p <- pkde1d_cpp(q, obj)
p[below] <- 0
p[above] <- 1
p
pkde1d_cpp(q, obj)
}

#' @param p vector of probabilities.
Expand Down
14 changes: 0 additions & 14 deletions R/tools.R
Expand Up @@ -36,17 +36,3 @@ prep_eval_arg <- function(x, obj) {
x <- ordered(x, levels(obj$x))
as.numeric(x) - 1
}

#' boolean vector for observations below the distribution's support
#' @noRd
is_below_support <- function(x, obj) {
lower <- if (!is.nan(obj$xmin)) obj$xmin else -Inf
x < lower
}

#' boolean vector for observations above the distribution's support
#' @noRd
is_above_support <- function(x, obj) {
upper <- if (!is.nan(obj$xmax)) obj$xmax else Inf
x > upper
}
12 changes: 10 additions & 2 deletions inst/include/kde1d/kde1d.hpp
Expand Up @@ -89,8 +89,7 @@ class Kde1d
void check_xmin_xmax(const double& xmin, const double& xmax) const;
void check_inputs(const Eigen::VectorXd& x,
const Eigen::VectorXd& weights = Eigen::VectorXd()) const;
void fit_internal(const Eigen::VectorXd& x,
const Eigen::VectorXd& weights = Eigen::VectorXd());
void check_boundaries(const Eigen::VectorXd& x) const;
Eigen::VectorXd pdf_continuous(const Eigen::VectorXd& x) const;
Eigen::VectorXd cdf_continuous(const Eigen::VectorXd& x) const;
Eigen::VectorXd quantile_continuous(const Eigen::VectorXd& x) const;
Expand Down Expand Up @@ -196,6 +195,7 @@ inline void
Kde1d::fit(const Eigen::VectorXd& x, const Eigen::VectorXd& weights)
{
check_inputs(x, weights);
check_boundaries(x);

// preprocessing for nans and jittering
Eigen::VectorXd xx = x;
Expand Down Expand Up @@ -770,6 +770,14 @@ Kde1d::check_inputs(const Eigen::VectorXd& x,
throw std::invalid_argument("x and weights must have the same size");
}

inline void
Kde1d::check_boundaries(const Eigen::VectorXd& x) const
{
if ((x.array() < xmin_).any() | (x.array() > xmax_).any()) {
throw std::invalid_argument("x must be contained in [xmin, xmax].");
}
}

void
Kde1d::set_interpolation_grid(const interp::InterpolationGrid& grid)
{
Expand Down

0 comments on commit 367cc54

Please sign in to comment.