Consistent behavior for fit classification when the response isn't a factor #115

@EmilHvitfeldt

For whatever reason, you might accidentally specify your response as something other than a factor. This gives rather uninformative error messages.

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────── tidymodels 0.0.2 ──
#> ✔ broom     0.5.1     ✔ purrr     0.2.5
#> ✔ dials     0.0.2     ✔ recipes   0.1.4
#> ✔ dplyr     0.7.8     ✔ rsample   0.0.3
#> ✔ ggplot2   3.1.0     ✔ tibble    1.4.2
#> ✔ infer     0.4.0     ✔ yardstick 0.0.2
#> ✔ parsnip   0.0.1
#> ── Conflicts ───────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ rsample::fill()  masks tidyr::fill()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()

### factor

model_logistic <- logistic_reg(mode = "classification") %>%
  set_engine("glm") %>%
  fit(Species ~ ., data = iris)
#> Warning: glm.fit: algorithm did not converge
#> Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

predict_class(model_logistic, iris)
#>   [1] setosa     setosa     setosa     setosa     setosa     setosa    
#>   [7] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [13] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [19] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [25] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [31] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [37] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [43] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [49] setosa     setosa     versicolor versicolor versicolor versicolor
#>  [55] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [61] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [67] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [73] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [79] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [85] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [91] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [97] versicolor versicolor versicolor versicolor versicolor versicolor
#> [103] versicolor versicolor versicolor versicolor versicolor versicolor
#> [109] versicolor versicolor versicolor versicolor versicolor versicolor
#> [115] versicolor versicolor versicolor versicolor versicolor versicolor
#> [121] versicolor versicolor versicolor versicolor versicolor versicolor
#> [127] versicolor versicolor versicolor versicolor versicolor versicolor
#> [133] versicolor versicolor versicolor versicolor versicolor versicolor
#> [139] versicolor versicolor versicolor versicolor versicolor versicolor
#> [145] versicolor versicolor versicolor versicolor versicolor versicolor
#> Levels: setosa versicolor virginica

### logical

model_logistic <- logistic_reg(mode = "classification") %>%
  set_engine("glm") %>%
  fit(Species ~ ., data = iris %>% mutate(Species = Species == "setosa"))
#> Warning: glm.fit: algorithm did not converge

#> Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

predict_class(model_logistic, iris)
#> Warning in rep(yes, length.out = length(ans)): 'x' is NULL so the result
#> will be NULL
#> Error in ans[test & ok] <- rep(yes, length.out = length(ans))[test & ok]: replacement has length zero

### numeric

model_logistic <- logistic_reg(mode = "classification") %>%
  set_engine("glm") %>%
  fit(Species ~ ., data = iris %>% mutate(Species = as.numeric(Species)))
#> Error in eval(family$initialize): y values must be 0 <= y <= 1

### character

model_logistic <- logistic_reg(mode = "classification") %>%
  set_engine("glm") %>%
  fit(Species ~ ., data = iris %>% mutate(Species = as.character(Species)))
#> Error in eval(family$initialize): y values must be 0 <= y <= 1

Created on 2018-12-13 by the reprex package (v0.2.1)

When the response is logical, fit() succeeds but predict_class() fails. When the response is numeric or character, fit() itself fails, with the same glm error in both cases.

I propose we check inside fit.model_spec that the response is of type factor when doing classification.
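
A minimal sketch of what such a check could look like, using base R only. The helper name `check_classification_outcome()` and its arguments are hypothetical and are not part of parsnip; where exactly it would be called inside fit.model_spec, and the wording of the error, are assumptions for illustration.

```r
# Hypothetical helper (not parsnip API): stop early with a clear message
# when a classification model is given a non-factor outcome.
check_classification_outcome <- function(formula, data, mode) {
  # Only classification models need a factor outcome.
  if (!identical(mode, "classification")) {
    return(invisible(TRUE))
  }
  # A two-sided formula has length 3: `~`, left-hand side, right-hand side.
  if (length(formula) != 3) {
    stop("`formula` must have a left-hand side.", call. = FALSE)
  }
  # Evaluate the left-hand side in the data, falling back to the
  # formula's environment for anything not found there.
  y <- eval(formula[[2]], data, environment(formula))
  if (!is.factor(y)) {
    stop(
      "For classification models, the outcome should be a factor, not ",
      class(y)[1], ".",
      call. = FALSE
    )
  }
  invisible(TRUE)
}

# Factor outcome: passes silently.
check_classification_outcome(Species ~ ., iris, "classification")

# Character outcome: fails before any engine code runs.
iris_chr <- iris
iris_chr$Species <- as.character(iris_chr$Species)
msg <- tryCatch(
  check_classification_outcome(Species ~ ., iris_chr, "classification"),
  error = function(e) conditionMessage(e)
)
```

With a check like this, the logical, numeric, and character cases above would all fail at fit time with the same message, instead of surfacing different errors from glm or from predict_class.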
