[BUG] Download of categorical data in MinariExperienceReplay is completely broken #3105

@marcosgalleterobbva

Describe the bug

When downloading Minari datasets that contain categorical data (such as missions), the resulting dataset is wrong: every step in the dataset is assigned the full array of categorical values from the first episode.

This error originates from treating the Minari categorical data as NonTensorData instead of NonTensorStack.
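
To make the distinction concrete, here is a toy model in plain Python. It does not use tensordict: `SharedValue` and `StackedValues` are hypothetical stand-ins that only mimic the indexing behaviour described in this issue, and the second mission string is made up for illustration.

```python
# Toy model only: these classes are NOT the tensordict API. They mimic the
# indexing behaviour reported above to show why the container choice matters.


class SharedValue:
    """One payload shared by every batch index (the NonTensorData-like case)."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __getitem__(self, index):
        if not 0 <= index < self.batch_size:
            raise IndexError(index)
        # The same payload comes back regardless of the index.
        return self.data


class StackedValues:
    """One payload per batch index (the NonTensorStack-like case)."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]


# Illustrative data: three steps of a first episode, two of a second one.
first_episode = ["pick up a blue box"] * 3
all_steps = first_episode + ["pick up a red ball"] * 2

broken = SharedValue(first_episode, batch_size=len(all_steps))
fixed = StackedValues(all_steps)

print(broken[4])  # the whole first-episode list, even for a later step
print(fixed[4])   # the single mission that belongs to that step
```

With the shared container, step 4 still reports the first episode's whole array, which matches the symptom shown in the reproduction below.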

To Reproduce

Take this code, for example:

from torchrl.data.datasets.minari_data import MinariExperienceReplay

BATCH_SIZE = 1      # Small batch size to minimize processing
SAVE_ROOT = None    # Set to your desired save location, or None for default


def download_minari_datasets(dataset_id):
    # Instantiating the replay buffer triggers the download and preprocessing.
    dataset = MinariExperienceReplay(
        dataset_id=dataset_id,
        batch_size=BATCH_SIZE,
        root=SAVE_ROOT,
    )
    print(f"✓ Successfully downloaded {dataset_id}")
    # Return the buffer so it can be inspected afterwards.
    return dataset


if __name__ == "__main__":
    a = download_minari_datasets("minigrid/BabyAI-Pickup/optimal-v0")

And now watch very closely what happens when we ask for mission values across random steps in random episodes:

In [3]: a[0][('observation', 'mission')][:5], len(a[0][('observation', 'mission')])
Out[3]: 
(array([b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box'], dtype=object),
 109)

In [4]: a[200][('observation', 'mission')][:5], len(a[0][('observation', 'mission')])
Out[4]: 
(array([b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box'], dtype=object),
 109)

In [5]: a[5000][('observation', 'mission')][:5], len(a[0][('observation', 'mission')])
Out[5]: 
(array([b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box'], dtype=object),
 109)

It is the same array copied over and over again!
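
A quick way to check an already-downloaded dataset for this symptom is to test whether the value stored at a single step is one mission or a whole array. The sketch below is a minimal, dependency-free heuristic; the helper name `is_per_step_value` is hypothetical and not part of torchrl or tensordict.

```python
def is_per_step_value(value):
    """Return True if `value` looks like one categorical entry for one step.

    Returns False if it looks like a whole episode's array attached to a
    single step, which is the symptom described in this issue.
    """
    # A plain string or bytes object is a single mission.
    if isinstance(value, (bytes, str)):
        return True
    try:
        # Sequences (lists, 1-D arrays) with more than one entry are suspicious.
        return len(value) == 1
    except TypeError:
        # Objects without a length (e.g. 0-d arrays) hold a single value.
        return True


# One mission for one step: fine.
print(is_per_step_value(b"pick up a blue box"))
# 109 missions attached to one step, as in the session above: broken.
print(is_per_step_value([b"pick up a blue box"] * 109))
```

Assuming `a` is the dataset from the reproduction, calling the helper on `a[0][('observation', 'mission')]` would flag the 109-element array shown above.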

Expected behavior

Minari datasets that have missions need this data to be usable, and it is clearly intended to be stored in a stacked manner: each step should carry its own mission value rather than the whole array from the first episode.

Additional context

I opened one PR and one issue in pytorch/tensordict that approached this problem incorrectly:
pytorch/tensordict#1399
pytorch/tensordict#1400

In this comment, @vmoens pointed me to the correct way of dealing with this issue, which is to use the set_list_to_stack function.

Reason and Possible fixes

I have manually corrected this in MinariExperienceReplay. I will attach the PR as soon as I can.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug