Skip to content

glmnet engines should always respect the penalty value set in the spec #858

@hfrick

Description

@hfrick

parsnip requires a penalty value to be set in the model spec and uses that value for predictions if no other value is provided. This works as expected for `linear_reg()`, but predictions of `type = "raw"` from `logistic_reg()` and `multinom_reg()` ignore that value and instead return predictions for every penalty value in the fitted path.

library(parsnip)

# for linear_reg(), predict() respects the penalty value set in the spec
data("hpc_data", package = "modeldata")
hpc <- hpc_data[1:150, c(2:5, 8)]

lm_spec <- linear_reg(penalty = 0.123) |> set_engine("glmnet")
lm_fit <- fit(lm_spec, input_fields ~ log(compounds) + class, data = hpc)

predict(lm_fit, hpc[1:5,], type = "numeric") 
#> # A tibble: 5 × 1
#>   .pred
#>   <dbl>
#> 1  569.
#> 2  164.
#> 3  169.
#> 4  159.
#> 5  168.
predict(lm_fit, hpc[1:5,], type = "raw") 
#>         s1
#> 1 569.0363
#> 2 164.1463
#> 3 168.7923
#> 4 159.3046
#> 5 167.6483


# for logistic_reg(), predict() respects it for types "class" and "prob" but not for "raw"
data(lending_club, package = "modeldata")

lr_spec <- logistic_reg(penalty = 0.123) %>% set_engine("glmnet")
f_fit <- fit(lr_spec, Class ~ log(funded_amnt) + int_rate + term,
             data = lending_club)

predict(f_fit, lending_club[1:5, ], type = "class") 
#> # A tibble: 5 × 1
#>   .pred_class
#>   <fct>      
#> 1 good       
#> 2 good       
#> 3 good       
#> 4 good       
#> 5 good
predict(f_fit, lending_club[1:5, ], type = "prob") 
#> # A tibble: 5 × 2
#>   .pred_bad .pred_good
#>       <dbl>      <dbl>
#> 1    0.0525      0.948
#> 2    0.0525      0.948
#> 3    0.0525      0.948
#> 4    0.0525      0.948
#> 5    0.0525      0.948
predict(f_fit, lending_club[1:5, ], type = "raw")
#>         s0       s1       s2       s3       s4       s5       s6       s7
#> 1 2.894019 2.873291 2.859939 2.851942 2.847857 2.846633 2.847493 2.849842
#> 2 2.894019 2.905465 2.920151 2.936813 2.954577 2.972826 2.991126 3.009075
#> 3 2.894019 2.836290 2.790695 2.754340 2.725129 2.701510 2.682315 2.666723
#> 4 2.894019 2.878439 2.869573 2.865522 2.864932 2.866824 2.870474 2.875319
#> 5 2.894019 2.979466 3.058638 3.132017 3.200033 3.263072 3.321482 3.375312
#>         s8       s9      s10      s11      s12      s13      s14      s15
#> 1 2.853267 2.857425 2.862063 2.866986 2.872046 2.877132 2.882159 2.887067
#> 2 3.026646 3.043592 3.059812 3.075242 3.089848 3.103611 3.116535 3.128633
#> 3 2.653881 2.643334 2.634652 2.627491 2.621574 2.616680 2.612627 2.609267
#> 4 2.881007 2.887212 2.893703 2.900307 2.906894 2.913368 2.919660 2.925718
#> 5 3.425417 3.471776 3.514636 3.554232 3.590791 3.624514 3.655600 3.684235
#>        s16      s17      s18      s19      s20      s21      s22      s23
#> 1 2.891811 2.896360 2.900693 2.904798 2.908667 2.912303 2.908247 2.900061
#> 2 3.139927 3.150444 3.160218 3.169284 3.177673 3.185434 3.208686 3.240286
#> 3 2.606478 2.604163 2.602239 2.600640 2.599310 2.598202 2.588008 2.573076
#> 4 2.931510 2.937013 2.942217 2.947116 2.951708 2.956004 2.974772 3.001445
#> 5 3.710593 3.734838 3.757126 3.777601 3.796387 3.813635 3.827194 3.838366
#>        s24      s25      s26      s27      s28      s29      s30      s31
#> 1 2.892788 2.886310 2.880529 2.875362 2.870732 2.866588 2.862868 2.859525
#> 2 3.269413 3.296215 3.320859 3.343503 3.364276 3.383355 3.400853 3.416893
#> 3 2.559637 2.547531 2.536616 2.526763 2.517867 2.509820 2.502540 2.495949
#> 4 3.026068 3.048760 3.069653 3.088874 3.106530 3.122760 3.137658 3.151324
#> 5 3.848788 3.858456 3.867409 3.875687 3.883304 3.890355 3.896851 3.902831
#>        s32      s33      s34      s35      s36      s37      s38      s39
#> 1 2.856518 2.853811 2.848782 2.842368 2.836579 2.831347 2.826616 2.822334
#> 2 3.431589 3.445046 3.454869 3.462132 3.468825 3.474972 3.480612 3.485784
#> 3 2.489978 2.484566 2.479786 2.475555 2.471746 2.468315 2.465222 2.462431
#> 4 3.163855 3.175338 3.187092 3.198830 3.209593 3.219447 3.228464 3.236712
#> 5 3.908329 3.913381 3.924719 3.940136 3.954274 3.967211 3.979043 3.989861
#>        s40      s41      s42      s43      s44      s45      s46      s47
#> 1 2.818457 2.814944 2.811759 2.808871 2.806251 2.803873 2.801714 2.799754
#> 2 3.490525 3.494868 3.498844 3.502423 3.505756 3.508806 3.511594 3.514142
#> 3 2.459911 2.457634 2.455575 2.453725 2.452037 2.450507 2.449121 2.447864
#> 4 3.244255 3.251150 3.257451 3.263152 3.268415 3.273223 3.277613 3.281620
#> 5 3.999750 4.008786 4.017042 4.024471 4.031366 4.037661 4.043409 4.048654
#>        s48      s49      s50      s51
#> 1 2.797973 2.796354 2.794883 2.793552
#> 2 3.516470 3.518597 3.520539 3.522250
#> 3 2.446724 2.445689 2.444749 2.443911
#> 4 3.285278 3.288616 3.291662 3.294374
#> 5 4.053441 4.057809 4.061794 4.065295

