Merge pull request #187 from zuoxingdong/step_info_trajectory
add SAC
zuoxingdong committed May 12, 2019
2 parents 18a7441 + dbb3d78 commit ed2df3c
Showing 458 changed files with 36 additions and 3,271 deletions.
18 changes: 18 additions & 0 deletions baselines/sac/README.md
@@ -0,0 +1,18 @@
# Soft Actor-Critic (SAC)

This is an implementation of the [SAC](https://arxiv.org/abs/1812.05905) algorithm with automatic temperature tuning.

# Usage

Run the following command to start parallelized training:

```bash
python experiment.py
```

Modify [experiment.py](./experiment.py) to set up different configurations.

# Results

## MLP policy
<img src='logs/default/result.png' width='100%'>
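
The README refers to automatic temperature tuning. For reference, the temperature objective from the linked SAC paper can be written as follows, where $\bar{\mathcal{H}}$ denotes the target entropy (this appears to correspond to `self.target_entropy` in `agent.py`). Note that the `alpha_loss` line in this commit's `agent.py` multiplies by `log_alpha` rather than `alpha`, which is an implementation variant and not the literal form below.

```latex
J(\alpha) = \mathbb{E}_{a_t \sim \pi_t}\left[ -\alpha \log \pi_t(a_t \mid s_t) - \alpha \bar{\mathcal{H}} \right]
```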
17 changes: 5 additions & 12 deletions baselines/sac/agent.py
@@ -20,6 +20,7 @@
from lagom.networks import ortho_init


# TODO: import from PyTorch when PR merged: https://github.com/pytorch/pytorch/pull/19785
class TanhTransform(Transform):
r"""
Transform via the mapping :math:`y = \tanh(x)`.
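
The mapping $y = \tanh(x)$ named in the docstring above is typically used to squash an unbounded sample into $(-1, 1)$, which requires a log-determinant correction to the log-probability. The following is a minimal scalar sketch in plain Python, not the `TanhTransform` class from this commit or from PyTorch; the function names are hypothetical, and the numerically stable rewrite of $\log(1 - \tanh^2 x)$ is an assumption about what one would want, not a claim about the code here.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_forward(x):
    # y = tanh(x), mapping the real line into (-1, 1).
    return math.tanh(x)


def tanh_inverse(y):
    # x = atanh(y); only defined for y strictly inside (-1, 1).
    return math.atanh(y)


def tanh_log_abs_det_jacobian(x):
    # log|dy/dx| = log(1 - tanh(x)^2), rewritten as
    # 2 * (log 2 - x - softplus(-2x)) to avoid log(0) for large |x|.
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def squashed_log_prob(base_log_prob, x):
    # Change of variables: log p(y) = log p(x) - log|dy/dx|.
    return base_log_prob - tanh_log_abs_det_jacobian(x)
```

For moderate `x` the stable form agrees with the naive `math.log(1 - math.tanh(x) ** 2)`, while for large `|x|` the naive form can hit `log(0)` and the rewritten one stays finite.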
@@ -177,40 +178,32 @@ def learn(self, D, **kwargs):
for i in range(episode_length):
observations, actions, rewards, next_observations, masks = replay.sample(self.config['replay.batch_size'])

# Update Critic
Qs1, Qs2 = self.critic(observations, actions)
#Qs1, Qs2 = map(lambda x: x.squeeze(-1), [Qs1, Qs2])
with torch.no_grad():
out_actor = self.choose_action(next_observations, mode='train')
next_actions = out_actor['action']
next_actions_logprob = out_actor['action_logprob'].unsqueeze(-1)
next_Qs1, next_Qs2 = self.critic_target(next_observations, next_actions)
next_Qs = torch.min(next_Qs1, next_Qs2) - self.alpha.detach()*next_actions_logprob
Q_targets = rewards.unsqueeze(-1) + self.config['agent.gamma']*masks.unsqueeze(-1)*next_Qs

Q_targets = rewards + self.config['agent.gamma']*masks*next_Qs
critic_loss = F.mse_loss(Qs1, Q_targets.detach()) + F.mse_loss(Qs2, Q_targets.detach())
self.optimizer_zero_grad()
critic_loss.backward()
critic_grad_norm = nn.utils.clip_grad_norm_(self.critic.parameters(), self.config['agent.max_grad_norm'])
self.critic_optimizer.step()

# Update Actor
out_actor = self.choose_action(observations, mode='train')
policy_actions = out_actor['action']
policy_actions_logprob = out_actor['action_logprob']

policy_actions_logprob = out_actor['action_logprob'].unsqueeze(-1)
actor_Qs1, actor_Qs2 = self.critic(observations, policy_actions)
actor_Qs = torch.min(actor_Qs1, actor_Qs2).squeeze(-1)
actor_Qs = torch.min(actor_Qs1, actor_Qs2)
actor_loss = torch.mean(self.alpha.detach()*policy_actions_logprob - actor_Qs)

self.optimizer_zero_grad()
actor_loss.backward()
actor_grad_norm = nn.utils.clip_grad_norm_(self.actor.parameters(), self.config['agent.max_grad_norm'])
self.actor_optimizer.step()

# Update alpha

alpha_loss = torch.mean(self.log_alpha*(-policy_actions_logprob - self.target_entropy).detach())

self.optimizer_zero_grad()
alpha_loss.backward()
self.log_alpha_optimizer.step()
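
The three updates in the hunk above (critic target, actor loss, temperature loss) can be sketched with plain floats and lists to make the arithmetic explicit. This is an illustrative restatement, not the lagom implementation: the function names are hypothetical, and `mask` is assumed to be `0.0` for a terminal transition and `1.0` otherwise.

```python
def soft_q_target(reward, mask, next_q1, next_q2, next_logprob, alpha, gamma):
    # Clipped double-Q: use the smaller target-critic estimate, then
    # subtract the entropy term alpha * log pi(a'|s').
    soft_value = min(next_q1, next_q2) - alpha * next_logprob
    # Assumed convention: mask == 0.0 stops bootstrapping at episode end.
    return reward + gamma * mask * soft_value


def actor_loss(q1s, q2s, logprobs, alpha):
    # Batch mean of alpha * log pi(a|s) - min(Q1, Q2).
    terms = [alpha * lp - min(a, b) for a, b, lp in zip(q1s, q2s, logprobs)]
    return sum(terms) / len(terms)


def alpha_loss(log_alpha, logprobs, target_entropy):
    # Batch mean of log_alpha * (-log pi(a|s) - target_entropy); the
    # bracketed factor is treated as a constant, as with .detach() above.
    terms = [log_alpha * (-lp - target_entropy) for lp in logprobs]
    return sum(terms) / len(terms)
```

In the tensor version, every operand in these expressions needs the same shape (for example all `[batch, 1]`). Otherwise broadcasting can silently produce a `[batch, batch]` result, which is likely why the `unsqueeze`/`squeeze` calls are adjusted in this hunk.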
2 changes: 1 addition & 1 deletion baselines/sac/experiment.py
@@ -86,5 +86,5 @@ def run(config, seed, device, logdir):
log_dir='logs/default',
max_workers=os.cpu_count(),
chunksize=1,
use_gpu=True,
use_gpu=True, # GPU much faster, note that performance differs between CPU/GPU
gpu_ids=None)
Binary file not shown.
21 changes: 0 additions & 21 deletions baselines/sac/logs/___default/0/config.yml

This file was deleted.

Binary file not shown.
21 changes: 0 additions & 21 deletions baselines/sac/logs/___default/1/config.yml

This file was deleted.

Binary file not shown.
21 changes: 0 additions & 21 deletions baselines/sac/logs/___default/2/config.yml

This file was deleted.

Binary file not shown.
21 changes: 0 additions & 21 deletions baselines/sac/logs/___default/3/config.yml

This file was deleted.

Binary file removed baselines/sac/logs/___default/configs.pkl
Binary file not shown.
Binary file removed baselines/sac/logs/___default/result.png
Binary file not shown.
Empty file.
