
Implement Deep Q-Network #617

Merged
merged 36 commits on Aug 10, 2020

Conversation

@seungjaeryanlee (Contributor) commented Jun 29, 2020

This is a mostly finished implementation of a Deep Q-Network (DQN) on the CartPole environment that still needs some improvements. Following the original paper (Mnih et al., 2015), the code implements both experience replay and a target network.

I have identified the following areas for improvement:

  • Parallelize training
    • The Q-network is given a batch of states and outputs a batch of Q-values for all actions (batch size x # actions). This should be reduced to a (batch size x 1) tensor that selects the Q-value of the action that was actually taken. In TensorFlow's Python API, tf.gather_nd seems to be the best option, but it does not seem to exist in _Raw. Is there a way to use this function? (See the sketch after this list.)
    • Update: _Raw.gatherNd exists.
  • Refactor code
    • There are a lot of commented out print() statements. These are left out to aid me as I work through the improvements, but they should be removed once I am finished.
  • Remove uses of _Raw
    • As per @BradLarson's suggestion, most _Raw calls can be replaced by S4TF functions. Replace them accordingly.
  • Use Huber loss
    • Huber loss is mentioned in the original paper and can help with training stability; it is also included in the sketch after this list.
  • Use the on: parameter for Tensors to support X10 devices
  • Improve reproducibility
    • The hyperparameters were chosen via a rough random search, and the results vary heavily.
  • Move gatherNd into a separate file
    • For a more streamlined upstream contribution to swift-apis, move the gatherNd redefinition into its own file.
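For concreteness, here is a rough sketch (not the code in this PR) of how the per-action Q-value selection and the Huber loss could look in Swift for TensorFlow. It assumes _Raw.gatherNd(params:indices:) behaves like tf.gather_nd; the function name temporalDifferenceLoss and the default delta are placeholders.

```swift
import TensorFlow

/// Hypothetical sketch: select, for each row of a [batchSize, actionCount] Q-value
/// tensor, the Q-value of the action that was actually taken, then apply Huber loss.
func temporalDifferenceLoss(
  qValues: Tensor<Float>,   // shape: [batchSize, actionCount]
  actions: Tensor<Int32>,   // shape: [batchSize]
  targets: Tensor<Float>,   // shape: [batchSize]
  delta: Float = 1.0
) -> Tensor<Float> {
  let batchSize = Int32(qValues.shape[0])
  // Build [batchSize, 2] indices of the form [rowIndex, actionIndex].
  let rowIndices = Tensor<Int32>(rangeFrom: 0, to: batchSize, stride: 1)
  let indices = Tensor<Int32>(stacking: [rowIndices, actions], alongAxis: 1)
  // Gather one Q-value per row, giving a [batchSize] tensor.
  let predicted = _Raw.gatherNd(params: qValues, indices: indices)
  // Huber loss written out manually for clarity.
  let error = abs(predicted - targets)
  let quadratic = min(error, Tensor<Float>(delta))
  let linear = error - quadratic
  return (0.5 * quadratic * quadratic + Tensor<Float>(delta) * linear).mean()
}
```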

@dan-zheng added the gsoc (Google Summer of Code) label on Jul 8, 2020
@seungjaeryanlee (Contributor, Author)

I have tried various hyperparameter settings from existing DQN implementations, including eaplatanios/swift-rl, higgsfield/RL-Adventure, and TF-Agents. None of them consistently reached a score of 200, so I improved the model with three additional algorithms.

With these three improvements and the TF-Agents hyperparameters, I am now getting consistent 200s!

… to four files, added Tensor extension, formatted via swift-format.
@BradLarson (Contributor)

I just pushed a commit that does a few things:

  • Fixes segmentation faults on GPU eager mode execution.
  • Creates an extension on Tensor for gatherNd (which I'm calling dimensionalGathering until we can figure out a better name or better way of expressing the difference in parameter names).
  • Splits out the types from main.swift into a few files to better organize the components here.
  • Replaces a few Numpy uses with native (and faster) Swift equivalents.
  • Reformats the files per swift-format's style that we use for this repository.

The segmentation fault I was seeing was due to trying to set a Tensor at an index within a larger Tensor during append() in the ReplayBuffer. I've changed this to a simpler (and probably more performant) approach using arrays of tensors that are stacked when needed during the batch extraction. For a replay buffer, we might want to consider using pure scalars there, because they'll be faster to work with when interacting with environments that take and provide scalars. I've done this with other environments and gotten very fast speeds by converting them to Tensors only at the point where you want to use the accelerator (model training and agent logic).
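As an illustration of the arrays-of-tensors approach described above (this is a sketch, not the actual ReplayBuffer.swift in this PR; the stored fields and method names are assumptions), the buffer might look roughly like this:

```swift
import TensorFlow

// Sketch: store individual transition tensors in Swift arrays and only stack them
// into batched tensors at sampling time.
class ReplayBuffer {
  let capacity: Int
  var states: [Tensor<Float>] = []
  var actions: [Tensor<Int32>] = []
  var rewards: [Tensor<Float>] = []
  var nextStates: [Tensor<Float>] = []
  var isDones: [Tensor<Bool>] = []

  init(capacity: Int) { self.capacity = capacity }

  func append(state: Tensor<Float>, action: Tensor<Int32>, reward: Tensor<Float>,
              nextState: Tensor<Float>, isDone: Tensor<Bool>) {
    if states.count >= capacity {
      // Drop the oldest transition once the buffer is full.
      states.removeFirst(); actions.removeFirst(); rewards.removeFirst()
      nextStates.removeFirst(); isDones.removeFirst()
    }
    states.append(state); actions.append(action); rewards.append(reward)
    nextStates.append(nextState); isDones.append(isDone)
  }

  // Assumes the buffer is non-empty; stacks the chosen transitions into batch tensors.
  func sample(batchSize: Int)
    -> (Tensor<Float>, Tensor<Int32>, Tensor<Float>, Tensor<Float>, Tensor<Bool>) {
    let indices = (0..<batchSize).map { _ in Int.random(in: 0..<states.count) }
    return (
      Tensor(stacking: indices.map { states[$0] }),
      Tensor(stacking: indices.map { actions[$0] }),
      Tensor(stacking: indices.map { rewards[$0] }),
      Tensor(stacking: indices.map { nextStates[$0] }),
      Tensor(stacking: indices.map { isDones[$0] })
    )
  }
}
```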

There are a few places where Numpy functions are being used when there are Swift-native equivalents, and that can impact performance on tight loops around an environment. I've replaced a few of these, but I think that all the rest should also be able to be replaced with Swift functions.
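The specific calls that were swapped aren't listed in this thread; as an illustration only, replacements of this kind typically look like the following:

```swift
import TensorFlow

// np.random.uniform(0, 1) -> Swift's built-in random number generation.
let explorationSample = Float.random(in: 0..<1)

// np.argmax on a 1-D array -> Tensor.argmax(), keeping the value on the Swift side.
let qValues = Tensor<Float>([0.1, 0.7, 0.2])
let greedyAction = Int(qValues.argmax().scalarized())
```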

This isn't always solving the environment (it seems to solve it about half of the time), so if that's unexpected, you may want to check over my replay buffer logic and other changes to make sure I didn't introduce a subtle error somewhere.

For organization, I thought there was a lot of code in main.swift, so I broke it out into separate files for the core types. I think that makes it a little easier to follow what's going on here with the various pieces.

@seungjaeryanlee (Contributor, Author) commented Aug 6, 2020

Thank you for the wonderful commit @BradLarson !

  1. I agree that the replay buffer should store numbers as Swift native types instead of Tensors and only convert them in the sample() function. I forgot Swift was fast! I will make that change.
  2. Converting NumPy functions to Swift is definitely a good idea. I could not find some of the functions I wanted, which is why I had been using NumPy, but using Swift functions makes much more sense.
  3. The division of code into Gathering / Agent / ReplayBuffer / main makes sense to me. I am considering moving the eval() function to the agent later.
  4. The degraded performance stems from ReplayBuffer being changed from a class to a struct. As far as I know, that means replayBuffer is passed by value, so agent.replayBuffer is never updated and stays empty the entire time, and thus training never occurs. I changed ReplayBuffer back to a class, and it now solves CartPole in under 200 episodes or under 10,000 steps.

Other than that, I used epsilon decay to make it solve CartPole a bit faster.
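The exact decay schedule and hyperparameters aren't shown in this thread; as a rough illustration, a linear epsilon-decay schedule in Swift could look like this (all constants are placeholders):

```swift
// Epsilon decays linearly from epsilonStart toward epsilonEnd over decaySteps steps.
func epsilon(step: Int,
             epsilonStart: Float = 1.0,
             epsilonEnd: Float = 0.01,
             decaySteps: Float = 1000) -> Float {
  let fraction = min(Float(step) / decaySteps, 1)
  return epsilonStart + fraction * (epsilonEnd - epsilonStart)
}

// The agent explores with probability epsilon(step:) and acts greedily otherwise.
let exploring = Float.random(in: 0..<1) < epsilon(step: 500)
```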

@BradLarson (Contributor)

@seungjaeryanlee - Oh, right, I accidentally left the class->struct conversion in there from my testing. I had been trying to see if ReplayBuffer as a class was interacting badly with the autodiff system, hence the attempt at converting it to a struct. I totally forgot to convert that back, and that completely explains why it wasn't being mutated properly.

There might be an argument in the future for converting that to a struct and then using appropriate inout semantics throughout to utilize some advantages of value types while preserving appropriate mutation, but we're good for now in having this be a class. Good catch, and glad that was that simple to fix.
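A minimal sketch of that value-type alternative, assuming a struct ReplayBuffer passed inout (the names here are hypothetical):

```swift
import TensorFlow

// A value-type buffer: mutations are only visible if the buffer is passed inout.
struct ValueReplayBuffer {
  var states: [Tensor<Float>] = []

  mutating func append(state: Tensor<Float>) {
    states.append(state)
  }
}

// Callers take the buffer inout so appends are visible to the caller.
func collectStep(into buffer: inout ValueReplayBuffer, state: Tensor<Float>) {
  buffer.append(state: state)
}

var buffer = ValueReplayBuffer()
collectStep(into: &buffer, state: Tensor<Float>(zeros: [4]))
```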

Do you feel ready to convert this from a draft to a full PR for final review and submission?

@seungjaeryanlee (Contributor, Author)

Sounds good!

@seungjaeryanlee marked this pull request as ready for review on August 6, 2020, 23:11
@dan-zheng (Member) left a comment

Nice work, I think this PR is ready to land! I left some minor comments that can be addressed later.

(Review comments on Gym/README.md, Gym/DQN/Agent.swift, Gym/DQN/ReplayBuffer.swift, and Gym/DQN/main.swift, since resolved.)
@seungjaeryanlee (Contributor, Author)

I have added documentation comments, focusing mainly on the ReplayBuffer and Agent classes and the list of hyperparameters. I would love to hear your feedback on them!

@BradLarson (Contributor) commented Aug 10, 2020

Thanks for working hard on getting this ready; it's exciting to have an implementation of this in the models.

To follow on, I've created an issue in swift-apis to track lifting the dimensionGathering() function (or whatever we end up naming it) into swift-apis, and an issue in swift-models to add a little extra documentation to the Gym README describing the Deep Q-Network model in the same way as the other Gym models, as well as indicating the additional dependencies on Matplotlib and NumPy.
