
Updating to Lightning 2.0 #210

Merged
30 commits merged into torchmd:main on Sep 6, 2023

Conversation

RaulPPelaez
Collaborator

This is an effort to update the PyTorch Lightning dependency (see #168).
A lot has changed since the LNNP module was written (the version we currently use is 1.6.4), and I expect this update to improve compatibility with some functionality we are after (i.e. torch.compile/TorchScript training).

@RaulPPelaez
Collaborator Author

The resume_from_checkpoint option in train.py has disappeared:

    trainer = pl.Trainer(
        strategy=DDPStrategy(find_unused_parameters=False),
        max_epochs=args.num_epochs,
        accelerator="auto",
        devices=args.ngpus,
        num_nodes=args.num_nodes,
        default_root_dir=args.log_dir,
        #resume_from_checkpoint=None if args.reset_trainer else args.load_model,
        callbacks=[early_stopping, checkpoint_callback],
        logger=_logger,
        precision=args.precision,
        gradient_clip_val=args.gradient_clipping,
        inference_mode=False
    )

Trying to see how one is supposed to replace it

@AntonioMirarchi
Contributor

From LightningDeprecationWarning: Setting Trainer(resume_from_checkpoint=) is deprecated in v1.5 and will be removed in v1.7. Please pass Trainer.fit(ckpt_path=) directly instead.
So it could be something like this:

    if not args.load_model:
        trainer.fit(model, data)
    else:
        trainer.fit(model, data, ckpt_path=args.load_model)

@RaulPPelaez
Collaborator Author

There is one lingering issue: training seems to be overall 6x faster with this PR than with current main:
[wandb screenshots: train/val curves for this PR vs. current main]

But as you can see, the testing graphs are not being updated in wandb. I am afraid it is simply not testing, although the train and val losses look identical to me.

@RaulPPelaez
Collaborator Author

I found this issue where they discuss why Lightning does not want to support testing during the training loop:
Lightning-AI/pytorch-lightning#9254
They also describe the trick we are currently using of adding the test dataloader as an extra validation dataloader so it can be run during training.

However, I do not understand why we are paying such a high performance price to test during training. @PhilippThoelke, you implemented this AFAIK; could you provide some insight?
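
For reference, my understanding of that trick in Lightning terms is something like this (a minimal sketch with placeholder names, not the actual torchmd-net code):

    # Sketch only: expose the test set as a second validation dataloader so it
    # runs inside the fit loop; val_loader/test_loader/_loss are placeholders.
    import pytorch_lightning as pl

    class DataModule(pl.LightningDataModule):
        def val_dataloader(self):
            # index 0 -> validation set, index 1 -> test set
            return [self.val_loader, self.test_loader]

    class Model(pl.LightningModule):
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            loss = self._loss(batch)
            stage = "val" if dataloader_idx == 0 else "test"
            self.log(f"{stage}_loss", loss)
            return loss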

@PhilippThoelke
Collaborator

Testing during training was useful since it can be difficult to estimate model performance from the val loss when the val loss is also used to adjust the learning rate (e.g. through the ReduceLROnPlateau scheduler). This is particularly important for fast prototyping and architecture development. I wasn't aware that this slows down training by 6x, that is insane! How much of that is actually due to the testing, though, and how much comes from improvements in Lightning 2.0?
I don't know if that LR scheduler is still relevant for torchmd-net, but a 6x efficiency decrease is never acceptable.
Possibly related: #27
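
For reference, the coupling I mean looks roughly like this in Lightning (a generic sketch, not the exact LNNP code; the hyperparameters are illustrative):

    # Generic sketch: the scheduler steps on val_loss, so val_loss does double
    # duty as both the model-selection metric and the LR-control signal.
    import torch

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.8, patience=10
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }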

@RaulPPelaez
Collaborator Author

TBH I do not really know how much of the speedup is due to other improvements in Lightning.
It makes sense to me that not testing is the main cause of the speedup, simply because one is not going over a whole dataset every few epochs, which requires reading a bunch of stuff from disk, etc.
I understand the fast-prototyping argument, but we are currently going against the Lightning devs.
I will try to recreate the multiple-dataloader hack (looking at the docs/GitHub issues I do not see any other workaround) so the functionality is still provided and I can measure whether this is indeed the cause of the slowdown.

Well, I guess I can also run the baseline without testing during training -.-

If that turns out to be the case, then perhaps we can print a warning in the "test during training" case to explain that you are paying a high price for it hehe

@giadefa
Contributor

giadefa commented Aug 8, 2023 via email

@RaulPPelaez
Collaborator Author

OK, since I do not yet know how to reproduce the trick in the latest Lightning, I ran current main with and without testing during validation to compare:
[screenshot: wall-clock comparison with and without test-during-training]
Note that I was not trying to benchmark when I noticed this, I was checking correctness. So this training example may not be the most representative one (it does not fully occupy the GPU), but still...
Testing is carried out every 10 epochs in the orange line.
Roughly, it is a 3x slowdown if you test every 10 epochs. Without testing, I see a 3x speedup going from Lightning 1.6.3 to 2.0.4.

I get why testing is super slow: the way it is set up in this example, it just goes over the whole test dataset each time. Besides being a lot of forward passes, it requires reading the whole thing from disk.

With this, I now believe that:

  1. Updating Lightning is well worth it.
  2. We should keep the testing functionality for prototyping.
  3. We should include a warning explaining that test-during-training can be expensive.

@RaulPPelaez
Collaborator Author

I managed to bring back the functionality.
As far as I can tell, the only way to reproduce the test-during-train trick is to reload the dataloaders every epoch.
This on its own does not seem to have a profound effect on performance.
Sadly, the performance when test-during-training is enabled does not look better than before.

I will test some more to make sure these performance numbers are not a fluke, but functionality-wise this should be ready to go!
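
For reference, the relevant knob is roughly this (a sketch; the other Trainer arguments from train.py are omitted):

    # Sketch: reloading the dataloaders every epoch is what lets val_dataloader()
    # decide, per epoch, whether to also hand back the test loader.
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        reload_dataloaders_every_n_epochs=1,
        inference_mode=False,
    )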

Comment on lines -119 to -121
    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        store_dataloader = (
            store_dataloader and self.trainer.reload_dataloaders_every_n_epochs <= 0
        )
Collaborator Author


I removed this check here.
With it gone, the dataloaders are always stored.
AFAIK the check was never actually exercised in the project, so the behavior is unchanged, but I want to point it out in case I am missing something.

@guillemsimeon
Collaborator

I get that testing during training only matters when you are prototyping, and in fact in my case it helped me stop training after a few epochs when some hyperparameter change did not provide any improvement. If we are changing this, I have two suggestions:

  • Report validation MAEs on both energies and forces on top of the (separate) validation loss(es), which are MSEs. This can give a better sense of how the training is going.
  • Even when no testing is performed during training, I would suggest providing test errors by default once training has finished. Then people would not need to run inference on the test set separately to report performance.

@guillemsimeon
Collaborator

Another option is, when training is finished, to look for the checkpoint with the lowest validation loss and run the test with that one. I think this is what happens in MACE, for example.

@RaulPPelaez
Collaborator Author

Thanks for your insights!
I am convinced test-during-training has some useful cases. Luckily, I was able to cook it back in, so I believe it is best to leave it there as an option, at least for now. If you want to go fast, simply do not set the "test_interval" option.

Another option is, when training is finished, to look for the checkpoint with the lowest validation loss and run the test with that one. I think this is what happens in MACE, for example.

AFAIK this is exactly what torchmd-train does:

    trainer.fit(model, data)
    # run test set after completing the fit
    model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    trainer = pl.Trainer(logger=_logger)
    trainer.test(model, data)

Report validation MAEs

As you say, the values reported as "val_loss_y/dy" are mse_loss. I can add another reported value called "val_l1_loss_y/dy". It should have no effect whatsoever on performance.

@RaulPPelaez
Collaborator Author

Adding these metrics to the reports has zero performance cost.

I changed things so that, each validation epoch, the loss is reported for every loss function in a list (currently L1 and L2).
Additionally, the y, neg_dy and total losses are always reported for each loss function, regardless of their weights (even if total_loss == loss_neg_dy, you might be interested in knowing how loss_y behaves).
At this point, the following metrics are reported when test-during-training is off (the default, set by test_interval < 0):

  • train_loss
  • train_neg_dy_mse_loss
  • train_total_mse_loss
  • train_y_mse_loss
  • val_loss
  • val_neg_dy_l1_loss
  • val_neg_dy_mse_loss
  • val_total_l1_loss
  • val_total_mse_loss
  • val_y_l1_loss
  • val_y_mse_loss

train_loss and val_loss are aliases for train_total_mse_loss and val_total_mse_loss, respectively. I left them there because the checkpoint naming scheme currently in place uses these particular names.

If test_interval > 0, then test_[y,neg_dy,total]_l1_loss entries are added as well.
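
Schematically, the reporting amounts to something like this (a sketch, not the literal module code; log_val_metrics is a made-up helper name and the loss weighting is omitted):

    # Schematic only: each epoch, every component (y, neg_dy, total) is logged
    # once per loss function, regardless of the loss weights used for training.
    import torch.nn.functional as F

    loss_fns = {"mse": F.mse_loss, "l1": F.l1_loss}

    def log_val_metrics(self, y_pred, y, neg_dy_pred, neg_dy):
        for name, fn in loss_fns.items():
            loss_y = fn(y_pred, y)
            loss_neg_dy = fn(neg_dy_pred, neg_dy)
            self.log(f"val_y_{name}_loss", loss_y)
            self.log(f"val_neg_dy_{name}_loss", loss_neg_dy)
            self.log(f"val_total_{name}_loss", loss_y + loss_neg_dy)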

@RaulPPelaez
Collaborator Author

I noticed there was no "on_test_epoch_end" hook in LNNP, meaning that the trainer.test line in train.py effectively just discarded all the work it did, as far as I can tell. Also, no logger is attached to the test run, so nothing about it is written to disk.

I wrote one so that the test losses are actually reported to the terminal:
[screenshot: test losses printed to the terminal]
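
The hook is along these lines (a simplified sketch, not the exact implementation; self._test_losses is a made-up buffer name):

    # Simplified sketch of the on_test_epoch_end hook: average the per-batch test
    # losses accumulated during the epoch and print them, since no logger is
    # attached to the test run.
    import torch

    def on_test_epoch_end(self):
        for name, values in self._test_losses.items():  # made-up buffer name
            mean = torch.stack(values).mean()
            print(f"{name}: {mean.item():.6f}")
            values.clear()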

Commit: Always log losses on y and neg_dy even if they are weighted 0 for the total loss
@RaulPPelaez changed the title from [WIP] Updating to Lightning 2.0 to Updating to Lightning 2.0 on Sep 5, 2023
@RaulPPelaez
Collaborator Author

This is ready to merge on my part.
@AntonioMirarchi, could you rerun some training with this PR to double check?
cc @raimis please review

Review comments on tests/test_model.py and torchmdnet/module.py (outdated, resolved).
@stefdoerr
Collaborator

Test failed for some reason. Can you take a look?

@RaulPPelaez
Collaborator Author

It was just an overprotective numerical check; retriggering the CI fixed it.

@guillemsimeon
Collaborator

guillemsimeon commented Sep 5, 2023

In the end, is this intrinsically faster, or was it just the effect of the testing during training?

@RaulPPelaez
Collaborator Author

The small, not very representative example I am using is definitely faster beyond the test-during-training change.
Although every test I have run turns out faster, I would not expect real-life runs to improve as much. For instance, the ET-SPICE.yaml run, with test-during-training removed, is only about 5% faster. YMMV.
I would expect to see improvements mostly when epochs are really short.

@RaulPPelaez merged commit ac16c09 into torchmd:main on Sep 6, 2023
1 check passed
@cuicathy

cuicathy commented Feb 7, 2024

Hi, thanks for the effort on test-during-training. I wonder whether we can use test-during-training now? If so, could you please give us an example? Thank you very much!
P.S. I do not always need test-during-training, but in some projects, for specific research purposes, I have to...

@RaulPPelaez
Collaborator Author

Hi, the behavior is just like before this PR, via the "test_interval" option. Setting it to -1 will skip testing during training.

Check out this doc page: https://torchmd-net.readthedocs.io/en/latest/torchmd-train.html
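
For example, in a training YAML (assuming the option keeps the test_interval spelling used above):

    # run the test set every 10 epochs during training; set to -1 to skip
    # testing during training
    test_interval: 10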
