
predict(..., type = 'prob') with glmnet logistic_reg #206

@skinnider

When trying to predict probabilities from a multinomial logistic regression with the 'glmnet' engine, I get the following error:

options(stringsAsFactors = F)
library(tidyverse)
library(magrittr)
#> 
#> Attaching package: 'magrittr'
#> The following object is masked from 'package:purrr':
#> 
#>     set_names
#> The following object is masked from 'package:tidyr':
#> 
#>     extract
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.2 ──
#> ✔ broom     0.5.2       ✔ recipes   0.1.6  
#> ✔ dials     0.0.2       ✔ rsample   0.0.5  
#> ✔ infer     0.4.0.1     ✔ yardstick 0.0.3  
#> ✔ parsnip   0.0.3.1
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard()     masks purrr::discard()
#> ✖ magrittr::extract()   masks tidyr::extract()
#> ✖ dplyr::filter()       masks stats::filter()
#> ✖ recipes::fixed()      masks stringr::fixed()
#> ✖ dplyr::lag()          masks stats::lag()
#> ✖ magrittr::set_names() masks purrr::set_names()
#> ✖ yardstick::spec()     masks readr::spec()
#> ✖ recipes::step()       masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> The following object is masked from 'package:tidyr':
#> 
#>     expand
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loaded glmnet 2.0-18

# create a toy dataset
n_obs = 100
n_feats = 200
mat = matrix(NA, nrow = n_obs, ncol = n_feats,
             dimnames = list(paste0("observation_", seq_len(n_obs)),
                             paste0("feature_", seq_len(n_feats))))
mat[] = rnorm(length(c(mat)))

# create labels
labels = runif(n = n_obs, min = 0, max = 3) %>% floor() %>% factor()

# get optimized penalty with cv.glmnet
penalty = cv.glmnet(mat, labels, nfolds = 3, family = 'multinomial') %>%
  extract2("lambda.1se")

# create classifier
clf = logistic_reg(mixture = 1,
                   penalty = penalty,
                   mode = 'classification') %>%
  set_engine('glmnet', family = 'multinomial')

# fit models in cross-validation
x = as.data.frame(mat) %>% mutate(label = labels)
cv = vfold_cv(x, v = 3, strata = 'label')
folded = cv %>%
  mutate(
    recipes = splits %>%
      map(~ prepper(., recipe = recipe(.$data, label ~ .))),
    test_data = splits %>% map(analysis),
    fits = map2(
      recipes,
      test_data,
      ~ fit(
        clf,
        label ~ .,
        data = bake(object = .x, new_data = .y)
      )
    )
  )

# predict on the left-out data
retrieve_predictions = function(split, recipe, model) {
  test = bake(recipe, assessment(split))
  tbl = tibble(
    true = test$label,
    pred = predict(model, test)$.pred_class,
    prob = predict(model, test, type = 'prob')) %>%
    # convert prob from nested df to columns
    cbind(.$prob) %>%
    select(-prob)
  return(tbl)
}
predictions = folded %>%
  mutate(
    pred = list(
      splits,
      recipes,
      fits
    ) %>%
      pmap(retrieve_predictions)
  )
#> Error in attr(x, "names") <- as.character(value): 'names' attribute [3] must be the same length as the vector [2]

Created on 2019-08-21 by the reprex package (v0.3.0)

The issue is caused by this line:

predict(model, test, type = 'prob')

Running traceback() on that line alone gives the following:

10: `names<-.tbl_df`(`*tmp*`, value = value)
9: `names<-`(`*tmp*`, value = value)
8: `colnames<-`(`*tmp*`, value = object$lvl)
7: object$spec$method$pred$prob$post(res, object)
6: predict_classprob.model_fit(object, new_data = new_data, ...)
5: predict_classprob._multnet(object = object, new_data = new_data, 
       ...)
4: predict_classprob(object = object, new_data = new_data, ...)
3: predict.model_fit(object = object, new_data = new_data, type = type, 
       opts = opts)
2: predict._multnet(model, test, type = "prob")
1: predict(model, test, type = "prob")
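Frame 8 of the traceback shows parsnip's probability post-processing assigning `object$lvl` as column names. With three outcome levels, the error message suggests that the object receiving those names has only two columns. The base R sketch below reproduces the same message using a hypothetical two-column table. It only illustrates the failing assignment and is not parsnip's actual code.

```r
# Hypothetical stand-in for the probability table that reaches the
# post-processing step: two columns, although the outcome has three levels.
probs <- data.frame(a = c(0.2, 0.7), b = c(0.8, 0.3))
lvl <- c("0", "1", "2")  # three class levels, analogous to object$lvl

# Assigning three column names to a two-column data frame fails;
# capture the error message instead of stopping.
msg <- tryCatch({
  colnames(probs) <- lvl
  "no error"
}, error = function(e) conditionMessage(e))

print(msg)
```

The captured message should have the same "'names' attribute [3] must be the same length as the vector [2]" form as the one in the reprex. The call reported in the reprex differs only because the real object is a tibble (`names<-.tbl_df`).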

(As a side note, I am having trouble interpreting the output of predict(model, test)$.pred_class. With most tidymodels models, this vector has the same length as test$label, but in this example it is three times as long. Why is this?)
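One possible explanation, not confirmed in this issue, is that parsnip documents `logistic_reg()` for two-class outcomes and provides `multinom_reg()` for outcomes with more than two classes. The following minimal sketch assumes a recent parsnip release with glmnet installed. It fits `multinom_reg()` on toy data shaped like the reprex, then checks that the class and probability predictions each have one row per observation. The data, the seed, and the penalty value of 0.01 are arbitrary placeholders.

```r
library(parsnip)  # the glmnet package must also be installed

set.seed(1)
n_obs <- 60
# Toy data in the spirit of the reprex: numeric features, 3-level outcome
x <- as.data.frame(matrix(rnorm(n_obs * 5), ncol = 5))
x$label <- factor(floor(runif(n_obs, min = 0, max = 3)))

# Alternative spec: multinom_reg() instead of logistic_reg().
# penalty = 0.01 is a placeholder, not a tuned lambda.
clf_multi <- set_engine(multinom_reg(penalty = 0.01, mixture = 1), "glmnet")
fit_multi <- fit(clf_multi, label ~ ., data = x)

# Class probabilities: one column per outcome level
probs <- predict(fit_multi, new_data = x, type = "prob")
# Hard class predictions: one value per observation
classes <- predict(fit_multi, new_data = x)

print(dim(probs))
print(length(classes$.pred_class))
```

If this spec behaves as expected, both the probability tibble and `.pred_class` line up row for row with the input data. Neither output would show the threefold length described above.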
