Skip to content

Commit

Permalink
Moved exception code out of train.default to model code for issue #710
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Aug 18, 2017
1 parent 9bfa386 commit ecfc0fa
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 29 deletions.
63 changes: 40 additions & 23 deletions models/files/pam.R
Expand Up @@ -11,11 +11,15 @@ modelInfo <- list(label = "Nearest Shrunken Centroids",
initialThresh <- pamr::pamr.train(list(x=t(x), y=y))$threshold
initialThresh <- initialThresh[-c(1, length(initialThresh))]
if(search == "grid") {
out <- data.frame(threshold = seq(from = min(initialThresh),
to = max(initialThresh),
length = len))
out <- data.frame(threshold =
seq(from = min(initialThresh),
to = max(initialThresh),
length = len))
} else {
out <- data.frame(threshold = runif(len, min = min(initialThresh),max = max(initialThresh)))
out <- data.frame(threshold =
runif(len,
min = min(initialThresh),
max = max(initialThresh)))
}
out

Expand All @@ -26,18 +30,23 @@ modelInfo <- list(label = "Nearest Shrunken Centroids",
submodels <- list(grid[-1,,drop = FALSE])
list(loop = loop, submodels = submodels)
},
fit = function(x, y, wts, param, lev, last, classProbs, ...)
pamr::pamr.train(list(x = t(x), y = y), threshold = param$threshold, ...),
fit = function(x, y, wts, param, lev, last, classProbs, ...) {
res <- pamr::pamr.train(list(x = t(x), y = y),
threshold = param$threshold, ...)
if (last) {
res$xData <- x
res$yData <- y
}
res
},
predict = function(modelFit, newdata, submodels = NULL) {
out <- pamr::pamr.predict(modelFit,
t(newdata),
threshold = modelFit$tuneValue$threshold)
if(!is.null(submodels))
{
if(!is.null(submodels)) {
tmp <- vector(mode = "list", length = nrow(submodels) + 1)
tmp[[1]] <- out
for(j in seq(along = submodels$threshold))
{
for(j in seq(along = submodels$threshold)) {
tmp[[j+1]] <- pamr::pamr.predict(modelFit,
t(newdata),
threshold = submodels$threshold[j])
Expand All @@ -50,13 +59,11 @@ modelInfo <- list(label = "Nearest Shrunken Centroids",
out <- pamr::pamr.predict(modelFit, t(newdata),
threshold = modelFit$tuneValue$threshold,
type= "posterior")
if(!is.null(submodels))
{
if(!is.null(submodels)) {
tmp <- vector(mode = "list", length = nrow(submodels) + 1)
tmp[[1]] <- out

for(j in seq(along = submodels$threshold))
{
for(j in seq(along = submodels$threshold)) {
tmpProb <- pamr::pamr.predict(modelFit, t(newdata),
threshold = submodels$threshold[j],
type= "posterior")
Expand All @@ -67,20 +74,30 @@ modelInfo <- list(label = "Nearest Shrunken Centroids",
out
},
predictors = function(x, newdata = NULL, threshold = NULL, ...) {
if(is.null(newdata))
{
if(!is.null(x$xData)) newdata <- x$xData else stop("must supply newdata")
if (is.null(newdata)) {
if (!is.null(x$xData))
newdata <- x$xData
else
stop("must supply newdata")
}
if(is.null(threshold))
{
if(!is.null(x$threshold)) threshold <- x$threshold else stop("must supply threshold")
if (is.null(threshold)) {
if (!is.null(x$threshold))
threshold <-
x$threshold
else
stop("must supply threshold")
}
varIndex <- pamr::pamr.predict(x, newx = newdata, threshold = threshold, type = "nonzero")
varIndex <- pamr::pamr.predict(x,
newx = newdata,
threshold = threshold,
type = "nonzero")
colnames(newdata)[varIndex]
},
varImp = function (object, threshold = NULL, data = NULL, ...) {
if(is.null(data)) data <- object$xData
if(is.null(threshold)) threshold <- object$tuneValue$threshold
if(is.null(data))
data <- object$xData
if(is.null(threshold))
threshold <- object$tuneValue$threshold
if( dim(object$centroids)[1] != dim(data)[2])
stop("the number of columns (=variables) is not consistent with the pamr object")

Expand Down
6 changes: 0 additions & 6 deletions pkg/caret/R/train.default.R
Expand Up @@ -861,12 +861,6 @@ train.default <- function(x, y,
}
} else outData <- NULL

## In the case of pam, the data will need to be saved differently
if(trControl$returnData & method == "pam") {
finalModel$xData <- x
finalModel$yData <- y
}

if(trControl$savePredictions == "final")
tmp$predictions <- merge(bestTune, tmp$predictions)

Expand Down
Binary file modified pkg/caret/inst/models/models.RData
Binary file not shown.

0 comments on commit ecfc0fa

Please sign in to comment.