When to do mdl.train() and mdl.eval() for MAML? #19
My conclusion, finally: see the comments and this Stack Overflow post for extended links, discussion, and a one-sentence answer: https://stackoverflow.com/questions/69845469/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-highe

@tristandeleu I will close this. Let me know if you (or anyone else) disagree.
I was thinking that one would do it as follows during meta-training:

- support set (inner loop): mdl.train(), because we want to collect the running averages across tasks;
- query set (outer loop): mdl.train(), to use the same params;

which is what you're doing here: https://github.com/tristandeleu/pytorch-meta/blob/d487ad0a1268bd6e6a7290b8780c6b62c7bed688/examples/maml-higher/train.py#L93
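A minimal sketch of the meta-training recipe described above, using a hypothetical toy model in place of the real MAML base learner (the higher/pytorch-meta inner-loop machinery is omitted for brevity):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model with a BatchNorm layer (stands in for the MAML base model).
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
bn = model[1]

# Meta-training: keep train mode for BOTH the support (inner-loop) and
# query (outer-loop) forward passes, so BatchNorm normalizes with batch
# statistics and accumulates running averages across tasks.
model.train()
support_logits = model(torch.randn(5, 4))    # inner-loop pass (5-shot support)
query_logits = model(torch.randn(15, 4))     # outer-loop pass (query set)

# The running statistics have moved away from their initial value (zero mean).
print(bn.running_mean.abs().sum().item() > 0)   # True
```

Both passes deliberately stay in train mode; the running-average buffers therefore aggregate statistics over every task seen during meta-training.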
The real question is what to do during evaluation, since at meta-eval time the tasks are completely different (e.g. image classes we've never seen). There are really three options (call them a, b, c) for meta-eval (inference, e.g. validation and testing):

a. Use .train() for both the support (inner loop) and query sets. The issue here is that the model would (accidentally) cheat, since it would use the statistics of the eval set.

b. Use .eval() for both the support (inner loop) and query sets. Here the model would use the statistics from training and would not cheat. The pro is that the model was trained with those statistics, so perhaps that's good; but the true statistics of the eval set are most likely something completely different, since the classes have not been seen.

c. Use .eval() AND set track_running_stats = False. This would use batch statistics, which means the model uses "the right stats", but it was not trained on them, so who knows if that is better. Plus, I don't know what the BN layer would do for 1-shot learning; it would probably crash unless it uses layer norm (LN) instead.

I am basically curious what standard MAML does. From your code here:
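Option (c) has a PyTorch subtlety worth noting: once a BatchNorm layer has been constructed with running-stat buffers, flipping its track_running_stats attribute alone is not enough to switch it to batch statistics in eval mode; the buffers must also be cleared. A sketch under that assumption:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(8)
bn.eval()  # eval mode, as in option (c)

# Flipping the flag alone leaves the running-stat buffers in place, and
# eval-mode BatchNorm would still use them; clearing the buffers forces
# the layer to fall back to batch statistics even in eval mode.
bn.track_running_stats = False
bn.running_mean = None
bn.running_var = None

x = torch.randn(16, 8) * 3 + 5        # batch whose stats differ from N(0, 1)
out = bn(x)

# Batch-statistic normalization (with default affine weight=1, bias=0)
# gives roughly zero mean per channel.
print(out.mean(dim=0).abs().max().item() < 1e-4)   # True
```

Constructing the layer as nn.BatchNorm1d(8, track_running_stats=False) from the start achieves the same thing more cleanly. Note also that a batch of size 1 (the 1-shot worry above) raises an error in this batch-statistics mode, since per-channel variance is undefined for a single sample.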
pytorch-maml/maml/metalearners/maml.py, line 231 (commit 4410427)

Is that right?
My implementation currently:
ref: https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du