From 14738ce23a1f8ebb46728220f0381381bafb61b0 Mon Sep 17 00:00:00 2001
From: Anson Wong
Date: Tue, 4 Aug 2020 08:23:44 -0700
Subject: [PATCH] Backport 1791 (#1847)

---
 .../tf/policies/discrete_qf_derived_policy.py | 105 +++++++++++++++++-
 .../tf/policies/test_qf_derived_policy.py     |   9 +-
 2 files changed, 105 insertions(+), 9 deletions(-)

diff --git a/src/garage/tf/policies/discrete_qf_derived_policy.py b/src/garage/tf/policies/discrete_qf_derived_policy.py
index 0f5bfc3494..848f223db2 100644
--- a/src/garage/tf/policies/discrete_qf_derived_policy.py
+++ b/src/garage/tf/policies/discrete_qf_derived_policy.py
@@ -16,6 +16,7 @@ class DiscreteQfDerivedPolicy(Policy):
         env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
         qf (garage.q_functions.QFunction): The q-function used.
         name (str): Name of the policy.
+
     """
 
     def __init__(self, env_spec, qf, name='DiscreteQfDerivedPolicy'):
@@ -34,7 +35,12 @@ def _initialize(self):
 
     @property
     def vectorized(self):
-        """Vectorized or not."""
+        """Vectorized or not.
+
+        Returns:
+            bool: True if vectorized.
+
+        """
         return True
 
     def get_action(self, observation):
@@ -44,7 +50,7 @@ def get_action(self, observation):
             observation (numpy.ndarray): Observation from environment.
 
         Returns:
-            Single optimal action from this policy.
+            int: Single optimal action from this policy.
 
         """
         q_vals = self._f_qval([observation])
@@ -59,7 +65,7 @@ def get_actions(self, observations):
             observations (numpy.ndarray): Observations from environment.
 
         Returns:
-            Optimal actions from this policy.
+            numpy.ndarray: Optimal actions from this policy.
 
         """
         q_vals = self._f_qval(observations)
@@ -67,13 +73,102 @@ def get_actions(self, observations):
 
         return opt_actions
 
+    def get_trainable_vars(self):
+        """Get trainable variables.
+
+        Returns:
+            List[tf.Variable]: A list of trainable variables in the current
+                variable scope.
+
+        """
+        return self._qf.get_trainable_vars()
+
+    def get_global_vars(self):
+        """Get global variables.
+
+        Returns:
+            List[tf.Variable]: A list of global variables in the current
+                variable scope.
+
+        """
+        return self._qf.get_global_vars()
+
+    def get_regularizable_vars(self):
+        """Get all network weight variables in the current scope.
+
+        Returns:
+            List[tf.Variable]: A list of network weight variables in the
+                current variable scope.
+
+        """
+        return self._qf.get_regularizable_vars()
+
+    def get_params(self, trainable=True):
+        """Get the trainable variables.
+
+        Args:
+            trainable (bool): Trainable or not.
+
+        Returns:
+            List[tf.Variable]: A list of trainable variables in the current
+                variable scope.
+
+        """
+        return self._qf.get_params()
+
+    def get_param_shapes(self, **tags):
+        """Get parameter shapes.
+
+        Args:
+            tags: Extra arguments.
+
+        Returns:
+            List[tuple]: A list of variable shapes.
+
+        """
+        return self._qf.get_param_shapes()
+
+    def get_param_values(self, **tags):
+        """Get param values.
+
+        Args:
+            tags: Extra arguments.
+
+        Returns:
+            np.ndarray: Values of the parameters evaluated in
+                the current session
+
+        """
+        return self._qf.get_param_values()
+
+    def set_param_values(self, param_values, name=None, **tags):
+        """Set param values.
+
+        Args:
+            param_values (np.ndarray): A numpy array of parameter values.
+            name (str): Name of the scope.
+            tags: Extra arguments.
+
+        """
+        self._qf.set_param_values(param_values)
+
     def __getstate__(self):
-        """Object.__getstate__."""
+        """Object.__getstate__.
+
+        Returns:
+            dict: the state to be pickled for the instance.
+
+        """
         new_dict = self.__dict__.copy()
         del new_dict['_f_qval']
         return new_dict
 
     def __setstate__(self, state):
-        """Object.__setstate__."""
+        """Object.__setstate__.
+
+        Args:
+            state (dict): Unpickled state.
+
+        """
         self.__dict__.update(state)
         self._initialize()
diff --git a/tests/garage/tf/policies/test_qf_derived_policy.py b/tests/garage/tf/policies/test_qf_derived_policy.py
index 0a081b6438..20333c5de2 100644
--- a/tests/garage/tf/policies/test_qf_derived_policy.py
+++ b/tests/garage/tf/policies/test_qf_derived_policy.py
@@ -10,12 +10,13 @@
 
 
 class TestQfDerivedPolicy(TfGraphTestCase):
+
     def setup_method(self):
         super().setup_method()
         self.env = TfEnv(DummyDiscreteEnv())
         self.qf = SimpleQFunction(self.env.spec)
-        self.policy = DiscreteQfDerivedPolicy(
-            env_spec=self.env.spec, qf=self.qf)
+        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
+                                              qf=self.qf)
         self.sess.run(tf.compat.v1.global_variables_initializer())
         self.env.reset()
 
@@ -28,8 +29,8 @@ def test_discrete_qf_derived_policy(self):
         assert self.env.action_space.contains(action)
 
     def test_is_pickleable(self):
-        with tf.compat.v1.variable_scope(
-                'SimpleQFunction/SimpleMLPModel', reuse=True):
+        with tf.compat.v1.variable_scope('SimpleQFunction/SimpleMLPModel',
+                                         reuse=True):
             return_var = tf.compat.v1.get_variable('return_var')
             # assign it to all one
             return_var.load(tf.ones_like(return_var).eval())