Skip to content

Commit

Permalink
Fixed MCTS
Browse files Browse the repository at this point in the history
  • Loading branch information
kharitonov-ivan committed Apr 5, 2019
1 parent 472387d commit 8a5c682
Showing 1 changed file with 50 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"# In google collab, uncomment this:\n",
"# !wget https://bit.ly/2FMJP5K -q -O setup.py\n",
"# !bash setup.py 2>&1 1>stdout.log | tee stderr.log\n",
"\n",
"# This code creates a virtual display to draw game images on.\n",
"# If you are running locally, just ignore it\n",
"import os\n",
"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
" !bash ../xvfb start\n",
" %env DISPLAY = : 1\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
Expand All @@ -29,6 +39,7 @@
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"from gym.core import Wrapper\n",
"from pickle import dumps, loads\n",
"from collections import namedtuple\n",
Expand All @@ -53,7 +64,7 @@
" - self.render(close=True) #close window, same as self.env.render(close=True)\n",
" \"\"\"\n",
"\n",
" def get_snapshot(self):\n",
" def get_snapshot(self, render=False):\n",
" \"\"\"\n",
" :returns: environment state that can be loaded with load_snapshot \n",
" Snapshots guarantee same env behaviour each time they are loaded.\n",
Expand All @@ -67,13 +78,16 @@
" In case of doubt, use pickle.dumps or deepcopy.\n",
"\n",
" \"\"\"\n",
" self.render() # close popup windows since we can't pickle them\n",
" if render:\n",
" self.render() # close popup windows since we can't pickle them\n",
" self.close()\n",
" \n",
" if self.unwrapped.viewer is not None:\n",
" self.unwrapped.viewer.close()\n",
" self.unwrapped.viewer = None\n",
" return dumps(self.env)\n",
"\n",
" def load_snapshot(self, snapshot):\n",
" def load_snapshot(self, snapshot, render=False):\n",
" \"\"\"\n",
" Loads snapshot as current env state.\n",
" Should not change snapshot inplace (in case of doubt, deepcopy).\n",
Expand All @@ -82,8 +96,9 @@
" assert not hasattr(self, \"_monitor\") or hasattr(\n",
" self.env, \"_monitor\"), \"can't backtrack while recording\"\n",
"\n",
" # close popup windows since we can't load into them\n",
" self.render(close=True)\n",
" if render:\n",
" self.render() # close popup windows since we can't load into them\n",
" self.close()\n",
" self.env = loads(snapshot)\n",
"\n",
" def get_result(self, snapshot, action):\n",
Expand All @@ -101,15 +116,16 @@
" <your code here load, commit, take snapshot >\n",
"\n",
" return ActionResult(< next_snapshot > , #fill in the variables\n",
" < next_observation > ,\n",
" < next_observation > ,\n",
" < reward > , < is_done > , < info > )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### try out snapshots:\n"
"### Try out snapshots:\n",
"Let`s check our wrapper. At first, reset environment and save it, further randomly play some actions and restore our environment from the snapshot. It should be the same as our previous initial state."
]
},
{
Expand All @@ -132,8 +148,8 @@
"outputs": [],
"source": [
"print(\"initial_state:\")\n",
"\n",
"plt.imshow(env.render('rgb_array'))\n",
"env.close()\n",
"\n",
"# create first snapshot\n",
"snap0 = env.get_snapshot()"
Expand All @@ -154,7 +170,7 @@
"\n",
"print(\"final state:\")\n",
"plt.imshow(env.render('rgb_array'))\n",
"plt.show()"
"env.close()"
]
},
{
Expand All @@ -168,7 +184,7 @@
"\n",
"print(\"\\n\\nAfter loading snapshot\")\n",
"plt.imshow(env.render('rgb_array'))\n",
"plt.show()"
"env.close()"
]
},
{
Expand Down Expand Up @@ -196,6 +212,13 @@
"\n",
"In this section, we'll implement the vanilla MCTS algorithm with UCB1-based node selection.\n",
"\n",
"$$\n",
"\\dot{v_a} = v_a + \\sqrt{\\frac{2 \\log {N}}{n_a}}\n",
"$$\n",
"\n",
"where: $N$ - number of time-steps so far,\n",
"$n_a$ - times action a is taken\n",
"\n",
"We will start by implementing the `Node` class - a simple class that acts like MCTS node and supports some of the MCTS algorithm steps.\n",
"\n",
"This MCTS implementation makes some assumptions about the environment, you can find those _in the notes section at the end of the notebook_."
Expand Down Expand Up @@ -429,6 +452,7 @@
"metadata": {},
"outputs": [],
"source": [
"env = WithSnapshots(gym.make(\"CartPole-v0\"))\n",
"root_observation = env.reset()\n",
"root_snapshot = env.get_snapshot()\n",
"root = Root(root_snapshot, root_observation)"
Expand Down Expand Up @@ -620,7 +644,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"version": "3.6.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down

0 comments on commit 8a5c682

Please sign in to comment.