# for multinom_reg(), predict() respects it for types "class" and "prob" but not for "raw"
data("penguins", package = "modeldata")
penguins <- tidyr::drop_na(penguins)

mr_spec <- multinom_reg(penalty = 0.123) %>% set_engine("glmnet")
f_fit <- fit(mr_spec, species ~ island + bill_length_mm + bill_depth_mm,
             data = penguins)

predict(f_fit, penguins[1:5,], type = "class")
#> # A tibble: 5 × 1
#>   .pred_class
#>   <fct>      
#> 1 Adelie     
#> 2 Adelie     
#> 3 Adelie     
#> 4 Adelie     
#> 5 Adelie
predict(f_fit, penguins[1:5,], type = "prob")
#> # A tibble: 5 × 3
#>   .pred_Adelie .pred_Chinstrap .pred_Gentoo
#>          <dbl>           <dbl>        <dbl>
#> 1        0.848          0.0759       0.0761
#> 2        0.765          0.0756       0.159 
#> 3        0.770          0.0924       0.138 
#> 4        0.921          0.0459       0.0327
#> 5        0.888          0.0835       0.0283
predict(f_fit, penguins[1:5,], type = "raw")
#> , , s0
#> 
#>      Adelie  Chinstrap    Gentoo
#> 1 0.3228607 -0.4412382 0.1183776
#> 2 0.3228607 -0.4412382 0.1183776
#> 3 0.3228607 -0.4412382 0.1183776
#> 4 0.3228607 -0.4412382 0.1183776
#> 5 0.3228607 -0.4412382 0.1183776
#> 
#> , , s1
#> 
#>         Adelie  Chinstrap     Gentoo
#> 1 -0.070245047 -0.9610952 -0.4388192
#> 2 -0.080752141 -0.9610952 -0.4080160
#> 3 -0.101766329 -0.9610952 -0.4222329
#> 4 -0.007202482 -0.9610952 -0.4530360
#> 5 -0.075498594 -0.9610952 -0.4838391
#> 
#> , , s2
#> 
#>       Adelie Chinstrap    Gentoo
#> 1 -0.6584729 -1.651419 -1.236512
#> 2 -0.6778014 -1.651419 -1.121317
#> 3 -0.7164584 -1.651419 -1.174484
#> 4 -0.5425019 -1.651419 -1.289679
#> 5 -0.6681371 -1.651419 -1.404875
#> 
#> , , s3
#> 
#>      Adelie Chinstrap    Gentoo
#> 1 -1.209431 -2.295697 -1.987125
#> 2 -1.237089 -2.295697 -1.793266
#> 3 -1.292407 -2.295697 -1.882739
#> 4 -1.043477 -2.295697 -2.076599
#> 5 -1.223260 -2.295697 -2.270459
#> 
#> , , s4
#> 
#>      Adelie Chinstrap    Gentoo
#> 1 -1.733630 -2.907092 -2.704208
#> 2 -1.769264 -2.907092 -2.435748
#> 3 -1.840532 -2.907092 -2.559653
#> 4 -1.519825 -2.907092 -2.828113
#> 5 -1.751447 -2.907092 -3.096573
#> 
#> , , s5
#> 
#>      Adelie Chinstrap    Gentoo
#> 1 -2.238098 -3.494453 -3.396790
#> 2 -2.281444 -3.494453 -3.056687
#> 3 -2.368136 -3.494453 -3.213657
#> 4 -1.978021 -3.494453 -3.553761
#> 5 -2.259771 -3.494453 -3.893864

[output truncated: many more slices like these follow, one per penalty value in the path]

Created on 2023-01-18 with reprex v2.0.2

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions