Skip to content

Commit

Permalink
add Sparse Distance Weighted Discrimination #98
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 5, 2015
1 parent b8975b7 commit c341c31
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
78 changes: 78 additions & 0 deletions RegressionTests/Code/sdwd.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
library(caret)
# Regression-test script for the caret "sdwd" model: fits the model under
# several resampling schemes, collects predictions, and saves every object
# whose name contains "test_" (plus session info and a timestamp) to
# <model>.RData in the working directory.
timestamp <- format(Sys.time(), "%Y_%m_%d_%H_%M")

model <- "sdwd"

# ---- Simulated two-class data ------------------------------------------

set.seed(1)
training <- twoClassSim(50, linearVars = 2)
testing <- twoClassSim(500, linearVars = 2)

# The last column is dropped to form the predictor sets; the outcome is
# taken from the Class column.
trainX <- training[, -ncol(training)]
trainY <- training$Class
testX <- testing[, -ncol(testing)]

# ---- Resampling specifications -----------------------------------------

cctrl1 <- trainControl(
  method = "cv", number = 3, returnResamp = "all",
  classProbs = TRUE,
  summaryFunction = twoClassSummary
)
cctrl2 <- trainControl(
  method = "LOOCV",
  classProbs = TRUE, summaryFunction = twoClassSummary
)
cctrl3 <- trainControl(
  method = "none",
  classProbs = TRUE, summaryFunction = twoClassSummary
)

# ---- 3-fold CV: x/y interface and formula interface --------------------

set.seed(849)
test_class_cv_model <- train(
  trainX, trainY,
  method = "sdwd",
  trControl = cctrl1,
  metric = "ROC",
  preProc = c("center", "scale")
)

set.seed(849)
test_class_cv_form <- train(
  Class ~ ., data = training,
  method = "sdwd",
  trControl = cctrl1,
  metric = "ROC",
  preProc = c("center", "scale")
)

test_class_pred <- predict(test_class_cv_model, testX)
test_class_prob <- predict(test_class_cv_model, testX, type = "prob")
test_class_pred_form <- predict(test_class_cv_form, testX)
test_class_prob_form <- predict(test_class_cv_form, testX, type = "prob")

# ---- Leave-one-out CV ---------------------------------------------------

set.seed(849)
test_class_loo_model <- train(
  trainX, trainY,
  method = "sdwd",
  trControl = cctrl2,
  metric = "ROC",
  preProc = c("center", "scale")
)

# ---- No resampling: refit at the tuning values chosen by the CV model --

set.seed(849)
test_class_none_model <- train(
  trainX, trainY,
  method = "sdwd",
  trControl = cctrl3,
  tuneGrid = test_class_cv_model$bestTune,
  metric = "ROC",
  preProc = c("center", "scale")
)

test_class_none_pred <- predict(test_class_none_model, testX)
test_class_none_prob <- predict(test_class_none_model, testX, type = "prob")

# ---- Class levels reported by the fitted model -------------------------

test_levels <- levels(test_class_cv_model)
if (any(!(levels(trainY) %in% test_levels))) {
  cat("wrong levels")
}

# ---- Predictors used by the fitted model -------------------------------

test_class_predictors1 <- predictors(test_class_cv_model)

# ---- Save results --------------------------------------------------------

# Every object whose name contains the literal text "test_" is saved.
tests <- grep("test_", ls(), fixed = TRUE, value = TRUE)

sInfo <- sessionInfo()

save(
  list = c(tests, "sInfo", "timestamp"),
  file = file.path(getwd(), paste0(model, ".RData"))
)

# Quit without saving the workspace image.
q("no")


56 changes: 56 additions & 0 deletions models/files/sdwd.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# caret model definition for Sparse Distance Weighted Discrimination.
# Each element below is one piece of the model specification: metadata,
# tuning-grid construction, fitting, prediction and post-processing helpers.
# The two classes are coded numerically for sdwd(): the first factor level
# becomes 1 and every other value becomes -1.
modelInfo <- list(
  label = "Sparse Distance Weighted Discrimination",
  library = "sdwd",
  type = "Classification",
  # Tuning parameters: the L1 penalty (lambda) and the L2 penalty (lambda2).
  parameters = data.frame(parameter = c('lambda', 'lambda2'),
                          class = c("numeric", "numeric"),
                          label = c('L1 Penalty', 'L2 Penalty')),
  # Build the tuning grid.
  #   x, y: predictors and factor outcome.
  #   len:  number of candidate values requested per parameter.
  # Returns a data frame with columns `lambda` and `lambda2` containing
  # every combination of the candidate values.
  grid = function(x, y, len = NULL) {
    lev <- levels(y)
    y <- ifelse(y == lev[1], 1, -1)
    # An initial fit with lambda2 = 0 supplies the candidate lambda values;
    # len + 2 values are requested because the first and last are discarded.
    init <- sdwd(as.matrix(x), y,
                 nlambda = len + 2,
                 lambda2 = 0)
    lambda <- unique(init$lambda)
    lambda <- lambda[-c(1, length(lambda))]
    # seq_len() rather than 1:n: when nothing is left after trimming,
    # 1:0 would index c(1, 0) and inject an NA penalty into the grid.
    lambda <- lambda[seq_len(min(length(lambda), len))]
    expand.grid(lambda = lambda,
                lambda2 = seq(0.1, 1, length.out = len))
  },
  loop = NULL,
  # Fit sdwd() at a single (lambda, lambda2) pair taken from `param`.
  # Extra arguments in `...` are forwarded to sdwd(); `wts`, `last` and
  # `classProbs` are accepted but not used here.
  fit = function(x, y, wts, param, lev, last, classProbs, ...) {
    y <- ifelse(y == lev[1], 1, -1)
    sdwd(as.matrix(x), y = y,
         lambda = param$lambda,
         lambda2 = param$lambda2,
         ...)
  },
  # Class predictions: a predicted value of 1 maps to the first entry of
  # modelFit$obsLevels, anything else to the second.
  # NOTE(review): obsLevels is not set in this file; presumably it is
  # attached to the fitted object by caret -- confirm.
  predict = function(modelFit, newdata, submodels = NULL) {
    if(!is.matrix(newdata)) newdata <- as.matrix(newdata)
    out <- predict(modelFit, newx = newdata, type = "class")
    ifelse(out == 1, modelFit$obsLevels[1], modelFit$obsLevels[2])
  },
  # Class probabilities: the "link" prediction is passed through the
  # binomial inverse link; the result is the column for the first level and
  # its complement the column for the second.
  # NOTE(review): whether this transform gives calibrated probabilities for
  # this model is not established here -- confirm against the sdwd docs.
  prob = function(modelFit, newdata, submodels = NULL) {
    if(!is.matrix(newdata)) newdata <- as.matrix(newdata)
    out <- predict(modelFit, newx = newdata, type = "link")
    out <- binomial()$linkinv(out)
    out <- data.frame(c1 = out, c2 = 1 - out)
    colnames(out) <- modelFit$obsLevels
    out
  },
  # Names of the rows of x$beta that contain at least one non-zero value.
  predictors = function(x, ...) {
    out <- apply(x$beta, 1, function(x) any(x != 0))
    names(out)[out]
  },
  # Variable importance: absolute values of object$beta in a one-column
  # data frame named "Overall" (the `lambda` argument is unused).
  varImp = function(object, lambda = NULL, ...) {
    out <- as.data.frame(as.matrix(abs(object$beta)))
    colnames(out) <- "Overall"
    out
  },
  # Class levels stored on the fitted object, or NULL when absent.
  levels = function(x) if(any(names(x) == "obsLevels")) x$obsLevels else NULL,
  tags = c("Discriminant Analysis Models", "Implicit Feature Selection",
           "L1 Regularization", "L2 Regularization", "Linear Classifier"),
  # Order candidate tuning rows by decreasing lambda, then decreasing lambda2.
  sort = function(x) x[order(-x$lambda, -x$lambda2),],
  # Drop the stored call from the fitted object.
  trim = function(x) {
    x$call <- NULL
    x
  })
2 changes: 2 additions & 0 deletions pkg/caret/inst/NEWS.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
\itemize{
\item A new model using the \cpkg{randomForest} and \cpkg{inTrees} packages called \code{rfRules} was added. A basic random forest model is used and then is decomposed into rules (of user-specified complexity). The \cpkg{inTrees} package is used to prune and optimize the rules. Thanks to Mirjam Jenny who suggested the workflow.
\item From the \cpkg{rotationForest} package, a model of the same name was added.
\item From the \cpkg{sdwd} package, the model \code{sdwd} was added.
\item Localized linear discriminant analysis (\code{method = "loclda"}) from the \cpkg{klaR} package was added.
\item From the \cpkg{nnls} package, a model of the same name was added.
\item Another linear SVM model from the \cpkg{e1071} package was added using \code{method = "svmLinear2"}.
Expand All @@ -18,6 +19,7 @@
\item More error traps were added for common mistakes (e.g. bad factor levels in classification).
\item An internal function (\code{class2ind}) that can be used to make dummy variables for a single factor vector is now documented and exported.
\item A bug was fixed in \code{xyplot.lift} where the reference line was incorrectly computed. Thanks to Einat Sitbon for finding this.
\item A bug related to calculating the Box-Cox transformation found by John Johnson was fixed.
}
}

Expand Down

0 comments on commit c341c31

Please sign in to comment.