Describe the bug
When downloading Minari datasets that contain categorical data (such as missions), the resulting dataset is broken: every step is assigned the array of categorical values from the first episode.
The error comes from storing the Minari categorical data as a `NonTensorData` instead of a `NonTensorStack`.
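To make the difference concrete, here is a minimal conceptual sketch in plain Python. It does not use tensordict: ordinary lists stand in for `NonTensorData` (one value shared by the whole batch) and `NonTensorStack` (one value per element). The function names and the second episode's mission are made up for illustration.

```python
# Conceptual model only: plain lists stand in for tensordict containers.
# The missions below are made up; the second one is hypothetical.
episodes = [
    [b"pick up a blue box"] * 3,  # episode 0: 3 steps
    [b"pick up a red ball"] * 2,  # episode 1: 2 steps
]


def store_as_single_value(episodes):
    """Like NonTensorData: one value (episode 0's missions) shared by every step."""
    shared = episodes[0]
    n_steps = sum(len(ep) for ep in episodes)
    return [shared for _ in range(n_steps)]


def store_as_stack(episodes):
    """Like NonTensorStack: episodes concatenated so each step keeps its own mission."""
    return [mission for ep in episodes for mission in ep]


broken = store_as_single_value(episodes)
fixed = store_as_stack(episodes)

print(broken[4])  # [b'pick up a blue box', b'pick up a blue box', b'pick up a blue box']
print(fixed[4])   # b'pick up a red ball'
```

In the broken layout, step 4 (which belongs to episode 1) returns episode 0's whole mission array, which is the symptom shown in the reproduction.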
To Reproduce
The following script reproduces it:

```python
from torchrl.data.datasets.minari_data import MinariExperienceReplay

BATCH_SIZE = 1  # Small batch size to minimize processing
SAVE_ROOT = None  # Set to your desired save location, or None for default


def download_minari_datasets(dataset_id):
    a = MinariExperienceReplay(
        dataset_id=dataset_id,
        batch_size=BATCH_SIZE,
        root=SAVE_ROOT,
    )
    print(f"✓ Successfully downloaded {dataset_id}")
    return a


if __name__ == "__main__":
    a = download_minari_datasets("minigrid/BabyAI-Pickup/optimal-v0")
```
Now look closely at what happens when we query the mission values at steps taken from different episodes:
```python
In [3]: a[0][('observation', 'mission')][:5], len(a[0][('observation', 'mission')])
Out[3]:
(array([b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box'], dtype=object),
 109)

In [4]: a[200][('observation', 'mission')][:5], len(a[0][('observation', 'mission')])
Out[4]:
(array([b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box'], dtype=object),
 109)

In [5]: a[5000][('observation', 'mission')][:5], len(a[0][('observation', 'mission')])
Out[5]:
(array([b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box', b'pick up a blue box',
        b'pick up a blue box'], dtype=object),
 109)
```
It is the same 109-element array, copied over and over again: every step returns the missions of the first episode instead of its own mission.
Expected behavior
Each step should return its own mission. Minari datasets that include missions need this per-step data to be usable, and the data is clearly meant to be stored stacked along the step dimension rather than as a single value shared by all steps.
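The manual inspection above can be turned into a quick check. The sketch below is a hypothetical helper, not part of torchrl. It assumes that indexing one step and then reading the mission key returns the raw mission (a single `str` or `bytes`) when the data is stacked correctly, and an array when it is broadcast as in this bug.

```python
from typing import Any, Callable, Hashable, Iterable


def missions_are_per_step(
    get_step: Callable[[int], Any],
    key: Hashable,
    indices: Iterable[int],
) -> bool:
    """Return True if every sampled step holds a single mission value.

    With the bug described above, each step instead holds a whole array
    (the first episode's missions), so this returns False.
    """
    for i in indices:
        value = get_step(i)[key]
        # A correctly stacked entry is one str/bytes value per step.
        if not isinstance(value, (str, bytes)):
            return False
    return True
```

With the replay buffer from the reproduction script, the call would look like `missions_are_per_step(lambda i: a[i], ("observation", "mission"), [0, 200, 5000])`, which should return `False` while the bug is present (not verified here).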
Additional context
I first opened a PR and an issue in pytorch/tensordict, both of which approached the problem incorrectly:
pytorch/tensordict#1399
pytorch/tensordict#1400
In this comment, @vmoens pointed me to the correct way of handling it, which is to use the `set_list_to_stack` function.
Reason and Possible fixes
I have manually fixed this in `MinariExperienceReplay` and will link the PR as soon as I can.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)