Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add refit.NNETAR #287

Merged
merged 4 commits into from Jun 30, 2020
Merged

add refit.NNETAR #287

merged 4 commits into from Jun 30, 2020

Conversation

Tim-TU
Copy link
Contributor

@Tim-TU Tim-TU commented Jun 30, 2020

No description provided.

Copy link
Member

@mitchelloharawild mitchelloharawild left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Some small changes requested. I'll fix support for the scaling of inputs.


# check for scale_inputs:
scale_in <- TRUE
if (length(unlist(object$scales)) == 0) scale_in <- FALSE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simplified to:
scale_in <- length(unlist(object$scales)) != 0

wts_list <- object$model %>%
purrr::map(.,"wts")

out <- train_nnetar(new_data, specials, n_nodes = size, n_networks = n_nets, scale_inputs = scale_in, wts = wts_list,...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A simple trick to refit without re-estimation is to set maxit=0:

# from ?nnet::nnet
library(nnet)
ir <- rbind(iris3[,,1],iris3[,,2],iris3[,,3])
targets <- class.ind( c(rep("s", 50), rep("c", 50), rep("v", 50)) )
samp <- c(sample(1:50,25), sample(51:100,25), sample(101:150,25))
ir1 <- nnet(ir[samp,], targets[samp,], size = 2, rang = 0.1,
            decay = 5e-4, maxit = 200)
#> # weights:  19
#> initial  value 55.564730 
#> iter  10 value 45.612146
#> iter  20 value 11.169247
#> iter  30 value 2.334405
#> iter  40 value 2.300697
#> iter  50 value 2.285848
#> iter  60 value 2.269265
#> iter  70 value 2.214906
#> iter  80 value 1.942610
#> iter  90 value 1.779559
#> iter 100 value 1.760020
#> iter 110 value 1.752948
#> iter 120 value 1.751550
#> iter 130 value 1.750537
#> iter 140 value 1.750014
#> iter 150 value 1.749923
#> iter 160 value 1.749906
#> iter 170 value 1.749886
#> iter 180 value 1.749876
#> final  value 1.749870 
#> converged
ir2 <- nnet(ir[samp,], targets[samp,], size = 2, rang = 0.1,
            decay = 5e-4, maxit = 0, Wts = ir1$wts)
#> # weights:  19
identical(ir1$fitted.values, ir2$fitted.values)
#> [1] TRUE

Created on 2020-06-30 by the reprex package (v0.3.0)

@mitchelloharawild mitchelloharawild merged commit acd925b into tidyverts:master Jun 30, 2020
1 check failed
@mitchelloharawild
Copy link
Member

Oops, unintentionally merged this when trying to push to your pull request. I'll fix the above requests.

Thanks again for this pull request!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants