Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced Batch slicing & minor fix of RNN support #106

Merged
merged 8 commits into from
Jun 30, 2020
Merged

Conversation

Trinkle23897
Copy link
Collaborator

Related to #104

@Trinkle23897 Trinkle23897 linked an issue Jun 29, 2020 that may be closed by this pull request
@Trinkle23897
Copy link
Collaborator Author

Trinkle23897 commented Jun 29, 2020

@duburcqa I add as what you said, but it turns out that it is not compatible with tests in test/base.

Some issues:

  • min(*data_shape) will return the smallest data_shape, not the smallest size in every dimension. E.g., min((1,2,3,4), (3,2,1)) will result in (1,2,3,4) instead of (1,2,1);
  • if we want to force use np.array instead of list, we should further overwrite the __setattr__ method;
  • some tests fail strangely, TypeError: zip argument #2 must support iteration, and in "test_collector_with_dict_state" it fails to deep-copy. Should we change all the list in test to np.ndarray or add extra except-handling method within Batch?

One more thing: ListReplayBuffer use list instead of np.ndarray. To entirely change list to np.ndarray, this class should be abandoned, which will related to #105 ?

@Trinkle23897
Copy link
Collaborator Author

elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
self.state[id] *= 0

I haven't figured it out the best practice to make it compatible with _create_value (None issue). Maybe we should create batch.zero_()?

tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
@duburcqa
Copy link
Collaborator

Maybe we should create batch.zero_()

Yes it would be nice, or rather empty_, since it is doing something closer to empty than zero.

@Trinkle23897
Copy link
Collaborator Author

  • some tests fail strangely, TypeError: zip argument #2 must support iteration, and in "test_collector_with_dict_state" it fails to deep-copy. Should we change all the list in test to np.ndarray or add extra except-handling method within Batch?

One more thing: ListReplayBuffer use list instead of np.ndarray. To entirely change list to np.ndarray, this class should be abandoned, which will related to #105 ?

How about these issue?

@duburcqa
Copy link
Collaborator

duburcqa commented Jun 29, 2020

min(*data_shape) will return the smallest data_shape, not the smallest size in every dimension. E.g., min((1,2,3,4), (3,2,1)) will result in (1,2,3,4) instead of (1,2,1);

True. It should be replaced by list(map(min, zip(*data_shape)))

@duburcqa
Copy link
Collaborator

if we want to force use np.array instead of list, we should further overwrite the setattr method;

Yes, and I think it is a good thing to better control the data added to Batch instances or modified.

@duburcqa
Copy link
Collaborator

some tests fail strangely, TypeError: zip argument #2 must support iteration, and in "test_collector_with_dict_state" it fails to deep-copy. Should we change all the list in test to np.ndarray or add extra except-handling method within Batch?

I think it is related to __setattr__ issue, not converting list to array.

@duburcqa
Copy link
Collaborator

One more thing: ListReplayBuffer use list instead of np.ndarray. To entirely change list to np.ndarray, this class should be abandoned.

Good point. I don't know how to handle this, but it is only used by Collector, that you are refactoring, so maybe it would be better to refactor this part. I would only a raw list of Batch/dict, replacing:

self._cached_buf[i].add(**data)

by simply

self._cached_buf[i].append(Batch(data))

@Trinkle23897
Copy link
Collaborator Author

One more thing: ListReplayBuffer use list instead of np.ndarray. To entirely change list to np.ndarray, this class should be abandoned.

Good point. I don't know how to handle this, but it is only used by Collector, that you are refactoring, so maybe it would be better to refactor this part. I would only a raw list of Batch/dict, replacing:

self._cached_buf[i].add(**data)

by simply

self._cached_buf[i].append(Batch(data))

I found an interesting fact:

        if self._meta.__dict__.get(name, None) is None:
            self._meta.__dict__[name] = []
        self._meta.__dict__[name].append(inst)

this will bypass both of the __setitem__ and __setattr__ method. So the ListReplayBuffer would not be changed.

@Trinkle23897
Copy link
Collaborator Author

Why did you use @staticmethod in Batch.stack, @classmethod in Batch.cat?

@duburcqa
Copy link
Collaborator

duburcqa commented Jun 29, 2020

this will bypass both of the setitem and setattr method. So the ListReplayBuffer would not be changed.

Yes, but in fact it would be better if we could use self._meta[name] instead of self._meta.__dict__[name], precisely to avoid such bypass behavior, but I was afraid of the impact on the performances (yet I haven't tried).

Why did you use @staticmethod in Batch.stack, @classmethod in Batch.cat?

No reason, @staticmethod is enough for both.

@Trinkle23897
Copy link
Collaborator Author

Trinkle23897 commented Jun 29, 2020

@duburcqa It should be ready now. Please have a code review at your convenience.
By the way, I have not removed the redundant code. You can comment and point them out (which will be added in this pr).

test/base/test_batch.py Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Outdated Show resolved Hide resolved
@duburcqa
Copy link
Collaborator

Very good job ! Still 2 comments to address 😁

@duburcqa
Copy link
Collaborator

I was thinking Batch.__init__ could have an extra argument copy similar to numpy (default to False), to enforce deepcopy of input argument. It could be convenient and it is only 2 lines of code.

@duburcqa
Copy link
Collaborator

By the way, did you do some profiling to check the performances ?

@Trinkle23897
Copy link
Collaborator Author

By the way, did you do some profiling to check the performances?

On my computer, it is the same as the last commit. However, GitHub shows that after the third commit of this PR, it shows down to some degree. I think maybe it is because of slicing: every time the __getitem__ will recursively call the _valid_bounds.

@duburcqa
Copy link
Collaborator

duburcqa commented Jun 30, 2020

output

Apparently it is fine, at least for PPO.

test/base/test_batch.py Outdated Show resolved Hide resolved
tianshou/data/batch.py Show resolved Hide resolved
tianshou/data/batch.py Show resolved Hide resolved
@Trinkle23897
Copy link
Collaborator Author

Trinkle23897 commented Jun 30, 2020

I'd like to merge because I have other things to do (two projects and one deadline tomorrow). You can create a new PR to adapt your thought, and I can make a code review for you :)

Something to do:

  • change Batch.empty to in-place fill
  • Batch(copy=True)

@Trinkle23897 Trinkle23897 merged commit db0e2e5 into master Jun 30, 2020
@Trinkle23897 Trinkle23897 deleted the jiayi-dev branch June 30, 2020 10:02
@duburcqa duburcqa mentioned this pull request Jul 8, 2020
BFAnas pushed a commit to BFAnas/tianshou that referenced this pull request May 5, 2024
* add shape property and modify __getitem__

* change Batch.size to Batch.shape

* setattr

* Batch.empty

* remove scalar in advanced slicing

* modify empty_ and __getitem__

* missing testcase

* fix empty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Advance slicing method of Batch
2 participants