New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add refit.NNETAR #287
add refit.NNETAR #287
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great. Some small changes requested. I'll fix support for the scaling of inputs.
|
||
# check for scale_inputs: | ||
scale_in <- TRUE | ||
if (length(unlist(object$scales)) == 0) scale_in <- FALSE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be simplified to:
scale_in <- length(unlist(object$scales)) != 0
wts_list <- object$model %>% | ||
purrr::map(.,"wts") | ||
|
||
out <- train_nnetar(new_data, specials, n_nodes = size, n_networks = n_nets, scale_inputs = scale_in, wts = wts_list,...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A simple trick to refit without re-estimation is to set maxit=0
:
# from ?nnet::nnet
library(nnet)
ir <- rbind(iris3[,,1],iris3[,,2],iris3[,,3])
targets <- class.ind( c(rep("s", 50), rep("c", 50), rep("v", 50)) )
samp <- c(sample(1:50,25), sample(51:100,25), sample(101:150,25))
ir1 <- nnet(ir[samp,], targets[samp,], size = 2, rang = 0.1,
decay = 5e-4, maxit = 200)
#> # weights: 19
#> initial value 55.564730
#> iter 10 value 45.612146
#> iter 20 value 11.169247
#> iter 30 value 2.334405
#> iter 40 value 2.300697
#> iter 50 value 2.285848
#> iter 60 value 2.269265
#> iter 70 value 2.214906
#> iter 80 value 1.942610
#> iter 90 value 1.779559
#> iter 100 value 1.760020
#> iter 110 value 1.752948
#> iter 120 value 1.751550
#> iter 130 value 1.750537
#> iter 140 value 1.750014
#> iter 150 value 1.749923
#> iter 160 value 1.749906
#> iter 170 value 1.749886
#> iter 180 value 1.749876
#> final value 1.749870
#> converged
ir2 <- nnet(ir[samp,], targets[samp,], size = 2, rang = 0.1,
decay = 5e-4, maxit = 0, Wts = ir1$wts)
#> # weights: 19
identical(ir1$fitted.values, ir2$fitted.values)
#> [1] TRUE
Created on 2020-06-30 by the reprex package (v0.3.0)
Oops, unintentionally merged this when trying to push to your pull request. I'll fix the above requests. Thanks again for this pull request! |
No description provided.