multiClassSummary #107
Comments
|
Can you add some error trapping for the case where the class probabilities columns are not in the data? Otherwise, it looks great. Thanks |
|
Definitely. I'll add some error checking and put together a PR.
|
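Not code from this thread, just a minimal sketch of the kind of error trapping being requested, assuming the usual caret summary-function arguments (data, lev, model); the function name checkProbColumns is made up for illustration:
# Hypothetical sketch of the requested check (not the code that was merged).
# `data` is the data frame caret hands to a summaryFunction, `lev` the class levels.
checkProbColumns <- function(data, lev = NULL) {
  if (is.null(lev))
    stop("the summary function requires the class levels via `lev`")
  if (!all(lev %in% colnames(data)))
    stop("the class probability columns (one per class level) were not found in `data`")
  invisible(TRUE)
}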
|
You might want to add something like this around line 35:
to help understand what is being computed. Thanks, Max |
|
I tested this function for a 3-class problem and it worked with glmnet and xgbTree, but it failed for gbm. |
|
Just FYI there's similar code floating around: https://github.com/rseiter/PracticalMLProject/blob/master/multiClassSummary.R Might be some useful bits to incorporate? |
|
For future reference, in case that gist goes away:
#Multi-Class Summary Function
#Based on caret:::twoClassSummary
# From: http://moderntoolmaking.blogspot.com/2012/07/error-metrics-for-multi-class-problems.html
# RES: disable compilation for debugging
# require(compiler)
# multiClassSummary <- cmpfun(function (data, lev = NULL, model = NULL){
multiClassSummary <- function (data, lev = NULL, model = NULL){
  #Load Libraries
  require(Metrics)
  require(caret)
  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")
  #Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(levels(data[, "pred"]), function(class){
    #Grab one-vs-all data for the class
    pred <- ifelse(data[, "pred"] == class, 1, 0)
    obs <- ifelse(data[, "obs"] == class, 1, 0)
    prob <- data[, class]
    #Calculate one-vs-all AUC and logLoss and return
    #Clip probabilities away from 0 and 1 so logLoss stays finite
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    prob_stats <- c(auc(obs, prob), logLoss(obs, cap_prob))
    names(prob_stats) <- c('ROC', 'logLoss')
    return(prob_stats)
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste('Class:', levels(data[, "pred"]))
  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])
  #Aggregate and average class-wise stats
  #Todo: add weights
  # RES: support two classes here as well
  #browser() # Debug
  if (length(levels(data[, "pred"])) == 2) {
    class_stats <- c(CM$byClass, prob_stats[1, ])
  } else {
    class_stats <- cbind(CM$byClass, prob_stats)
    class_stats <- colMeans(class_stats)
  }
  # Aggregate overall stats
  overall_stats <- c(CM$overall)
  # Combine overall with class-wise stats and remove some stats we don't want
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull',
                                       'Prevalence', 'Detection Prevalence')]
  # Clean names
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  if (length(levels(data[, "pred"])) == 2) {
    # Change name ordering to place most useful first
    # May want to remove some of these eventually
    stats <- stats[c("ROC", "Sensitivity", "Specificity", "Accuracy", "Kappa", "logLoss",
                     "AccuracyLower", "AccuracyUpper", "AccuracyPValue", "McnemarPValue",
                     "Pos_Pred_Value", "Neg_Pred_Value", "Detection_Rate",
                     "Balanced_Accuracy")]
  }
  return(stats)
} |
|
We used to have a faster version in the package too. |
|
Why'd we take it out? I remember a commit about the "cake is a lie," but I can't remember why |
|
I figured that I shouldn't create and support code that was done better somewhere else. I'd like to do a release soon and add this. I was going to:
How does that sound? |
|
On second thought, averaging the ROC values is a better idea. Never mind about that second part... |
|
I like the first part and averaging the ROC values. Having class-wise one-vs-all ROC and logLoss might be nice to have too, but I think it's better to optimize the overall values. |
|
Here's an updated version:
multiClassSummary <- function (data, lev = NULL, model = NULL){
  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")
  ## Overall multinomial loss
  lloss <- mnLogLoss(data = data, lev = lev, model = model)
  #Calculate custom one-vs-all ROC curves for each class
  prob_stats <- lapply(levels(data[, "pred"]),
                       function(class){
                         #Grab one-vs-all data for the class
                         pred <- ifelse(data[, "pred"] == class, 1, 0)
                         obs <- ifelse(data[, "obs"] == class, 1, 0)
                         prob <- data[, class]
                         #Calculate one-vs-all AUC
                         prob_stats <- as.vector(auc(obs, prob))
                         names(prob_stats) <- c('ROC')
                         return(prob_stats)
                       })
  roc_stats <- mean(unlist(prob_stats))
  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])
  #Aggregate and average class-wise stats
  #Todo: add weights
  # RES: support two classes here as well
  #browser() # Debug
  if (length(levels(data[, "pred"])) == 2) {
    class_stats <- c(CM$byClass, ROC = roc_stats)
  } else {
    class_stats <- colMeans(CM$byClass)
    names(class_stats) <- paste("Mean", names(class_stats))
  }
  # Aggregate overall stats
  overall_stats <- c(CM$overall, lloss, Mean_ROC = roc_stats)
  # Combine overall with class-wise stats and remove some stats we don't want
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull', "AccuracyLower", "AccuracyUpper",
                                       "AccuracyPValue", "McnemarPValue",
                                       'Mean Prevalence', 'Mean Detection Prevalence')]
  # Clean names
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  if (length(levels(data[, "pred"])) > 2) {
    # Change name ordering to place most useful first
    # (only the multi-class branch produces the Mean_ names used below)
    # May want to remove some of these eventually
    stats <- stats[c("logLoss", "Mean_ROC",
                     "Mean_Sensitivity", "Mean_Specificity", "Accuracy", "Kappa",
                     "Mean_Pos_Pred_Value", "Mean_Neg_Pred_Value", "Mean_Detection_Rate",
                     "Mean_Balanced_Accuracy")]
  }
  return(stats)
}
library(caret)
library(Metrics)   # assumed source of auc() used above
library(AppliedPredictiveModeling)
data(schedulingData)
set.seed(1417)
in_train <- createDataPartition(schedulingData$Class, p = 3/4, list = FALSE)
training <- schedulingData[ in_train, ]
testing  <- schedulingData[-in_train, ]
mod <- train(Class ~ ., data = training, method = "rpart",
             tuneLength = 10,
             trControl = trainControl(classProbs = TRUE,
                                      summaryFunction = multiClassSummary))
test_pred <- predict(mod, testing, type = "prob")
test_pred$obs <- testing$Class
test_pred$pred <- predict(mod, testing)
multiClassSummary(test_pred, lev = levels(test_pred$obs)) |
|
I just checked in code. I made some changes so that, if class probabilities are not computed, you can still get the other measures. Take a look at the man file and see if you like the wording. Thanks, Max |
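The checked-in code itself isn't quoted in this thread; as a rough sketch of the behavior described (not the actual caret source), a summary function can fall back to confusion-matrix statistics when the probability columns are absent. This assumes a problem with more than two classes, that auc() comes from the Metrics package, and a made-up function name:
# Hypothetical sketch of "still get the other measures" when probs are missing
library(caret)
library(Metrics)
gracefulSummary <- function(data, lev = NULL, model = NULL) {
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])
  stats <- c(CM$overall, colMeans(CM$byClass))   # always computable
  # Only add the probability-based measures when the class probability columns exist
  if (!is.null(lev) && all(lev %in% colnames(data))) {
    lloss <- mnLogLoss(data = data, lev = lev, model = model)
    roc_stats <- mean(sapply(lev, function(class)
      auc(ifelse(data[, "obs"] == class, 1, 0), data[, class])))
    stats <- c(lloss, Mean_ROC = roc_stats, stats)
  }
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  stats
}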
|
|
|
I get the error:
when using I checked
I expected "Mean F1" to be available as a metric. |
|
Your model is using You can use your own function to pass in to
Take a look at
This controls what is displayed versus what is computed. |
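A concrete sketch of that workflow, reusing the `training` data and multiClassSummary from the example above; the choice of "Kappa" here is arbitrary, since any column returned by the summary function can be named as the metric:
# The summary function is computed for every resample; the `metric` argument to
# train() only selects which of those columns drives tuning. The other columns
# are still reported in mod$results.
ctrl <- trainControl(classProbs = TRUE,
                     summaryFunction = multiClassSummary)
mod <- train(Class ~ ., data = training, method = "rpart",
             tuneLength = 10,
             trControl = ctrl,
             metric = "Kappa")   # optimize one computed metric; the rest are kept
head(mod$results)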
|
From the following line seems that
And then takes column means for Thanks for the tip about |
|
Yes, but it looks like the item |
|
Thanks. Where can I find the source code for the most recent implementation of |
|
It is here. If you add it, feel free to do a pull request. Thanks, Max |
|
I added "Mean_F1" to In testing, I've run into some errors:
Result is:
I'm sure I've read of such errors on issues here (at least one which was fixed due to named/unnamed list) and on SO, etc. Ideas? PS. As an aside, in my local copy of |
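The Mean_F1 addition mentioned above isn't quoted here; as a rough sketch (not the actual PR), a macro-averaged F1 could be computed inside such a summary function like this, assuming more than two classes so that CM$byClass is a matrix with one row per class:
# Inside the summary function, `data` has "obs" and "pred" factor columns.
CM <- confusionMatrix(data[, "pred"], data[, "obs"])
by_class <- CM$byClass
if ("F1" %in% colnames(by_class)) {
  # newer versions of confusionMatrix() report F1 per class directly
  mean_f1 <- mean(by_class[, "F1"], na.rm = TRUE)
} else {
  # otherwise derive it from precision (Pos Pred Value) and recall (Sensitivity)
  prec <- by_class[, "Pos Pred Value"]
  rec  <- by_class[, "Sensitivity"]
  mean_f1 <- mean(2 * prec * rec / (prec + rec), na.rm = TRUE)
}
c(Mean_F1 = mean_f1)   # appended to the other statistics in the returned vector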
Look in the package
Yes, it is in Thanks |
|
Sorry, ignore the previous error; I could not reproduce it. Not sure, but I think it was because before running it, I had used My example runs without error, but it does give a warning:
I've been seeing this warning a lot in other code I've been running. Is this of concern? I also ran the
|
Probably not. It often happens when a regression model predicts using the mean value and R^2 can't be computed. You would probably want to look at the The other warnings are expected; we fix a few things here and there and let people know when we do. |
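A tiny standalone illustration of that failure mode (not caret internals): when a model predicts the same value, typically the mean, for every holdout sample, the correlation underlying R^2 is undefined and R warns about it.
obs  <- c(1.2, 3.4, 2.2, 5.0, 4.1)
pred <- rep(mean(obs), length(obs))   # constant predictions
r2 <- cor(obs, pred)^2                # warns "the standard deviation is zero", returns NA
r2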
Prototype here: https://gist.github.com/zachmayer/3061272
Just needs some testing!