Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace continuous MLP baseline with model primitive #898

Merged
merged 2 commits into from
Sep 26, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/garage/tf/baselines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
"""Baseline estimators for TensorFlow-based algorithms."""
from garage.tf.baselines.continuous_mlp_baseline import ContinuousMLPBaseline
from garage.tf.baselines.continuous_mlp_baseline_with_model import (
ContinuousMLPBaselineWithModel)
from garage.tf.baselines.gaussian_cnn_baseline_with_model import (
GaussianCNNBaselineWithModel)
from garage.tf.baselines.gaussian_conv_baseline import GaussianConvBaseline
from garage.tf.baselines.gaussian_mlp_baseline import GaussianMLPBaseline

__all__ = [
'ContinuousMLPBaseline',
'ContinuousMLPBaselineWithModel',
'GaussianConvBaseline',
'GaussianCNNBaselineWithModel',
'GaussianMLPBaseline',
Expand Down
53 changes: 28 additions & 25 deletions src/garage/tf/baselines/continuous_mlp_baseline.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""This module implements continuous mlp baseline."""
"""A value function (baseline) based on a MLP model."""
import numpy as np

from garage.core import Serializable
from garage.misc.overrides import overrides
from garage.np.baselines import Baseline
from garage.tf.core import Parameterized
from garage.tf.regressors import ContinuousMLPRegressor


class ContinuousMLPBaseline(Baseline, Parameterized, Serializable):
"""A value function using a mlp network."""
class ContinuousMLPBaseline(Baseline):
"""A value function using a MLP network."""

def __init__(
self,
Expand All @@ -20,37 +18,31 @@ def __init__(
name='ContinuousMLPBaseline',
):
"""
Constructor.
Continuous MLP Baseline.

:param env_spec: environment specification.
:param subsample_factor:
:param num_seq_inputs: number of sequence inputs.
:param regressor_args: regressor arguments.
It fits the input data by performing linear regression
to the outputs.

Args:
env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
subsample_factor (float): The factor to subsample the data. By
default it is 1.0, which means using all the data.
num_seq_inputs (float): Number of sequence per input. By default
it is 1.0, which means only one single sequence.
regressor_args (dict): Arguments for regressor.
"""
Parameterized.__init__(self)
Serializable.quick_init(self, locals())
super(ContinuousMLPBaseline, self).__init__(env_spec)
super().__init__(env_spec)
if regressor_args is None:
regressor_args = dict()

self._regressor = ContinuousMLPRegressor(
input_shape=(
env_spec.observation_space.flat_dim * num_seq_inputs, ),
input_shape=(env_spec.observation_space.flat_dim *
num_seq_inputs, ),
output_dim=1,
name=name,
**regressor_args)
self.name = name

@overrides
def get_param_values(self, **tags):
"""Get parameter values."""
return self._regressor.get_param_values(**tags)

@overrides
def set_param_values(self, val, **tags):
"""Set parameter values to val."""
self._regressor.set_param_values(val, **tags)

@overrides
def fit(self, paths):
"""Fit regressor based on paths."""
Expand All @@ -63,6 +55,17 @@ def predict(self, path):
"""Predict value based on paths."""
return self._regressor.predict(path['observations']).flatten()

@overrides
def get_param_values(self, **tags):
"""Get parameter values."""
return self._regressor.get_param_values(**tags)

@overrides
def set_param_values(self, flattened_params, **tags):
"""Set parameter values to val."""
self._regressor.set_param_values(flattened_params, **tags)

@overrides
def get_params_internal(self, **tags):
"""Get internal parameters."""
return self._regressor.get_params_internal(**tags)
71 changes: 0 additions & 71 deletions src/garage/tf/baselines/continuous_mlp_baseline_with_model.py

This file was deleted.

10 changes: 4 additions & 6 deletions src/garage/tf/regressors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
CategoricalMLPRegressorWithModel)
from garage.tf.regressors.continuous_mlp_regressor import (
ContinuousMLPRegressor)
from garage.tf.regressors.continuous_mlp_regressor_with_model import (
ContinuousMLPRegressorWithModel)
from garage.tf.regressors.gaussian_cnn_regressor_model import (
GaussianCNNRegressorModel)
from garage.tf.regressors.gaussian_cnn_regressor_with_model import (
Expand All @@ -23,8 +21,8 @@
__all__ = [
'BernoulliMLPRegressor', 'BernoulliMLPRegressorWithModel',
'CategoricalMLPRegressor', 'CategoricalMLPRegressorWithModel',
'ContinuousMLPRegressor', 'ContinuousMLPRegressorWithModel',
'GaussianCNNRegressorModel', 'GaussianCNNRegressorWithModel',
'GaussianMLPRegressor', 'GaussianConvRegressor', 'Regressor2',
'StochasticRegressor', 'StochasticRegressor2'
'ContinuousMLPRegressor', 'GaussianCNNRegressorModel',
'GaussianCNNRegressorWithModel', 'GaussianMLPRegressor',
'GaussianConvRegressor', 'Regressor2', 'StochasticRegressor',
'StochasticRegressor2'
]
Loading