This issue aims to settle the discussion of how to run WebDataset with PyTorch Lightning.
Several concepts should be clarified here:
If we use WebDataset, the dataset is presumably large-scale (if not, why bother?), so multi-GPU support is definitely needed!
Yes, the dataset length is important. No one (yes, no one!) wants to train on a large-scale dataset without knowing the steps per epoch. So the dataset length (the number of samples in the dataset) is definitely needed!
It's odd that torch's IterableDataset offers so little multi-GPU support, because the reason we use IterableDataset is exactly that we're working with a large-scale dataset, and multi-GPU is necessary for large-scale training. The toy illustration below shows what goes wrong without any splitting.
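To make the problem concrete (a toy illustration, not code from this issue): if every DDP process iterates the same IterableDataset without any splitting, each rank sees every sample, so one "epoch" silently processes world_size copies of the data.

```python
import torch

class NaiveIterable(torch.utils.data.IterableDataset):
    def __iter__(self):
        # every DDP process runs this same iterator from the start,
        # so with 2 GPUs each of the 8 samples is consumed twice
        yield from range(8)

# launched with 2 processes, each rank prints all 8 values: 16 samples total
for x in NaiveIterable():
    print(x)
```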
OK, if the dataset length is important, how can we get it?
Well, we wrote a parallel script that scans the large-scale dataset, records the image count for each tar file, and stores the counts in a JSON file as a dict: {'/data/laion115m/00000.tar': 1147, '/data/laion115m/00001.tar': ..., '/data/laion115m/11164.tar': 551}. A sketch of such a script is shown below.
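A minimal sketch of such a scanning script (the helper name, the pool size, the output filename, and the assumption of one .jpg per sample are ours, not from the original script):

```python
import glob
import json
import tarfile
from multiprocessing import Pool

def count_samples(tar_path):
    # one sample per .jpg member; adjust the suffix to your dataset layout
    with tarfile.open(tar_path) as tf:
        n = sum(1 for name in tf.getnames() if name.endswith('.jpg'))
    return tar_path, n

if __name__ == '__main__':
    tar_files = sorted(glob.glob('/data/laion115m/*.tar'))
    with Pool(16) as pool:
        counts = dict(pool.map(count_samples, tar_files))
    with open('tar_counts.json', 'w') as f:
        json.dump(counts, f)
```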
After some preamble, let's begin 👍
First step: use WebDataset in toy fashion only (a for-loop over a single stream)
import webdataset as wds

# pass `wds.split_by_worker` as the nodesplitter; it is the default worker
# splitter, so this will not affect the results
dataset = (
    wds.WebDataset(url, nodesplitter=wds.split_by_worker)
    .shuffle(1000)
    .decode('pilrgb', handler=wds.warn_and_continue)
    .to_tuple("jpg", "txt", handler=wds.warn_and_continue)
    .map(trfs)
)
# toy usage: iterate the stream sample by sample
for sample_id, sample in enumerate(dataset):
    print('ok, good! we can get the sample')
Second step: wrap WebDataset in a torch IterableDataset (to do the node split ourselves)
import torch
import webdataset as wds
from torch.distributed import get_rank, get_world_size

class Iter_ds(torch.utils.data.IterableDataset):
    def __init__(self, urls, trfs, n_sample):
        self.urls = urls
        self.trfs = trfs
        self.n_sample = n_sample

    def __len__(self):
        # say we have 100 images in total, 2 GPUs, and batch_size = 4.
        # each GPU then sees 100 // 2 = 50 samples, so the steps per epoch
        # are ceil(50 / 4) = 13, with the last batch holding only 2 samples.
        # here we control the number of steps via __len__: we return
        # 100 // 2 = 50, and the torch DataLoader divides that by batch_size
        # automatically (since batch_size is set on the DataLoader)
        return self.n_sample // get_world_size()

    def __iter__(self):
        process_rank = get_rank()
        world_size = get_world_size()
        for url in self.urls:
            # pass `wds.split_by_worker` as the nodesplitter; it is the default
            # worker splitter, so it will not affect the results. The default
            # nodesplitter (`wds.single_node_split`) would break our manual
            # node-split procedure below.
            dataset = (
                wds.WebDataset(url, nodesplitter=wds.split_by_worker)
                .shuffle(1000)
                .decode('pilrgb', handler=wds.warn_and_continue)
                .to_tuple("jpg", "txt", handler=wds.warn_and_continue)
                .map(self.trfs)
            )
            for sample_id, sample in enumerate(dataset):
                # assign each sample to one GPU according to its rank
                # (this is the node splitting)
                if sample_id % world_size == process_rank:
                    yield sample
                # skip samples that belong to other GPUs
                else:
                    continue
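As a quick sanity check (an illustrative snippet, assuming torch.distributed is already initialized, e.g. by Lightning's DDP, and that `urls` and `trfs` are defined as above):

```python
loader = torch.utils.data.DataLoader(
    Iter_ds(urls, trfs, n_sample=100),
    batch_size=4,
    num_workers=2,
)
# on 2 GPUs: __len__ returns 50 per rank, and the DataLoader reports
# ceil(50 / 4) == 13 steps per epoch
print(len(loader))
```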
Last step: take the torch IterableDataset as a member of a pl.LightningDataModule; a sketch follows.
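A minimal sketch of that DataModule (the class name, hyperparameters, and the `tar_counts.json` path are illustrative assumptions, not the original code):

```python
import json
import pytorch_lightning as pl
import torch

class WdsDataModule(pl.LightningDataModule):
    def __init__(self, urls, trfs, count_json='tar_counts.json',
                 batch_size=4, num_workers=4):
        super().__init__()
        self.urls = urls
        self.trfs = trfs
        self.batch_size = batch_size
        self.num_workers = num_workers
        # recover the total sample count from the pre-scanned JSON
        with open(count_json) as f:
            self.n_sample = sum(json.load(f).values())

    def train_dataloader(self):
        # batching is done by the DataLoader; the per-GPU split happens
        # inside Iter_ds.__iter__
        return torch.utils.data.DataLoader(
            Iter_ds(self.urls, self.trfs, self.n_sample),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
```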
Package versions 🥇
OK, now there's peace; no more arguments about PyTorch Lightning + WebDataset, hopefully...