Using DDP with WebDataset in pytorch lightning #250

Open
adhakal224 opened this issue Feb 20, 2023 · 21 comments

adhakal224 commented Feb 20, 2023

This is my first time using WebDataset. I have multiple shards (about 60) with a large number of images. It was working as I would expect from the normal Dataset class while I was using a single GPU. However, once I set the devices to 2 I received the error ValueError: you need to add an explicit nodesplitter to your input pipeline for multi-node training from WebDataset.
I have seen two approaches for using multiple GPUs with WebDataset, but my training is extremely slow at the moment.

  1. Using .with_epoch
    According to the WebDataset GitHub documentation, I could simply use the with_epoch function on my dataset as follows:
dataset = wds.WebDataset(url, resampled=True).shuffle(1000).decode("rgb").to_tuple("png", "json").map(preprocess).with_epoch(10000)
dataloader = wds.WebLoader(dataset, batch_size=batch_size)
  2. Using ddp_equalize
    According to the WebDataset multi-node documentation:
dataset_size, batch_size = 1282000, 64
dataset = wds.WebDataset(urls).decode("pil").shuffle(5000).batched(batch_size, partial=False)
loader = wds.WebLoader(dataset, num_workers=4)
loader = loader.ddp_equalize(dataset_size // batch_size)

Could someone please help me understand what is happening in these two pieces of code and how they differ? In the second case, is dataset_size just a nominal size? Which, if any, is better? There is also a mention of MultiDataset in the docs, but I could not find detailed documentation about it. I would also appreciate it if someone has an example of the best way to use WebDataset with PyTorch Lightning in a multi-GPU and multi-node scenario.

Currently I am using the first approach and my training is extremely slow. My dataloader roughly looks like this:

self.dataset = wds.WebDataset(self.wds_path, resampled=True)
self.dataset = self.dataset.shuffle(1000).decode('rgb').to_tuple("a.jpg", "b.jpg", "c.json","__key__")
self.dataset = self.dataset.with_epoch(10000)
dataloader = wds.WebLoader(self.dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

Do I need to do anything other than this to make DDP work properly with WebDataset? Would also appreciate any feedback that might make this a bit more efficient.

superhero-7 commented Feb 26, 2023

I have also been using DDP with WebDataset in PyTorch Lightning recently, taking the open_clip codebase as a reference: https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py#L152. I can run on multiple GPUs at first, but after a while it gets stuck without throwing any error. By the way, I am also confused about why resampled + with_epoch works in a distributed manner. Did you successfully run WebDataset with PyTorch Lightning?

adhakal224 commented Feb 26, 2023

@superhero-7 I have been training for a couple of days with the resample+with_epoch approach and so far the training is running without any issues. However, it would be nice to have an example of the optimal way to use WebDataset with Lightning and DDP somewhere in the docs.

@superhero-7

Sounds great! Can you share your code example? I still get stuck after training for a while. By the way, I am curious whether your progress bar is displayed normally. I could not get the progress bar to work properly, as shown below:
[screenshot of the progress bar]

adhakal224 commented Mar 1, 2023

Sure! Here is my dataset and dataloader

self.dataset = self.dataset.shuffle(1000).decode('pil').to_tuple("groundlevel.jpg", "overhead.jpg", "metadata.json","__key__").map(self.do_transforms).batched(self.args.train_batch_size).with_epoch(self.args.train_epoch_length)

trainloader = wds.WebLoader(self.trainset, batch_size=None,
                    shuffle=False, pin_memory=True, num_workers=self.hparams.num_workers)
trainloader = trainloader.unbatched().shuffle(1000).batched(self.hparams.train_batch_size)

I can't see the progress bar normally. I think that's natural, as there is no concept of length in WebDataset. Hence, it is not possible to figure out how many steps it will take for one epoch to end.

@superhero-7

Thanks! I can run it normally now, but I still need to run a validation to make sure the results are right. Does your progress bar look like mine? For example, it has already run 253 iterations in the image above, but the total number of steps is not shown. I also guess that is because WebDataset does not give a length to Lightning.

tmbdev commented Mar 3, 2023

You can find a DDP training example here:

https://github.com/webdataset/webdataset-imagenet

At this point, there are several ways of dealing with DDP training:

(1) use node splitting and Join (this gives you "exact epochs")
(2) use resampling and with_epoch; this gives you slightly different sample statistics but is actually a good thing to do
(3) use Ray datasets with webdataset (once that is released)
(4) use WebDataset with repeat and with_epoch ("ddp equalize")
(5) use torchdata

(1) didn't use to work with Lightning, but I think they added support for it now. (2) should work fine with Lightning, but I haven't tested it recently. (3) allows you to do on the fly repartitioning and shuffling, which is the right thing, but I don't know the performance yet. (4) is really more of a workaround, and (1) is preferable.

As for (5), I don't know what they are currently doing for DDP training, but I expect they will have a good solution sooner or later.

I usually use (2).
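For concreteness, a minimal sketch of (2) might look like the following (the shard pattern, keys, and epoch length are placeholder values; adjust them to your data):

import webdataset as wds

urls = "shards/train-{000000..000059}.tar"   # placeholder shard pattern
batch_size = 64
batches_per_epoch = 1000                     # per-rank epoch length, in batches

# resampled=True: every rank and worker draws shards independently (with
# replacement), so no explicit nodesplitter is needed for DDP.
dataset = (
    wds.WebDataset(urls, resampled=True)
    .shuffle(1000)
    .decode("rgb")
    .to_tuple("png", "json")
)

# Put with_epoch on the loader (not the dataset) so the epoch length is not
# multiplied by num_workers; here it counts batches, since the loader batches.
loader = (
    wds.WebLoader(dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
    .with_epoch(batches_per_epoch)
)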

@cliffzhao

@superhero-7 I have been training for a couple of days with the resample+with_epoch approach and so far the training is running without any issues. However, it would be nice to have an example of the optimal way to use WebDataset with Lightning and DDP somewhere in the docs.

Hi @adhakal224 , have you solved the slow speed issue? Currently, I'm using webdataset with pytorch-lightning in DDP training, but the speed is extremely slow.

@ForJadeForest

You can find a DDP training example here:

https://github.com/webdataset/webdataset-imagenet

At this point, there are several ways of dealing with DDP training:

(1) use node splitting and Join (this gives you "exact epochs")
(2) use resampling and with_epoch; this gives you slightly different sample statistics but is actually a good thing to do
(3) use Ray datasets with webdataset (once that is released)
(4) use WebDataset with repeat and with_epoch ("ddp equalize")
(5) use torchdata

(1) didn't use to work with Lightning, but I think they added support for it now. (2) should work fine with Lightning, but I haven't tested it recently. (3) allows you to do on the fly repartitioning and shuffling, which is the right thing, but I don't know the performance yet. (4) is really more of a workaround, and (1) is preferable.

As for (5), I don't know what they are currently doing for DDP training, but I expect they will have a good solution sooner or later.

I usually use (2).

If resampled=True, will it lead to different devices getting the same data?
Besides, does with_epoch(1000) mean all URLs will be repeated 1000 times?

@cliffzhao

@superhero-7 I have been training for a couple of days with the resample+with_epoch approach and so far the training is running without any issues. However, it would be nice to have an example of the optimal way to use WebDataset with Lightning and DDP somewhere in the docs.

Hi @adhakal224 , have you solved the slow speed issue? Currently, I'm using webdataset with pytorch-lightning in DDP training, but the speed is extremely slow.

The slow issue has been resolved, and it was not a problem with WebDataset.

tmbdev commented Mar 20, 2023

If resampled=True, will it lead to different devices getting the same data?

It should. The RNGs are initialized differently on each host and worker.

Besides, does with_epoch(1000) mean all URLs will be repeated 1000 times?

  • with_epoch sets the epoch length to 1000 items (either samples or batches, depending on where in the pipeline you use it).
  • with_length sets the nominal length of the dataset (the value returned by len(...)), but doesn't change anything else about the pipeline.
  • repeat(n) repeats the iterator n times (i.e., it runs for n epochs)
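To illustrate the three calls (the shard pattern and keys below are placeholders):

import webdataset as wds

urls = "shards/train-{000000..000009}.tar"  # placeholder shard pattern

# with_epoch(1000): the pipeline yields exactly 1000 items per epoch, then stops.
ds_epoch = wds.WebDataset(urls, resampled=True).decode("rgb").to_tuple("png", "json").with_epoch(1000)

# with_length(1000): len(ds_len) == 1000, but iteration behavior is unchanged.
ds_len = wds.WebDataset(urls).decode("rgb").to_tuple("png", "json").with_length(1000)

# repeat(3): iterates over the underlying shards three times in one pass.
ds_rep = wds.WebDataset(urls).decode("rgb").to_tuple("png", "json").repeat(3)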

@ForJadeForest

If resampled=True, will it lead to different devices getting the same data?

It should. The RNGs are initialized differently on each host and worker.

Besides, does with_epoch(1000) mean all URLs will be repeated 1000 times?

  • with_epoch sets the epoch length to 1000 items (either samples or batches, depending on where in the pipeline you use it).
  • with_length sets the nominal length of the dataset (the value returned by len(...)), but doesn't change anything else about the pipeline.
  • repeat(n) repeats the iterator n times (i.e., it runs for n epochs)

I find that the number of steps in one epoch is not correct when I set num_workers > 0. Here is my data pipeline:

dataset = (
    wds.WebDataset(urls, resampled=True)
    .shuffle(shuffle)
    .decode("pil")
    .to_tuple("jpg", "txt")
    .map(lambda x: (transform(x[0]), tokenize(x[1], truncate=True).squeeze()))
    .with_epoch(768 * 20)
)

loader = wds.WebLoader(
    dataset,
    batch_size=768,
    shuffle=False,
    num_workers=16,
    pin_memory=True
)

During training, I obtained 320 batches, each consisting of 768 items. However, in my opinion, num_workers should not affect the total number of batches, which would mean that I should have obtained only 20 batches.

In addition, I would like to know the correct way to set with_epoch(n) when training with DDP and num_workers > 0. The total number of data points is approximately 3 million image-text pairs, with a batch size of 768 and 16 workers. How should I determine the appropriate value for with_epoch(n) to make use of all URLs (all image-text pairs)?

Btw, the dataset pipeline has a .batched() method; where should I set the batch_size (in the dataloader or in the dataset)?

@urinieto

In case this helps, here's an example of using DDP + WebDataset + PyTorch Lightning: https://github.com/webdataset/webdataset-lightning

RoyJames commented Jun 6, 2023

Sure! Here is my dataset and dataloader

self.dataset = self.dataset.shuffle(1000).decode('pil').to_tuple("groundlevel.jpg", "overhead.jpg", "metadata.json","__key__").map(self.do_transforms).batched(self.args.train_batch_size).with_epoch(self.args.train_epoch_length)

trainloader = wds.WebLoader(self.trainset, batch_size=None,
                    shuffle=False, pin_memory=True, num_workers=self.hparams.num_workers)
trainloader = trainloader.unbatched().shuffle(1000).batched(self.hparams.train_batch_size)

I can't see the progress bar normally. I think that's natural, as there is no concept of length in WebDataset. Hence, it is not possible to figure out how many steps it will take for one epoch to end.

You can also display the progress bar correctly by using both .with_epoch(epoch_length) and .with_length(epoch_length).
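For example (a sketch; the shard pattern, keys, and sizes are placeholders):

import webdataset as wds

epoch_length = 1000   # batches per epoch (placeholder)
batch_size = 64

dataset = (
    wds.WebDataset("shards/train-{000000..000059}.tar", resampled=True)  # placeholder shards
    .shuffle(1000)
    .decode("pil")
    .to_tuple("jpg", "json")
    .batched(batch_size)
)

loader = (
    wds.WebLoader(dataset, batch_size=None, num_workers=4, pin_memory=True)
    .with_epoch(epoch_length)    # the iterator stops after epoch_length batches
    .with_length(epoch_length)   # len(loader) == epoch_length, so Lightning can size the progress bar
)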

alexqdh commented Jun 12, 2023

  1. ddp_equalize seems to be deprecated now. Refer to ddp_equalize #194 (IGNORE_test_ddp_equalize).
  2. ddp_equalize also uses the with_epoch and with_length functions. Refer to ddp fixes.

@jrcavani

(2) use resampling and with_epoch; this gives you slightly different sample statistics but is actually a good thing to do

@tmbdev does this mean that, rather than going through strictly non-overlapping samples in an epoch, having each rank independently resample shards is in practice better during training?

@laolongboy

If resampled=True, will it lead to different devices getting the same data?

It should. The RNGs are initialized differently on each host and worker.

Besides, does with_epoch(1000) mean all URLs will be repeated 1000 times?

  • with_epoch sets the epoch length to 1000 items (either samples or batches, depending on where in the pipeline you use it).
  • with_length sets the nominal length of the dataset (the value returned by len(...)), but doesn't change anything else about the pipeline.
  • repeat(n) repeats the iterator n times (i.e., it runs for n epochs)

I find that the number of steps in one epoch is not correct when I set num_workers > 0. Here is my data pipeline:

dataset = (
    wds.WebDataset(urls, resampled=True)
    .shuffle(shuffle)
    .decode("pil")
    .to_tuple("jpg", "txt")
    .map(lambda x: (transform(x[0]), tokenize(x[1], truncate=True).squeeze()))
    .with_epoch(768 * 20)
)

loader = wds.WebLoader(
    dataset,
    batch_size=768,
    shuffle=False,
    num_workers=16,
    pin_memory=True
)

During training, I obtained 320 batches, each consisting of 768 items. However, in my opinion, num_workers should not affect the total number of batches, which would mean that I should have obtained only 20 batches.

In addition, I would like to know the correct way to set with_epoch(n) when training with DDP and num_workers > 0. The total number of data points is approximately 3 million image-text pairs, with a batch size of 768 and 16 workers. How should I determine the appropriate value for with_epoch(n) to make use of all URLs (all image-text pairs)?

Btw, the dataset pipeline has a .batched() method; where should I set the batch_size (in the dataloader or in the dataset)?

Same question. How should the dataloader's num_workers be set to get the correct number of batches for each epoch?

HuangChiEn commented Mar 19, 2024

ddp_equalize is painful to use and is deprecated.
resampled=True + with_epoch(...) has behavior that is hard to understand, and it doesn't support the multi-node case (each node seeing a different part of the dataset, but consuming the same dataset in each epoch).
A naive wds.split_by_node simply gets stuck in the single-node multi-GPU setting.

I hope the following thread helps close this issue:
Webdataset (Liaon115M) + Torchlightning (pl.DataModule) with visualizing progressbar during training

tmbdev commented Mar 20, 2024

@jrcavani Resampling is a typical strategy in statistics to generate slight variations of the dataset. It is used for various statistical estimators. As such, you can view it as a kind of "data augmentation". It's potentially a good thing to do in that sense.

tmbdev commented Mar 21, 2024

@laolongboy

Same question. How should the dataloader's num_workers be set to get the correct number of batches for each epoch?

Short answer: use .with_epoch on WebLoader, not WebDataset. Note that this will specify the epoch in terms of the number of batches by default.

Long answer: the best pipeline is probably:

dataset = WebDataset(...) ... .batched(16)
loader = WebLoader(dataset, batch_size=None, ...).unbatched().shuffle(1000).with_epoch(n).batched(batch_size)

The first batching is done to make the worker-to-loader transfers more efficient. The unbatching, shuffling, and rebatching step shuffles samples across workers and then constructs the final batches.
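Filled in as a runnable sketch (the shard pattern, keys, and sizes are placeholders):

import webdataset as wds

urls = "shards/train-{000000..000059}.tar"   # placeholder shard pattern
batch_size = 64
epoch_samples = 100_000                      # per-rank epoch length, in samples

# Small batches inside the workers make worker-to-loader transfers efficient.
dataset = (
    wds.WebDataset(urls, resampled=True)     # resampled: no nodesplitter needed for DDP
    .shuffle(1000)
    .decode("torchrgb")
    .to_tuple("jpg", "json")
    # (a .map(transform) resize step would normally go here so images stack cleanly)
    .batched(16)
)

# batch_size=None because the dataset already yields batches of 16.
# unbatched/shuffle/batched mixes samples across workers and builds the final
# batches; with_epoch sits before the final batched(), so it counts samples here.
loader = (
    wds.WebLoader(dataset, batch_size=None, num_workers=8, pin_memory=True)
    .unbatched()
    .shuffle(1000)
    .with_epoch(epoch_samples)
    .batched(batch_size)
)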

tmbdev commented Mar 21, 2024

@HuangChiEn

The short answer is: use wids and ShardListDataset. It behaves just like other indexed datasets and works exactly like other datasets for distributed training.

We implemented this because distributed training based on IterableDataset and DataLoader is just hard.

resampled=True + with_epoch(...) has behavior that is hard to understand, and it doesn't support the multi-node case (each node seeing a different part of the dataset, but consuming the same dataset in each epoch).

.with_epoch is just a synonym for islice; that's all. The proper place to use it is on WebLoader, not WebDataset (see above).

In many environments, we want every shard to be available on every node. If you are trying to split shards by node so that each node only sees a subset, you can use DataPipeline to do that.

The distributed sampler in wids uses a fixed subset of shards on each node by default. So, again, the simplest solution if you want that behavior is to use wids.
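Roughly, the wids route looks like this (the shard list is a placeholder, and the exact ShardListDataset and sampler options should be checked against the wids documentation):

import wids  # installed together with the webdataset package
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Placeholder shard list; "nsamples" must match the real sample count of each shard.
shards = [
    {"url": "shards/train-000000.tar", "nsamples": 10000},
    {"url": "shards/train-000001.tar", "nsamples": 10000},
]

dataset = wids.ShardListDataset(shards)

sample = dataset[0]        # ordinary indexed access; keys are file extensions
image = sample[".jpg"]     # decoded according to the configured decoder
meta = sample[".json"]

# Because this is a plain map-style dataset, the standard DistributedSampler
# works; wids also ships its own distributed sampler that keeps each node on a
# fixed subset of shards (see the wids docs). A transform/collate_fn that turns
# the decoded images into tensors would be added for real training.
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)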

@HuangChiEn

@HuangChiEn

The short answer is: use wids and ShardListDataset. It behaves just like other indexed datasets and works exactly like other datasets for distributed training.

We implemented this because distributed training based on IterableDataset and DataLoader is just hard.

resampled=True + with_epoch(...) has behavior that is hard to understand, and it doesn't support the multi-node case (each node seeing a different part of the dataset, but consuming the same dataset in each epoch).

.with_epoch is just a synonym for islice; that's all. The proper place to use it is on WebLoader, not WebDataset (see above).

In many environments, we want every shard to be available on every node. If you are trying to split shards by node so that each node only sees a subset, you can use DataPipeline to do that.

The distributed sampler in wids uses a fixed subset of shards on each node by default. So, again, the simplest solution if you want that behavior is to use wids.

I see, I haven't tried it before, but it also sounds like a promising solution. Thanks for sharing!
