Fix start index in DataLoader. #215

Merged: 1 commit into patrick-kidger:main, Jan 30, 2023
Conversation

RehMoritz (Contributor)

The start index in the `DataLoader` should stay within `[0, dataset_size - batch_size]`; otherwise the `batch_indices` repeat across the batches inside one epoch. The reason is that `start = step * self.batch_size` quickly exceeds the dataset size, and `lax.dynamic_slice_in_dim` clamps out-of-range start indices back into bounds, so from the second epoch onwards every step within an epoch reads the same tail slice of the permutation.

A short sketch of the clamping follows, and then an MWE demonstrating that the same data is drawn for two different steps. The proposed fix is included in the MWE as the commented-out assignment to `start`.
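To see the clamping in isolation (a minimal standalone sketch, not part of the PR):

```
import jax.numpy as jnp
import jax.lax as lax

perm = jnp.arange(10)
# A start index of 500 is out of range for a slice of length 5, so
# dynamic_slice_in_dim clamps it to 10 - 5 = 5:
print(lax.dynamic_slice_in_dim(perm, 500, 5))  # [5 6 7 8 9]
```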

```
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx
from typing import Tuple


class DataLoader(eqx.Module):
    arrays: Tuple[jnp.ndarray, ...]
    batch_size: int
    key: jrandom.PRNGKey

    def __post_init__(self):
        dataset_size = self.arrays[0].shape[0]
        assert all(array.shape[0] == dataset_size for array in self.arrays)

    def __call__(self, step):
        dataset_size = self.arrays[0].shape[0]
        num_batches = dataset_size // self.batch_size
        epoch = step // num_batches
        key = jrandom.fold_in(self.key, epoch)
        perm = jrandom.permutation(key, jnp.arange(dataset_size))
        # Proposed fix: wrap `step` into the current epoch, so `start` stays
        # within [0, dataset_size - batch_size]:
        # start = (step % num_batches) * self.batch_size
        # Current (buggy) behaviour: `start` grows without bound, and
        # lax.dynamic_slice_in_dim clamps it back into range, so from the
        # second epoch onwards every step returns the same slice of `perm`.
        start = step * self.batch_size
        slice_size = self.batch_size
        batch_indices = lax.dynamic_slice_in_dim(perm, start, slice_size)
        return tuple(array[batch_indices] for array in self.arrays)


data = jnp.arange(10)
dataloader = DataLoader(arrays=(data,), batch_size=5, key=jax.random.PRNGKey(0))
step = 100

# Both calls print the same batch, demonstrating the bug.
print(dataloader(step))
print(dataloader(step + 1))
```
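For reference, here is how `__call__` looks with the proposed fix applied (a minimal sketch under the same assumptions as the MWE, i.e. `dataset_size` is an exact multiple of `batch_size`):

```
    def __call__(self, step):
        dataset_size = self.arrays[0].shape[0]
        num_batches = dataset_size // self.batch_size
        # Draw a fresh permutation once per epoch...
        epoch = step // num_batches
        key = jrandom.fold_in(self.key, epoch)
        perm = jrandom.permutation(key, jnp.arange(dataset_size))
        # ...and wrap the step into the current epoch, so `start` stays
        # within [0, dataset_size - batch_size].
        start = (step % num_batches) * self.batch_size
        batch_indices = lax.dynamic_slice_in_dim(perm, start, self.batch_size)
        return tuple(array[batch_indices] for array in self.arrays)
```

With this change, `dataloader(100)` and `dataloader(101)` return the two disjoint halves of the epoch-50 permutation instead of the same clamped slice.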

patrick-kidger merged commit 5c0bfb3 into patrick-kidger:main on Jan 30, 2023

patrick-kidger (Owner)

Whoops!
Good catch, and thank you for submitting this PR.

RehMoritz deleted the patch-1 branch on January 30, 2023 at 13:13