Skip to content

Commit

Permalink
Hotfix policy shape
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Mar 8, 2019
1 parent 1b7ace3 commit 7a97d31
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion digideep/environment/make_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def space2config(S):
# S.__class__.__name__: "Discrete" / "Box"
if isinstance(S, spaces.Discrete):
typ = S.__class__.__name__
dim = S.n
dim = np.int32(S.n)
lim = (np.nan, np.nan) # Discrete Spaces do not have high/low
config = {"typ":typ, "dim":dim, "lim":lim}
elif isinstance(S, spaces.Box):
Expand Down
4 changes: 2 additions & 2 deletions digideep/policy/stochastic/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(self, device, obs_space, act_space, modelname, modelargs):

# TODO: For discrete actions, `act_space["dim"][0]` works. It works for constinuous actions as well.
# Even for discrete actions `np.isscalar(act_space["dim"])` returns False.
# num_outputs = act_space["dim"] if np.isscalar(act_space["dim"]) else act_space["dim"][0]
num_outputs = act_space["dim"].item() if len(act_space["dim"].shape)==0 else act_space["dim"][0]
num_outputs = act_space["dim"] if np.isscalar(act_space["dim"]) else act_space["dim"][0]
# num_outputs = act_space["dim"].item() if len(act_space["dim"].shape)==0 else act_space["dim"][0]
if act_space["typ"] == "Discrete":
print("Discrete is recognized and num_outputs=", num_outputs)
self.model["dist"] = Categorical(num_inputs=modelargs["output_size"], num_outputs=num_outputs)
Expand Down

0 comments on commit 7a97d31

Please sign in to comment.