Skip to content

Commit

Permalink
[RLlib] Deprecate all Model(v1) usage. (#8146)
Browse the repository at this point in the history
Deprecate all Model(v1) usage.
  • Loading branch information
sven1977 committed Apr 29, 2020
1 parent eb91619 commit bf25aee
Show file tree
Hide file tree
Showing 28 changed files with 554 additions and 679 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ py_test(
name = "examples/nested_action_spaces_ppo",
main = "examples/nested_action_spaces.py",
tags = ["examples", "examples_N"],
-size = "small",
+size = "medium",
srcs = ["examples/nested_action_spaces.py"],
args = ["--stop=-500", "--run=PPO"]
)
Expand Down
15 changes: 9 additions & 6 deletions rllib/agents/ars/ars_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ def __init__(self, obs_space, action_space, config):
dist_class, dist_dim = ModelCatalog.get_action_dist(
self.action_space, config["model"], dist_type="deterministic")

-model = ModelCatalog.get_model({
-SampleBatch.CUR_OBS: self.inputs
-}, self.observation_space, self.action_space, dist_dim,
-config["model"])
-dist = dist_class(model.outputs, model)
+self.model = ModelCatalog.get_model_v2(
+obs_space=self.preprocessor.observation_space,
+action_space=self.action_space,
+num_outputs=dist_dim,
+model_config=config["model"])
+dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
+dist = dist_class(dist_inputs, self.model)

self.sampler = dist.sample()

self.variables = ray.experimental.tf_utils.TensorFlowVariables(
-model.outputs, self.sess)
+dist_inputs, self.sess)

self.num_params = sum(
np.prod(variable.shape.as_list())
Expand Down

0 comments on commit bf25aee

Please sign in to comment.