Implement Deep Q-Network #617
Conversation
I have tried various hyperparameter settings from existing DQN implementations, including eaplatanios/swift-rl, higgsfield/RL-Adventure, and TF-Agents. None of them consistently got a score of 200, so I improved the model with three additional algorithms:
With these three improvements and the TF-Agents hyperparameters, I am now getting consistent 200s!
… to four files, added Tensor extension, formatted via swift-format.
I just pushed a commit that does a few things:
- The segmentation fault I was seeing was due to trying to set a Tensor at an index within a larger Tensor during …
- There are a few places where NumPy functions are being used when Swift-native equivalents exist, and that can hurt performance in tight loops around an environment. I've replaced a few of these, but I think all of the rest should also be replaceable with Swift functions.
- This isn't always solving the environment (it seems to solve about half the time), so if that's unexpected, you may want to check over my replay buffer logic and other changes to make sure I didn't introduce a subtle error somewhere.
- For organization, there was a lot of code in main.swift, so I broke it out into separate files for the core types. I think that makes it a little easier to follow what's going on with the various pieces.
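As a hypothetical illustration of the kind of substitution described above (the exact call sites in this commit may differ), a NumPy round trip on the hot path can be replaced with a native Tensor op:

```swift
import TensorFlow
// import PythonKit  // only needed for the NumPy variant below

let qValues = Tensor<Float>([0.1, 0.9, 0.4])

// NumPy round trip (crosses the Python bridge on every step):
// let np = Python.import("numpy")
// let action = Int32(np.argmax(qValues.makeNumpyArray()))!

// Swift-native equivalent that stays on Tensor the whole time:
let action = qValues.argmax().scalarized()
```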
Thank you for the wonderful commit, @BradLarson!
Other than that, I used epsilon decay to make it solve CartPole a bit faster.
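A minimal sketch of what epsilon decay for epsilon-greedy action selection can look like in Swift; the schedule constants and function name are illustrative, not the values used in this PR:

```swift
import TensorFlow

var epsilon: Float = 1.0          // start fully random
let epsilonMin: Float = 0.01      // exploration floor
let epsilonDecay: Float = 0.995   // multiplicative decay per step

func selectAction(from qValues: Tensor<Float>, actionCount: Int) -> Int32 {
    defer { epsilon = max(epsilonMin, epsilon * epsilonDecay) }
    if Float.random(in: 0..<1) < epsilon {
        return Int32.random(in: 0..<Int32(actionCount))  // explore
    }
    return qValues.argmax().scalarized()                 // exploit
}
```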
@seungjaeryanlee - Oh, right, I accidentally left the class-to-struct conversion in there from my testing. I had been trying to see if ReplayBuffer as a class was interacting badly with the autodiff system, hence the attempt at converting it to a struct. I totally forgot to convert it back, and that makes complete sense as to why it wasn't being mutated properly. There might be an argument in the future for converting it to a struct and using appropriate inout semantics throughout, to get the advantages of value types while preserving the right mutation behavior, but we're good for now with this being a class. Good catch, and I'm glad that was that simple to fix. Do you feel ready to convert this from a draft to a full PR for final review and submission?
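For illustration, a minimal sketch of the struct-plus-inout alternative discussed here; this `ReplayBuffer` is a hypothetical simplification, not the PR's actual type:

```swift
import TensorFlow

// Simplified replay buffer as a value type.
struct ReplayBuffer {
    var states: [Tensor<Float>] = []
    var actions: [Int32] = []
    let capacity: Int

    // `mutating` is required for methods that modify a struct in place.
    mutating func append(state: Tensor<Float>, action: Int32) {
        if states.count >= capacity {
            states.removeFirst()
            actions.removeFirst()
        }
        states.append(state)
        actions.append(action)
    }
}

// Callers must take the buffer `inout` for mutations to stick:
func recordStep(into buffer: inout ReplayBuffer,
                state: Tensor<Float>, action: Int32) {
    buffer.append(state: state, action: action)
}
```

With a class, this plumbing disappears because the buffer is shared by reference, which is why keeping ReplayBuffer as a class is the simpler choice for now.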
Sounds good!
Nice work, I think this PR is ready to land! I left some minor comments that can be addressed later.
I have added documentation comments, focusing mainly on the …
Thanks for working hard on getting this ready; it's exciting to have an implementation of this in the models. To follow on, I've created an issue in swift-apis to track lifting up the …
This is a mostly finished implementation of Deep Q-Network (DQN) for the CartPole environment that still needs some improvements. Following the original paper (Mnih et al., 2015), the code implements both experience replay and a target network.
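As context for the target-network mechanism just mentioned, here is a minimal sketch of the hard-update pattern from Mnih et al. (2015); the `QNetwork` layer and its sizes are hypothetical, not the PR's actual model:

```swift
import TensorFlow

// Hypothetical Q-network for CartPole (4 observations, 2 actions).
struct QNetwork: Layer {
    var dense1 = Dense<Float>(inputSize: 4, outputSize: 64, activation: relu)
    var dense2 = Dense<Float>(inputSize: 64, outputSize: 2)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: dense1, dense2)
    }
}

var qNetwork = QNetwork()     // updated on every gradient step
var targetNetwork = qNetwork  // frozen copy used to compute TD targets

// Every N steps, refresh the frozen copy:
targetNetwork = qNetwork
```

Because `Layer` types are structs, plain assignment yields an independent copy of the weights, which is exactly what a hard target-network update needs.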
I have identified several places for improvement:
- The Q-network is given a batch of states and outputs a batch of Q-values for all actions (batch size × # of actions). This should be transformed to a 1D vector (batch size × 1) that selects the Q-value of the action that was actually selected. In Python, `tf.gather_nd` seems to be the best option, but it does not seem to exist in `_Raw`. Is there a way to use this function? `_Raw.gatherNd` exists (see the sketch after this list).
- There are a lot of commented-out `print()` statements. These were left in to aid me as I work through the improvements, but they should be removed once I am finished.
- As per @BradLarson's suggestion, most `_Raw` commands can be replaced by S4TF functions. Replace them accordingly.
- Huber loss is mentioned in the original paper and can help with training stability.
- Add `on:` for Tensors to support X10 devices.
- The hyperparameters were chosen via a rough random search, and the results vary heavily.
- Move `gatherNd` into a separate file: in `swift-api`, separate the `gatherNd` redefinition to a new file.
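For reference, a minimal sketch of the `_Raw.gatherNd` selection described in the first item above, using toy values; the final line assumes the `huberLoss` function from swift-apis:

```swift
import TensorFlow

let qValues = Tensor<Float>([[0.1, 0.9],
                             [0.4, 0.6],
                             [0.8, 0.2]])  // [batch, actionCount]
let actions = Tensor<Int32>([1, 0, 1])     // action taken in each state

// Pair each row index with its chosen action to form [batch, 2] indices.
let rows = Tensor<Int32>(rangeFrom: 0, to: 3, stride: 1)
let indices = Tensor<Int32>(stacking: [rows, actions], alongAxis: 1)

// Select the Q-value of the action actually taken: shape [batch].
let selectedQ = _Raw.gatherNd(params: qValues, indices: indices)

// Huber loss against hypothetical TD targets, per the original paper.
let targets = Tensor<Float>([1.0, 0.5, 0.0])
let loss = huberLoss(predicted: selectedQ, expected: targets, delta: 1.0)
```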