Fix/reproducibility RNN #118
Conversation
darts/models/rnn_model.py
Outdated
""" | ||
|
||
kwargs['output_length'] = output_length | ||
kwargs['input_size'] = input_size | ||
kwargs['output_size'] = output_size | ||
|
||
# TODO : make it a util function? -> reusable in other torch models that needs fixed seed... | ||
# set the random seed |
I think it should be a util function used in every torch model that needs a fixed seed. What do you think?
probably more part of the superclass
Agreed, I think it could be one of the kwargs in ForecastingModel and set there if possible - we just need to make sure that fixing the seed in one class will not leak outside the scope of the current instance and affect all of the other ones.
Perhaps it is enough to set it in `TorchForecastingModel` (at least for now).
This was definitely missing, thanks! Just one thing: do you think it would be possible to add this functionality to the superclass `TorchForecastingModel`?
I was thinking about that, but it is supposed to be there only for models that have some randomness (probably most of the torch-implemented models will), which is why I proposed implementing it as a util function used only on a selection of torch models. If you think they will probably all need it, then yes, we should move it. What do you think?
Hmm yeah, I see what you mean. To be honest I'm not sure what's best. Any ideas @hrzn?
After some thought I think adding it to the superclass is better, as it will cover most of the use cases. If an inherited model is deterministic it will still work, and if you really want to, it could be enforced that you can't specify a `random_state` for such a model.
```
@@ -145,12 +149,22 @@ def __init__(self,
            Sizes of hidden layers connecting the last hidden layer of the RNN module to the output, if any.
        dropout
            Fraction of neurons affected by Dropout.
        random_state
            Control the randomness of the weights initialization. Check this
            `link <https://scikit-learn.org/stable/glossary.html#term-random-state>`_ for more details.
```
Please correct me if I am wrong, but I think that `random_state` in sklearn affects only the function it is passed to. Here, however, I see that the torch seed will be set by `random_state` for all torch-related pseudorandom number generation.
Yes, you are absolutely right. I didn't find a cleaner way to avoid the side effect - do you know of any?
Maybe this could be a lead: https://pytorch.org/docs/stable/random.html#torch.random.fork_rng
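A minimal sketch of what that could look like (illustrative only, not code from this PR):

```python
import torch

# fork_rng saves the global RNG state on entry and restores it on exit,
# so seeding inside the block does not leak into the rest of the process.
with torch.random.fork_rng():
    torch.manual_seed(42)
    weights = torch.randn(5)  # reproducible across runs

noise = torch.randn(5)  # unaffected by the seed set inside the block
```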
This looks good. It would need to be added to the `fit` function as well - I think we use `shuffle=True` in there - but that looks possible with `fork_rng`. What do you think about just using `manual_seed` before using the model, rather than inside the model itself? I feel like it would be much simpler.
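A sketch of that simpler, user-side alternative in plain torch (illustrative, hypothetical model code not from darts):

```python
import torch
import torch.nn as nn

# Seed torch's global RNG once, up front. Every torch-based random
# operation afterwards (weight init, batch shuffling, sampling) is then
# reproducible - but the seed applies process-wide, not per model instance.
torch.manual_seed(42)

layer = nn.Linear(4, 2)  # weight initialization is now reproducible
noise = torch.randn(3)   # as is any subsequent sampling
print(layer.weight, noise)
```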
+1 for `manual_seed`
> This looks good. It would need to be added to the `fit` function as well - I think we use `shuffle=True` in there - but that looks possible with `fork_rng`. What do you think about just using `manual_seed` before using the model, rather than inside the model itself? I feel like it would be much simpler.
It would be simpler indeed; however, we would not have a unified API, meaning the sklearn-based models would be passed a `random_state` while the torch models would need `torch.manual_seed` before usage for reproducibility.
In order to make sure there is no side effect, and to provide the end user with the same API as sklearn for fixing a random state, I propose the following. Let me know what you think @Kostiiii, @TheMP, @pennfranc, @hrzn (might be too much?)

```python
import numpy as np
import torch

# parent = TorchForecastingModel
class Parent:
    def __init__(self, random_state=None):
        if not hasattr(self, "_random_instance"):
            # a random instance is associated with the model and used in
            # every function that requires randomness
            self._random_instance = np.random.RandomState(random_state)
```

```python
# children = a specific model (i.e. RNNModel, GRU, ...)
# the darts developer adds a @random_method decorator to each method
# that uses a random number generator (rng)
class Children(Parent):
    @random_method
    def __init__(self, **kwargs):
        print("create some model with random initial weights: {}".format(torch.randn(5)))
        super().__init__(**kwargs)

    @random_method
    def fit(self):
        print("train model with randomized batches {}".format(torch.randn(5)))
```

```python
# in darts.utils.torch
MAX_TORCH_SEED_VALUE = (1 << 63) - 1

def random_method(decorated):
    def decorator(self, *args, **kwargs):
        if hasattr(self, "_random_instance"):
            # the parent class has been initialized already, so the model
            # has a random instance -> use it
            with torch.random.fork_rng():
                torch.random.manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
                return decorated(self, *args, **kwargs)
        elif "random_state" in kwargs.keys():
            # the parent class has not been initialized yet, but a
            # random_state was provided as argument -> use it
            self._random_instance = np.random.RandomState(kwargs["random_state"])
            with torch.random.fork_rng():
                torch.random.manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
                return decorated(self, *args, **kwargs)
        else:
            # the parent class has not been initialized and no random_state
            # was provided -> default randomness (not reproducible)
            return decorated(self, *args, **kwargs)
    return decorator
```

Usage for a darts user:

```python
children = Children(random_state=42)  # reproducible weight initialization
children.fit()                        # reproducible batch randomness
```
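As a quick sanity check of the sketch (assuming the snippets above are combined into one module), two models built with the same `random_state` draw identical seeds, while the global RNG outside the decorated methods is left untouched:

```python
a = Children(random_state=42)  # prints the same initial weights...
b = Children(random_state=42)  # ...as this one

a.fit()  # identical "randomized batches" printout for a and b
b.fit()

print(torch.randn(2))  # global RNG state is unaffected by the models' seeding
```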
The decorator approach looks quite neat @guillaumeraille, I think you can go for it.
This reverts commit c3e70c3.
Fixes #DARTS-123.
Summary
Adds the possibility to specify a `random_state` at model creation on the RNN model, using the same API as sklearn, for easy usage across the whole darts library.

Other Information