### Infinite dSprites

#### Introduction
The `InfiniteDSprites` dataset is an infinite stream of 2D sprites undergoing translation, rotation, and scaling. It extends [dSprites](https://github.com/deepmind/dsprites-dataset).

#### Quick start
To get started, create an instance of `InfiniteDSprites`. It is a subclass of PyTorch's `IterableDataset`, so you can iterate over it with a `DataLoader`:

```python
from torch.utils.data import DataLoader

from disco.data import InfiniteDSprites
from disco.visualization import draw_batch

dataset = InfiniteDSprites()
dataloader = DataLoader(dataset, batch_size=16)
batch_img, latents = next(iter(dataloader))  # one batch of images and their latent factors
draw_batch(batch_img, show=True)
print(latents.position_y)
```
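Conceptually, the dataset behaves like a generator that never runs out of `(image, latents)` pairs, and the `DataLoader` groups consecutive samples into a batch, turning a list of latent namedtuples into a single namedtuple whose fields hold the per-sample values. The plain-Python sketch below mimics that flow. `toy_stream`, `Latents`, `collate` and `batches` are hypothetical stand-ins, not part of `disco`, and the real collate step produces tensors rather than tuples.

```python
import itertools
from collections import namedtuple

# Hypothetical latent record (NOT the disco class); field order follows the
# dSprites-compatible order mentioned in the note on `latents`.
Latents = namedtuple(
    "Latents", ["color", "shape", "scale", "orientation", "position_x", "position_y"]
)


def toy_stream():
    """An endless generator of (image, latents) pairs standing in for the dataset."""
    for step in itertools.count():
        latents = Latents("red", 0, 1.0, 0.0, 0.5, step / 100)
        image = f"img{step}"  # placeholder for an image tensor
        yield image, latents


def collate(samples):
    """Transpose a list of (image, latents) pairs into (images, batched latents).

    Each latents field becomes a tuple of per-sample values. PyTorch's default
    collate treats namedtuples similarly but stacks numeric fields into tensors.
    """
    images, latents = zip(*samples)
    batched = Latents(*zip(*latents))
    return list(images), batched


def batches(stream, batch_size):
    """Group a stream into fixed-size batches, like a DataLoader does."""
    iterator = iter(stream)
    while True:
        chunk = list(itertools.islice(iterator, batch_size))
        if not chunk:  # only reachable for finite streams
            return
        yield collate(chunk)


batch_img, latents = next(batches(toy_stream(), batch_size=4))
print(latents.position_y)  # (0.0, 0.01, 0.02, 0.03)
```

Because the stream is infinite, `next(iter(...))` is the natural way to pull a single batch; a `for` loop over it would never terminate on its own.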

#### Fine-grained control
The dataset has many dials you can turn to adjust the data distribution. In particular, you can control the image size and the range of values each generative factor can take (by default, these ranges match the original dSprites):

```python
import numpy as np

# Override the default (dSprites) ranges of the generative factors.
dataset = InfiniteDSprites(
    img_size=128,
    color_range=["red", "green", "blue"],
    scale_range=np.linspace(0, 1, 100),
    orientation_range=np.linspace(0, 2 * np.pi, 100),
    position_x_range=np.linspace(0, 1, 100),
    position_y_range=np.linspace(0, 1, 100),
)
dataloader = DataLoader(dataset, batch_size=16)
batch_img, _ = next(iter(dataloader))
draw_batch(batch_img, show=True)
```
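A quick count shows how large this grid is. Assuming the exhaustive, odometer-style iteration this section describes, every shape is paired with every combination of the other factors before the next shape appears. The dictionary below simply records the lengths of the ranges passed in the cell above:

```python
import math

# Lengths of the ranges passed to InfiniteDSprites in the cell above.
range_lengths = {
    "color": 3,  # ["red", "green", "blue"]
    "scale": 100,  # np.linspace(0, 1, 100)
    "orientation": 100,
    "position_x": 100,
    "position_y": 100,
}

images_per_shape = math.prod(range_lengths.values())
print(images_per_shape)  # 300000000
```

At 300 million images per shape, a short run over this configuration may never get past the first shape, so coarser ranges (or random sampling) can be worth considering.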

By default, the dataset enumerates the exhaustive (Cartesian) product of these ranges in the order (shape, color, scale, orientation, position_x, position_y). Think of it as an odometer: the rightmost factor advances at every step, and each factor to its left advances only when the one to its right wraps around.
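Python's `itertools.product` follows the same odometer convention, so it gives a small-scale picture of this iteration order. The factor values below are toy placeholders, not the dataset's defaults:

```python
import itertools

# Toy ranges in the documented order:
# (shape, color, scale, orientation, position_x, position_y)
shapes = ["shape0", "shape1"]
colors = ["red", "blue"]
scales = [0.5, 1.0]
orientations = [0.0]
positions_x = [0.0, 1.0]
positions_y = [0.0, 0.5, 1.0]

order = list(
    itertools.product(shapes, colors, scales, orientations, positions_x, positions_y)
)

# The rightmost factor (position_y) changes on every step...
print([combo[-1] for combo in order[:4]])  # [0.0, 0.5, 1.0, 0.0]
# ...while position_x advances only once position_y wraps around.
print([combo[-2] for combo in order[:4]])  # [0.0, 0.0, 0.0, 1.0]
print(len(order))  # 48
```

The leftmost factor, shape, changes least often: here it switches only after all 24 combinations of the other factors.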


> **NOTE**: The field order of the `latents` namedtuple is (color, shape, scale, orientation, position_x, position_y), which differs from the iteration order above. This keeps it compatible with dSprites. In any case, it is best to access latent factors by name.
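The mismatch is easy to trip over when indexing or unpacking by position. A minimal illustration with a hypothetical namedtuple that uses the field order from the note (it is not the `disco` class itself):

```python
from collections import namedtuple

# Hypothetical stand-in with the dSprites-compatible field order.
Latents = namedtuple(
    "Latents", ["color", "shape", "scale", "orientation", "position_x", "position_y"]
)

latents = Latents(
    color="red", shape=3, scale=0.5, orientation=0.0, position_x=0.25, position_y=0.75
)

# Index 0 is color, not shape, even though shape is the outermost iteration factor.
print(latents[0])  # red
# Access by name is unambiguous regardless of field order.
print(latents.shape)  # 3
print(latents._asdict()["position_y"])  # 0.75
```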

#### Subclasses
Use the `RandomDSprites` subclass to sample a random value for each latent factor from its range at every step, instead of enumerating the ranges in order:

```python
from disco.data import RandomDSprites

dataset = RandomDSprites(img_size=128, color_range=["red", "green", "blue"])
dataloader = DataLoader(dataset, batch_size=16)
batch_img, latents = next(iter(dataloader))
draw_batch(batch_img, show=True)
print(latents.position_y)
print(latents.color)
```
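Where the base class walks the ranges in order, the random variant draws each factor independently. The sketch below shows such a sampling step, assuming each value is picked uniformly from its range (the library may use a different distribution). `sample_latents` and the toy ranges are hypothetical, and shape generation is left out:

```python
import random

# Hypothetical toy ranges; RandomDSprites builds its own from the constructor arguments.
RANGES = {
    "color": ["red", "green", "blue"],
    "scale": [0.5, 0.75, 1.0],
    "orientation": [0.0, 1.57, 3.14],
    "position_x": [0.0, 0.5, 1.0],
    "position_y": [0.0, 0.5, 1.0],
}


def sample_latents(rng: random.Random) -> dict:
    """Draw one value per factor, independently and uniformly, from its range."""
    return {name: rng.choice(values) for name, values in RANGES.items()}


rng = random.Random(0)  # seeded so the sketch is reproducible
batch = [sample_latents(rng) for _ in range(16)]
print([sample["color"] for sample in batch[:4]])
```

Unlike the odometer ordering, consecutive samples here vary in every factor at once, so even a single batch of 16 covers a spread of colors, scales, and positions.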