Implement forward_all function that performs forwarding on multiple variables at once #243
Thank you for your PR! We think you made a good point.
While your code looks good for the most part, we have two suggestions that may further improve it.
First, the GIL should be released while forward_all is running.
Otherwise, other Python threads are paused until forward_all finishes, which is not desirable.
So, could you make it callable without the GIL, as we suggested?
Second, could you write tests for the forward_all function?
We would like to make sure that forward_all works correctly, especially for networks with more complicated architectures, such as skip, branching, and merging connections.
It may help to refer to the tests for the normal forward function at https://github.com/sony/nnabla/blob/master/python/test/test_graph.py
We thank you again for your contribution.
@TE-AkioHayakawa Hi, Hayakawa san, it's been a week!
@TE-AkioHayakawa Hi, Hayakawa san!
Sorry for the late reply. Thank you for your commitment!
Tests failed in the Python 2.7.15 environment (test_graph_clear_buffer, test_graph_rewire).
Thank you for reviewing my code! I'll check the unit tests with Python 2.
Sorry, I might be wrong. This problem is not related to Python 2. I found that the resulting gradients are not correct when CpuArray is used (not CpuCachedArray) and backward() is called with clear_buffer=True.
I think I got your point. Unlike
The CI tests have passed.
@takuseno I will merge your pull request as soon as possible.
@TE-AkioHayakawa Thank you!!
Hi, @TE-TakuyaNarihira san.
I added a new function, forward_all. Previously, forwarding multiple outputs that share hidden layers (such as the policy and value of an actor-critic) required multiple forward passes in a static graph. See the example below.
This overhead becomes critical for huge architectures.
Thus, I added the forward_all function. This function performs forwarding with a shared fclosed, which prevents shared layers from being revisited.
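To illustrate the idea, here is a minimal pure-Python sketch. It does not use nnabla and is not the code in this PR. The Node class, _visit, and build_graph are hypothetical names invented for the illustration, and the closed set stands in for the shared fclosed described above. It assumes a small graph where policy and value share one hidden node, and policy also has a skip connection from the input.

```python
class Node:
    """Hypothetical graph node: applies fn to the values of its inputs."""

    def __init__(self, fn, inputs=()):
        self.fn = fn
        self.inputs = list(inputs)
        self.value = None
        self.calls = 0  # how many times this node has been computed


def _visit(node, closed):
    # Skip nodes that were already computed during this traversal.
    if node in closed:
        return
    for parent in node.inputs:
        _visit(parent, closed)
    node.value = node.fn(*[p.value for p in node.inputs])
    node.calls += 1
    closed.add(node)


def forward(node):
    """Forward a single output; the closed set is private to this call."""
    _visit(node, set())


def forward_all(nodes):
    """Forward several outputs with one closed set shared by all of them."""
    closed = set()
    for node in nodes:
        _visit(node, closed)


def build_graph():
    # x -> hidden -> {policy, value}; policy also reads x (skip connection).
    x = Node(lambda: 2.0)
    hidden = Node(lambda v: v * 3.0, [x])
    policy = Node(lambda h, v: h + v, [hidden, x])
    value = Node(lambda h: h * h, [hidden])
    return hidden, policy, value


# Separate forwardings recompute the shared hidden node for each output.
hidden, policy, value = build_graph()
forward(policy)
forward(value)
separate_calls = hidden.calls

# One forward_all call computes the shared hidden node only once.
hidden, policy, value = build_graph()
forward_all([policy, value])
shared_calls = hidden.calls
```

In this sketch the shared hidden node is computed twice with separate forward calls and once with forward_all, while both outputs keep the same values.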
What do you think of this?