Fix format for single prediction (for multinom_reg() and prediction type "prob") #612
When calling `stats::predict()` on a multinom model (engine = "nnet"), the output format differs between a scalar input and an input of length > 1. I searched a bit for whether this is intentional but could not find anything, so I'm not sure whether this is a bug or a feature.

I personally prefer the behaviour for inputs of length > 1, where the column names identify the class labels. When predicting on a scalar, it is not clear which labels the probabilities correspond to. I assume one could get them from `model_trained$lvl`, but I would be nervous about whether the order is guaranteed to match, especially since this model is going into production in my use case.

So my suggestion is to always use the output format that is currently returned for vector input.
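Until the format is unified, a defensive check on the caller's side can make the label mapping explicit. Below is a minimal base-R sketch, assuming a single prediction may arrive either as a 1 x K matrix or as a plain vector whose names (if present) are the class labels. `ensure_prob_matrix()` is a hypothetical helper, not part of parsnip or nnet. It always returns one row per observation and one column per class, and it reorders columns by name instead of trusting position, which addresses the ordering concern about `lvl`.

```r
# Hypothetical helper (NOT part of parsnip or nnet).
# Assumption: class probabilities arrive either as a matrix with class
# names as column names, or, for a single prediction, as a named vector.
ensure_prob_matrix <- function(probs, lvl) {
  if (is.null(dim(probs))) {
    # A lone prediction may come back as a bare vector; without names we
    # cannot know which probability belongs to which class, so fail loudly.
    if (is.null(names(probs))) {
      stop("cannot infer class labels from an unnamed probability vector")
    }
    # Reshape into a 1 x K matrix, keeping the names as column labels.
    probs <- matrix(probs, nrow = 1, dimnames = list(NULL, names(probs)))
  }
  # Refuse to guess if the labels do not match the expected class levels.
  if (!setequal(colnames(probs), lvl)) {
    stop("probability columns do not match the class levels")
  }
  # Select columns by name so their order follows `lvl` explicitly.
  probs[, lvl, drop = FALSE]
}

# Usage: a single prediction whose names are in a different order than lvl.
single <- c(b = 0.5, a = 0.2, c = 0.3)
ensure_prob_matrix(single, c("a", "b", "c"))
```

Selecting by name rather than by position means a reordered or relabelled engine output raises an error instead of silently mislabelling probabilities, which matters for a production model.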
Here is a reproducible example: