-
Notifications
You must be signed in to change notification settings - Fork 106
Closed
Description
With the glmnet engine, parsnip requires a single penalty value to be set in the model specification and uses that value for predictions when no other value is provided. This works as expected for `linear_reg()`, but predictions of `type = "raw"` from `logistic_reg()` and `multinom_reg()` ignore that value and instead return predictions for the entire penalty path.
library(parsnip)

# linear_reg() + glmnet: predict() uses the penalty stored in the spec both for
# type = "numeric" and for type = "raw" (raw returns a single column, `s1`).
data("hpc_data", package = "modeldata")
hpc <- hpc_data[seq_len(150), c(2:5, 8)]
new_rows <- hpc[1:5, ]
lm_spec <- linear_reg(penalty = 0.123) |>
  set_engine("glmnet")
lm_fit <- fit(
  lm_spec,
  formula = input_fields ~ log(compounds) + class,
  data = hpc
)
predict(lm_fit, new_rows, type = "numeric")
#> # A tibble: 5 × 1
#> .pred
#> <dbl>
#> 1 569.
#> 2 164.
#> 3 169.
#> 4 159.
#> 5 168.
predict(lm_fit, new_rows, type = "raw")
#> s1
#> 1 569.0363
#> 2 164.1463
#> 3 168.7923
#> 4 159.3046
#> 5 167.6483
# logistic_reg() + glmnet: type = "class" and type = "prob" use the penalty
# stored in the spec, but type = "raw" ignores it and returns one column per
# penalty value on the fitted path (s0 ... s51) instead of a single column.
data("lending_club", package = "modeldata")
lr_spec <- logistic_reg(penalty = 0.123) |> set_engine("glmnet")
lr_fit <- fit(
  lr_spec,
  Class ~ log(funded_amnt) + int_rate + term,
  data = lending_club
)
predict(lr_fit, lending_club[1:5, ], type = "class")
#> # A tibble: 5 × 1
#> .pred_class
#> <fct>
#> 1 good
#> 2 good
#> 3 good
#> 4 good
#> 5 good
predict(lr_fit, lending_club[1:5, ], type = "prob")
#> # A tibble: 5 × 2
#> .pred_bad .pred_good
#> <dbl> <dbl>
#> 1 0.0525 0.948
#> 2 0.0525 0.948
#> 3 0.0525 0.948
#> 4 0.0525 0.948
#> 5 0.0525 0.948
predict(lr_fit, lending_club[1:5, ], type = "raw")
#> s0 s1 s2 s3 s4 s5 s6 s7
#> 1 2.894019 2.873291 2.859939 2.851942 2.847857 2.846633 2.847493 2.849842
#> 2 2.894019 2.905465 2.920151 2.936813 2.954577 2.972826 2.991126 3.009075
#> 3 2.894019 2.836290 2.790695 2.754340 2.725129 2.701510 2.682315 2.666723
#> 4 2.894019 2.878439 2.869573 2.865522 2.864932 2.866824 2.870474 2.875319
#> 5 2.894019 2.979466 3.058638 3.132017 3.200033 3.263072 3.321482 3.375312
#> s8 s9 s10 s11 s12 s13 s14 s15
#> 1 2.853267 2.857425 2.862063 2.866986 2.872046 2.877132 2.882159 2.887067
#> 2 3.026646 3.043592 3.059812 3.075242 3.089848 3.103611 3.116535 3.128633
#> 3 2.653881 2.643334 2.634652 2.627491 2.621574 2.616680 2.612627 2.609267
#> 4 2.881007 2.887212 2.893703 2.900307 2.906894 2.913368 2.919660 2.925718
#> 5 3.425417 3.471776 3.514636 3.554232 3.590791 3.624514 3.655600 3.684235
#> s16 s17 s18 s19 s20 s21 s22 s23
#> 1 2.891811 2.896360 2.900693 2.904798 2.908667 2.912303 2.908247 2.900061
#> 2 3.139927 3.150444 3.160218 3.169284 3.177673 3.185434 3.208686 3.240286
#> 3 2.606478 2.604163 2.602239 2.600640 2.599310 2.598202 2.588008 2.573076
#> 4 2.931510 2.937013 2.942217 2.947116 2.951708 2.956004 2.974772 3.001445
#> 5 3.710593 3.734838 3.757126 3.777601 3.796387 3.813635 3.827194 3.838366
#> s24 s25 s26 s27 s28 s29 s30 s31
#> 1 2.892788 2.886310 2.880529 2.875362 2.870732 2.866588 2.862868 2.859525
#> 2 3.269413 3.296215 3.320859 3.343503 3.364276 3.383355 3.400853 3.416893
#> 3 2.559637 2.547531 2.536616 2.526763 2.517867 2.509820 2.502540 2.495949
#> 4 3.026068 3.048760 3.069653 3.088874 3.106530 3.122760 3.137658 3.151324
#> 5 3.848788 3.858456 3.867409 3.875687 3.883304 3.890355 3.896851 3.902831
#> s32 s33 s34 s35 s36 s37 s38 s39
#> 1 2.856518 2.853811 2.848782 2.842368 2.836579 2.831347 2.826616 2.822334
#> 2 3.431589 3.445046 3.454869 3.462132 3.468825 3.474972 3.480612 3.485784
#> 3 2.489978 2.484566 2.479786 2.475555 2.471746 2.468315 2.465222 2.462431
#> 4 3.163855 3.175338 3.187092 3.198830 3.209593 3.219447 3.228464 3.236712
#> 5 3.908329 3.913381 3.924719 3.940136 3.954274 3.967211 3.979043 3.989861
#> s40 s41 s42 s43 s44 s45 s46 s47
#> 1 2.818457 2.814944 2.811759 2.808871 2.806251 2.803873 2.801714 2.799754
#> 2 3.490525 3.494868 3.498844 3.502423 3.505756 3.508806 3.511594 3.514142
#> 3 2.459911 2.457634 2.455575 2.453725 2.452037 2.450507 2.449121 2.447864
#> 4 3.244255 3.251150 3.257451 3.263152 3.268415 3.273223 3.277613 3.281620
#> 5 3.999750 4.008786 4.017042 4.024471 4.031366 4.037661 4.043409 4.048654
#> s48 s49 s50 s51
#> 1 2.797973 2.796354 2.794883 2.793552
#> 2 3.516470 3.518597 3.520539 3.522250
#> 3 2.446724 2.445689 2.444749 2.443911
#> 4 3.285278 3.288616 3.291662 3.294374
#> 5 4.053441 4.057809 4.061794 4.065295
# multinom_reg() + glmnet: type = "class" and type = "prob" use the penalty
# stored in the spec, but type = "raw" ignores it and returns a 3-D array with
# one slice per penalty value on the fitted path (s0, s1, ...) instead of a
# single slice.
data("penguins", package = "modeldata")
penguins <- tidyr::drop_na(penguins)
mr_spec <- multinom_reg(penalty = 0.123) |> set_engine("glmnet")
mr_fit <- fit(
  mr_spec,
  species ~ island + bill_length_mm + bill_depth_mm,
  data = penguins
)
predict(mr_fit, penguins[1:5, ], type = "class")
#> # A tibble: 5 × 1
#> .pred_class
#> <fct>
#> 1 Adelie
#> 2 Adelie
#> 3 Adelie
#> 4 Adelie
#> 5 Adelie
predict(mr_fit, penguins[1:5, ], type = "prob")
#> # A tibble: 5 × 3
#> .pred_Adelie .pred_Chinstrap .pred_Gentoo
#> <dbl> <dbl> <dbl>
#> 1 0.848 0.0759 0.0761
#> 2 0.765 0.0756 0.159
#> 3 0.770 0.0924 0.138
#> 4 0.921 0.0459 0.0327
#> 5 0.888 0.0835 0.0283
predict(mr_fit, penguins[1:5, ], type = "raw")
#> , , s0
#>
#> Adelie Chinstrap Gentoo
#> 1 0.3228607 -0.4412382 0.1183776
#> 2 0.3228607 -0.4412382 0.1183776
#> 3 0.3228607 -0.4412382 0.1183776
#> 4 0.3228607 -0.4412382 0.1183776
#> 5 0.3228607 -0.4412382 0.1183776
#>
#> , , s1
#>
#> Adelie Chinstrap Gentoo
#> 1 -0.070245047 -0.9610952 -0.4388192
#> 2 -0.080752141 -0.9610952 -0.4080160
#> 3 -0.101766329 -0.9610952 -0.4222329
#> 4 -0.007202482 -0.9610952 -0.4530360
#> 5 -0.075498594 -0.9610952 -0.4838391
#>
#> , , s2
#>
#> Adelie Chinstrap Gentoo
#> 1 -0.6584729 -1.651419 -1.236512
#> 2 -0.6778014 -1.651419 -1.121317
#> 3 -0.7164584 -1.651419 -1.174484
#> 4 -0.5425019 -1.651419 -1.289679
#> 5 -0.6681371 -1.651419 -1.404875
#>
#> , , s3
#>
#> Adelie Chinstrap Gentoo
#> 1 -1.209431 -2.295697 -1.987125
#> 2 -1.237089 -2.295697 -1.793266
#> 3 -1.292407 -2.295697 -1.882739
#> 4 -1.043477 -2.295697 -2.076599
#> 5 -1.223260 -2.295697 -2.270459
#>
#> , , s4
#>
#> Adelie Chinstrap Gentoo
#> 1 -1.733630 -2.907092 -2.704208
#> 2 -1.769264 -2.907092 -2.435748
#> 3 -1.840532 -2.907092 -2.559653
#> 4 -1.519825 -2.907092 -2.828113
#> 5 -1.751447 -2.907092 -3.096573
#>
#> , , s5
#>
#> Adelie Chinstrap Gentoo
#> 1 -2.238098 -3.494453 -3.396790
#> 2 -2.281444 -3.494453 -3.056687
#> 3 -2.368136 -3.494453 -3.213657
#> 4 -1.978021 -3.494453 -3.553761
#> 5 -2.259771 -3.494453 -3.893864
[... and many more slices like these; output truncated ...]

Created on 2023-01-18 with reprex v2.0.2
Metadata
Metadata
Assignees
Labels
No labels