Skip to content

Commit

Permalink
Fixing num_outputs for discrete actions
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Mar 7, 2019
1 parent 3652b57 commit 4d2ff68
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions digideep/environment/dmc2gym/spec2space.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dm_control.rl import specs
import numpy as np
import collections

# import warnings

def spec2space_single(spec):
"""
Expand All @@ -21,7 +21,7 @@ def spec2space_single(spec):

if (type(spec) is specs.BoundedArraySpec) and (spec.dtype == np.int):
# Discrete
warnings.warn("The DMC environment uses a discrete action space!")
# warnings.warn("The DMC environment uses a discrete action space!")
if spec.minimum == 0:
return spaces.Discrete(spec.maximum)
else:
Expand Down
1 change: 0 additions & 1 deletion digideep/environment/dmc2gym/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
import collections
import warnings
import sys
import copy

Expand Down
6 changes: 4 additions & 2 deletions digideep/policy/stochastic/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ def __init__(self, device, obs_space, act_space, modelname, modelargs):
else:
raise NotImplementedError


num_outputs = act_space["dim"] if np.isscalar(act_space["dim"]) else act_space["dim"][0]
# TODO: For discrete actions, `act_space["dim"][0]` works. It works for constinuous actions as well.
# Even for discrete actions `np.isscalar(act_space["dim"])` returns False.
# num_outputs = act_space["dim"] if np.isscalar(act_space["dim"]) else act_space["dim"][0]
num_outputs = act_space["dim"].item() if len(act_space["dim"].shape)==0 else act_space["dim"][0]
if act_space["typ"] == "Discrete":
print("Discrete is recognized and num_outputs=", num_outputs)
self.model["dist"] = Categorical(num_inputs=modelargs["output_size"], num_outputs=num_outputs)
Expand Down

0 comments on commit 4d2ff68

Please sign in to comment.