Skip to content

Commit

Permalink
trying to support discrete action space
Browse files: browse the repository at this point in the history
  • Loading branch information
swy99 committed Aug 25, 2023
1 parent 3292e4a commit e27b256
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
6 changes: 6 additions & 0 deletions agent/mb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def __init__(self, config, act_spec, tfstep):
self._use_amp = (config.precision == 16)
self.device = config.device

discrete = 'int' in act_spec.dtype.name
if self.cfg.actor.dist == 'auto':
self.cfg.actor.dist = 'onehot' if discrete else 'trunc_normal'
if self.cfg.actor_grad == 'auto':
self.cfg.actor_grad = 'reinforce' if discrete else 'dynamics'

inp_size = config.rssm.deter
if config.rssm.discrete:
inp_size += config.rssm.stoch * config.rssm.discrete
Expand Down
7 changes: 7 additions & 0 deletions agent/skill_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def __init__(self, config, act_spec, tfstep, skill_dim, solved_meta=None, imagin
self._use_amp = (config.precision == 16)
self.device = config.device

discrete = 'int' in act_spec.dtype.name

if self.cfg.actor.dist == 'auto':
self.cfg.actor.dist = 'onehot' if discrete else 'trunc_normal'
if self.cfg.actor_grad == 'auto':
self.cfg.actor_grad = 'reinforce' if discrete else 'dynamics'

self.imagine_obs = imagine_obs
self.solved_meta = solved_meta
self.skill_dim = skill_dim
Expand Down
4 changes: 2 additions & 2 deletions configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0}
model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6}
replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False}

-actor: {layers: 4, units: 400, norm: none, dist: trunc_normal, min_std: 0.1 }
+actor: {layers: 4, units: 400, norm: none, dist: auto, min_std: 0.1 }
critic: {layers: 4, units: 400, norm: none, dist: mse}
actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
discount: 0.99
discount_lambda: 0.95
-actor_grad: dynamics
+actor_grad: auto
slow_target: True
slow_target_update: 100
slow_target_fraction: 1
Expand Down

0 comments on commit e27b256

Please sign in to comment.