
Webdataset (Liaon115M) + Torchlightning (pl.DataModule) with a visible progress bar during training #346

Closed
HuangChiEn opened this issue Mar 19, 2024 · 1 comment

Comments


HuangChiEn commented Mar 19, 2024

This issue aims to wrap up the discussion of how to run WebDataset with PyTorch Lightning.

Several concepts should be clarified here:

  1. If we use WebDataset, the dataset is presumably large-scale (if not, why bother?), so multi-GPU support is definitely needed!!
  2. Yes, the dataset length is important. No one, really no one, wants to train on a large-scale dataset without knowing the number of steps per epoch!! So the dataset length (the number of samples in the dataset) is definitely needed!!
  3. It's odd that torch IterableDataset offers so little support for the multi-GPU setting, because the very reason we use IterableDataset is that we're working with a large-scale dataset, and multi-GPU is necessary for large-scale training.

OK, if the dataset length is important, then how do I get it?

Well, we wrote a parallel script that scans the large-scale dataset, counts the images in each tar file, and stores the counts in a JSON file as a dict: {'/data/laion115m/00000.tar': 1147, '/data/laion115m/00001.tar': ..., '/data/laion115m/11164.tar': 551}.
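For reference, here is a minimal sketch of such a counting script; the shard location, the output file name, and the use of multiprocessing.Pool are illustrative assumptions, not the exact script we ran:

import json
import tarfile
from glob import glob
from multiprocessing import Pool

def count_jpgs(tar_path):
    # count the .jpg members in one shard
    with tarfile.open(tar_path) as tf:
        return tar_path, sum(1 for m in tf.getmembers() if m.name.endswith('.jpg'))

if __name__ == '__main__':
    shards = sorted(glob('/data/laion115m/*.tar'))          # assumed shard location
    with Pool(processes=16) as pool:
        counts = dict(pool.map(count_jpgs, shards))         # {'/data/laion115m/00000.tar': 1147, ...}
    with open('/data/laion115m/shard_counts.json', 'w') as fp:
        json.dump(counts, fp)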


After some grumbling, let's begin 👍

  1. First step: use WebDataset in the simplest way (a plain for-loop over a single stream)
import webdataset as wds

# pass `wds.split_by_worker` as the nodesplitter; it is effectively the default worker split, so it does not affect the results here
dataset = wds.WebDataset(url, nodesplitter=wds.split_by_worker).shuffle(1000).decode('pilrgb', handler=wds.warn_and_continue).to_tuple("jpg", "txt", handler=wds.warn_and_continue).map(trfs)

# toy usage:
for batch_id, sample in enumerate(dataset):
    print('ok, good! we can get the sample')
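In the snippet above, `url` is assumed to be the list of shard paths. A minimal way to build it from the counts JSON (the file name follows the counting sketch above):

import json

with open('/data/laion115m/shard_counts.json') as fp:   # shard -> sample-count mapping
    tar_dict = json.load(fp)

url = list(tar_dict.keys())        # e.g. ['/data/laion115m/00000.tar', ...]
n_total = sum(tar_dict.values())   # total number of samples across all shards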
  2. Second step: wrap the WebDataset in a torch IterableDataset (so we can do the node split ourselves)
import torch
import webdataset as wds
from torch.distributed import get_rank, get_world_size

class Iter_ds(torch.utils.data.IterableDataset):
    def __init__(self, urls, trfs, n_sample):
        self.urls = urls
        self.trfs = trfs
        self.n_sample = n_sample

    def __len__(self):
        # say we have 100 images in total, 2 GPUs and batch_size = 4;
        # then the number of steps per epoch is 100 // (2 * 4) = 12, and the last partial batch is dropped.
        # here we control the number of steps indirectly through __len__: we return the per-GPU sample
        # count (100 // 2 = 50), and the torch DataLoader divides it by batch_size automatically
        # (since batch_size is set on the DataLoader)
        return self.n_sample // get_world_size()

    def __iter__(self):
        process_rank = get_rank()
        world_size = get_world_size()
        for url in self.urls:
            # pass `wds.split_by_worker` as the nodesplitter; it is effectively the default worker split,
            # so it does not change the results, whereas the library's default node splitter would break
            # the manual node split below
            dataset = wds.WebDataset(url, nodesplitter=wds.split_by_worker).shuffle(1000).decode('pilrgb', handler=wds.warn_and_continue).to_tuple("jpg", "txt", handler=wds.warn_and_continue).map(self.trfs)
            for sample_id, sample in enumerate(dataset):
                # manual node split: yield only the samples assigned to this GPU (by index)
                # and skip the ones that belong to the other GPUs
                if sample_id % world_size == process_rank:
                    yield sample
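A quick sanity check of the bookkeeping in `__len__`, using hypothetical numbers (the LAION-115M record count mentioned below, 8 GPUs, batch size 16):

# hypothetical numbers: total samples from the JSON, 8 GPUs, batch_size 16
n_sample, world_size, batch_size = 104_881_705, 8, 16

per_rank = n_sample // world_size         # what Iter_ds.__len__ returns on each rank
steps_per_epoch = per_rank // batch_size  # what the DataLoader reports with drop_last=True

print(per_rank)         # 13110213
print(steps_per_epoch)  # 819388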
  3. Last step: make the torch IterableDataset a member of a pl.LightningDataModule
import json
from pathlib import Path

import torch
import pytorch_lightning as pl


class Liaon115M(pl.LightningDataModule):

    def __init__(self, 
        data_dir: str = "path/to/shard_counts.json",   # path to the shard -> sample-count JSON produced above
        split_ratio: list = [0.7, 0.1, 0.2], 
        img_transforms = None, 
        txt_transforms = None, 
        num_workers: int = 4, 
        batch_size: int = 16,
        num_epoch: int = 1,
        pin_memory: bool = False
    ):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.split_ratio = split_ratio
        self.img_transforms = img_transforms
        self.txt_transforms = txt_transforms

        # apply the image / text transforms to the (jpg, txt) tuple produced by WebDataset
        self.trfs = lambda tup : (self.img_transforms(tup[0]), self.txt_transforms(tup[1]))

        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # LAION-115M: 104881705 records were successfully downloaded
        
    # load the shard -> sample-count JSON and split the shard list into train / val / test
    def prepare_data(self): 
        with self.data_dir.open('r') as js_ptr:
            self.tar_dict = json.load(js_ptr)
        
        tar_lst = list(self.tar_dict.keys())
        n_shard = len(tar_lst)
        
        tra_ratio, val_ratio, _ = self.split_ratio
        self.tra_lst = tra_lst = tar_lst[ : int(n_shard * tra_ratio) ]
        self.val_lst = val_lst = tar_lst[ len(tra_lst) : len(tra_lst) + int(n_shard * val_ratio) ]
        self.tst_lst = tar_lst[ len(tra_lst) + len(val_lst) : ]

    def _get_sample_num(self, tar_lst):
        # total number of samples across the given shards
        return sum(self.tar_dict[tar_key] for tar_key in tar_lst)

    def setup(self, stage: str = 'fit'):
        self.prepare_data()
        # Lightning passes stage='fit' for training/validation and stage='test' for testing
        if stage in ('fit', 'train'):
            n_tra_sample = self._get_sample_num(self.tra_lst)
            print(f'Total samples in train ds : {n_tra_sample}')
            self.laion_train = Iter_ds(self.tra_lst, trfs=self.trfs, n_sample=n_tra_sample)
            n_val_sample = self._get_sample_num(self.val_lst)
            print(f'Total samples in valid ds : {n_val_sample}')
            self.laion_valid = Iter_ds(self.val_lst, trfs=self.trfs, n_sample=n_val_sample)
        else:
            n_tst_sample = self._get_sample_num(self.tst_lst)
            print(f'Total samples in test ds : {n_tst_sample}')
            self.laion_test = Iter_ds(self.tst_lst, trfs=self.trfs, n_sample=n_tst_sample)
    def train_dataloader(self):
        loader = torch.utils.data.DataLoader(
            self.laion_train, 
            batch_size=self.batch_size, 
            shuffle=False, 
            pin_memory=self.pin_memory, 
            num_workers=self.num_workers,
            prefetch_factor=2,
            drop_last=True
        )
        return loader 

    def val_dataloader(self):
        loader = torch.utils.data.DataLoader(
            self.laion_valid, 
            batch_size=self.batch_size, 
            shuffle=False, 
            pin_memory=self.pin_memory, 
            num_workers=self.num_workers,
            prefetch_factor=2,
            drop_last=True
        )
        return loader 
  4. Final step: since we defined __len__ on the IterableDataset, the progress bar shows the number of steps per epoch and the epoch count during training.
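For completeness, here is a hypothetical end-to-end wiring of the DataModule into a Trainer. `MyModule`, the transforms, and the paths are placeholders for illustration, not part of the original setup:

import pytorch_lightning as pl
import torch
from torchvision import transforms

class MyModule(pl.LightningModule):
    # toy stand-in for the real model
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(3 * 224 * 224, 128)

    def training_step(self, batch, batch_idx):
        imgs, _txts = batch
        return self.proj(imgs.flatten(1)).mean()   # dummy scalar "loss"

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

dm = Liaon115M(
    data_dir='/data/laion115m/shard_counts.json',   # the counts JSON assumed above
    img_transforms=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
    txt_transforms=lambda t: t,                     # keep captions as raw strings here
    batch_size=16, num_workers=4,
)
trainer = pl.Trainer(devices=2, strategy='ddp', max_epochs=1)
trainer.fit(MyModule(), datamodule=dm)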

Package versions 🥇

  1. webdataset : 0.2.86
  2. pytorch-lightning : 2.2.1
  3. torch : 2.2.0+cu118

OK, now there is peace: no more arguments about Torch Lightning + WebDataset, hopefully...

@HuangChiEn
Author

HAPPY CODING WITH TORCHLIGHTNING!!
