[BUG] DDPG worse performance than stable baselines #1181

Closed
3 tasks done
smorad opened this issue May 22, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@smorad
Contributor

smorad commented May 22, 2023

Describe the bug

The Stable-Baselines3 (SB3) version of DDPG seems to significantly outperform the torchrl DDPG implementation. I implemented the torchrl version following @matteobettini's VMAS MADDPG script very closely, so it is possible that I made mistakes.

To Reproduce

SB3 version

import gymnasium as gym
import numpy as np

from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

env = gym.make("Pendulum-v1", render_mode="rgb_array")

# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)

Here are the results from the last few SB3 epochs:

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 200      |
|    ep_rew_mean     | -144     |
| time/              |          |
|    episodes        | 130      |
|    fps             | 338      |
|    time_elapsed    | 76       |
|    total_timesteps | 26000    |
| train/             |          |
|    actor_loss      | 14.1     |
|    critic_loss     | 0.968    |
|    learning_rate   | 0.001    |
|    n_updates       | 25800    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 200      |
|    ep_rew_mean     | -147     |
| time/              |          |
|    episodes        | 140      |
|    fps             | 334      |
|    time_elapsed    | 83       |
|    total_timesteps | 28000    |
| train/             |          |
|    actor_loss      | 13.6     |
|    critic_loss     | 1.21     |
|    learning_rate   | 0.001    |
|    n_updates       | 27800    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 200      |
|    ep_rew_mean     | -147     |
| time/              |          |
|    episodes        | 150      |
|    fps             | 332      |
|    time_elapsed    | 90       |
|    total_timesteps | 30000    |
| train/             |          |
|    actor_loss      | 12.4     |
|    critic_loss     | 1.01     |
|    learning_rate   | 0.001    |
|    n_updates       | 29800    |
---------------------------------

Here is my torchrl implementation:

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv
import torchrl
from torchrl.envs.transforms.transforms import Compose, RewardSum
from torchrl.objectives import DDPGLoss
from torchrl.data import ReplayBuffer, ListStorage
from torchrl.modules import (
    ConvNet,
    EGreedyWrapper,
    LSTMModule,
    MLP,
    QValueModule,
    TanhDelta,
)
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
import torch
from torchrl.envs.transforms import DoubleToFloat, TransformedEnv
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.exploration import AdditiveGaussianWrapper
from torchrl.objectives.utils import SoftUpdate, ValueEstimators
from torchrl.data.replay_buffers.samplers import PrioritizedSampler


# Config
seed = 0
gamma = 0.99
tau = 0.005
frames_per_batch = 200  # Frames sampled each sampling iteration
batch_size = 1  # in terms of rollout size
utd = 1  # update to data ratio
max_steps = 100
n_iters = 500  # Number of sampling/training iterations
warmup_frames = 100  # To prefill replay buffer
total_frames = frames_per_batch * n_iters
memory_size = frames_per_batch * 100
lr = 0.0001

torch.manual_seed(0)


def env():
    env = GymEnv("Pendulum-v1", from_pixels=False)
    env = TransformedEnv(
        env, Compose(DoubleToFloat(in_keys=["observation"]), RewardSum())
    )
    return env


proof_env = env()
obs_size = proof_env.observation_spec["observation"].shape[0]
act_size = proof_env.action_spec.shape[0]

# Nets


class CriticNet(nn.Module):
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        nn.init.normal_(self.net[-1].weight, 0, 0.001)
        self.net[-1].bias.data.zero_()

    def forward(self, observation, action):
        return self.net(torch.cat([observation, action], dim=-1))


class ActorNet(nn.Module):
    def __init__(self, obs_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        nn.init.normal_(self.net[-1].weight, 0, 0.001)
        self.net[-1].bias.data.zero_()

    def forward(self, observation):
        return self.net(observation)


actor_mod = Mod(ActorNet(obs_size), in_keys=["observation"], out_keys=["param"])
actor = ProbabilisticActor(
    module=actor_mod,
    spec=proof_env.action_spec,
    in_keys=["param"],
    distribution_class=TanhDelta,
    distribution_kwargs={
        "min": proof_env.action_spec.space.minimum,
        "max": proof_env.action_spec.space.maximum,
    },
    return_log_prob=False,
)
policy = AdditiveGaussianWrapper(
    actor, annealing_num_steps=int(total_frames), sigma_end=0.05
)

critic = ValueOperator(CriticNet(obs_size, act_size), in_keys=["observation", "action"])

collector = torchrl.collectors.SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
)
buffer = ReplayBuffer(storage=LazyTensorStorage(memory_size), batch_size=batch_size)

loss_module = DDPGLoss(
    actor_network=policy,
    value_network=critic,
    gamma=gamma,
)
target_net_updater = SoftUpdate(loss_module, eps=1 - tau)

critic_opt = torch.optim.Adam(actor.parameters(), lr)
actor_opt = torch.optim.Adam(critic.parameters(), lr)

for i, tensordict in enumerate(collector):
    buffer.add(tensordict)
    for j in range(utd):
        data = buffer.sample()
        loss = loss_module(data)
        critic_opt.zero_grad()
        actor_opt.zero_grad()
        loss["loss_actor"].backward()
        loss["loss_value"].backward()
        critic_opt.step()
        actor_opt.step()
        print(
            "Epoch {}: episode reward {:.2f} actor loss {:.2f}, critic loss {:.2f}".format(
                i,
                tensordict['next']["episode_reward"][tensordict['next']['done']].mean(),
                loss["loss_actor"].item(),
                loss["loss_value"].item(),
            )
        )
    target_net_updater.step()
    collector.update_policy_weights_()

And here are the torchrl results:

Epoch 1: episode reward -1740.12 actor loss 0.01, critic loss 44.02
Epoch 2: episode reward -1327.09 actor loss 0.02, critic loss 44.01
Epoch 3: episode reward -977.49 actor loss 0.01, critic loss 76.29
Epoch 4: episode reward -1261.87 actor loss 0.03, critic loss 50.06
Epoch 5: episode reward -868.89 actor loss 0.04, critic loss 49.02
Epoch 6: episode reward -1044.71 actor loss 0.05, critic loss 49.02
Epoch 7: episode reward -1095.83 actor loss 0.02, critic loss 76.26
Epoch 8: episode reward -1666.79 actor loss 0.05, critic loss 50.04
Epoch 9: episode reward -1308.08 actor loss 0.06, critic loss 50.04
Epoch 10: episode reward -1070.14 actor loss 0.08, critic loss 48.99
...
Epoch 200: episode reward -1835.48 actor loss 28.66, critic loss 51.26
Epoch 201: episode reward -1423.16 actor loss 30.82, critic loss 51.08
Epoch 202: episode reward -1350.76 actor loss 31.53, critic loss 46.58
Epoch 203: episode reward -1597.77 actor loss 34.97, critic loss 53.95
Epoch 204: episode reward -1505.26 actor loss 30.50, critic loss 48.00
Epoch 205: episode reward -1197.51 actor loss 23.92, critic loss 46.47
Epoch 206: episode reward -1560.22 actor loss 18.10, critic loss 50.99
Epoch 207: episode reward -1182.17 actor loss 25.47, critic loss 47.51
Epoch 208: episode reward -1331.73 actor loss 18.64, critic loss 52.32
Epoch 209: episode reward -1738.40 actor loss 18.79, critic loss 51.17
Epoch 210: episode reward -1536.80 actor loss 31.51, critic loss 50.80
Epoch 211: episode reward -1270.51 actor loss 46.14, critic loss 76.46
Epoch 212: episode reward -1403.69 actor loss 36.24, critic loss 53.10
Epoch 213: episode reward -1447.37 actor loss 30.20, critic loss 50.71
Epoch 214: episode reward -1506.56 actor loss 32.46, critic loss 51.48
Epoch 215: episode reward -1609.89 actor loss 26.80, critic loss 47.40
Epoch 216: episode reward -1420.35 actor loss 36.40, critic loss 52.10
Epoch 217: episode reward -1450.20 actor loss 34.28, critic loss 49.61
Epoch 218: episode reward -1240.91 actor loss 21.95, critic loss 48.86
Epoch 219: episode reward -1450.16 actor loss 20.31, critic loss 47.86
Epoch 220: episode reward -1234.47 actor loss 47.02, critic loss 77.76
Epoch 221: episode reward -1735.47 actor loss 28.17, critic loss 51.75
Epoch 222: episode reward -1480.90 actor loss 21.89, critic loss 42.26
Epoch 223: episode reward -1544.21 actor loss 51.45, critic loss 88.18
Epoch 224: episode reward -1424.87 actor loss 63.30, critic loss 80.44
Epoch 225: episode reward -1367.20 actor loss 41.04, critic loss 57.50
Epoch 226: episode reward -1565.24 actor loss 32.88, critic loss 49.72
Epoch 227: episode reward -1173.81 actor loss 36.75, critic loss 57.84
Epoch 228: episode reward -1882.98 actor loss 38.38, critic loss 53.28
Epoch 229: episode reward -1710.50 actor loss 21.60, critic loss 48.92
Epoch 230: episode reward -1552.52 actor loss 38.31, critic loss 58.09
Epoch 231: episode reward -1762.59 actor loss 29.91, critic loss 50.14
Epoch 232: episode reward -1570.50 actor loss 31.51, critic loss 47.88
Epoch 233: episode reward -1663.63 actor loss 68.49, critic loss 87.92
Epoch 234: episode reward -1734.04 actor loss 24.08, critic loss 47.89
Epoch 235: episode reward -1483.62 actor loss 36.49, critic loss 52.02
Epoch 236: episode reward -1790.86 actor loss 39.68, critic loss 57.76
Epoch 237: episode reward -1414.47 actor loss 40.03, critic loss 58.14

Expected behavior

I'd expect the SB3 and torchrl implementations to be roughly equivalent in terms of reward and value loss. It appears that the torchrl version isn't really learning and the critic loss is much worse.

System info

Describe the characteristics of your environment:
torchrl installed from source (master branch, May 22, 2023)

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@smorad smorad added the bug Something isn't working label May 22, 2023
@vmoens
Contributor

vmoens commented May 23, 2023

Hi @smorad
Thanks for reporting this. It's a valid point: we should provide algos with SOTA performance without trouble, so we'll work on making life easier for our users in this regard.

I made a working example out of your script, which I'm hosting on a separate branch. You can see my edits here:
ac7eff7

This is the file:
https://github.com/pytorch/rl/blob/ddpg_example/ddpg_example.py

I tested it for 10k frames and it seems to be working fine.

The major issue is that in your script, you're not building target parameters. I see two problems for torchrl here:

  1. We must set delay_value to True by default; this is addressed here
  2. The updater should raise an exception if there is no target param; this is addressed here
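To make the two points above concrete, here is a plain-Python sketch (no torchrl involved) of a soft updater that keeps a separate copy of the target parameters and refuses to be built when no such copy exists. PolyakUpdater and the dict-of-floats "parameters" are hypothetical stand-ins. The eps convention, target = eps * target + (1 - eps) * online, is assumed to match how the script above calls SoftUpdate(loss_module, eps=1 - tau).

```python
import copy


class PolyakUpdater:
    """Hypothetical soft updater: target <- eps * target + (1 - eps) * online."""

    def __init__(self, online, target, eps):
        # Refuse to build without a distinct target copy; otherwise the
        # "target" would just alias the online parameters being trained.
        if target is None or target is online:
            raise ValueError("no target parameters: build a separate target copy first")
        if not 0.0 <= eps < 1.0:
            raise ValueError("eps must be in [0, 1)")
        self.online, self.target, self.eps = online, target, eps

    def step(self):
        # Move every target entry a fraction (1 - eps) toward the online value.
        for name, value in self.online.items():
            self.target[name] = self.eps * self.target[name] + (1.0 - self.eps) * value


tau = 0.005
online = {"w": 1.0, "b": 0.0}
target = copy.deepcopy(online)  # delayed copy, analogous in spirit to delay_value=True

updater = PolyakUpdater(online, target, eps=1 - tau)
online["w"] = 2.0  # pretend an optimizer step changed the online weight
updater.step()
print(round(target["w"], 6))  # 1.005: the target moved only tau of the way toward 2.0

try:
    PolyakUpdater(online, None, eps=1 - tau)
except ValueError as err:
    print("error:", err)
```

Without the separate copy, the bootstrapped target in the critic loss would shift with every gradient step, which is the usual source of the kind of divergence seen in the torchrl log above.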

Other small issues:

  • I see from the code you ran in SB3 that there are as many updates as steps, but with torchrl you run 200x fewer (one update per 200-frame batch). So I used an update-to-data (UTD) ratio of 200.
  • You call add on the buffer rather than extend, which means that each stored item is an entire trajectory. I think you can get things to work better by storing single transitions, especially for off-policy learning.
  • I see that the actor loss in SB3 is about 14 when the total reward of a trajectory is about -140, i.e. one order of magnitude less. With a gamma of 0.99 (an effective horizon of about 1 / (1 - 0.99) = 100 steps) and an average reward per step of about -0.7, I would expect the actor loss to be about 100 * 0.7 = 70. Not sure if they scale the reward somewhere? I could not find it. I added a reward scaling to get more reasonable losses, but I guess this won't have as much impact as the points above.
  • I would use a TensorDictReplayBuffer. Regular RBs work fine too, but you can more easily pack info in a TDRB.
  • I think you swapped the optimizers (critic_opt is built on actor.parameters() and actor_opt on critic.parameters()), but this has no impact since both use the same learning rate and both step after both losses are backpropagated. Also, the Polyak update should be done after each optimizer update, not once per collector iteration.
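The update-count gap in the first bullet can be checked with a few lines of arithmetic. This sketch uses only the config values from the torchrl script and the counters in the last SB3 log entry (29800 updates after 30000 timesteps), and assumes one gradient step per pass of the script's inner utd loop.

```python
frames_per_batch = 200  # frames collected per collector iteration (from the script)
n_iters = 500           # collector iterations (from the script)
total_frames = frames_per_batch * n_iters


def updates_per_frame(utd):
    """Gradient steps per collected frame when each iteration runs `utd` updates."""
    return utd / frames_per_batch


# Last SB3 log entry: n_updates 29800 after total_timesteps 30000.
sb3_ratio = 29800 / 30000

for utd in (1, 200):
    total_updates = utd * n_iters
    print(f"utd={utd}: {total_updates} updates over {total_frames} frames "
          f"({updates_per_frame(utd):.3f} per frame)")
print(f"SB3: {sb3_ratio:.3f} updates per frame")
```

With utd = 1 the script performs 0.005 updates per frame, 200x fewer than SB3's roughly one update per environment step; utd = 200 closes that gap.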
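The add versus extend distinction in the second bullet can be illustrated without torchrl. ToyBuffer below is a hypothetical list-backed buffer, not torchrl's ReplayBuffer. It only mimics the behaviour described above: add stores its argument as a single item, while extend stores each element as its own item.

```python
import random


class ToyBuffer:
    """Hypothetical list-backed replay buffer used only to contrast add and extend."""

    def __init__(self, batch_size):
        self.items = []
        self.batch_size = batch_size

    def add(self, item):
        # The whole argument becomes ONE entry (e.g. a full 200-step rollout).
        self.items.append(item)

    def extend(self, items):
        # Each element becomes its own entry (e.g. one transition per frame).
        self.items.extend(items)

    def sample(self):
        return random.sample(self.items, self.batch_size)


rollout = [{"step": t} for t in range(200)]  # stand-in for a 200-frame collector batch

by_add = ToyBuffer(batch_size=1)
by_add.add(rollout)
by_extend = ToyBuffer(batch_size=1)
by_extend.extend(rollout)

print(len(by_add.items), len(by_extend.items))  # 1 200
print(len(by_add.sample()[0]))                  # 200: one "sample" is the whole rollout
```

With add and batch_size = 1, every gradient step sees one correlated 200-step trajectory. With extend, a batch is drawn from individual transitions spread across many trajectories, which is what off-policy methods like DDPG usually assume.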
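The expected loss magnitude in the third bullet comes from a geometric sum. DDPG's actor loss is the negated Q-value of the policy's action, so with negative rewards its size tracks |Q|. This sketch reproduces the back-of-the-envelope estimate, assuming a constant per-step reward equal to SB3's logged episode mean (-144 over 200 steps). The finite-horizon line additionally accounts for Pendulum-v1 episodes ending after 200 steps, as the ep_len_mean column confirms.

```python
gamma = 0.99
episode_return = -144.0  # SB3 ep_rew_mean from the log above
episode_len = 200        # SB3 ep_len_mean (Pendulum-v1 episodes last 200 steps)
reward_per_step = episode_return / episode_len

# Infinite-horizon approximation: sum_k gamma^k = 1 / (1 - gamma)
effective_horizon = 1.0 / (1.0 - gamma)
q_infinite = reward_per_step * effective_horizon

# Finite horizon, from the first step of an episode:
# sum_{k < T} gamma^k = (1 - gamma^T) / (1 - gamma)
q_finite = reward_per_step * (1.0 - gamma ** episode_len) / (1.0 - gamma)

print(f"reward/step={reward_per_step:.2f}  horizon={effective_horizon:.1f}")
print(f"|Q| approx {abs(q_infinite):.1f} (infinite), {abs(q_finite):.1f} (200-step)")
```

Both estimates land well above SB3's logged actor loss of about 14, which is why the one-order-of-magnitude difference looked suspicious.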
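To see why the Polyak update belongs inside the update loop once the UTD ratio is raised to 200, this sketch tracks a single scalar target that starts at 0.0 while its online counterpart is held fixed at 1.0, using eps = 1 - tau = 0.995 from the script. It is a simplified illustration (the online parameter never changes), not a model of real training.

```python
tau = 0.005
eps = 1.0 - tau  # same convention as SoftUpdate(loss_module, eps=1 - tau) in the script
utd = 200        # gradient updates per collector iteration


def target_after_one_iteration(per_update):
    """Return the target value after one collector iteration of `utd` updates."""
    online, target = 1.0, 0.0  # online parameter held fixed for illustration
    for _ in range(utd):
        # ... critic and actor optimizer steps would go here ...
        if per_update:
            target = eps * target + (1.0 - eps) * online
    if not per_update:
        # Soft update once per iteration, as in the original script.
        target = eps * target + (1.0 - eps) * online
    return target


print(f"soft update once per iteration: {target_after_one_iteration(False):.3f}")
print(f"soft update after every update: {target_after_one_iteration(True):.3f}")
```

Updating once per iteration moves the target only 0.5% of the way per 200 gradient steps, whereas updating after every step moves it about 63% (1 - 0.995^200), so the target network actually keeps pace with the critic.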

Let me know if you have questions or remarks! Happy to address anything that I missed.

@smorad
Contributor Author

smorad commented May 23, 2023

Fantastic, thanks for the quick response!

@smorad
Contributor Author

smorad commented May 23, 2023

The missing target network (oops) and the switch to buffer.extend fix the divergence and make this learn significantly better.

@smorad smorad closed this as completed May 23, 2023
@vmoens
Contributor

vmoens commented May 23, 2023

DDPG should default to having a target net, and we should not be able to create a SoftUpdate without one, so this one's defo on me :)
