Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
273 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Cartpole DQN" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Deep Q-Learning Network with Keras and OpenAI Gym, based on [Keon Kim's code](https://github.com/keon/deep-q-learning/blob/master/dqn.py)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Import dependencies" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import random\n", | ||
"import gym\n", | ||
"import numpy as np\n", | ||
"from collections import deque\n", | ||
"from keras.models import Sequential\n", | ||
"from keras.layers import Dense\n", | ||
"from keras.optimizers import Adam\n", | ||
"import os " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Set parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"env = # FILL IN: the Gym environment, e.g. gym.make('CartPole-v0')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"state_size = env.observation_space.shape[0] # dimensionality of the environment's observation vector\n", | ||
"state_size" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"action_size = env.action_space.n # number of discrete actions available to the agent\n", | ||
"action_size" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"batch_size = # FILL IN: minibatch size sampled from replay memory each training step, e.g. 32" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"n_episodes = # FILL IN: number of episodes (games) to train over, e.g. 1000" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"output_dir = 'model_output/cartpole/' # where periodic weight checkpoints are written" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"if not os.path.exists(output_dir):\n", | ||
" os.makedirs(output_dir) # create the checkpoint directory on first run\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Define agent" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class DQNAgent: # deep Q-learning agent: replay memory + epsilon-greedy policy\n", | ||
" \n", | ||
" def __init__(self, state_size, action_size):\n", | ||
" self.state_size = state_size # dimensionality of the observation vector\n", | ||
" self.action_size = action_size # number of discrete actions\n", | ||
" \n", | ||
" self.memory = deque(maxlen=2000) # replay buffer; oldest transitions are evicted past 2000\n", | ||
" \n", | ||
" self.gamma = # FILL IN: discount factor on future rewards, e.g. 0.95\n", | ||
" \n", | ||
" self.epsilon = # FILL IN: initial exploration rate, e.g. 1.0 (act fully at random at first)\n", | ||
" self.epsilon_decay = # FILL IN: multiplicative decay applied after each replay, e.g. 0.995\n", | ||
" self.epsilon_min = # FILL IN: exploration floor, e.g. 0.01\n", | ||
" \n", | ||
" self.learning_rate = # FILL IN: Adam step size, e.g. 0.001\n", | ||
" \n", | ||
" self.model = self._build_model() # build the Q-value network\n", | ||
" \n", | ||
" def _build_model(self): # dense net mapping a state vector to one Q-value per action\n", | ||
"\n", | ||
" model = Sequential()\n", | ||
"\n", | ||
" # FILL IN NEURAL NETWORK ARCHITECTURE (e.g. Dense hidden layers with relu; linear output of size self.action_size)\n", | ||
"\n", | ||
" # COMPILE MODEL (e.g. loss='mse', optimizer=Adam(lr=self.learning_rate))\n", | ||
" \n", | ||
" return model\n", | ||
" \n", | ||
" def remember(self, state, action, reward, next_state, done): # store one transition for later replay\n", | ||
" self.memory.append((state, action, reward, next_state, done)) # (s, a, r, s', done)\n", | ||
"\n", | ||
" def act(self, state): # choose an action epsilon-greedily\n", | ||
" if np.random.rand() <= self.epsilon: # explore with probability epsilon\n", | ||
" return random.randrange(self.action_size)\n", | ||
" act_values = self.model.predict(state) # predicted Q-values, one per action\n", | ||
" return np.argmax(act_values[0]) # exploit: action with the highest Q\n", | ||
"\n", | ||
" def replay(self, batch_size): # train the network on a random minibatch of memories\n", | ||
" \n", | ||
" minibatch = random.sample(self.memory, batch_size) # sample without replacement\n", | ||
" \n", | ||
" for state, action, reward, next_state, done in minibatch: \n", | ||
" target = reward # terminal step: no future reward to bootstrap from\n", | ||
" if not done: \n", | ||
" target = (reward + self.gamma * np.amax(self.model.predict(next_state)[0])) # (maximum target Q based on future action a')\n", | ||
" target_f = self.model.predict(state) # current predictions serve as the training-target template\n", | ||
" target_f[0][action] = target # overwrite only the taken action's Q-value\n", | ||
" \n", | ||
" self.model.fit(state, target_f, epochs=1, verbose=0) # one gradient step toward the TD target\n", | ||
"\n", | ||
" if self.epsilon > self.epsilon_min: # anneal exploration toward the floor\n", | ||
" self.epsilon *= self.epsilon_decay\n", | ||
"\n", | ||
" def load(self, name): # restore weights from an .hdf5 checkpoint\n", | ||
" self.model.load_weights(name)\n", | ||
"\n", | ||
" def save(self, name): # write weights to an .hdf5 checkpoint\n", | ||
" self.model.save_weights(name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Interact with environment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"agent = DQNAgent(state_size, action_size) # instantiate the agent for this environment's dimensions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"done = False # set each timestep by env.step; True ends the episode\n", | ||
"\n", | ||
"for e in range(n_episodes): # one iteration per game\n", | ||
"\n", | ||
" state = env.reset() # begin a fresh episode\n", | ||
" state = np.reshape(state, [1, state_size]) # shape (1, state_size) expected by model.predict\n", | ||
" \n", | ||
" for time in range(5000): # cap on timesteps per episode; 'time' doubles as the reported score\n", | ||
" \n", | ||
" # env.render() # uncomment to watch the episodes\n", | ||
"\n", | ||
" action = agent.act(state) # epsilon-greedy choice\n", | ||
" \n", | ||
" next_state, reward, done, _ = env.step(action) # advance the environment one step\n", | ||
" \n", | ||
" reward = reward if not done else -10 # penalize the episode-ending step\n", | ||
" \n", | ||
" next_state = np.reshape(next_state, [1, state_size]) # match the model's (1, state_size) input\n", | ||
" \n", | ||
" agent.remember(state, action, reward, next_state, done) # record the transition for replay\n", | ||
"\n", | ||
" state = next_state # carry the new observation forward\n", | ||
" \n", | ||
" if done: \n", | ||
" print(\"episode: {}/{}, score: {}, e: {:.2}\".format(e, n_episodes, time, agent.epsilon))\n", | ||
" break \n", | ||
"\n", | ||
" if len(agent.memory) > batch_size: # train once enough transitions are stored\n", | ||
" agent.replay(batch_size)\n", | ||
"\n", | ||
" if e % 50 == 0: # checkpoint weights every 50 episodes\n", | ||
" agent.save(output_dir + \"weights_\" + '{:04d}'.format(e) + \".hdf5\") " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# saved agents can be loaded with agent.load(\"./path/filename.hdf5\") " | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |