Great Stuff but Needs Better Usability #1
FYI, I have just opened a feature request in the PyTorch repository as well: pytorch/pytorch#52626
Thank you for your interest in our work, and for opening the feature request! It is a good idea to make the code applicable to any network without having to add functions to the network class to switch between states. I just looked at the documentation and feel this is achievable without adding new features to PyTorch. We can use
Sounds great!!! If we could also use it with any given optimizer, that would be perfect.
Hi, in the end we will probably try to implement it ourselves, but we would rather use official instructions to ensure it works properly. Great work! Thanks for the publication.
@kayuksel I have just pushed a new version that supports any CNN whose only parameterized layers are nn.Conv2d, nn.Linear and nn.BatchNorm2d. Please refer to the note in the README for more details, and feel free to ask if you have any further questions.

@danarte Thanks for the interest in our work! Basically, we use GradInit on all parameters of the network: we learn a scale factor for each weight and each bias (if any, and non-zero at initialization). Please refer to the notes in the updated README.md to see how to extend it to other models like Transformers. Essentially, we just need to be able to iterate over all trainable modules in a fixed order and take gradient steps (to compute the objective of Eq. 1) for all their parameters. I will release the code for fairseq ASAP. Feel free to open a new issue if you have any questions.
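The "scale factor for each weight and bias" idea above can be sketched roughly as follows. This is a hypothetical helper for illustration, not the repository's actual code: it iterates the model's parameters in a fixed order and attaches one learnable scalar to every parameter tensor that is non-zero at initialization.

```python
import torch
import torch.nn as nn

def collect_gradinit_scales(model: nn.Module):
    """Hypothetical helper: one learnable scale per weight/bias tensor.

    Mirrors the description above: iterate trainable parameters in a
    fixed order and skip tensors that are zero at initialization
    (e.g. BatchNorm biases), since scaling zeros has no effect.
    """
    scales = {}
    for name, param in model.named_parameters():
        if param.requires_grad and param.abs().sum().item() > 0:
            scales[name] = torch.ones(1, requires_grad=True)
    return scales

model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
scales = collect_gradinit_scales(model)
# BatchNorm2d's bias starts at zero, so it gets no scale factor:
# conv weight+bias, bn weight, linear weight+bias -> 5 scales.
```

The scale tensors returned here would then be the variables optimized by GradInit's outer objective, while the underlying weights stay fixed.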
@zhuchen03 I see that it requires a dataloader and seems to be specific to classification.
@kayuksel I'm curious: what does your problem look like? I think you can try it out as long as your model can be optimized with SGD; you just need to replace the loss function with yours. I agree the current version is restricted to image classification, but it shouldn't be too difficult to adapt it to other tasks. Happy to assist, or perhaps improve the API, if you could provide more details.
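The "replace the loss function with yours" point can be illustrated with a minimal sketch. This is not the repository's actual API, just an assumed interface showing that the task loss is a pluggable callable, so any model trainable by plain SGD can reuse the same probing step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def one_sgd_probe_step(model, loss_fn, batch, lr=0.01):
    """Take a single plain-SGD step with an arbitrary task loss.

    Hypothetical sketch: `loss_fn(model, batch)` can be any scalar loss,
    e.g. cross-entropy for classification or MSE for regression.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    loss = loss_fn(model, batch)
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= lr * g
    return loss.item()

model = nn.Linear(4, 1)
batch = (torch.randn(8, 4), torch.randn(8, 1))
mse = lambda m, b: nn.functional.mse_loss(m(b[0]), b[1])

loss_before = one_sgd_probe_step(model, mse, batch)
loss_after = mse(model, batch).item()
```

Swapping `mse` for any other differentiable loss leaves the rest of the step unchanged, which is what makes the method task-agnostic in principle.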
@zhuchen03 In my case, it is a generative model trained with QHAdam (with an adaptive gradient clipping wrapper), which learns to continuously generate populations of solutions to, e.g., a mathematical function. In this type of reinforcement learning problem, the network initialization can be an important factor, as it affects how the agent starts taking actions and hence how experiences are acquired to update the policy (leading to the severe reproducibility issues and random-seed sensitivity of RL).
@kayuksel I see. I do not have much background in the problem you are trying to solve. From your description, it looks like you are using an Adam-like optimizer, and GradInit should be applicable as long as we can write down its update rule for the first step. I can check whether there are other issues hindering the implementation of GradInit for your problem if you could share some simple sample code.
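For plain Adam, "the update rule for the first step" has a particularly simple closed form, since both moment buffers start at zero: after bias correction, the first moment is exactly the gradient and the second is its square, so the first step is approximately `-lr * sign(g)`. A minimal sketch (QHAdam and a clipping wrapper would change the exact constants, so this is only the vanilla-Adam case):

```python
import torch

def adam_first_step(param, grad, lr=1e-3, eps=1e-8):
    """Plain Adam's very first update, from zero-initialized moments.

    At t = 1: m = (1 - beta1) * g and v = (1 - beta2) * g**2, and bias
    correction divides those factors back out, so m_hat = g and
    v_hat = g**2. The step reduces to -lr * g / (|g| + eps),
    i.e. approximately -lr * sign(g).
    """
    m_hat = grad          # bias-corrected first moment at t = 1
    v_hat = grad ** 2     # bias-corrected second moment at t = 1
    return param - lr * m_hat / (v_hat.sqrt() + eps)

w = torch.tensor([0.5, -0.5])
g = torch.tensor([0.2, -0.1])
w_new = adam_first_step(w, g)   # ~ w - lr * sign(g)
```

This is why GradInit only needs the optimizer's first-step rule: the scale factors are fit against what the network looks like after exactly one such update.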
Thanks @zhuchen03! How can I send you the sample code? Can I use the (cs.umd.edu) e-mail mentioned in your resume?
Yes, that works. Thank you!
Hi @kayuksel, @danarte. Just in case, I wanted to point you to our recent work: https://github.com/facebookresearch/ppuda. You should be able to initialize almost any neural net in a single function call.
Hello,
Thanks for such great work. Auto-initializing a DNN in a proper way definitely sounds amazing.
Yet, the usability needs to be significantly improved so that I can plug this into my existing networks.
It would be great if it could be as easy as installing and then importing an additional package.
We should maybe open a feature request in PyTorch so that they integrate this into the framework.