
Do not require ALE if using DQN with gym only (fix #1091)

ppwwyyxx committed Feb 21, 2019
1 parent e4aca03 commit 7b9d9b8c25478875782c484e8f88300fd3f081fc
Showing with 25 additions and 10 deletions.
  1. +1 −0 docs/conf.py
  2. +17 −7 docs/tutorial/save-load.md
  3. +1 −1 examples/DeepQNetwork/DQN.py
  4. +5 −1 examples/DeepQNetwork/expreplay.py
  5. +1 −1 examples/DoReFa-Net/README.md
docs/conf.py:

```diff
@@ -401,6 +401,7 @@ def process_signature(app, what, name, obj, options, signature,
     'average_grads',
     'aggregate_grads',
     'allreduce_grads',
+    'get_checkpoint_path'
 ])
 
 def autodoc_skip_member(app, what, name, obj, skip, options):
```
docs/tutorial/save-load.md:

```diff
@@ -10,28 +10,37 @@ Both are necessary.
 
 `tf.train.NewCheckpointReader` is the official tool to parse a TensorFlow checkpoint.
 Read [TF docs](https://www.tensorflow.org/api_docs/python/tf/train/NewCheckpointReader) for details.
-Tensorpack also provides some small tools to work with checkpoints, see
-[documentation](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
+Tensorpack also provides a small tool to load checkpoints, see
+[load_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
 for details.
 
 [scripts/ls-checkpoint.py](../scripts/ls-checkpoint.py)
 demos how to print all variables and their shapes in a checkpoint.
 
 [scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables in a checkpoint.
 It takes a metagraph file (which is also saved by `ModelSaver`) and only saves variables that the model needs at inference time.
-It can dump the model to a `var-name: value` dict saved in npz format.
+It dumps the model to a `var-name: value` dict saved in npz format.
 
 ## Load a Model to a Session
 
-Model loading (in either training or inference) is through the `session_init` interface.
+Model loading (in both training and inference) is through the `session_init` interface.
 Currently there are two ways a session can be restored:
 [session_init=SaverRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore)
 which restores a TF checkpoint,
 or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict.
-[get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader)
-is a small helper to decide which one to use from a file name.
 To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit).
 
+Many models in tensorpack model zoo are provided in the form of numpy dictionary (`.npz`),
+because it is easier to load and manipulate without requiring TensorFlow.
+To load such files to a session, use `DictRestore(dict(np.load(filename)))`.
+You can also use
+[get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader),
+a small helper to create a `SaverRestore` or `DictRestore` based on the file name.
+
+`DictRestore` is the most general loader because you can make arbitrary changes
+you need (e.g., remove variables, rename variables) to the dict.
+To load a TF checkpoint into a dict in order to make changes, use
+[load_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars).
 
 Variable restoring is completely based on __name match__ between
 variables in the current graph and variables in the `session_init` initializer.
```
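The name-match rule described above can be illustrated with a small standalone sketch, using plain dicts to stand in for the graph's variables and the loaded checkpoint (variable names here are hypothetical, and the function is an illustration of the matching behavior, not tensorpack's actual implementation):

```python
import numpy as np

def restore_by_name(graph_vars, loaded):
    """Assign loaded values to graph variables with exactly matching names;
    names present on only one side are reported as warnings."""
    for name in sorted(set(graph_vars) | set(loaded)):
        if name in graph_vars and name in loaded:
            graph_vars[name] = loaded[name]
        else:
            side = "graph" if name in graph_vars else "checkpoint"
            print("WARN: variable {} appears only in the {}".format(name, side))
    return graph_vars

# Hypothetical example: one matching name, one mismatch on each side.
graph = {"conv0/W": np.zeros(3), "fc/W": np.zeros(2)}
ckpt = {"conv0/W": np.ones(3), "old_fc/W": np.ones(2)}
restore_by_name(graph, ckpt)
# conv0/W is restored; fc/W and old_fc/W each trigger a warning.
```

`fc/W` keeps its fresh initialization because nothing in the checkpoint matched its name.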
```diff
@@ -40,4 +49,5 @@ Variables that appear in only one side will be printed as warning.
 ## Transfer Learning
 Therefore, transfer learning is trivial.
 If you want to load a pre-trained model, just use the same variable names.
-If you want to re-train some layer, just rename it.
+If you want to re-train some layer, just rename either the variables in the
+graph or the variables in your loader.
```
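Renaming on the loader side amounts to editing the dict before handing it to `DictRestore`. A minimal sketch with hypothetical layer names, dropping the entries of the layer to be re-trained so that name matching fails and the graph keeps its fresh initialization:

```python
import numpy as np

# Hypothetical pretrained weights, as loaded from an .npz file.
pretrained = {
    "conv0/W": np.ones((3, 3), dtype="float32"),
    "fc/W": np.ones((4, 10), dtype="float32"),
}

# To re-train `fc` from scratch, remove (or rename) its entries:
to_load = {k: v for k, v in pretrained.items() if not k.startswith("fc/")}
# Then: session_init = DictRestore(to_load)
```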
examples/DeepQNetwork/DQN.py:

```diff
@@ -12,7 +12,6 @@
 
 from tensorpack import *
 
-from atari import AtariPlayer
 from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
 from common import Evaluator, eval_model_multithread, play_n_episodes
 from DQNModel import Model as DQNModel
@@ -52,6 +51,7 @@ def get_player(viz=False, train=False):
     if USE_GYM:
         env = gym.make(ENV_NAME)
     else:
+        from atari import AtariPlayer
         env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
                           live_lost_as_eoe=train, max_num_frames=60000)
     env = FireResetEnv(env)
```
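The fix is the standard lazy-import pattern: the import moves into the branch that needs it, so the optional dependency (here ALE, pulled in by the `atari` module) is only required when that branch actually runs. A generic, self-contained sketch of the pattern, with hypothetical names:

```python
def get_env(name, use_gym=True):
    """Build an environment; the optional backend is imported only on demand."""
    if use_gym:
        # Common path: no optional dependency touched.
        return ("gym", name)
    # Deferred import: an ImportError is raised only if this branch is taken.
    import importlib
    backend = importlib.import_module("atari")  # hypothetical optional module
    return backend.make(name)

# Users with only gym installed never trigger the `atari` import:
env = get_env("Breakout-v0", use_gym=True)
```

The trade-off is that a missing dependency surfaces at call time rather than at program start, which is exactly what issue #1091 asked for.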
examples/DeepQNetwork/expreplay.py:

```diff
@@ -31,7 +31,11 @@ def __init__(self, max_size, state_shape, history_len):
         self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
         self.history_len = int(history_len)
 
-        self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
+        state_shape = (self.max_size,) + state_shape
+        logger.info("Creating experience replay buffer of {:.1f} GB ... "
+                    "use a smaller buffer if you don't have enough CPU memory.".format(
+                        np.prod(state_shape) / 1024.0**3))
+        self.state = np.zeros(state_shape, dtype='uint8')
         self.action = np.zeros((self.max_size,), dtype='int32')
         self.reward = np.zeros((self.max_size,), dtype='float32')
         self.isOver = np.zeros((self.max_size,), dtype='bool')
```
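The logged figure is simple arithmetic: the state buffer holds one `uint8` per element of the `(max_size,) + state_shape` array, so the byte count equals the element count. A sketch with hypothetical Atari-like settings (`max_size=10**6`, `state_shape=(84, 84)`), which alone come to about 6.6 GB:

```python
import numpy as np

def buffer_gb(max_size, state_shape):
    # uint8 => 1 byte per element; use int64 so the product cannot
    # overflow the default 32-bit int on some platforms.
    shape = (max_size,) + tuple(state_shape)
    return float(np.prod(np.asarray(shape, dtype="int64"))) / 1024.0 ** 3

size = buffer_gb(10**6, (84, 84))  # ~6.57 GB for a million 84x84 frames
```

This is why the commit adds the warning: the buffer is allocated up front, and an oversized `max_size` fails only once the allocation (or page-in) exhausts CPU memory.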
examples/DoReFa-Net/README.md:

```diff
@@ -45,7 +45,7 @@ In this implementation, quantized operations are all performed through `tf.float32`.
 
 + Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
 
-Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
+Pretrained model for (1,4,32)-ResNet18 and several AlexNet are available at
 [tensorpack model zoo](http://models.tensorpack.com/DoReFa-Net/).
 They're provided in the format of numpy dictionary.
 The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
```