modify examples/development to support SQL
henry-zhang-bohan authored and hartikainen committed Feb 1, 2019
1 parent 6ac7db6 · commit 4453876
Showing 4 changed files with 91 additions and 55 deletions.
77 changes: 49 additions & 28 deletions examples/development/main.py
@@ -53,16 +53,25 @@ def _build(self):
         initial_exploration_policy = self.initial_exploration_policy = (
             get_policy('UniformPolicy', env))

-        self.algorithm = get_algorithm_from_variant(
-            variant=variant,
-            env=env,
-            policy=policy,
-            initial_exploration_policy=initial_exploration_policy,
-            Qs=Qs,
-            pool=replay_pool,
-            sampler=sampler,
-            session=self._session,
-        )
+        if self._variant['algorithm_params']['type'] == 'SQL':
+            self.algorithm = get_algorithm_from_variant(
+                variant=self._variant,
+                env=self.env,
+                policy=policy,
+                Q=Qs[0],
+                pool=replay_pool,
+                sampler=sampler,
+                session=self._session)
+        else:
+            self.algorithm = get_algorithm_from_variant(
+                variant=self._variant,
+                env=self.env,
+                policy=policy,
+                initial_exploration_policy=initial_exploration_policy,
+                Qs=Qs,
+                pool=replay_pool,
+                sampler=sampler,
+                session=self._session)

         initialize_tf_variables(self._session, only_uninitialized=True)

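The branch is needed because the two algorithm constructors take different arguments: SQL consumes a single Q-function, while the SAC path takes the full list of Qs plus an initial exploration policy. A minimal sketch of the dispatch; the helper name `_algorithm_kwargs` is hypothetical, not part of this commit:

    def _algorithm_kwargs(variant, Qs, initial_exploration_policy):
        # SQL's constructor expects a single Q; everything else (SAC)
        # expects the list of Qs plus an initial exploration policy.
        if variant['algorithm_params']['type'] == 'SQL':
            return {'Q': Qs[0]}
        return {
            'Qs': Qs,
            'initial_exploration_policy': initial_exploration_policy,
        }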
@@ -142,7 +151,7 @@ def _restore_replay_pool(self, current_checkpoint_dir):
         experience_paths = [
             self._replay_pool_pickle_path(checkpoint_dir)
             for checkpoint_dir in sorted(glob.iglob(
-                os.path.join(experiment_root, 'checkpoint_*')))
+                    os.path.join(experiment_root, 'checkpoint_*')))
         ]

         for experience_path in experience_paths:
@@ -187,15 +196,25 @@ def _restore(self, checkpoint_dir):
         initial_exploration_policy = self.initial_exploration_policy = (
             get_policy('UniformPolicy', env))

-        self.algorithm = get_algorithm_from_variant(
-            variant=self._variant,
-            env=self.env,
-            policy=policy,
-            initial_exploration_policy=initial_exploration_policy,
-            Qs=Qs,
-            pool=replay_pool,
-            sampler=sampler,
-            session=self._session)
+        if self._variant['algorithm_params']['type'] == 'SQL':
+            self.algorithm = get_algorithm_from_variant(
+                variant=self._variant,
+                env=self.env,
+                policy=policy,
+                Q=Qs[0],
+                pool=replay_pool,
+                sampler=sampler,
+                session=self._session)
+        else:
+            self.algorithm = get_algorithm_from_variant(
+                variant=self._variant,
+                env=self.env,
+                policy=policy,
+                initial_exploration_policy=initial_exploration_policy,
+                Qs=Qs,
+                pool=replay_pool,
+                sampler=sampler,
+                session=self._session)
         self.algorithm.__setstate__(pickleable['algorithm'].__getstate__())

         tf_checkpoint = self._get_tf_checkpoint()
@@ -205,10 +224,12 @@ def _restore(self, checkpoint_dir):
         status.assert_consumed().run_restore_ops(self._session)
         initialize_tf_variables(self._session, only_uninitialized=True)

-        # TODO(hartikainen): target Qs should either be checkpointed
-        # or pickled.
-        for Q, Q_target in zip(self.algorithm._Qs, self.algorithm._Q_targets):
-            Q_target.set_weights(Q.get_weights())
+        # TODO(hartikainen): target Qs should either be checkpointed or pickled.
+        if self._variant['algorithm_params']['type'] == 'SQL':
+            self.algorithm._Q_target.set_weights(self.algorithm._Q.get_weights())
+        else:
+            for Q, Q_target in zip(self.algorithm._Qs, self.algorithm._Q_targets):
+                Q_target.set_weights(Q.get_weights())

         self._built = True

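At restore time the targets are hard-copied from the freshly restored Q-functions. During training they are instead moved gradually toward the online weights using the `tau` coefficient from the algorithm kwargs. For contrast, a sketch of that soft (Polyak) update, assuming the Keras-style `get_weights`/`set_weights` interface used above:

    def soft_update_target(Q, Q_target, tau=5e-3):
        # Move each target weight a small step (tau) toward the online Q's
        # weights instead of overwriting it outright.
        Q_target.set_weights([
            tau * w + (1.0 - tau) * w_target
            for w, w_target in zip(Q.get_weights(), Q_target.get_weights())
        ])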
@@ -219,12 +240,12 @@ def main():
     universe, domain, task = parse_universe_domain_task(args)

     if ('image' in task.lower()
-        or 'blind' in task.lower()
-        or 'image' in domain.lower()):
+            or 'blind' in task.lower()
+            or 'image' in domain.lower()):
         variant_spec = get_variant_spec_image(
-            universe, domain, task, args.policy)
+            universe, domain, task, args.policy, args.algorithm)
     else:
-        variant_spec = get_variant_spec(universe, domain, task, args.policy)
+        variant_spec = get_variant_spec(universe, domain, task, args.policy, args.algorithm)

     variant_spec['mode'] = args.mode
62 changes: 41 additions & 21 deletions examples/development/variants.py
@@ -8,7 +8,6 @@

 NUM_COUPLING_LAYERS = 2

-
 GAUSSIAN_POLICY_PARAMS_BASE = {
     'type': 'GaussianPolicy',
     'kwargs': {
@@ -27,7 +26,6 @@
     'gaussian': POLICY_PARAMS_BASE['GaussianPolicy'],
 })

-
 POLICY_PARAMS_FOR_DOMAIN = {
     'GaussianPolicy': GAUSSIAN_POLICY_PARAMS_FOR_DOMAIN,
 }
@@ -48,20 +46,37 @@
         'epoch_length': 1000,
         'train_every_n_steps': 1,
         'n_train_repeat': 1,
         'n_initial_exploration_steps': int(1e3),
-        'reparameterize': REPARAMETERIZE,
         'eval_render_mode': None,
         'eval_n_episodes': 1,
         'eval_deterministic': True,

-        'lr': 3e-4,
         'discount': 0.99,
-        'target_update_interval': 1,
         'tau': 5e-3,
-        'target_entropy': 'auto',
         'reward_scale': 1.0,
-        'store_extra_policy_info': False,
-        'action_prior': 'uniform',
     }
 }

+ALGORITHM_PARAMS_ADDITIONAL = {
+    'SAC': {
+        'type': 'SAC',
+        'kwargs': {
+            'reparameterize': REPARAMETERIZE,
+            'lr': 3e-4,
+            'target_update_interval': 1,
+            'tau': 1e-4,
+            'target_entropy': 'auto',
+            'store_extra_policy_info': False,
+            'action_prior': 'uniform',
+            'n_initial_exploration_steps': int(1e3),
+        }
+    },
+    'SQL': {
+        'type': 'SQL',
+        'kwargs': {
+            'policy_lr': 3e-4,
+            'td_target_update_interval': 1,
+            'n_initial_exploration_steps': 0,
+        }
+    }
+}

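Splitting the shared kwargs from the per-algorithm ones relies on `deep_update` merging nested dicts recursively, so an `ALGORITHM_PARAMS_ADDITIONAL` entry extends or overrides `ALGORITHM_PARAMS_BASE` without replacing its whole `kwargs` dict. A minimal re-implementation of those semantics, for illustration only (the repo imports its own `deep_update`):

    def deep_update(base, update):
        # Merge `update` into a copy of `base`: nested dicts are merged
        # key-by-key; any other value is overwritten.
        result = dict(base)
        for key, value in update.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                result[key] = deep_update(result[key], value)
            else:
                result[key] = value
        return result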
@@ -90,9 +105,9 @@
             'n_epochs': NUM_EPOCHS_PER_DOMAIN.get(
                 domain, DEFAULT_NUM_EPOCHS),
             'n_initial_exploration_steps': (
-                MAX_PATH_LENGTH_PER_DOMAIN.get(
-                    domain, DEFAULT_MAX_PATH_LENGTH
-                ) * 10),
+                    MAX_PATH_LENGTH_PER_DOMAIN.get(
+                        domain, DEFAULT_MAX_PATH_LENGTH
+                    ) * 10),
         }
     } for domain in NUM_EPOCHS_PER_DOMAIN
 }
@@ -149,18 +164,26 @@
     },
     'Point2DEnv': {
         'Default': {
-            'observation_keys': ('observation', ),
+            'observation_keys': ('observation',),
         },
         'Wall': {
-            'observation_keys': ('observation', ),
+            'observation_keys': ('observation',),
         },
     }
 }

 NUM_CHECKPOINTS = 10


-def get_variant_spec(universe, domain, task, policy):
+def get_variant_spec(universe, domain, task, policy, algorithm):
+    algorithm_params = deep_update(
+        ALGORITHM_PARAMS_BASE,
+        ALGORITHM_PARAMS_PER_DOMAIN.get(domain, {})
+    )
+    algorithm_params = deep_update(
+        algorithm_params,
+        ALGORITHM_PARAMS_ADDITIONAL.get(algorithm, {})
+    )
     variant_spec = {
         'domain': domain,
         'task': task,
@@ -178,10 +201,7 @@ def get_variant_spec(universe, domain, task, policy):
                 'hidden_layer_sizes': (M, M),
             }
         },
-        'algorithm_params': deep_update(
-            ALGORITHM_PARAMS_BASE,
-            ALGORITHM_PARAMS_PER_DOMAIN.get(domain, {})
-        ),
+        'algorithm_params': algorithm_params,
         'replay_pool_params': {
             'type': 'SimpleReplayPool',
             'kwargs': {
@@ -211,9 +231,9 @@ def get_variant_spec(universe, domain, task, policy):
     return variant_spec


-def get_variant_spec_image(universe, domain, task, policy, *args, **kwargs):
+def get_variant_spec_image(universe, domain, task, policy, algorithm, *args, **kwargs):
     variant_spec = get_variant_spec(
-        universe, domain, task, policy, *args, **kwargs)
+        universe, domain, task, policy, algorithm, *args, **kwargs)

     if 'image' in task.lower() or 'image' in domain.lower():
         preprocessor_params = {
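With the new `algorithm` argument threaded through, the returned variant carries the merged algorithm params. A usage sketch (the universe/domain/task/policy strings are illustrative; the assertions follow from the dicts above):

    # 'SQL' picks up the SQL-specific kwargs on top of the shared base.
    spec = get_variant_spec('gym', 'Swimmer', 'Default', 'GaussianPolicy', 'SQL')
    assert spec['algorithm_params']['type'] == 'SQL'
    assert spec['algorithm_params']['kwargs']['policy_lr'] == 3e-4
    assert spec['algorithm_params']['kwargs']['n_initial_exploration_steps'] == 0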
6 changes: 1 addition & 5 deletions examples/multi_goal/variants.py
@@ -14,6 +14,7 @@
         'discount': 0.99,
         'reward_scale': 1.0,
         'save_full_state': True,
+        'tau': 1e-4,
     }
 }

@@ -24,7 +25,6 @@
             'reparameterize': True,
             'lr': 3e-4,
             'target_update_interval': 1,
-            'tau': 1e-4,
             'target_entropy': -2.0,
             'store_extra_policy_info': False,
             'action_prior': 'uniform',
@@ -43,10 +43,6 @@

 def get_variant_spec(universe, domain, task, policy, local_dir, algorithm):
     layer_size = 64
-    print(deep_update(
-        ALGORITHM_PARAMS_BASE,
-        ALGORITHM_PARAMS_ADDITIONAL.get(algorithm, {})
-    ))
     variant_spec = {
         'seed': 1,

1 change: 0 additions & 1 deletion softlearning/algorithms/sql.py
@@ -12,7 +12,6 @@

 def assert_shape(tensor, expected_shape):
     tensor_shape = tensor.shape.as_list()
-    print(tensor_shape, expected_shape)
     assert len(tensor_shape) == len(expected_shape)
     assert all([a == b for a, b in zip(tensor_shape, expected_shape)])
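For reference, `assert_shape` compares a tensor's static shape element-wise against the expected list, so unknown dimensions must be matched with `None`. A usage sketch in the TF 1.x style this module uses:

    import tensorflow as tf

    observations = tf.placeholder(tf.float32, shape=(None, 3))
    assert_shape(observations, [None, 3])  # passes silently
    assert_shape(observations, [None, 4])  # raises AssertionError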
