
multiClassSummary #107

Closed · zachmayer opened this issue Jan 24, 2015 · 24 comments

zachmayer (Collaborator) commented Jan 24, 2015

Prototype here: https://gist.github.com/zachmayer/3061272

multiClassSummary <- function (data, lev = NULL, model = NULL){

  #Load Libraries
  require(Metrics)
  require(caret)

  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) 
    stop("levels of observed and predicted data do not match")

  #Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(levels(data[, "pred"]), function(class){

    #Grab one-vs-all data for the class
    pred <- ifelse(data[, "pred"] == class, 1, 0)
    obs  <- ifelse(data[,  "obs"] == class, 1, 0)
    prob <- data[,class]

    #Calculate one-vs-all AUC and logLoss and return
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    prob_stats <- c(auc(obs, prob), logLoss(obs, cap_prob))
    names(prob_stats) <- c('ROC', 'logLoss')
    return(prob_stats) 
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste('Class:', levels(data[, "pred"]))

  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])

  #Aggregate and average class-wise stats
  #Todo: add weights
  class_stats <- cbind(CM$byClass, prob_stats)
  class_stats <- colMeans(class_stats)

  #Aggregate overall stats
  overall_stats <- c(CM$overall)

  #Combine overall with class-wise stats and remove some stats we don't want 
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull', 
    'Prevalence', 'Detection Prevalence')]

  #Clean names and return
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  return(stats)

}

Just needs some testing!

topepo (Owner) commented Jan 25, 2015

Can you add some error trapping so that, if the class probabilities columns are not in data, an error occurs (or make those performance estimates NA with a warning)?

Otherwise, it looks great. Thanks
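
One way to implement that check (a sketch, assuming the class-probability columns are named after the factor levels, which is how caret lays out data when classProbs = TRUE):

#Sketch of the requested guard: warn and make the probability-based
#stats NA when the probability columns are missing
has_probs <- all(levels(data[, "obs"]) %in% colnames(data))
if (!has_probs) {
  warning("class probabilities were not computed; ROC and logLoss will be NA")
  prob_stats <- matrix(NA_real_, nrow = nlevels(data[, "obs"]), ncol = 2,
                       dimnames = list(NULL, c('ROC', 'logLoss')))
}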

zachmayer (Collaborator, Author) commented Jan 25, 2015

Definitely.  I'll add some error checking and put together a PR.



topepo (Owner) commented Jan 28, 2015

You might want to add something like this around line 35:

names(class_stats) <- paste0("Mean_", names(class_stats))

to help understand what is being computed.

Thanks,

Max

krz commented Jun 5, 2015

I tested this function on a 3-class problem and it worked with glmnet and xgbTree, but it failed for gbm.
Otherwise it looks good.

andrewcstewart commented Jul 29, 2015

Just FYI there's similar code floating around: https://github.com/rseiter/PracticalMLProject/blob/master/multiClassSummary.R

Might be some useful bits to incorporate?

zachmayer (Collaborator, Author) commented Aug 12, 2015

For future reference, in case that gist goes away:

#Multi-Class Summary Function
#Based on caret:::twoClassSummary
# From: http://moderntoolmaking.blogspot.com/2012/07/error-metrics-for-multi-class-problems.html

# RES: disable compilation for debugging
# require(compiler)
# multiClassSummary <- cmpfun(function (data, lev = NULL, model = NULL){
multiClassSummary <- function (data, lev = NULL, model = NULL){

  #Load Libraries
  require(Metrics)
  require(caret)

  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) 
    stop("levels of observed and predicted data do not match")

  #Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(levels(data[, "pred"]), function(class){

    #Grab one-vs-all data for the class
    pred <- ifelse(data[, "pred"] == class, 1, 0)
    obs  <- ifelse(data[,  "obs"] == class, 1, 0)
    prob <- data[,class]

    #Calculate one-vs-all AUC and logLoss and return
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    prob_stats <- c(auc(obs, prob), logLoss(obs, cap_prob))
    names(prob_stats) <- c('ROC', 'logLoss')
    return(prob_stats) 
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste('Class:', levels(data[, "pred"]))

  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])

  #Aggregate and average class-wise stats
  #Todo: add weights
  # RES: support two classes here as well
  #browser() # Debug
  if (length(levels(data[, "pred"])) == 2) {
    class_stats <- c(CM$byClass, prob_stats[1,])
  } else {
    class_stats <- cbind(CM$byClass, prob_stats)
    class_stats <- colMeans(class_stats)
  }

  # Aggregate overall stats
  overall_stats <- c(CM$overall)

  # Combine overall with class-wise stats and remove some stats we don't want 
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull', 
                                       'Prevalence', 'Detection Prevalence')]

  # Clean names
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))

  if (length(levels(data[, "pred"]) == 2)) {
    # Change name ordering to place most useful first
    # May want to remove some of these eventually
    stats <- stats[c("ROC", "Sensitivity", "Specificity", "Accuracy", "Kappa", "logLoss",
                     "AccuracyLower", "AccuracyUpper", "AccuracyPValue", "McnemarPValue",
                     "Pos_Pred_Value", "Neg_Pred_Value", "Detection_Rate",
                     "Balanced_Accuracy")]
  }

  return(stats)
}

topepo (Owner) commented Aug 13, 2015

We used to have a faster version in the package too.

zachmayer (Collaborator, Author) commented Aug 13, 2015

Why'd we take it out?

I remember a commit about the "cake is a lie," but I can't remember why.

topepo (Owner) commented Sep 7, 2015

I figured that I shouldn't create and support code that was done better somewhere else (in pROC).

I'd like to do a release soon and add this. I was going to:

  • use the mnLogLoss function in the package instead of the one in Metrics
  • compute the one vs all ROC values using the corresponding class probability. For example, in the iris data, we could compute the ROC curve for the setosa probability with the labels "setosa" and "other" then do the same thing for the other classes. That may be a better option than getting all possible ROC curves and averaging since the corresponding class probability should be the most informative.

How does that sound?
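
For concreteness, a sketch of that one-vs-all computation on iris using pROC (the probabilities here are made up; in practice they would come from predict(fit, type = "prob")):

library(pROC)

#Hypothetical setosa probabilities standing in for real model output
set.seed(1)
p_setosa <- ifelse(iris$Species == "setosa",
                   runif(nrow(iris), 0.5, 1.0),
                   runif(nrow(iris), 0.0, 0.5))

#Relabel to "setosa" vs. "other" and score with the setosa probability
obs_binary <- factor(ifelse(iris$Species == "setosa", "setosa", "other"),
                     levels = c("other", "setosa"))
auc(roc(obs_binary, p_setosa))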

topepo (Owner) commented Sep 7, 2015

On second thought, averaging the ROC values is a better idea. Never mind about that second part...

zachmayer (Collaborator, Author) commented Sep 7, 2015

I like the first part, and averaging the ROC values.

Having class-wise one-vs-all ROC and logLoss might be nice to have too, but I think it's better to optimize the overall values.

topepo (Owner) commented Sep 8, 2015

Here's an updated version:

multiClassSummary <- function (data, lev = NULL, model = NULL){

  #auc() below comes from the Metrics package
  require(Metrics)

  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) 
    stop("levels of observed and predicted data do not match")

  ## Overall multinomial loss
  lloss <- mnLogLoss(data = data, lev = lev, model = model)

  #Calculate custom one-vs-all ROC curves for each class
  prob_stats <- lapply(levels(data[, "pred"]), 
                       function(class){
                         #Grab one-vs-all data for the class
                         pred <- ifelse(data[, "pred"] == class, 1, 0)
                         obs  <- ifelse(data[,  "obs"] == class, 1, 0)
                         prob <- data[,class]

                         #Calculate one-vs-all AUC
                         prob_stats <- as.vector(auc(obs, prob))
                         names(prob_stats) <- c('ROC')
                         return(prob_stats) 
                       })
  roc_stats <- mean(unlist(prob_stats))

  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])

  #Aggregate and average class-wise stats
  #Todo: add weights
  # RES: support two classes here as well
  #browser() # Debug
  if (length(levels(data[, "pred"])) == 2) {
    class_stats <- c(CM$byClass, roc_stats)
  } else {
    class_stats <- colMeans(CM$byClass)
    names(class_stats) <- paste("Mean", names(class_stats))
  }

  # Aggregate overall stats
  overall_stats <- c(CM$overall, lloss, Mean_ROC = roc_stats)

  # Combine overall with class-wise stats and remove some stats we don't want 
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull', "AccuracyLower", "AccuracyUpper",
                                       "AccuracyPValue", "McnemarPValue", 
                                       'Mean Prevalence', 'Mean Detection Prevalence')]

  # Clean names
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))

  if (length(levels(data[, "pred"]) == 2)) {
    # Change name ordering to place most useful first
    # May want to remove some of these eventually
    stats <- stats[c("logLoss", "Mean_ROC", 
                     "Mean_Sensitivity", "Mean_Specificity", "Accuracy", "Kappa", 
                     "Mean_Pos_Pred_Value", "Mean_Neg_Pred_Value", "Mean_Detection_Rate",
                     "Mean_Balanced_Accuracy")]
  }

  return(stats)
}



library(caret)
library(AppliedPredictiveModeling)
data(schedulingData)

set.seed(1417)
in_train <- createDataPartition(schedulingData$Class, p = 3/4, list = FALSE)

training <- schedulingData[ in_train,]
testing  <- schedulingData[-in_train,]

mod <- train(Class ~ ., data = training, method = "rpart",
             tuneLength = 10,
             trControl = trainControl(classProbs = TRUE,
                                      summaryFunction = multiClassSummary))

test_pred <- predict(mod, testing, type = "prob")
test_pred$obs <- testing$Class
test_pred$pred <- predict(mod, testing)

multiClassSummary(test_pred, lev = levels(test_pred$obs))
topepo added a commit that referenced this issue Sep 8, 2015
topepo (Owner) commented Sep 8, 2015

I just checked in code. I made some changes so that, if class probabilities are not computed, you can still get the other measures.

Take a look at the man file and see if you like the wording.

Thanks,

Max

topepo added a commit that referenced this issue Sep 8, 2015
zachmayer (Collaborator, Author) commented Sep 8, 2015

👍 from me!

zachmayer closed this Sep 8, 2015
adatum (Contributor) commented Jan 5, 2017

I get the error:

The metric "Mean F1" was not in the result set. NA will be used instead.

when using train(..., metric = "Mean F1", ...) with multiClassSummary as the summary function. Why is this?

I checked caret::confusionMatrix, and F1 is one of the byClass statistics, even though it is not displayed by default when the confusion matrix is printed out. So, when reading the code above (which may be out of date) for multiClassSummary, specifically:

class_stats <- colMeans(CM$byClass)
names(class_stats) <- paste("Mean", names(class_stats))

I expected "Mean F1" to be available as a metric.

topepo (Owner) commented Jan 5, 2017

Your model is using multiClassSummary (not confusionMatrix) to calculate performance.

You can use your own function to pass in to train if you want the results from confusionMatrix.

F1 is one of the byClass statistics, even though it is not displayed by default

Take a look at ?confusionMatrix. It has an option called mode:

a single character string either "sens_spec", "prec_recall", or "everything"

This controls what is displayed versus what is computed.
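
For example, with the test_pred data frame from the scheduling example above:

#F1, precision, and recall show up in the printed byClass table with this mode
confusionMatrix(test_pred$pred, test_pred$obs, mode = "everything")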

adatum (Contributor) commented Jan 5, 2017

From the following line, it seems that multiClassSummary does use the results from confusionMatrix, no?

CM <- confusionMatrix(data[, "pred"], data[, "obs"])

And then it takes column means of the byClass statistics, as I posted previously, to get multi-class summary statistics.

Thanks for the tip about confusionMatrix's mode. I keep learning new details about caret. So much hidden treasure!

topepo (Owner) commented Jan 5, 2017

Yes, but it looks like the item stat_list does not return the F scores from the byClass statistics. I can change that in the next release but, for now, you'll have to use an altered version of the function.
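
In the meantime, a minimal custom summary along these lines could expose it (a sketch, not the package function; two-class data would need separate handling since byClass is a vector there):

f1Summary <- function(data, lev = NULL, model = NULL) {
  #Multi-class case: byClass is a matrix with an F1 column
  CM <- caret::confusionMatrix(data[, "pred"], data[, "obs"])
  class_stats <- colMeans(CM$byClass, na.rm = TRUE)
  names(class_stats) <- gsub('[[:blank:]]+', '_',
                             paste('Mean', names(class_stats)))
  c(Accuracy = unname(CM$overall["Accuracy"]), class_stats)
}

Passing summaryFunction = f1Summary to trainControl and metric = "Mean_F1" to train would then work.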

adatum (Contributor) commented Jan 5, 2017

Thanks. Where can I find the source code for the most recent implementation of multiClassSummary so I can modify it right now? You mention stat_list which is not found in the code in this thread.

topepo (Owner) commented Jan 5, 2017

It is here. If you add it, feel free to do a pull request.

Thanks,

Max

adatum (Contributor) commented Jan 7, 2017

I added "Mean_F1" to stat_list and made a pull request.

In testing, I've run into some errors:


library(caret)

# source local copy of fixed multiClassSummary

in_training <- createDataPartition(iris$Species, p = 0.6, list = FALSE)
training <- iris[in_training, ]
testing <- iris[-in_training, ]

mycontrol <- trainControl(
    method = 'repeatedcv',
    number = 10,
    repeats = 3,
    classProbs = TRUE,
    selectionFunction = "oneSE",
    summaryFunction = multiClassSummary
)

set.seed(19556)

model_nnet <- train(x = subset(training, select = -Species),
                    y = training$Species,
                    preProcess = c("center", "scale"),
                    trControl = mycontrol,
                    method = "nnet",
                    metric = "Mean_F1")

Result is:

Error in vector(type, length) :
vector: cannot make a vector of mode 'NULL'.
In addition: Warning message:
In train.default(x = subset(training, select = -Species), y = training$Species, :
The metric "Mean_F1" was not in the result set. NA will be used instead.

I'm sure I've read of such errors on issues here (at least one which was fixed due to named/unnamed list) and on SO, etc. Ideas?

PS. As an aside, in my local copy of multiClassSummary I had to rename all instances of requireNamespaceQuietStop to requireNamespace, since no function by the former name was found. Does caret or another package implement it as a wrapper for requireNamespace(..., quietly = TRUE)?

topepo (Owner) commented Jan 8, 2017

The metric "Mean_F1" was not in the result set. NA will be used instead.c

Look in the package test directory and you'll see the test. It could be that you just need to adjust the expected results so that 1) it knows that it should output that value and 2) it knows what the correct value is. It could be something else though, so look carefully.
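
A hypothetical testthat expectation along those lines (the test name and the test_pred object from the example above are assumptions, not the package's actual test):

library(testthat)

test_that("multiClassSummary returns Mean_F1", {
  expect_true("Mean_F1" %in%
                names(multiClassSummary(test_pred, lev = levels(test_pred$obs))))
})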

Does caret or another package implement it as a wrapper for requireNamespace( ... , quietly = TRUE)

Yes, it is in caret. For debugging, you can just use caret:::requireNamespaceQuietStop, but remove the namespace part before committing. R CMD check has some explicit rules about how you are not allowed to call by namespace in packages, and those functions are just more formal ways of doing it.

Thanks
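
The wrapper is roughly equivalent to this (a sketch, not caret's exact source):

requireNamespaceQuietStop <- function(package) {
  #Fail with a clean message instead of the default loading error
  if (!requireNamespace(package, quietly = TRUE))
    stop(paste("package", package, "is required"), call. = FALSE)
}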

adatum (Contributor) commented Jan 8, 2017

Sorry, ignore the previous error; I could not reproduce it. I'm not sure, but I think it happened because I had used doParallel without running registerDoSEQ afterwards, even though I had run stopCluster and (probably) restarted the R session in RStudio.

My example runs without error, but it does give a warning:

Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
There were missing values in resampled performance measures.

I've been seeing this warning a lot in other code I've been running. Is this of concern?

I also ran the testthat tests and they all passed, including for multiClassSummary. You're probably aware, but I thought I'd mention there were four warnings for confusionMatrix and one for earth:

Warnings --------------------------------------------------------------------------

  1. Confusion matrix works (@test_confusionMatrix.R#30) - Levels are not in the same order for reference and data. Refactoring data to match.

  2. Confusion matrix works (@test_confusionMatrix.R#31) - Levels are not in the same order for reference and data. Refactoring data to match.

  3. Confusion matrix works (@test_confusionMatrix.R#32) - The data contains levels not found in the data, but they are empty and will be dropped.

  4. Confusion matrix works (@test_confusionMatrix.R#32) - Levels are not in the same order for reference and data. Refactoring data to match.

  5. bagEarth simple classification (@test_models_bagEarth.R#23) - glm.fit: fitted probabilities numerically 0 or 1 occurred

topepo (Owner) commented Jan 9, 2017

I've been seeing this warning a lot in other code I've been running. Is this of concern?

Probably not. It often happens when a regression model predicts using the mean value and R^2 can't be computed. You would probably want to look at the resamples part of the train object to really tell. It can also happen if the model fails for a particular resample.

The other warnings are expected; we fix a few things here and there and let people know when we do.
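
To see which resamples produced the NAs, something like this works (model_nnet is the fitted train object from the example above):

#Each row of $resample is one resample; NA cells mark metrics that
#could not be computed for that fold
head(model_nnet$resample)
model_nnet$resample[!complete.cases(model_nnet$resample), ]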
