-
Notifications
You must be signed in to change notification settings - Fork 307
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TensorFlow embedding models (#1168)
* Add Embedding APIs and GaussianMLPEncoder. * Add sample_sym to DiagonalGaussian to enable get_action_sym in the future embeddings or policies. * Add Module to tensorflow, that captures shared methods in Policy, Regressor and Baseline. Co-authored-by: CatherineSue <CatherineSue@users.noreply.github.com> Co-authored-by: Keren Zhu <naeioi@hotmail.com>
- Loading branch information
1 parent
6024138
commit be30964
Showing
26 changed files
with
1,176 additions
and
414 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""Embedding encoders and decoders which use NumPy as a numerical backend."""
from garage.np.embeddings.encoder import Encoder, StochasticEncoder

__all__ = ['Encoder', 'StochasticEncoder']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Base class for context encoder.""" | ||
import abc | ||
|
||
|
||
class Encoder(abc.ABC):
    """Base class of context encoders for training meta-RL algorithms."""

    @abc.abstractmethod
    def forward(self, input_value):
        """Encode an input value.

        Args:
            input_value (numpy.ndarray): Input values of (N, input_dim)
                shape.

        Returns:
            numpy.ndarray: Encoded embedding.

        """

    @property
    @abc.abstractmethod
    def input_dim(self):
        """int: Dimension of the encoder input."""

    @property
    @abc.abstractmethod
    def output_dim(self):
        """int: Dimension of the encoder output (embedding)."""

    def reset(self, do_resets=None):
        """Reset the encoder.

        Resetting only has an effect on recurrent encoders, and the
        do_resets argument only matters for vectorized encoders.

        For a vectorized encoder, do_resets is a boolean array telling
        which of the internal states should be reset; its length should
        match the number of inputs.

        Args:
            do_resets (numpy.ndarray): Bool array indicating which states
                to be reset.

        """
|
||
|
||
class StochasticEncoder(Encoder):
    """A stochastic context encoder.

    A stochastic encoder maps an input to a distribution rather than to a
    deterministic vector.
    """

    @property
    @abc.abstractmethod
    def distribution(self):
        """scipy.stats.rv_generic: Embedding distribution."""

    def dist_info(self, input_value, state_infos):
        """Distribution info.

        Get the information of the embedding distribution for a given
        input.

        Args:
            input_value (np.ndarray): Input values.
            state_infos (dict): A dictionary whose values contain
                information about the predicted embedding given an input.

        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,120 @@ | ||
"""Distributions Base.""" | ||
|
||
|
||
class Distribution:
    """Base class for distribution."""

    @property
    def dim(self):
        """int: Dimension of this distribution."""
        raise NotImplementedError

    def kl_sym(self, old_dist_info_vars, new_dist_info_vars, name='kl_sym'):
        """Compute the symbolic KL divergence of two distributions.

        Args:
            old_dist_info_vars (tf.Tensor): Symbolic parameters of
                the old distribution.
            new_dist_info_vars (tf.Tensor): Symbolic parameters of
                the new distribution.
            name (str): TensorFlow scope name.

        Returns:
            tf.Tensor: Symbolic KL divergence between the two
                distributions.

        """
        raise NotImplementedError

    def kl(self, old_dist_info, new_dist_info):
        """Compute the KL divergence of two distributions.

        Args:
            old_dist_info (dict): Parameters of the old distribution.
            new_dist_info (dict): Parameters of the new distribution.

        Returns:
            float: KL Divergence between two distributions.

        """
        raise NotImplementedError

    def likelihood_ratio_sym(self,
                             x_var,
                             old_dist_info_vars,
                             new_dist_info_vars,
                             name='ll_ratio_sym'):
        """Symbolic likelihood ratio.

        Args:
            x_var (tf.Tensor): Input placeholder.
            old_dist_info_vars (dict): Old distribution tensors.
            new_dist_info_vars (dict): New distribution tensors.
            name (str): TensorFlow scope name.

        Returns:
            tf.Tensor: Symbolic likelihood ratio.

        """
        raise NotImplementedError

    def entropy(self, dist_info):
        """Entropy of a distribution.

        Args:
            dist_info (dict): Parameters of a distribution.

        Returns:
            float: Entropy of the distribution.

        """
        raise NotImplementedError

    def entropy_sym(self, dist_info_vars, name='entropy_sym'):
        """Symbolic entropy of a distribution.

        Args:
            dist_info_vars (dict): Symbolic parameters of a distribution.
            name (str): TensorFlow scope name.

        Returns:
            tf.Tensor: Symbolic entropy of the distribution.

        """
        raise NotImplementedError

    def log_likelihood_sym(self, x_var, dist_info_vars, name='ll_sym'):
        """Symbolic log likelihood.

        Args:
            x_var (tf.Tensor): Input placeholder.
            dist_info_vars (dict): Parameters of a distribution.
            name (str): TensorFlow scope name.

        Returns:
            tf.Tensor: Symbolic log likelihood.

        """
        raise NotImplementedError

    def log_likelihood(self, xs, dist_info):
        """Log likelihood of a sample under a distribution.

        Args:
            xs (np.ndarray): Input value.
            dist_info (dict): Parameters of a distribution.

        Returns:
            float: Log likelihood of a sample under the distribution.

        """
        raise NotImplementedError

    @property
    def dist_info_specs(self):
        """list: Specification of the parameter of a distribution."""
        raise NotImplementedError

    @property
    def dist_info_keys(self):
        """list: Parameter names."""
        return [key for key, _ in self.dist_info_specs]
Oops, something went wrong.