In [1]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [2]:
import webdataset as wds
import logging
import torch
import io

def url_to_dataloader(url, num_workers=4, batch_size=16):
    """Build a WebDataset loader that yields batches of ``(image, txt)``.

    The pipeline lists the shards, splits them across nodes and workers,
    reads samples out of the tar files, drops samples that lack a required
    field, shuffles, decodes the latent and the caption, and finally batches.

    Args:
        url: Shard specification handed to ``wds.SimpleShardList`` (for
            example a brace pattern such as ``"/data/{000000..000009}.tar"``).
        num_workers: Number of worker processes passed to ``wds.WebLoader``.
        batch_size: Number of samples per batch. Batching happens inside the
            pipeline (the loader itself is created with ``batch_size=None``)
            and incomplete trailing batches are dropped (``partial=False``).

    Returns:
        A ``wds.WebLoader`` whose items are ``(image, txt)`` batches, where
        ``image`` comes from the ``latent.pt`` field loaded on CPU and cast to
        ``float32`` and ``txt`` is the UTF-8 decoded ``txt`` field.
    """
    # Every field the rename/map_dict stages below look up. A sample missing
    # any of them would make ``wds.rename`` raise and abort iteration, so such
    # samples are filtered out up front.
    required_fields = ("latent.pt", "txt")

    def log_and_continue(exn):
        """Log a warning for ``exn`` and return True so the pipeline continues."""
        logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
        return True

    def has_required_fields(sample):
        """Return True when ``sample`` carries both the latent and the caption."""
        return all(field in sample for field in required_fields)

    def load_latent(z):
        """Deserialize raw ``latent.pt`` bytes into a float32 CPU tensor."""
        # NOTE(security): torch.load unpickles its input, which can execute
        # arbitrary code. Only point this loader at shards you trust.
        return torch.load(io.BytesIO(z), map_location='cpu').to(torch.float32)

    def decode_caption(raw):
        """Decode the raw caption bytes as UTF-8 text."""
        return raw.decode("utf-8")

    pipeline = [
        wds.SimpleShardList(url),
        wds.split_by_node,
        wds.split_by_worker,
        # Errors while reading tar files are logged and skipped, not raised.
        wds.tarfile_to_samples(handler=log_and_continue),
        wds.select(has_required_fields),
        # Shuffle while samples are still raw bytes, before decoding.
        wds.shuffle(bufsize=5000, initial=1000),
        wds.rename(image="latent.pt", txt="txt"),
        wds.map_dict(image=load_latent, txt=decode_caption),
        wds.to_tuple("image", "txt"),
        wds.batched(batch_size, partial=False),
    ]

    dataset = wds.DataPipeline(*pipeline)

    # batch_size=None: the pipeline already emits full batches, so the loader
    # must not batch a second time.
    loader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=num_workers,
    )
    return loader

In [3]:
# Brace pattern naming ten local shards, 000000.tar through 000009.tar
# (presumably expanded by wds.SimpleShardList inside url_to_dataloader).
url = "/share/datasets/datasets/laicoyo/{000000..000009}.tar"
# Build the loader with the function's default num_workers and batch_size.
loader = url_to_dataloader(url)

In [None]:
# Pull a single batch from the loader as a quick smoke test of the pipeline.
d = next(iter(loader))
# Bare last expression so the notebook displays the fetched batch.
d