Skip to content

Commit

Permalink
Backport 1791 (#1847)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahtsan committed Aug 4, 2020
1 parent 411f657 commit 14738ce
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 9 deletions.
105 changes: 100 additions & 5 deletions src/garage/tf/policies/discrete_qf_derived_policy.py
Expand Up @@ -16,6 +16,7 @@ class DiscreteQfDerivedPolicy(Policy):
env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
qf (garage.q_functions.QFunction): The q-function used.
name (str): Name of the policy.
"""

def __init__(self, env_spec, qf, name='DiscreteQfDerivedPolicy'):
Expand All @@ -34,7 +35,12 @@ def _initialize(self):

@property
def vectorized(self):
    """Report whether this policy is vectorized.

    Returns:
        bool: Always ``True`` for this policy.

    """
    return True

def get_action(self, observation):
Expand All @@ -44,7 +50,7 @@ def get_action(self, observation):
observation (numpy.ndarray): Observation from environment.
Returns:
Single optimal action from this policy.
int: Single optimal action from this policy.
"""
q_vals = self._f_qval([observation])
Expand All @@ -59,21 +65,110 @@ def get_actions(self, observations):
observations (numpy.ndarray): Observations from environment.
Returns:
Optimal actions from this policy.
numpy.ndarray: Optimal actions from this policy.
"""
q_vals = self._f_qval(observations)
opt_actions = np.argmax(q_vals, axis=1)

return opt_actions

def get_trainable_vars(self):
    """Return the trainable variables of the wrapped Q-function.

    The call is delegated unchanged to ``self._qf``.

    Returns:
        List[tf.Variable]: The result of the Q-function's
            ``get_trainable_vars()``.

    """
    q_function = self._qf
    return q_function.get_trainable_vars()

def get_global_vars(self):
    """Return the global variables of the wrapped Q-function.

    The call is delegated unchanged to ``self._qf``.

    Returns:
        List[tf.Variable]: The result of the Q-function's
            ``get_global_vars()``.

    """
    q_function = self._qf
    return q_function.get_global_vars()

def get_regularizable_vars(self):
    """Return the network weight variables of the wrapped Q-function.

    The call is delegated unchanged to ``self._qf``.

    Returns:
        List[tf.Variable]: The result of the Q-function's
            ``get_regularizable_vars()``.

    """
    q_function = self._qf
    return q_function.get_regularizable_vars()

def get_params(self, trainable=True):
    """Return the parameters of the wrapped Q-function.

    Args:
        trainable (bool): Trainable or not. Accepted for interface
            compatibility only; it is not forwarded, and the Q-function's
            ``get_params()`` is always called without arguments.

    Returns:
        List[tf.Variable]: The result of the Q-function's
            ``get_params()``.

    """
    q_function = self._qf
    return q_function.get_params()

def get_param_shapes(self, **tags):
    """Return the parameter shapes of the wrapped Q-function.

    Args:
        tags: Extra arguments. They are accepted but not forwarded; the
            Q-function's ``get_param_shapes()`` is called without
            arguments.

    Returns:
        List[tuple]: The result of the Q-function's
            ``get_param_shapes()``.

    """
    q_function = self._qf
    return q_function.get_param_shapes()

def get_param_values(self, **tags):
    """Return the parameter values of the wrapped Q-function.

    Args:
        tags: Extra arguments. They are accepted but not forwarded; the
            Q-function's ``get_param_values()`` is called without
            arguments.

    Returns:
        np.ndarray: The result of the Q-function's
            ``get_param_values()``.

    """
    q_function = self._qf
    return q_function.get_param_values()

def set_param_values(self, param_values, name=None, **tags):
    """Set the parameter values of the wrapped Q-function.

    Only ``param_values`` is passed to the Q-function's
    ``set_param_values``; ``name`` and ``tags`` are accepted but not
    forwarded.

    Args:
        param_values (np.ndarray): A numpy array of parameter values.
        name (str): Name of the scope (not forwarded).
        tags: Extra arguments (not forwarded).

    """
    q_function = self._qf
    q_function.set_param_values(param_values)

def __getstate__(self):
    """Build the picklable state of this instance.

    Returns:
        dict: A shallow copy of the instance ``__dict__`` with the
            ``'_f_qval'`` entry removed.

    """
    state = dict(self.__dict__)
    # '_f_qval' is left out of the pickled state; presumably it is
    # recreated by `_initialize()` when the object is unpickled.
    # NOTE(review): confirm against `_initialize`.
    state.pop('_f_qval')
    return state

def __setstate__(self, state):
    """Restore this instance from unpickled state.

    The entries of ``state`` are merged into the instance ``__dict__``,
    then ``_initialize()`` is called.

    Args:
        state (dict): Unpickled state.

    """
    attributes = vars(self)
    attributes.update(state)
    self._initialize()
9 changes: 5 additions & 4 deletions tests/garage/tf/policies/test_qf_derived_policy.py
Expand Up @@ -10,12 +10,13 @@


class TestQfDerivedPolicy(TfGraphTestCase):

def setup_method(self):
super().setup_method()
self.env = TfEnv(DummyDiscreteEnv())
self.qf = SimpleQFunction(self.env.spec)
self.policy = DiscreteQfDerivedPolicy(
env_spec=self.env.spec, qf=self.qf)
self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
qf=self.qf)
self.sess.run(tf.compat.v1.global_variables_initializer())
self.env.reset()

Expand All @@ -28,8 +29,8 @@ def test_discrete_qf_derived_policy(self):
assert self.env.action_space.contains(action)

def test_is_pickleable(self):
with tf.compat.v1.variable_scope(
'SimpleQFunction/SimpleMLPModel', reuse=True):
with tf.compat.v1.variable_scope('SimpleQFunction/SimpleMLPModel',
reuse=True):
return_var = tf.compat.v1.get_variable('return_var')
# assign it to all one
return_var.load(tf.ones_like(return_var).eval())
Expand Down

0 comments on commit 14738ce

Please sign in to comment.