Skip to content

Commit

Permalink
validations
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinykuo committed Dec 5, 2018
1 parent 34c8a87 commit 7232101
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 99 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Expand Up @@ -5,3 +5,4 @@
^configure\.R$
^logs$
^spark-warehouse$
^man-roxygen$
4 changes: 3 additions & 1 deletion DESCRIPTION
Expand Up @@ -13,7 +13,9 @@ Depends:
R (>= 3.1.2)
Imports:
sparklyr,
forge
forge (>= 0.1.0.9002)
RoxygenNote: 6.1.0
Suggests:
testthat
Remotes:
rstudio/forge
1 change: 1 addition & 0 deletions NAMESPACE
Expand Up @@ -10,3 +10,4 @@ export(xgboost_classifier)
export(xgboost_regressor)
import(forge)
import(sparklyr)
importFrom(sparklyr,invoke)
1 change: 1 addition & 0 deletions R/imports.R
@@ -1,2 +1,3 @@
#' @import forge
#' @importFrom sparklyr invoke
NULL
20 changes: 6 additions & 14 deletions R/xgboost_classifier.R
Expand Up @@ -3,6 +3,7 @@
#' XGBoost classifier for Spark.
#'
#' @inheritParams xgboost_regressor
#' @param num_class Number of classes.
#' @template roxlate-ml-probabilistic-classifier-params
#' @export
xgboost_classifier <- function(x, formula = NULL, eta = 0.3, gamma = 0, max_depth = 6,
Expand Down Expand Up @@ -346,18 +347,9 @@ xgboost_classifier.tbl_spark <- function(x, formula = NULL, eta = 0.3, gamma = 0

# Validator
validator_xgboost_classifier <- function(args) {
args[["checkpoint_interval"]] <- cast_scalar_integer(args[["checkpoint_interval"]])
args[["max_bins"]] <- cast_scalar_integer(args[["max_bins"]])
args[["max_depth"]] <- cast_scalar_integer(args[["max_depth"]])
args[["nthread"]] <- cast_scalar_integer(args[["nthread"]])
args[["num_class"]] <- cast_nullable_scalar_integer(args[["num_class"]])
args[["num_early_stopping_rounds"]] <- cast_scalar_integer(args[["num_early_stopping_rounds"]])
args[["num_round"]] <- cast_scalar_integer(args[["num_round"]])
args[["num_workers"]] <- cast_scalar_integer(args[["num_workers"]])
args[["seed"]] <- cast_scalar_integer(args[["seed"]])
args[["silent"]] <- cast_scalar_integer(args[["silent"]])
args[["thresholds"]] <- cast_nullable_double_list(args[["thresholds"]])
args[["missing"]] <- cast_nullable_scalar_double(args[["missing"]])
args <- validator_xgboost_regressor(args)
args[["thresholds"]] <- cast_nullable_double_list(args[["thresholds"]]) %>%
certify(bounded(0, 1), .allow_null = TRUE, .id = "thresholds")
args
}

Expand All @@ -370,8 +362,8 @@ new_xgboost_classification_model <- function(jobj) {
jobj,
features_col = invoke(jobj, "getFeaturesCol"),
prediction_col = invoke(jobj, "getPredictionCol"),
probability_col = sparklyr:::try_null(invoke(jobj, "getProbabilityCol")),
raw_prediction_col = sparklyr:::try_null(invoke(jobj, "getRawPredictionCol")),
probability_col = invoke(jobj, "getProbabilityCol"),
raw_prediction_col = invoke(jobj, "getRawPredictionCol"),
class = "xgboost_classification_model")
}

Expand Down
237 changes: 153 additions & 84 deletions R/xgboost_regressor.R

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions man/xgboost_classifier.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/xgboost_regressor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7232101

Please sign in to comment.