-
Notifications
You must be signed in to change notification settings - Fork 95
Closed
Description
When trying to predict probabilities from a multinomial logistic regression with the 'glmnet' engine, I get the following error:
# Keep character columns as character; spelled with the FALSE constant
# because `F` is an ordinary variable that can be reassigned.
options(stringsAsFactors = FALSE)
library(tidyverse)
library(magrittr)
#>
#> Attaching package: 'magrittr'
#> The following object is masked from 'package:purrr':
#>
#> set_names
#> The following object is masked from 'package:tidyr':
#>
#> extract
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.2 ──
#> ✔ broom 0.5.2 ✔ recipes 0.1.6
#> ✔ dials 0.0.2 ✔ rsample 0.0.5
#> ✔ infer 0.4.0.1 ✔ yardstick 0.0.3
#> ✔ parsnip 0.0.3.1
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard() masks purrr::discard()
#> ✖ magrittr::extract() masks tidyr::extract()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ recipes::fixed() masks stringr::fixed()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ magrittr::set_names() masks purrr::set_names()
#> ✖ yardstick::spec() masks readr::spec()
#> ✖ recipes::step() masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#>
#> Attaching package: 'Matrix'
#> The following object is masked from 'package:tidyr':
#>
#> expand
#> Loading required package: foreach
#>
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#>
#> accumulate, when
#> Loaded glmnet 2.0-18
# create a toy dataset
# n_obs x n_feats matrix of standard-normal noise; rows are named
# "observation_<i>" and columns "feature_<j>".
n_obs <- 100
n_feats <- 200
mat <- matrix(
  rnorm(n_obs * n_feats),
  nrow = n_obs,
  ncol = n_feats,
  dimnames = list(
    paste0("observation_", seq_len(n_obs)),
    paste0("feature_", seq_len(n_feats))
  )
)
# create labels
# floor() of Uniform(0, 3) draws yields the values 0, 1 and 2, which become
# the factor levels of the outcome.
labels <- factor(floor(runif(n = n_obs, min = 0, max = 3)))
# get optimized penalty with cv.glmnet
# Take the "lambda.1se" element of a 3-fold cross-validated multinomial fit.
penalty <- cv.glmnet(
  mat,
  labels,
  nfolds = 3,
  family = "multinomial"
)[["lambda.1se"]]
# create classifier
# Lasso-type specification (mixture = 1) using the penalty chosen above.
# NOTE(review): logistic_reg() is parsnip's specification for two-class
# outcomes, whereas `labels` has three levels and the multinomial behaviour
# is forced through the glmnet `family` engine argument. parsnip also
# provides multinom_reg() for multiclass outcomes -- presumably this mismatch
# is what trips the class-probability post-processing; confirm.
clf <- set_engine(
  logistic_reg(mode = "classification", penalty = penalty, mixture = 1),
  "glmnet",
  family = "multinomial"
)
# fit models in cross-validation
# Attach the outcome to the features, build 3 folds stratified on the label,
# then add one recipe, one data set and one fitted model per fold.
x <- mutate(as.data.frame(mat), label = labels)
cv <- vfold_cv(x, v = 3, strata = "label")
folded <- mutate(
  cv,
  # one recipe per split, produced by prepper()
  recipes = map(
    splits,
    function(split) prepper(split, recipe = recipe(split$data, label ~ .))
  ),
  # NOTE: despite the column name, this holds the analysis() portion of each
  # split, i.e. the rows the model is fitted on.
  test_data = map(splits, analysis),
  # bake each fold's analysis rows with its recipe and fit the classifier
  fits = map2(
    recipes,
    test_data,
    function(prepped, rows) {
      fit(clf, label ~ ., data = bake(object = prepped, new_data = rows))
    }
  )
)
# predict on the left-out data
#
# Bakes the assessment() rows of `split` with `recipe`, then gathers the true
# label, the predicted class and the class probabilities from `model`.
#   split:  split object; only its assessment() rows are used
#   recipe: object handed to bake() (presumably already prepped -- confirm)
#   model:  fitted model handed to predict()
# Returns the cbind() of the collected columns and the probability table,
# with the nested `prob` column dropped.
retrieve_predictions <- function(split, recipe, model) {
  held_out <- bake(recipe, assessment(split))
  nested <- tibble(
    true = held_out$label,
    pred = predict(model, held_out)$.pred_class,
    prob = predict(model, held_out, type = "prob")
  )
  # convert prob from nested df to columns
  select(cbind(nested, nested$prob), -prob)
}
# Call retrieve_predictions() once per fold; the list elements are matched
# positionally to its (split, recipe, model) parameters.
predictions <- mutate(
  folded,
  pred = pmap(list(splits, recipes, fits), retrieve_predictions)
)
#> Error in attr(x, "names") <- as.character(value): 'names' attribute [3] must be the same length as the vector [2]
Created on 2019-08-21 by the reprex package (v0.3.0)
The issue is caused by this line:
predict(model, test, type = 'prob')
Running `traceback()` on that line alone gives the following:
10: `names<-.tbl_df`(`*tmp*`, value = value)
9: `names<-`(`*tmp*`, value = value)
8: `colnames<-`(`*tmp*`, value = object$lvl)
7: object$spec$method$pred$prob$post(res, object)
6: predict_classprob.model_fit(object, new_data = new_data, ...)
5: predict_classprob._multnet(object = object, new_data = new_data,
...)
4: predict_classprob(object = object, new_data = new_data, ...)
3: predict.model_fit(object = object, new_data = new_data, type = type,
opts = opts)
2: predict._multnet(model, test, type = "prob")
1: predict(model, test, type = "prob")
(As a side note, I am having trouble interpreting the output returned by `predict(model, test)$.pred_class`. In most tidymodels predict functions, this is an object with length equal to that of `test$label`, but in this example it is three times as long. Why is this?)
Metadata
Metadata
Assignees
Labels
No labels