Skip to content

Commit

Permalink
Extra test cases for case weights
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 13, 2017
1 parent dbccbdf commit 351a8fd
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 9 deletions.
44 changes: 41 additions & 3 deletions RegressionTests/Code/earth.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,29 @@ test_reg_loo_model <- train(trainX, trainY,
tuneGrid = egrid,
preProc = c("center", "scale"))

case_weights <- runif(nrow(trainX))

set.seed(849)
test_reg_cv_model_weights <- train(trainX, trainY,
method = "earth",
trControl = rctrl1,
weights = runif(nrow(trainX)),
weights = case_weights,
tuneGrid = egrid,
preProc = c("center", "scale"))

set.seed(849)
test_reg_cv_form_weights <- train(y ~ ., data = training,
method = "earth",
trControl = rctrl1,
weights = runif(nrow(trainX)),
weights = case_weights,
tuneGrid = egrid,
preProc = c("center", "scale"))

set.seed(849)
test_reg_loo_model_weights <- train(trainX, trainY,
method = "earth",
trControl = rctrl2,
weights = runif(nrow(trainX)),
weights = case_weights,
tuneGrid = egrid,
preProc = c("center", "scale"))

Expand All @@ -192,6 +194,42 @@ test_reg_rec <- train(recipe = rec_reg,
method = "earth",
trControl = rctrl1)

tmp <- training
tmp$wts <- case_weights

reg_rec <- recipe(y ~ ., data = tmp) %>%
add_role(wts, new_role = "case weight") %>%
step_center(all_predictors()) %>%
step_scale(all_predictors())

set.seed(849)
test_reg_cv_weight_rec <- train(reg_rec,
data = tmp,
method = "earth",
trControl = rctrl1,
tuneGrid = egrid)
if(
!isTRUE(
all.equal(test_reg_cv_weight_rec$results,
test_reg_cv_form_weights$results))
)
stop("CV weights not giving the same results")

set.seed(849)
test_reg_loo_weight_rec <- train(reg_rec,
data = tmp,
method = "earth",
trControl = rctrl2,
tuneGrid = egrid)
if(
!isTRUE(
all.equal(test_reg_loo_weight_rec$results,
test_reg_loo_model_weights$results))
)
stop("CV weights not giving the same results")



test_reg_pred_rec <- predict(test_reg_rec, testing[, -ncol(testing)])

#########################################################################
Expand Down
47 changes: 43 additions & 4 deletions RegressionTests/Code/glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ test_class_none_model <- train(trainX, trainY,
test_class_none_pred <- predict(test_class_none_model, testing[, -ncol(testing)])
test_class_none_prob <- predict(test_class_none_model, testing[, -ncol(testing)], type = "prob")

case_weights <- runif(nrow(trainX))

set.seed(849)
test_class_cv_weight <- train(trainX, trainY,
weights = runif(nrow(trainX)),
weights = case_weights,
method = "glm",
trControl = cctrl4,
tuneLength = 1,
Expand All @@ -83,7 +85,7 @@ test_class_cv_weight <- train(trainX, trainY,

set.seed(849)
test_class_loo_weight <- train(trainX, trainY,
weights = runif(nrow(trainX)),
weights = case_weights,
method = "glm",
trControl = cctrl5,
tuneLength = 1,
Expand All @@ -101,6 +103,43 @@ test_class_pred_rec <- predict(test_class_rec, testing[, -ncol(testing)])
test_class_prob_rec <- predict(test_class_rec, testing[, -ncol(testing)],
type = "prob")

tmp <- training
tmp$wts <- case_weights

class_rec <- recipe(Class ~ ., data = tmp) %>%
add_role(wts, new_role = "case weight") %>%
step_center(all_predictors()) %>%
step_scale(all_predictors())

set.seed(849)
test_class_cv_weight_rec <- train(class_rec,
data = tmp,
method = "glm",
trControl = cctrl4,
tuneLength = 1,
metric = "Accuracy")
if(
!isTRUE(
all.equal(test_class_cv_weight_rec$results,
test_class_cv_weight$results))
)
stop("CV weights not giving the same results")

set.seed(849)
test_class_loo_weight_rec <- train(class_rec,
data = tmp,
method = "glm",
trControl = cctrl5,
tuneLength = 1,
metric = "Accuracy")
if(
!isTRUE(
all.equal(test_class_loo_weight_rec$results,
test_class_loo_weight$results))
)
stop("CV weights not giving the same results")


test_levels <- levels(test_class_cv_model)
if(!all(levels(trainY) %in% test_levels))
cat("wrong levels")
Expand Down Expand Up @@ -160,15 +199,15 @@ test_reg_none_pred <- predict(test_reg_none_model, testX)

set.seed(849)
test_reg_cv_weight <- train(trainX, trainY,
weights = runif(nrow(trainX)),
weights = case_weights,
method = "glm",
trControl = cctrl4,
tuneLength = 1,
preProc = c("center", "scale"))

set.seed(849)
test_reg_loo_weight <- train(trainX, trainY,
weights = runif(nrow(trainX)),
weights = case_weights,
method = "glm",
trControl = cctrl5,
tuneLength = 1,
Expand Down
30 changes: 28 additions & 2 deletions RegressionTests/Code/multinom.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ testing <- twoClassSim(500, linearVars = 2)
trainX <- training[, -ncol(training)]
trainY <- training$Class

wts <- runif(nrow(trainX))

rec_cls <- recipe(Class ~ ., data = training) %>%
step_center(all_predictors()) %>%
step_scale(all_predictors())
Expand Down Expand Up @@ -87,7 +89,7 @@ test_class_none_prob <- predict(test_class_none_model, testing[, -ncol(testing)]

set.seed(849)
test_class_cv_weight <- train(trainX, trainY,
weights = runif(nrow(trainX)),
weights = wts,
method = "multinom",
trControl = cctrl4,
tuneLength = 2,
Expand All @@ -97,7 +99,7 @@ test_class_cv_weight <- train(trainX, trainY,

set.seed(849)
test_class_loo_weight <- train(trainX, trainY,
weights = runif(nrow(trainX)),
weights = wts,
method = "multinom",
trControl = cctrl5,
tuneLength = 2,
Expand All @@ -116,6 +118,30 @@ test_class_pred_rec <- predict(test_class_rec, testing[, -ncol(testing)])
test_class_prob_rec <- predict(test_class_rec, testing[, -ncol(testing)],
type = "prob")

tmp <- training
tmp$wts <- wts

weight_rec <- recipe(Class ~ ., data = tmp) %>%
add_role(wts, new_role = "case weight") %>%
step_center(all_predictors()) %>%
step_scale(all_predictors())

set.seed(849)
test_class_cv_weight_rec <- train(weight_rec, data = tmp,
method = "multinom",
trControl = cctrl4,
tuneLength = 2,
metric = "Accuracy",
trace = FALSE)

if(
!isTRUE(
all.equal(test_class_cv_weight_rec$results,
test_class_cv_weight$results))
)
stop("CV weights not giving the same results")


test_levels <- levels(test_class_cv_model)
if(!all(levels(trainY) %in% test_levels))
cat("wrong levels")
Expand Down
23 changes: 23 additions & 0 deletions RegressionTests/Code/xgbTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,29 @@ test_class_pred_rec <- predict(test_class_rec, testing[, -ncol(testing)])
test_class_prob_rec <- predict(test_class_rec, testing[, -ncol(testing)],
type = "prob")


tmp <- training
tmp$wts <- training_weight

class_rec <- recipe(Class ~ ., data = tmp) %>%
add_role(wts, new_role = "case weight") %>%
step_center(all_predictors()) %>%
step_scale(all_predictors())

set.seed(849)
test_class_cv_model_weight_rec <- train(class_rec,
data = tmp,
method = "xgbTree",
trControl = cctrl1,
metric = "ROC",
tuneGrid = xgbGrid)
if(
!isTRUE(
all.equal(test_class_cv_model_weight_rec$results,
test_class_cv_model_weight$results))
)
stop("CV weights not giving the same results")

test_levels <- levels(test_class_cv_model)
if(!all(levels(trainY) %in% test_levels))
cat("wrong levels")
Expand Down

0 comments on commit 351a8fd

Please sign in to comment.