Batch: only allow entries with the same length #1087

Open
MischaPanch opened this issue Apr 3, 2024 · 3 comments
Labels
Batch and Buffer: Improvements in internal data structures, temporary label
breaking changes: Changes in public interfaces. Includes small changes or changes in keys
refactoring: No change to functionality

Comments

@MischaPanch
Collaborator

Currently, sequences in a Batch can have different lengths, which can easily lead to problems. E.g.,

b = Batch(a=[1, 2, 3], b=[1, 2])

will lead to

len(b) == 2

and list(b) will essentially be [Batch(a=1, b=1), Batch(a=2, b=2)], silently dropping the last element of the sequence associated with a.

The issue is more severe because batches can (and often do) contain subbatches. If a subbatch contains sequences, the shortest of them likewise determines the length and the elements of the overall batch.

Since Batch is supposed to behave somewhat like an array (it is sliceable, for instance), it should only support sequences of equal length.

Note that empty subbatches and empty subarrays are treated differently! E.g.,

b = Batch(a=[1,2,3], b= {})

leads to a batch of length 3, whereas

b = Batch(a=[1,2,3], b= ())

leads to a batch of length 0.
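
To make the cases above easier to compare, here is a minimal pure-Python model of the length rule as described in this issue. Plain dicts stand in for Batch, and effective_length is a hypothetical helper name; this is not Tianshou's actual implementation, only a sketch that reproduces the lengths reported above.

```python
from collections.abc import Mapping, Sequence
from typing import Optional


def effective_length(data: Mapping) -> Optional[int]:
    """Return the length implied by the shortest leaf sequence, or None if there is none.

    Hypothetical model: plain dicts stand in for Batch objects.
    """
    lengths = []
    for value in data.values():
        if isinstance(value, Mapping):
            # A nested mapping without any sequences (e.g. {}) adds no constraint.
            nested = effective_length(value)
            if nested is not None:
                lengths.append(nested)
        elif isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            # Every sequence counts, including an empty one (length 0).
            lengths.append(len(value))
    return min(lengths) if lengths else None


print(effective_length({"a": [1, 2, 3], "b": [1, 2]}))  # 2
print(effective_length({"a": [1, 2, 3], "b": {}}))      # 3
print(effective_length({"a": [1, 2, 3], "b": ()}))      # 0
```

In this model the asymmetry comes from the empty dict being skipped entirely, while the empty tuple still takes part in the minimum.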

@MischaPanch added the refactoring, breaking changes, and Batch and Buffer labels Apr 3, 2024
@MischaPanch added this to To do in Overall Tianshou Status via automation Apr 3, 2024
@Trinkle23897
Collaborator

This is mostly for gathering env info in the buffer. If the info has multiple keys and different info dicts have different keys, it is hard to calculate the length needed to slice properly. For example,

arr_len_2 = [1, 2]
buffer = Batch(obs=arr_len_2, info={})
buffer[0]  # will return Batch(obs=1, info={}), which requires len(buffer) > 0
buffer = Batch(obs=arr_len_2, info={"a": arr_len_2})
buffer[0]  # will return Batch(obs=1, info={"a": 1})
buffer = Batch(obs=arr_len_2, info={"a": arr_len_2, "b": [1]})
buffer[1]  # will raise error because len(buffer) == 1
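
As an illustration of how such ragged lengths can arise in the first place, the sketch below stacks per-step info dicts key by key. It uses plain dicts and a hypothetical helper stack_infos with made-up keys; it is not how Tianshou's buffer actually gathers info, only one way differing keys end up as sequences of different lengths.

```python
def stack_infos(infos):
    """Collect per-step info dicts into one dict of lists, key by key."""
    stacked = {}
    for info in infos:
        for key, value in info.items():
            # Each key only grows when a step actually reports it.
            stacked.setdefault(key, []).append(value)
    return stacked


# Three steps; only the second step reports the (made-up) key "b".
infos = [{"a": 1}, {"a": 2, "b": 10}, {"a": 3}]
stacked = stack_infos(infos)
print(stacked)  # {'a': [1, 2, 3], 'b': [10]}
print({key: len(values) for key, values in stacked.items()})  # {'a': 3, 'b': 1}
```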

@Trinkle23897
Collaborator

But I agree that we can design a better interface that prevents this from happening, instead of adding syntactic sugar inside Batch.

@MischaPanch
Collaborator Author

Yeah, I was thinking of:

  1. By default, add NaN or user-provided default entries for each missing nested info key.
  2. Provide an option to raise errors whenever something would lead to different lengths, like the differing keys in your example above. For envs, it should be fine to demand that info always returns a dict with the same keys, even on reset. If it does not, an error should be raised or the missing entries filled in.
  3. Raise an error if adding a single item would introduce new keys, or if a sequence of the wrong length would be added.
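
The three options listed above could look roughly like the sketch below. It works on plain, flat dicts rather than Batch, does not handle nested info, and the names stack_infos_padded, add_sequence, default, and strict are all hypothetical; nothing here is an existing Tianshou API.

```python
import math


def stack_infos_padded(infos, default=math.nan, strict=False):
    """Stack info dicts key by key so every key ends up with len(infos) entries.

    Missing keys are filled with `default`; with strict=True they raise instead.
    """
    infos = list(infos)
    all_keys = []
    for info in infos:
        for key in info:
            if key not in all_keys:
                all_keys.append(key)
    stacked = {key: [] for key in all_keys}
    for step, info in enumerate(infos):
        missing = [key for key in all_keys if key not in info]
        if strict and missing:
            raise KeyError(f"info at step {step} is missing keys: {missing}")
        for key in all_keys:
            stacked[key].append(info.get(key, default))
    return stacked


def add_sequence(stacked, key, values):
    """Add a new key, refusing sequences whose length differs from the existing ones."""
    for other_key, other_values in stacked.items():
        if len(other_values) != len(values):
            raise ValueError(
                f"{key!r} has length {len(values)}, "
                f"but {other_key!r} has length {len(other_values)}"
            )
    stacked[key] = list(values)


padded = stack_infos_padded([{"a": 1}, {"a": 2, "b": 10}], default=None)
print(padded)  # {'a': [1, 2], 'b': [None, 10]}
```

Padding keeps every key at the same length so that slicing stays well defined, while strict=True corresponds to demanding that an env always returns the same info keys.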
