predict(..., type = 'prob') with glmnet logistic_reg #206
Comments
I have the same problem with predicting probabilities from a multinomial logistic regression using glmnet. If I do, e.g.,
Then I get:
Which seems consistent with what is reported above.
Now the value of
I assume from the naming that the function above is a postprocessing function that is to be applied to the vector of predictions.
Better to use multinom_reg().
@patr1ckm is right about the model type (and using multinom_reg()).
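For reference, here is a minimal sketch of that specification with made-up data (none of the objects below come from the reprexes in this thread, and the penalty value is just a placeholder):
library(tidymodels)
library(glmnet)
set.seed(1)
# toy data: three numeric features and a three-level factor outcome
toy = data.frame(x1 = rnorm(60), x2 = rnorm(60), x3 = rnorm(60),
                 label = factor(sample(c("a", "b", "c"), 60, replace = TRUE)))
spec = multinom_reg(penalty = 0.01, mixture = 1) %>%
  set_engine('glmnet')
fitted_clf = fit(spec, label ~ ., data = toy)
# returns one row per row of new_data, with one .pred_<level> column per class
predict(fitted_clf, new_data = toy, type = 'prob')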
In the meantime:
options(stringsAsFactors = F)
library(tidyverse)
library(magrittr)
#>
#> Attaching package: 'magrittr'
#> The following object is masked from 'package:purrr':
#>
#> set_names
#> The following object is masked from 'package:tidyr':
#>
#> extract
library(tidymodels)
#> Registered S3 method overwritten by 'xts':
#> method from
#> as.zoo.xts zoo
#> ── Attaching packages ─────────────────────────────────────────────────────────────────────────── tidymodels 0.0.3 ──
#> ✔ broom 0.5.2 ✔ recipes 0.1.7.9001
#> ✔ dials 0.0.3.9001 ✔ rsample 0.0.5
#> ✔ infer 0.5.0 ✔ yardstick 0.0.4
#> ✔ parsnip 0.0.4
#> ── Conflicts ────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard() masks purrr::discard()
#> ✖ magrittr::extract() masks tidyr::extract()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ recipes::fixed() masks stringr::fixed()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ dials::margin() masks ggplot2::margin()
#> ✖ dials::offset() masks stats::offset()
#> ✖ magrittr::set_names() masks purrr::set_names()
#> ✖ yardstick::spec() masks readr::spec()
#> ✖ recipes::step() masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#>
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#>
#> expand, pack, unpack
#> Loaded glmnet 3.0
set.seed(363)
n_obs = 100
n_feats = 200
mat = matrix(NA, nrow = n_obs, ncol = n_feats,
             dimnames = list(paste0("observation_", seq_len(n_obs)),
                             paste0("feature_", seq_len(n_feats))))
mat[] = rnorm(length(c(mat)))
labels = runif(n = n_obs, min = 0, max = 3) %>% floor() %>% factor()
# get optimized penalty with cv.glmnet
penalty = cv.glmnet(mat, labels, nfolds = 3, family = 'multinomial') %>%
  extract2("lambda.1se")
clf = multinom_reg(mixture = 1,
                   penalty = penalty,
                   mode = 'classification') %>%
  set_engine('glmnet')
x = as.data.frame(mat) %>% mutate(label = labels)
cv = vfold_cv(x, v = 3, strata = 'label')
# devtools::install_github("tidymodels/tune")
library(tune)
res <- fit_resamples(label ~ ., model = clf, resamples = cv,
                     control = control_resamples(save_pred = TRUE))
collect_predictions(res)
#> # A tibble: 100 x 4
#> id .pred_class .row label
#> <chr> <fct> <int> <fct>
#> 1 Fold1 1 5 2
#> 2 Fold1 1 6 1
#> 3 Fold1 1 7 2
#> 4 Fold1 1 13 0
#> 5 Fold1 1 21 0
#> 6 Fold1 1 23 1
#> 7 Fold1 1 27 0
#> 8 Fold1 1 28 2
#> 9 Fold1 1 31 0
#> 10 Fold1 1 37 0
#> # … with 90 more rows
Created on 2019-12-01 by the reprex package (v0.3.0)
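As a quick follow-up (a sketch, assuming the res object from the reprex above), the saved out-of-fold class predictions can be summarised with collect_metrics() or scored per fold with yardstick:
collect_metrics(res)
collect_predictions(res) %>%
  group_by(id) %>%
  accuracy(truth = label, estimate = .pred_class)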
Fixed now:
options(stringsAsFactors = F)
library(tidyverse)
library(magrittr)
#>
#> Attaching package: 'magrittr'
#> The following object is masked from 'package:purrr':
#>
#> set_names
#> The following object is masked from 'package:tidyr':
#>
#> extract
library(tidymodels)
#> Registered S3 method overwritten by 'xts':
#> method from
#> as.zoo.xts zoo
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────── tidymodels 0.0.3 ──
#> ✔ broom 0.5.2 ✔ recipes 0.1.7.9001
#> ✔ dials 0.0.3.9002 ✔ rsample 0.0.5
#> ✔ infer 0.5.0 ✔ yardstick 0.0.4
#> ✔ parsnip 0.0.4.9000
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard() masks purrr::discard()
#> ✖ magrittr::extract() masks tidyr::extract()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ recipes::fixed() masks stringr::fixed()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ dials::margin() masks ggplot2::margin()
#> ✖ dials::offset() masks stats::offset()
#> ✖ magrittr::set_names() masks purrr::set_names()
#> ✖ yardstick::spec() masks readr::spec()
#> ✖ recipes::step() masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#>
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#>
#> expand, pack, unpack
#> Loaded glmnet 3.0
set.seed(363)
n_obs = 100
n_feats = 200
mat = matrix(NA, nrow = n_obs, ncol = n_feats,
             dimnames = list(paste0("observation_", seq_len(n_obs)),
                             paste0("feature_", seq_len(n_feats))))
mat[] = rnorm(length(c(mat)))
labels = runif(n = n_obs, min = 0, max = 3) %>% floor() %>% factor()
# get optimized penalty with cv.glmnet
penalty = cv.glmnet(mat, labels, nfolds = 3, family = 'multinomial') %>%
  extract2("lambda.1se")
clf = multinom_reg(mixture = 1,
                   penalty = penalty,
                   mode = 'classification') %>%
  set_engine('glmnet')
x = as.data.frame(mat) %>% mutate(label = labels)
cv = vfold_cv(x, v = 3, strata = 'label')
folded = cv %>%
  mutate(
    recipes = splits %>%
      map(~ prepper(., recipe = recipe(.$data, label ~ .))),
    test_data = splits %>% map(analysis),
    fits = map2(
      recipes,
      test_data,
      ~ fit(
        clf,
        label ~ .,
        data = bake(object = .x, new_data = .y)
      )
    )
  )
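# (note) at this point `folded` holds, per fold: the rsample split, a recipe
# prepped on the analysis set, the raw analysis data (in `test_data`), and a
# multinom_reg/glmnet model fitted on the baked analysis data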
# predict on the left-out data
retrieve_predictions = function(split, recipe, model) {
  test = bake(recipe, assessment(split))
  tbl = tibble(
    true = test$label,
    pred = predict(model, test)$.pred_class,
    prob = predict(model, test, type = 'prob')) %>%
    # convert prob from nested df to columns
    cbind(.$prob) %>%
    select(-prob)
  return(tbl)
}
predictions = folded %>%
  mutate(
    pred = list(
      splits,
      recipes,
      fits
    ) %>%
      pmap(retrieve_predictions)
  )
head(predictions)
#> # A tibble: 3 x 6
#> splits id recipes test_data fits pred
#> * <named list> <chr> <named lis> <named list> <named li> <named list>
#> 1 <split [66/34… Fold1 <recipe> <df[,201] [66 × 20… <fit[+]> <df[,5] [34 ×…
#> 2 <split [67/33… Fold2 <recipe> <df[,201] [67 × 20… <fit[+]> <df[,5] [33 ×…
#> 3 <split [67/33… Fold3 <recipe> <df[,201] [67 × 20… <fit[+]> <df[,5] [33 ×…
Created on 2019-12-01 by the reprex package (v0.3.0)
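Since the class probabilities are now carried along per fold, they can be flattened and scored with a probability metric. A sketch building on the objects above (the .pred_0/.pred_1/.pred_2 names come from the factor levels of label):
predictions %>%
  select(id, pred) %>%
  unnest(cols = pred) %>%
  group_by(id) %>%
  roc_auc(truth = true, .pred_0, .pred_1, .pred_2)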
When trying to predict probabilities from a multinomial logistic regression with the 'glmnet' engine, I get the following error:
Created on 2019-08-21 by the reprex package (v0.3.0)
The issue is caused by this line:
predict(model, test, type = 'prob')
Running traceback() on that line alone gives the following:
(As a side note, I am having trouble interpreting the output returned by predict(model, test)$.pred_class. In most tidymodels predict functions, this is an object with length equal to that of test$label, but in this example it's three times as long. Why is this?)