Skip to content

Commit

Permalink
Merge 31354b9 into a3c4553
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmayer committed Jul 1, 2015
2 parents a3c4553 + 31354b9 commit d9ead7e
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 26 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: caretEnsemble
Type: Package
Title: Ensembles of Caret Models
Version: 1.0.4
Version: 1.0.5
Date: 2015-01-14
Authors@R: c(person(c("Zachary", "A."), "Deane-Mayer", role = c("aut", "cre"),
email = "zach.mayer@gmail.com"),
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import(grid)
import(plyr)
importFrom(caTools,colAUC)
importFrom(data.table,data.table)
importFrom(data.table,dcast.data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set)
importFrom(data.table,setorderv)
importFrom(digest,digest)
importFrom(gridExtra,grid.arrange)
Expand Down
11 changes: 11 additions & 0 deletions R/aaa.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

##Hack to make data.table functions work with devtools::load_all
#http://stackoverflow.com/questions/23252231/r-data-table-breaks-in-exported-functions
#http://r.789695.n4.nabble.com/Import-problem-with-data-table-in-packages-td4665958.html
assign(".datatable.aware", TRUE)

#Avoid false positives in R CMD CHECK:

utils::globalVariables(
c(".fitted", ".resid", "method", "id", "yhat",
"ymax", "yavg", "ymin", "metric", "metricSD"))
3 changes: 0 additions & 3 deletions R/caretEnsemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,3 @@ autoplot.caretEnsemble <- function(object, which = c(1:6), mfrow = c(3, 2),
labs(title = paste0("Residuals Against ", xvars[2])) + theme_bw()
grid.arrange(g1, g2, g3, g4, g5, g6, ncol=2)
}

utils::globalVariables(c(".fitted", ".resid", "method", "id", "yhat",
"ymax", "yavg", "ymin", "metric", "metricSD"))
67 changes: 45 additions & 22 deletions R/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,6 @@ extractModelTypes <- function(list_of_models){
return(type)
}

##Hack to make this function work with devtools::load_all
#http://stackoverflow.com/questions/23252231/r-data-table-breaks-in-exported-functions
#http://r.789695.n4.nabble.com/Import-problem-with-data-table-in-packages-td4665958.html
assign(".datatable.aware", TRUE)

#' @title Extract the best predictions from a train object
#' @description Extract predictions for the best tune from a model
#' @param x a train object
Expand All @@ -189,6 +184,7 @@ bestPreds <- function(x){
a <- data.table(x$bestTune, key=names(x$bestTune))
b <- data.table(x$pred, key=names(x$bestTune))
b <- b[a,]
sink <- gc(reset=TRUE)
setorderv(b, c("Resample", "rowIndex"))
return(b)
}
Expand All @@ -198,14 +194,20 @@ bestPreds <- function(x){
#' @param list_of_models an object of class caretList or a list of caret models
#' @importFrom pbapply pblapply
extractBestPreds <- function(list_of_models){
lapply(list_of_models, bestPreds)
out <- lapply(list_of_models, bestPreds)
if(is.null(names(out))){
names(out) <- make.names(sapply(list_of_models, function(x) x$method))
}
sink <- gc(reset=TRUE)
return(out)
}

#' @title Make a prediction matrix from a list of models
#' @description Extract obs from one models, and a matrix of predictions from all other models, a
#' helper function
#'
#' @param list_of_models an object of class caretList
#' @importFrom data.table set rbindlist dcast.data.table
makePredObsMatrix <- function(list_of_models){

#caretList Checks
Expand All @@ -214,30 +216,51 @@ makePredObsMatrix <- function(list_of_models){

#Make a list of models
modelLibrary <- extractBestPreds(list_of_models)
model_names <- names(modelLibrary)

#Model library checks
check_bestpreds_resamples(modelLibrary)
check_bestpreds_indexes(modelLibrary)
check_bestpreds_obs(modelLibrary)
check_bestpreds_preds(modelLibrary)
check_bestpreds_resamples(modelLibrary) #Re-write with data.table?
check_bestpreds_indexes(modelLibrary) #Re-write with data.table?
check_bestpreds_obs(modelLibrary) #Re-write with data.table?
check_bestpreds_preds(modelLibrary) #Re-write with data.table?

#Extract model type (class or reg)
type <- extractModelTypes(list_of_models)

#Extract observations from the frist model in the list
obs <- modelLibrary[[1]]$obs
if (type=="Classification"){
positive <- as.character(unique(modelLibrary[[1]]$obs)[2]) #IMPROVE THIS!
#Add names column
for(i in seq_along(modelLibrary)){
set(modelLibrary[[i]], j="modelname", value=names(modelLibrary)[[i]])
}

#Extract predicteds
if (type=="Regression"){
preds <- sapply(modelLibrary, function(x) as.numeric(x$pred))
} else if (type=="Classification"){
preds <- sapply(modelLibrary, function(x) as.numeric(x[[positive]]))
#Remove parameter columns
keep <- Reduce(intersect, lapply(modelLibrary, names))
for(i in seq_along(modelLibrary)){
rem <- setdiff(names(modelLibrary[[i]]), keep)
if(length(rem) > 0){
for(r in rem){
set(modelLibrary[[i]], j=r, value=NULL)
}
}
}
modelLibrary <- rbindlist(modelLibrary, fill=TRUE)

#For classification models that produce probs, use the probs as preds
#Otherwise, just use class predictions
if (type=="Classification"){
positive <- as.character(unique(modelLibrary$obs)[2]) #IMPROVE THIS!
pos <- as.numeric(modelLibrary[[positive]])
good_pos_values <- which(is.finite(pos))
set(modelLibrary, j="pred", value=as.numeric(modelLibrary[["pred"]]))
set(modelLibrary, i=good_pos_values, j="pred", value=modelLibrary[good_pos_values,positive,with=FALSE])
}

#Reshape wide for meta-modeling
modelLibrary <- data.table::dcast.data.table(
modelLibrary,
obs + rowIndex + Resample ~ modelname,
value.var = "pred"
)

#Name the predicteds and return
colnames(preds) <- make.names(sapply(list_of_models, function(x) x$method), unique=TRUE)
return(list(obs=obs, preds=preds, type=type))
#Return
return(list(obs=modelLibrary$obs, preds=as.matrix(modelLibrary[,model_names,with=FALSE]), type=type))
}

0 comments on commit d9ead7e

Please sign in to comment.