In [1]:
import webdataset as wds
import braceexpand
from torch.utils.data import IterableDataset

# Local and Remote Storage URLs

WebDataset refers to data sources using file paths or URLs. The following are all valid ways of referring to a data source:

In [2]:
dataset = wds.WebDataset("dataset-000.tar")
dataset = wds.WebDataset("file:dataset-000.tar")
dataset = wds.WebDataset("http://server/dataset-000.tar")
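How a string is opened depends on its URL scheme. The following sketch uses only `urllib.parse.urlparse` from the standard library to show roughly how the forms above can be told apart. `resolve_source` is a hypothetical helper for illustration, not WebDataset's internal code.

```python
from urllib.parse import urlparse


def resolve_source(url):
    """Return (scheme, target) for a data source string (simplified sketch)."""
    parsed = urlparse(url)
    if parsed.scheme in ("", "file"):
        # Plain paths and "file:" URLs both refer to local files.
        return ("file", parsed.path)
    if parsed.scheme == "pipe":
        # Everything after "pipe:" is treated as a shell command.
        return ("pipe", url[len("pipe:"):])
    # Remote schemes (http, https, gs, ...) keep the full URL for their handler.
    return (parsed.scheme, url)


for source in ["dataset-000.tar", "file:dataset-000.tar", "http://server/dataset-000.tar"]:
    print(source, "->", resolve_source(source))
```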

You can also refer to data with the `pipe:` scheme, which reads a shard from the standard output of a shell command. The following is equivalent to the references above:

In [3]:
dataset = wds.WebDataset("pipe:cat dataset-000.tar")
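Conceptually, a `pipe:` URL runs the rest of the string as a shell command and reads the shard from that command's standard output. Here is a minimal sketch with `subprocess`. `open_pipe_url` is a hypothetical name, and error handling, such as checking the command's exit status, is omitted.

```python
import subprocess


def open_pipe_url(url):
    """Open a "pipe:" URL by running its command and returning its stdout (sketch)."""
    if not url.startswith("pipe:"):
        raise ValueError("expected a pipe: URL")
    command = url[len("pipe:"):]
    # shell=True so the command string is interpreted as it would be in a shell.
    proc = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
    # The caller reads the shard bytes from this binary stream.
    return proc.stdout
```

For example, `open_pipe_url("pipe:cat dataset-000.tar").read()` would return the raw bytes of the shard.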

You can use the same notation for accessing data in cloud storage:

In [4]:
dataset = wds.WebDataset("pipe:gsutil cat gs://somebucket/dataset-000.tar")

Note that access to standard web schemes is implemented using `curl`: internally, `http://server/dataset.tar` is simply treated like `pipe:curl -s -L 'http://server/dataset.tar'`. Using `curl` for Internet protocols is actually more efficient than using the built-in `http` library, because the download runs in a separate process, so name resolution and data transfer proceed asynchronously.
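The rewrite from an HTTP URL to a `curl` pipe can be sketched as a plain string transformation. `http_url_to_pipe` is hypothetical. It uses `shlex.quote`, which adds quotes only when the URL contains characters the shell would interpret, so its output can differ cosmetically from the always-quoted form shown above.

```python
import shlex


def http_url_to_pipe(url):
    """Rewrite an http(s) URL as an equivalent pipe: URL that streams it with curl (sketch)."""
    # -s: silent (no progress meter); -L: follow redirects.
    return "pipe:curl -s -L " + shlex.quote(url)


print(http_url_to_pipe("http://server/dataset.tar"))
print(http_url_to_pipe("http://server/get?shard=1&split=train"))
```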

You can define handlers for new schemes or override implementations for existing schemes by adding entries to `wds.gopen_schemes`:

In [5]:
def gopen_gs(url, mode="rb", bufsize=8192):
    ...

wds.gopen_schemes["gs"] = gopen_gs 
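One way to fill in a handler like `gopen_gs` is to stream the object through a command-line tool. This sketch assumes that a handler receives `(url, mode, bufsize)`, as in the stub above, and returns a readable binary stream. `make_cli_handler` is a hypothetical helper, not part of WebDataset, and the `gs` case assumes `gsutil` is installed.

```python
import subprocess


def make_cli_handler(*command):
    """Build a read-only URL handler that streams the output of `command url` (sketch)."""

    def handler(url, mode="rb", bufsize=8192):
        if mode != "rb":
            raise ValueError(f"{command[0]} handler only supports reading")
        # Pass an argument list (no shell), so the URL is never shell-interpreted.
        proc = subprocess.Popen([*command, url], stdout=subprocess.PIPE, bufsize=bufsize)
        return proc.stdout

    return handler


# Hypothetical gs:// handler built on `gsutil cat`.
gopen_gs = make_cli_handler("gsutil", "cat")
# Registering it would then look like: wds.gopen_schemes["gs"] = gopen_gs
```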

# Standard Input/Output

As a special case, the string "-" refers to standard input (when reading) or standard output (when writing). This lets programs built on WebDataset be used in Unix pipelines, and it permits code written for shards to be applied to individual files on disk.

For example, assume that you have an image classification command line program and you want to apply it to a collection of images in a directory. You might write:

```Bash
tar cf - *.jpg | shard-classifier - -o - | tar xvf - --include '*.cls'
```

This is the rough equivalent of:

```Bash
for image in *.jpg; do
    image-classifier "$image" > "$(basename "$image" .jpg).cls"
done
```
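The `shard-classifier` in the pipeline above is a hypothetical program. The following sketch shows how such a filter could be written with Python's standard `tarfile` module in streaming mode, rather than with WebDataset itself. `classify` is a placeholder that labels images by file size and stands in for a real model. `classify_shard` and `main` are likewise hypothetical names.

```python
import io
import sys
import tarfile


def classify(image_bytes):
    """Hypothetical stand-in for a real classifier: returns a label string."""
    return "large" if len(image_bytes) > 100_000 else "small"


def classify_shard(src, dst):
    """Read a tar stream of .jpg files from src and write matching .cls files to dst."""
    # "r|" and "w|" are tarfile's streaming modes, so src and dst may be pipes.
    with tarfile.open(fileobj=src, mode="r|") as tin, tarfile.open(fileobj=dst, mode="w|") as tout:
        for member in tin:
            if not (member.isfile() and member.name.endswith(".jpg")):
                continue
            label = classify(tin.extractfile(member).read()).encode("utf-8")
            info = tarfile.TarInfo(member.name[: -len(".jpg")] + ".cls")
            info.size = len(label)
            tout.addfile(info, io.BytesIO(label))


def main():
    # The "-" convention: read the shard from stdin, write the result to stdout.
    classify_shard(sys.stdin.buffer, sys.stdout.buffer)
```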

# Multiple Shards and Mixing Datasets

The `WebDataset` and `ShardList` classes take either a string or a list of strings as an argument. When given a string, the string is expanded using `braceexpand`. Therefore, the following three datasets are equivalent:

In [6]:
dataset = wds.WebDataset(["dataset-000.tar", "dataset-001.tar", "dataset-002.tar", "dataset-003.tar"])
dataset = wds.WebDataset("dataset-{000..003}.tar")
dataset = wds.WebDataset("file:dataset-{000..003}.tar")
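For the numeric ranges used in shard names, brace expansion amounts to generating zero-padded numbers. The sketch below handles only `{start..end}` numeric ranges, with the padding width taken from `start`. The real `braceexpand` package supports much more (comma lists, letters, nesting), and `expand_numeric_braces` is a hypothetical name.

```python
import re


def expand_numeric_braces(spec):
    """Expand numeric ranges like "{000..003}", keeping zero padding (sketch)."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", spec)
    if match is None:
        return [spec]
    start, end = match.group(1), match.group(2)
    width = len(start)
    prefix, suffix = spec[: match.start()], spec[match.end():]
    results = []
    for i in range(int(start), int(end) + 1):
        # Recurse so that any further ranges later in the spec are expanded too.
        results.extend(expand_numeric_braces(f"{prefix}{i:0{width}d}{suffix}"))
    return results


print(expand_numeric_braces("dataset-{000..003}.tar"))
```

Applied to the three specs in the mixing example, it yields 147 + 548 + 1000 = 1695 shard names, the count printed there.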

For complex training problems, you may want to mix multiple datasets, where each dataset consists of multiple shards. A good way to do this is to expand each shard spec individually using `braceexpand` and concatenate the lists. Then you can pass the resulting list as an argument to `WebDataset`.

In [7]:
urls = (
    list(braceexpand.braceexpand("imagenet-{000000..000146}.tar")) +
    list(braceexpand.braceexpand("openimages-{000000..000547}.tar")) +
    list(braceexpand.braceexpand("custom-images-{000000..000999}.tar"))
)
print(len(urls))
dataset = wds.WebDataset(urls, shardshuffle=True).shuffle(10000).decode("torchrgb")

1695
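`.shuffle(10000)` does not hold the whole dataset in memory. Instead, it shuffles samples within a bounded buffer. The following is a plain-Python sketch of such a buffer shuffle. `buffer_shuffle` is hypothetical, and WebDataset's own implementation differs in details such as how the buffer is initially filled.

```python
import random


def buffer_shuffle(samples, bufsize, rng=None):
    """Approximately shuffle a stream using a buffer of at most `bufsize` items (sketch)."""
    rng = rng or random.Random()
    buffer = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) >= bufsize:
            # Swap a randomly chosen item to the end and emit it.
            k = rng.randrange(len(buffer))
            buffer[k], buffer[-1] = buffer[-1], buffer[k]
            yield buffer.pop()
    # Once the input is exhausted, emit the remaining items in random order.
    rng.shuffle(buffer)
    yield from buffer
```

A larger buffer mixes samples from more distant parts of the stream at the cost of memory. This is why shard-level shuffling (`shardshuffle=True`) is combined with it.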


# Mixing Datasets with a Custom `IterableDataset` Class

For more complex sampling problems, you can also write sample processors. For example, to sample equally from several datasets, you could write something like this (the `Shorthands` and `Composable` base classes just add some convenience methods):

In [8]:
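class SampleEqually(IterableDataset, wds.Shorthands, wds.Composable):
    """Yield one sample from each dataset in turn (round-robin)."""

    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        # Start a fresh iterator over every source dataset.
        sources = [iter(ds) for ds in self.datasets]
        while True:
            for source in sources:
                try:
                    yield next(source)
                except StopIteration:
                    # Stop as soon as any source runs out, keeping the mix balanced.
                    return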
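`SampleEqually` alternates strictly between sources. A possible variation picks the next source at random, with a probability proportional to a per-source weight. It is shown here as a hypothetical plain-Python generator (`sample_randomly`), not a WebDataset API. Wrapping it in a class like `SampleEqually` would make it usable in the same way.

```python
import random


def sample_randomly(iterables, weights, rng=None):
    """Yield samples from randomly chosen sources until a chosen source runs out."""
    rng = rng or random.Random()
    sources = [iter(it) for it in iterables]
    indices = list(range(len(sources)))
    while True:
        # Pick a source index with probability proportional to its weight.
        i = rng.choices(indices, weights=weights)[0]
        try:
            yield next(sources[i])
        except StopIteration:
            # Like SampleEqually, stop as soon as one source is exhausted.
            return
```

Each source keeps its own internal order, so per-dataset shuffling (for example `.shuffle(1000)` on each `WebDataset`) is still needed before mixing.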

Now we can mix samples from different sources in more complex ways:

In [9]:
dataset1 = wds.WebDataset("imagenet-{000000..000146}.tar", shardshuffle=True).shuffle(1000).decode("torchrgb")
dataset2 = wds.WebDataset("openimages-{000000..000547}.tar", shardshuffle=True).shuffle(1000).decode("torchrgb")
dataset3 = wds.WebDataset("custom-images-{000000..000999}.tar", shardshuffle=True).shuffle(1000).decode("torchrgb")
dataset = SampleEqually([dataset1, dataset2, dataset3]).shuffle(1000)