Commit

add Live Training Cartpole DQN NB
jonkrohn committed Apr 4, 2018
1 parent 5e2be15 commit 67865f9
Showing 1 changed file with 273 additions and 0 deletions.
273 changes: 273 additions & 0 deletions notebooks/live_training/cartpole_dqn_LT.ipynb
@@ -0,0 +1,273 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cartpole DQN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Deep Q-Learning Network with Keras and OpenAI Gym, based on [Keon Kim's code](https://github.com/keon/deep-q-learning/blob/master/dqn.py)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Import dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import gym\n",
"import numpy as np\n",
"from collections import deque\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"from keras.optimizers import Adam\n",
"import os "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = # FILL IN"
]
},
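{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One possible fill-in, shown as an assumption rather than the required answer:\n",
"# the CartPole environment from OpenAI Gym (either the v0 or v1 variant works\n",
"# with the 4-tuple env.step() API used further below).\n",
"env = gym.make('CartPole-v0')"
]
},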
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"state_size = env.observation_space.shape[0]\n",
"state_size"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"action_size = env.action_space.n\n",
"action_size"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = # FILL IN"
]
},
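{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One possible fill-in (an assumed, commonly used value, not the only valid choice):\n",
"batch_size = 32"
]
},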
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_episodes = # FILL IN"
]
},
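{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One possible fill-in (assumed; more episodes generally means more training time):\n",
"n_episodes = 1000"
]
},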
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_dir = 'model_output/cartpole/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists(output_dir):\n",
" os.makedirs(output_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Define agent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DQNAgent:\n",
" \n",
" def __init__(self, state_size, action_size):\n",
" self.state_size = state_size\n",
" self.action_size = action_size\n",
" \n",
" self.memory = deque(maxlen=2000) \n",
" \n",
" self.gamma = # FILL IN\n",
" \n",
" self.epsilon = # FILL IN\n",
" self.epsilon_decay = # FILL IN\n",
" self.epsilon_min = # FILL IN\n",
" \n",
" self.learning_rate = # FILL IN\n",
" \n",
" self.model = self._build_model() \n",
" \n",
" def _build_model(self):\n",
"\n",
" model = Sequential()\n",
"\n",
" # FILL IN NEURAL NETWORK ARCHITECTURE\n",
"\n",
" # COMPILE MODEL\n",
" \n",
" return model\n",
" \n",
" def remember(self, state, action, reward, next_state, done):\n",
" self.memory.append((state, action, reward, next_state, done)) \n",
"\n",
" def act(self, state):\n",
" if np.random.rand() <= self.epsilon: \n",
" return random.randrange(self.action_size)\n",
" act_values = self.model.predict(state)\n",
" return np.argmax(act_values[0]) \n",
"\n",
" def replay(self, batch_size): \n",
" \n",
" minibatch = random.sample(self.memory, batch_size) \n",
" \n",
" for state, action, reward, next_state, done in minibatch: \n",
" target = reward # N.B.: if done\n",
" if not done: \n",
" target = (reward + self.gamma * np.amax(self.model.predict(next_state)[0])) # (maximum target Q based on future action a')\n",
" target_f = self.model.predict(state) \n",
" target_f[0][action] = target\n",
" \n",
" self.model.fit(state, target_f, epochs=1, verbose=0)\n",
"\n",
" if self.epsilon > self.epsilon_min:\n",
" self.epsilon *= self.epsilon_decay\n",
"\n",
" def load(self, name):\n",
" self.model.load_weights(name)\n",
"\n",
" def save(self, name):\n",
" self.model.save_weights(name)"
]
},
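{
"cell_type": "markdown",
"metadata": {},
"source": [
"*The next cell is one possible way to complete the agent's fill-ins, offered as a sketch rather than the canonical solution: hyperparameter values in the spirit of the referenced Keon Kim code, and a small two-hidden-layer network compiled with mean squared error loss and the Adam optimizer. The helper name `example_build_model` exists only for illustration and assumes the import cell above has been run.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example values for the agent's fill-ins (assumptions, not required answers):\n",
"# self.gamma = 0.95          # discount factor on future rewards\n",
"# self.epsilon = 1.0         # initial exploration rate\n",
"# self.epsilon_decay = 0.995 # multiplicative decay applied after each replay\n",
"# self.epsilon_min = 0.01    # floor on the exploration rate\n",
"# self.learning_rate = 0.001 # Adam step size\n",
"\n",
"# One possible _build_model() body, wrapped in a standalone helper for illustration:\n",
"def example_build_model(state_size, action_size, learning_rate=0.001):\n",
"    model = Sequential()\n",
"    model.add(Dense(24, activation='relu', input_dim=state_size))\n",
"    model.add(Dense(24, activation='relu'))\n",
"    model.add(Dense(action_size, activation='linear')) # one output per action: its estimated Q-value\n",
"    model.compile(loss='mse', optimizer=Adam(lr=learning_rate))\n",
"    return model"
]
},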
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Interact with environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent = DQNAgent(state_size, action_size) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"done = False\n",
"\n",
"for e in range(n_episodes): \n",
"\n",
" state = env.reset() \n",
" state = np.reshape(state, [1, state_size])\n",
" \n",
" for time in range(5000): \n",
" \n",
" # env.render()\n",
"\n",
" action = agent.act(state) \n",
" \n",
" next_state, reward, done, _ = env.step(action) \n",
" \n",
" reward = reward if not done else -10 \n",
" \n",
" next_state = np.reshape(next_state, [1, state_size])\n",
" \n",
" agent.remember(state, action, reward, next_state, done) \n",
"\n",
" state = next_state \n",
" \n",
" if done: \n",
" print(\"episode: {}/{}, score: {}, e: {:.2}\".format(e, n_episodes, time, agent.epsilon))\n",
" break \n",
"\n",
" if len(agent.memory) > batch_size:\n",
" agent.replay(batch_size)\n",
"\n",
" if e % 50 == 0:\n",
" agent.save(output_dir + \"weights_\" + '{:04d}'.format(e) + \".hdf5\") "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# saved agents can be loaded with agent.load(\"./path/filename.hdf5\") "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
