
When to do mdl.train() and mdl.eval() for MAML? #19

Closed
brando90 opened this issue Sep 8, 2021 · 3 comments

brando90 commented Sep 8, 2021

I was thinking that one would do it as follows:

  1. During meta-training (fitting):
  • inner loop (support set): use mdl.train() (because we want to collect the running averages across tasks)
  • query set: use the same mdl.train() (so the same parameters and statistics are used)

which is what you're doing here: https://github.com/tristandeleu/pytorch-meta/blob/d487ad0a1268bd6e6a7290b8780c6b62c7bed688/examples/maml-higher/train.py#L93

The real question is what to do during evaluation (since at meta-eval the tasks are completely different, e.g. image classes we've never seen). There are really 3 options (call them a, b, c):

2. During meta-eval (inference, e.g. validation, testing):

2.a. Use .train() for both the support set (inner loop) and the query set. The issue here is that the model would (accidentally) cheat, since it would use the statistics of the eval set.

2.b. Use .eval() for both the support set (inner loop) and the query set. Here the model would use the statistics from training and would not cheat. The pro is that the model was trained with those statistics, so perhaps that's good - but the true statistics of the eval set are most likely something completely different (since the classes have not been seen).

2.c. Use .eval() AND set track_running_stats = False. This would use batch statistics, meaning the model uses "the right stats" - but it was not trained on them, so who knows if that is better. Plus I don't know what the BN layer would do for 1-shot learning - probably crash, unless one uses layer norm (LN) instead. A sketch of all three options follows.
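Here is a minimal sketch of the three options as a single toggle over a model's BatchNorm layers (the helper and its name are hypothetical; the mechanics rely on PyTorch's documented behaviour that eval-mode BN falls back to batch statistics only when the running buffers are absent):

    import torch.nn as nn

    def set_meta_eval_mode(model: nn.Module, option: str) -> None:
        """Hypothetical helper: configure BatchNorm for meta-eval per options a/b/c."""
        if option == 'a':
            model.train()  # batch stats on support AND query -> leaks eval-set stats
        elif option == 'b':
            model.eval()   # running stats from meta-train -> no cheating, but stale stats
        elif option == 'c':
            model.eval()
            for m in model.modules():
                if isinstance(m, nn.modules.batchnorm._BatchNorm):
                    m.track_running_stats = False
                    # In eval mode PyTorch only switches BN to batch statistics
                    # when the running buffers are None, so clear them explicitly.
                    m.running_mean = None
                    m.running_var = None
        else:
            raise ValueError(f"unknown option: {option}")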

I am basically curious what the standard MAML does. From your code here:

self.model.eval()
I infer that you chose option 2.b. So during the inner loop (support set) and on the query set, your model is in eval mode and uses statistics from training.

Is that right?


My implementation currently:

        # inner_opt = torch.optim.SGD(self.base_model.parameters(), lr=self.lr_inner)
        inner_opt = NonDiffMAML(self.base_model.parameters(), lr=self.lr_inner)
        # inner_opt = torch.optim.Adam(self.base_model.parameters(), lr=self.lr_inner)
        self.args.inner_opt_name = str(inner_opt)

        # Accumulate gradient of meta-loss wrt fmodel.param(t=0)
        meta_batch_size = spt_x.size(0)
        meta_losses, meta_accs = [], []
        for t in range(meta_batch_size):
            spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x[t], spt_y[t], qry_x[t], qry_y[t]
            # if torch.cuda.is_available():
            #     spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x_t.cuda(), spt_y_t.cuda(), qry_x_t.cuda(), qry_y_t.cuda()
            # Inner Loop Adaptation
            with higher.innerloop_ctx(self.base_model, inner_opt, copy_initial_weights=self.args.copy_initial_weights,
                                      track_higher_grads=self.args.track_higher_grads) as (fmodel, diffopt):
                diffopt.fo = self.fo
                fmodel.train()  # inner loop: batch statistics + running-average updates
                for i_inner in range(self.args.nb_inner_train_steps):
                    # base/child model forward pass
                    spt_logits_t = fmodel(spt_x_t)
                    inner_loss = self.args.criterion(spt_logits_t, spt_y_t)
                    # inner_train_err = calc_error(mdl=fmodel, X=S_x, Y=S_y)  # for more advanced learners like meta-lstm

                    # inner-opt update
                    diffopt.step(inner_loss)

            # batch stats during meta-train; running stats (option 2.b) on the query set during meta-eval
            if self.args.split == 'train':
                fmodel.train()
            else:
                fmodel.eval()
            # Evaluate on query set for current task
            qry_logits_t = fmodel(qry_x_t)
            qry_loss_t = self.args.criterion(qry_logits_t, qry_y_t)

            # Accumulate gradients wrt meta-params per task: https://github.com/facebookresearch/higher/issues/104
            # Calling backward() once per task is more memory efficient than summing the losses first,
            # since each task's intermediate graph is freed as soon as backward has run.
            (qry_loss_t / meta_batch_size).backward()

            # get accuracy
            if self.target_type == 'classification':
                qry_acc_t = calc_accuracy_from_logits(y_logits=qry_logits_t, y=qry_y_t)
            else:
                qry_acc_t = r2_score_from_torch(qry_y_t, qry_logits_t)
                # qry_acc_t = compressed_r2_score(y_true=qry_y_t.detach().numpy(), y_pred=qry_logits_t.detach().numpy())

            # collect losses & accs for logging/debugging
            meta_losses.append(qry_loss_t.item())
            meta_accs.append(qry_acc_t)

ref: https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du

brando90 commented Nov 5, 2021

My conclusion finally:

- Importantly, during inference (eval/testing) the running_mean and running_var computed during training are used (because we want a deterministic output and to use estimates of the population statistics).
- During training the batch statistics are used, while a population statistic is estimated with running averages. I assume the reason batch statistics are used during training is to introduce noise that regularizes training (noise robustness).
- In meta-learning I think using batch statistics during testing (and not the running means) is best, since we are supposed to be seeing a new distribution anyway. The price we pay is a loss of determinism. Just out of curiosity, it could be interesting to see what the accuracy is when using population stats estimated from meta-train. A tiny demo of the train/eval difference follows this list.
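A tiny, self-contained demo of that difference (plain PyTorch, nothing project-specific assumed):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    bn = nn.BatchNorm1d(4)

    # Train mode: normalize with *this batch's* statistics and update the running averages.
    bn.train()
    x = torch.randn(8, 4) * 3 + 5             # batch with non-trivial mean/std
    y_train = bn(x)
    print(bn.running_mean)                    # moved from 0 toward the batch mean (~5)

    # Eval mode: normalize with the running (population-estimate) statistics -> deterministic.
    bn.eval()
    y_eval = bn(x)
    print(torch.allclose(y_train, y_eval))    # False: the two modes normalize differently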

see comments for extended links and discussion: https://stackoverflow.com/questions/69845469/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-highe

brando90 commented May 5, 2022

My conclusion finally:

- Importantly, during inference (eval/testing) running_mean, running_std is used - that was calculated from training(because they want a deterministic output and to use estimates of the population statistics).  
- During training the batch statistics is used but a population statistic is estimated with running averages. I assume the reason batch_stats is used during training is to introduce noise that regularizes training (noise robustness)
- in meta-learning I think using batch statistics is the best during testing (and not calculate the running means) since we are supposed to be seeing distribution anyway. Price we pay is loss of determinism. Could be interesting just out of curiosity what the accuracy is using population stats estimated from meta-trian.

see comments for extended links and discussion: https://stackoverflow.com/questions/69845469/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-highe

One-sentence answer:

Always use .train() and never .eval() in few-shot learning (FSL) meta-learning. A minimal sketch of what that looks like is below.
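Here is a minimal sketch of that rule in a higher-style meta-eval loop (variable names such as base_model, inner_opt, spt_x_t and qry_x_t mirror my snippet above and are assumed to exist):

    import higher

    # Meta-eval, but the model never leaves train mode: BN always uses batch statistics.
    base_model.train()
    with higher.innerloop_ctx(base_model, inner_opt,
                              track_higher_grads=False) as (fmodel, diffopt):
        fmodel.train()                        # support set: batch statistics
        for _ in range(nb_inner_train_steps):
            spt_logits_t = fmodel(spt_x_t)
            diffopt.step(criterion(spt_logits_t, spt_y_t))
        fmodel.train()                        # query set too: never switch to .eval()
        qry_logits_t = fmodel(qry_x_t)
        qry_loss_t = criterion(qry_logits_t, qry_y_t)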

brando90 commented
@tristandeleu I will close this. Let me know (you or anyone else) if you disagree.
