
Add Support for multiple train loaders #1959

Merged
merged 38 commits into from Jan 4, 2021

Conversation

justusschock
Member

@justusschock justusschock commented May 26, 2020

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

When finished, this PR adds support for drawing batches from multiple train loaders at once. If the loaders are specified as a Mapping (dict), the resulting batch will contain one batch per loader, stored under the same keys as the loaders. For example:

loaders = {"x": loader_x, "y": loader_y, "z": loader_z}

will result in a batch like this:

{"x": batch_from_loader_x, "y": batch_from_loader_y, "z": batch_from_loader_z}

and loaders specified as a sequence will produce a sequence batch built from the separate batches in the same order:

loaders = [loader_0, loader_1, loader_2]

will result in a batch like this:

[batch_from_loader_0, batch_from_loader_1, batch_from_loader_2]
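
The combining behavior above can be sketched torch-free, with plain iterables standing in for DataLoaders. `combined_batches` is a hypothetical helper for illustration only, not the PR's implementation; for simplicity this sketch ends the epoch when the shortest loader runs out (the handling of unequal lengths is debated later in this thread).

```python
def combined_batches(loaders):
    """Yield one combined batch per step, mirroring the loaders' structure."""
    if isinstance(loaders, dict):
        iterators = {key: iter(loader) for key, loader in loaders.items()}
        while True:
            try:
                # One batch per loader, under the loader's key.
                yield {key: next(it) for key, it in iterators.items()}
            except StopIteration:
                return
    else:
        # A sequence of loaders yields sequence batches, in order;
        # zip stops at the shortest loader.
        yield from zip(*loaders)

batches = list(combined_batches({"x": [1, 2], "y": [10, 20]}))
# batches == [{'x': 1, 'y': 10}, {'x': 2, 'y': 20}]
```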

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@justusschock justusschock added the feature Is an improvement or enhancement label May 26, 2020
@justusschock justusschock self-assigned this May 26, 2020
@pep8speaks

pep8speaks commented May 26, 2020

Hello @justusschock! Thanks for updating this PR.

Line 260:93: W291 trailing whitespace
Line 310:1: W293 blank line contains whitespace

Line 45:1: W293 blank line contains whitespace

Line 177:1: W391 blank line at end of file

Comment last updated at 2021-01-04 19:58:04 UTC

@mergify mergify bot requested a review from a team May 26, 2020 15:01
@awaelchli
Member

In the eval loop we pass in a dataloader_idx, what are your thoughts on that? Does it make sense or not to have that in training_step as well?

@justusschock
Member Author

I don't think so, because in the validation phase we use the dataloaders sequentially, whereas in training we use them in parallel.

@awaelchli
Member

Yup, I like your idea. Just wanted to raise this because in Slack there was talk about making it consistent with the eval loops.

@Borda Borda modified the milestones: 0.7.7, 0.8.0 May 26, 2020
@williamFalcon
Contributor

How do we handle different length datasets?

Option 1:
Cycle through smaller one

(L is a long dataset, 0,1,2,3,4 are cycles of the smaller datasets)
LLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLL
0000000 1111111111 22222222 33333333 44

Option 2:
Cycle through the min length of all datasets.

I personally think it should be option 1.

@justusschock one very simple way to get this feature right now is just to add all training datasets to a concat dataset for the user:

import torch


class ConcatDataset(torch.utils.data.Dataset):
    """Combines datasets side by side, cycling the shorter ones (option 1)."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # Shorter datasets wrap around via modulo indexing.
        result = []
        for dataset in self.datasets:
            cycled_i = i % len(dataset)
            result.append(dataset[cycled_i])

        return tuple(result)

    def __len__(self):
        # The longest dataset determines the epoch length.
        return max(len(d) for d in self.datasets)
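
The cycling logic of the ConcatDataset above can be checked in isolation with a torch-free stand-in, where plain lists play the role of datasets (`CycledConcat` is an illustrative rename, not code from the PR):

```python
class CycledConcat:
    """Torch-free version of the max-length-with-cycling behavior (option 1)."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # Shorter datasets wrap around via modulo indexing.
        return tuple(d[i % len(d)] for d in self.datasets)

    def __len__(self):
        # The epoch length is governed by the longest dataset.
        return max(len(d) for d in self.datasets)

ds = CycledConcat([0, 1, 2], list(range(7)))
assert len(ds) == 7       # longest dataset wins
assert ds[5] == (2, 5)    # 5 % 3 == 2 cycles the shorter dataset
```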

@justusschock
Member Author

justusschock commented May 26, 2020

@williamFalcon I think it should be the minimum length, since this would be the most explicit version.
IMO the cycling should be done by the user within the dataset. (I don't favor this version, but just wanted to state it here)

Another option would be to sample from each loader as long as they did not yet run out of samples and to omit the loaders which are already exhausted.
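
That third option can be sketched torch-free, with dicts of iterables standing in for loaders (`batches_until_all_exhausted` is a hypothetical helper for illustration, not the PR's implementation):

```python
def batches_until_all_exhausted(loaders):
    """Keep drawing from every loader that still has samples,
    dropping loaders from the batch as they run out."""
    iterators = {key: iter(loader) for key, loader in loaders.items()}
    while iterators:
        batch = {}
        for key in list(iterators):
            try:
                batch[key] = next(iterators[key])
            except StopIteration:
                # This loader is exhausted; omit it from now on.
                del iterators[key]
        if batch:
            yield batch

out = list(batches_until_all_exhausted({"a": [1, 2], "b": [10]}))
# out == [{'a': 1, 'b': 10}, {'a': 2}]
```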

@williamFalcon
Contributor

But it seems weird to me that I wouldn't use the full dataset if I had a smaller one.
Maybe we need to support a few modes?

@Borda
Member

Borda commented Jun 9, 2020

@justusschock any progress here?
or someone from @PyTorchLightning/core-contributors can help...

@Borda Borda added help wanted Open to be worked on waiting on author Waiting on user action, correction, or update labels Jun 9, 2020
@justusschock
Member Author

I'll get this done, once the metrics are finally finished :D Sorry for the delay on my side

@reactivetype

@justusschock any progress here?

I could get around this problem using a custom PyTorch BatchSampler which I pass to the dataloader. My dataset's __getitem__ takes a tuple of integers as the index; each integer fetches a data item from the corresponding dataset. The only drawback is that the data items need to have the same shape for all 3 datasets, but the batch size can be different for each dataset's items.
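
A simplified, torch-free sketch of that workaround: a batch sampler that yields batches of index *tuples*, one integer per underlying dataset, so a single __getitem__ can fetch one item from each. The class name, sizes, and single shared batch size here are illustrative assumptions, not details from the comment above.

```python
import random

class TupleBatchSampler:
    """Yields batches of per-dataset index tuples (hypothetical sketch)."""

    def __init__(self, dataset_sizes, batch_size):
        self.dataset_sizes = dataset_sizes
        self.batch_size = batch_size

    def __len__(self):
        # One epoch is bounded by the smallest dataset in this sketch.
        return min(self.dataset_sizes) // self.batch_size

    def __iter__(self):
        for _ in range(len(self)):
            # Each sample in the batch is a tuple with one random index
            # per dataset, each valid for its own dataset.
            yield [
                tuple(random.randrange(size) for size in self.dataset_sizes)
                for _ in range(self.batch_size)
            ]

batches = list(TupleBatchSampler([100, 50, 200], batch_size=4))
assert len(batches) == 12                     # 50 // 4 steps per epoch
assert all(len(b) == 4 for b in batches)      # batch_size tuples per batch
```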

@justusschock justusschock marked this pull request as ready for review June 29, 2020 13:35
@justusschock justusschock changed the title WIP: Add Support for multiple train loaders Add Support for multiple train loaders Jun 29, 2020
@justusschock
Member Author

justusschock commented Jun 29, 2020

Okay, this is almost finished. Currently there is still a bug if a loader has no length. Any ideas how we should proceed with this? Shall we simply set the overall length to inf?

@awaelchli @williamFalcon @Borda

@Borda
Member

Borda commented Jun 29, 2020

I guess we made a similar "hotfix" for the validation dataloader and set the length to inf:
https://github.com/PyTorchLightning/pytorch-lightning/blob/f1c96930b19e608f9875df642235cc48dea2f8ee/pytorch_lightning/trainer/data_loading.py#L288-L291
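
The "length = inf" fallback discussed here can be sketched in a few lines (`loader_length` is an illustrative helper, not the linked implementation): loaders that define no __len__, e.g. ones backed by iterable-style datasets, get an infinite length instead of raising.

```python
def loader_length(loader):
    """Return len(loader), falling back to inf for loaders without __len__."""
    try:
        return len(loader)
    except TypeError:
        return float("inf")

assert loader_length([1, 2, 3]) == 3
assert loader_length(iter([1, 2, 3])) == float("inf")
```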

@christofer-f
Contributor

Hi, @omiita spotted this error...
The following code gives the wrong number of iterations in a training cycle:

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l_mnist = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l_mnist(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch['mnist']
        y_hat = torch.relu(self.l_mnist(x.view(x.size(0), -1)))
        loss_mnist = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss_mnist}
        return {'loss': loss_mnist, 'log': tensorboard_logs}

    def configure_optimizers(self):
        opt_mnist = torch.optim.Adam(self.l_mnist.parameters(), lr=0.02)
        return opt_mnist

    def train_dataloader(self):
        loader_mnist = DataLoader(
            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()),
            batch_size=32,
        )
        loaders = {"mnist": loader_mnist}
        return loaders

def main():
    mnist_model = MNISTModel()
    trainer = pl.Trainer(gpus=1, fast_dev_run=False, max_epochs=1)
    trainer.fit(mnist_model)

if __name__ == "__main__":
    main()

@Borda Borda modified the milestones: 1.1.x, 1.2 Dec 31, 2020
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
pytorch_lightning/trainer/supporters.py (outdated, resolved)
pytorch_lightning/utilities/data.py (outdated, resolved)
tests/base/model_train_dataloaders.py (outdated, resolved)
Borda and others added 2 commits December 31, 2020 10:53
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
pytorch_lightning/trainer/data_loading.py (resolved)
pytorch_lightning/trainer/supporters.py (resolved)
pytorch_lightning/trainer/supporters.py (outdated, resolved)
    length = all_lengths

elif isinstance(all_lengths, Mapping):
    length = compute_func(all_lengths.values())
Contributor

What happens if something defines something like

{"a":{"b":...}}

Member Author

good point, this would currently fail. Is this something we want to enable?

Contributor

@tchaton tchaton Jan 4, 2021

Not at this point I think. People will open an issue if needed.
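
For reference, supporting nested structures like the one raised above would amount to recursing over the lengths before reducing them. This is a hypothetical sketch of what that could look like, not the merged implementation (which handles only one level):

```python
from collections.abc import Mapping

def compute_length(lengths, compute_func=min):
    """Recursively reduce a nested Mapping/sequence of loader lengths."""
    if isinstance(lengths, Mapping):
        return compute_func(compute_length(v, compute_func) for v in lengths.values())
    if isinstance(lengths, (list, tuple)):
        return compute_func(compute_length(v, compute_func) for v in lengths)
    # Base case: a plain integer length.
    return lengths

assert compute_length({"a": {"b": 3, "c": 5}, "d": 4}) == 3
```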

pytorch_lightning/trainer/supporters.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/supporters.py Show resolved Hide resolved
pytorch_lightning/trainer/train_loader_patch.py Outdated Show resolved Hide resolved
Contributor

@tchaton tchaton left a comment

Awesome work !

@Borda Borda merged commit d88cf4a into release/1.2-dev Jan 4, 2021
@Borda Borda deleted the train_loaders branch January 4, 2021 19:57
@Borda
Member

Borda commented Jan 4, 2021

@justusschock it seems that this breaks the checks, mind checking it ASAP 🐰

@adamgayoso

It doesn't look like the documentation for fit() was updated. Should fit() be able to take multiple train dataloaders as well?

Labels
design (Includes a design discussion), feature (Is an improvement or enhancement), help wanted (Open to be worked on), priority: 0 (High priority task), ready (PRs ready to be merged)
Development

Successfully merging this pull request may close these issues.

Using multiple dataloaders in the training_step?