


## Problem Setup
This example demonstrates a multi-agent deep deterministic policy gradient (MADDPG) approach to a competitive environment in which chasers try to catch evaders. Each group of agents (chasers and evaders) has its own policy and value networks, trained either independently or in a mixed cooperative-competitive setting. It serves as our control for integrating prospect theory into the policy gradient: we check whether the prospect-theory variant produces reward curves that differ from the ones below. The code is based on the [TorchRL multi-agent competitive DDPG tutorial](https://pytorch.org/rl/0.6/tutorials/multiagent_competitive_ddpg.html).

## Clear Problem Statement
Train two chaser agents to minimize the evader’s cumulative reward while simultaneously training the evader agent to maximize its own cumulative reward. The environment runs for a fixed number of steps, and training can be halted for certain agents at a chosen iteration.

## Mathematical Formulation
- **Agent Policies**: $\pi_i(\mathbf{o}_i; \theta_i)$ maps agent $i$'s local observation $\mathbf{o}_i$ to a continuous action $a_i$.
- **Value Function**: $Q_i(\mathbf{o}, \mathbf{a}; \phi_i)$ estimates agent $i$'s expected future return given the observations $\mathbf{o}$ and actions $\mathbf{a}$ of all agents.
- **Loss Functions**: Each critic is trained to minimize the squared temporal-difference (TD) error against a bootstrapped target. Each actor is trained to maximize its critic's value for the action it selects. Critics use all agents' information during training, while each actor acts only on its own observation (centralized training, decentralized execution).
- **Updates**: Target networks for both the policy and value functions track the online networks through soft (Polyak) updates with rate $\tau$.
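Written out for agent $i$, the objectives above take roughly the following form. This is a sketch of the standard MADDPG formulation. It assumes target networks $\pi'$ and $Q'$ and a termination flag $d$. TorchRL's `DDPGLoss` may differ in details; for example, it keeps a separate target actor only when `delay_actor=True`.

$$
\begin{aligned}
y_i &= r_i + \gamma\,(1 - d)\,Q_i'\big(\mathbf{o}', a_1', \dots, a_N'; \phi_i'\big), \qquad a_j' = \pi_j'(\mathbf{o}_j'; \theta_j') \\
\mathcal{L}(\phi_i) &= \mathbb{E}\Big[\big(Q_i(\mathbf{o}, \mathbf{a}; \phi_i) - y_i\big)^2\Big] \\
\mathcal{L}(\theta_i) &= -\,\mathbb{E}\Big[Q_i\big(\mathbf{o}, a_1, \dots, \pi_i(\mathbf{o}_i; \theta_i), \dots, a_N; \phi_i\big)\Big] \\
\phi_i' &\leftarrow \tau\,\phi_i + (1 - \tau)\,\phi_i', \qquad \theta_i' \leftarrow \tau\,\theta_i + (1 - \tau)\,\theta_i'
\end{aligned}
$$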

## Data Requirements
- Episodes of agent interactions, collected with exploration strategies (e.g., Gaussian noise).
- Replay buffers per group for sampled training batches containing states, actions, rewards, and next states.
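The notebook uses TorchRL's `ReplayBuffer` with `LazyMemmapStorage` and `RandomSampler` to store these transitions. The sketch below illustrates the underlying idea in plain Python. `MinimalReplayBuffer` is a hypothetical class, not part of any library. It samples uniformly without replacement, which is a simplification and not necessarily what `RandomSampler` does.

```python
import random
from collections import deque

import torch


class MinimalReplayBuffer:
    """Hypothetical FIFO replay buffer of (obs, action, reward, next_obs) transitions."""

    def __init__(self, capacity: int):
        # A bounded deque silently evicts the oldest transition once full
        self.storage = deque(maxlen=capacity)

    def add(self, obs, action, reward, next_obs):
        self.storage.append((obs, action, reward, next_obs))

    def sample(self, batch_size: int):
        # Uniform sampling without replacement, then stack each field along a new batch dim
        batch = random.sample(list(self.storage), batch_size)
        obs, action, reward, next_obs = zip(*batch)
        return torch.stack(obs), torch.stack(action), torch.stack(reward), torch.stack(next_obs)

    def __len__(self):
        return len(self.storage)


buffer = MinimalReplayBuffer(capacity=4)
for t in range(6):
    # Observation filled with the step index so evictions are easy to see
    buffer.add(torch.full((3,), float(t)), torch.zeros(2), torch.tensor([1.0]), torch.ones(3))

obs, action, reward, next_obs = buffer.sample(2)
print(len(buffer), tuple(obs.shape))  # 4 (2, 3)
```

Because the capacity is 4, the transitions from steps 0 and 1 have been evicted, so every sampled observation comes from steps 2 to 5.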

## Success Metrics
- Mean episode reward for each group (chasers and evaders), typically measured and plotted over training iterations.
- Convergence or stabilization of the reward signal, indicating improved policy performance.
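The mean episode reward can be computed from the running episode return by reading it only at steps where an episode finished. The sketch below uses small hypothetical tensors in place of a collected batch.

```python
import torch

# Hypothetical data: running episode return for 2 environments over 4 steps
episode_reward = torch.tensor([[1.0, 2.0, 3.0, 1.0],
                               [0.5, 1.0, 1.5, 2.0]])
# True at the last step of each finished episode
done = torch.tensor([[False, False, True, False],
                     [False, False, False, True]])

# Only completed episodes contribute: their final running sum is the episode return
episode_reward_mean = episode_reward[done].mean()
print(episode_reward_mean.item())  # 2.5
```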


In [None]:
!pip3 install torchrl==0.6.0
!pip3 install vmas
!pip3 install pettingzoo[mpe]==1.24.3
!pip3 install tqdm


### **Approach: Importing Required Libraries**
This section imports essential modules for implementing MADDPG:  
- **PyTorch** for deep learning operations.  
- **`torchrl` modules** for multi-agent reinforcement learning, including environments, policies, collectors, and replay buffers.  
- **`tensordict`** for structured tensor operations.  
- **Matplotlib** for visualization.  
- **`tqdm`** for progress tracking.

In [None]:
import copy
import tempfile

import torch

from matplotlib import pyplot as plt
from tensordict import TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import multiprocessing

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, RandomSampler, ReplayBuffer

from torchrl.envs import (
    check_env_specs,
    ExplorationType,
    PettingZooEnv,
    RewardSum,
    set_exploration_type,
    TransformedEnv,
    VmasEnv,
)

from torchrl.modules import (
    AdditiveGaussianModule,
    MultiAgentMLP,
    ProbabilisticActor,
    TanhDelta,
)

from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators

from torchrl.record import CSVLogger, PixelRenderTransform, VideoRecorder

from tqdm import tqdm

try:
    is_sphinx = __sphinx_build__
except NameError:
    is_sphinx = False

In [None]:
from tensordict import is_tensor_collection

try:
    from torch.compiler import is_compiling
except ImportError:
    from torch._dynamo import is_compiling

### **Approach: Environment Setup & Hyperparameters**
- **Seed & Device**: Sets the random seed for reproducibility. Uses the first GPU if CUDA is available and the multiprocessing start method is not `fork`; otherwise it falls back to the CPU.  
- **Sampling**: Collects `1,000` frames per batch over `200` iterations, for `200,000` total frames.  
- **Replay Buffer**: Each group's buffer stores up to `1M` frames for experience replay.  
- **Training Parameters**:  
  - **Optimization**: `100` optimizer steps per iteration, each on a batch of `128` frames.  
  - **Learning Rate**: `3e-4`, with gradient-norm clipping at `1.0`.  
- **DDPG-Specific**: Discount factor `γ = 0.99` and soft-update parameter `τ = 0.005`.
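The sketch below shows how `lr`, `max_grad_norm` and `polyak_tau` enter a single update in plain PyTorch. It applies them to a toy `nn.Linear` network rather than the notebook's actor and critic. The notebook performs the soft update with TorchRL's `SoftUpdate`. Using Adam as the optimizer is an assumption here, because this excerpt does not show the optimizer choice.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
lr, max_grad_norm, polyak_tau = 3e-4, 1.0, 0.005  # values from the hyperparameter cell

online = nn.Linear(4, 1)  # stand-in for an actor or critic network
target = copy.deepcopy(online)  # target network starts as an exact copy
optimizer = torch.optim.Adam(online.parameters(), lr=lr)  # optimizer choice is an assumption

# One optimisation step on a dummy regression loss with a batch of 128 samples
x, y = torch.randn(128, 4), torch.randn(128, 1)
loss = F.mse_loss(online(x), y)
optimizer.zero_grad()
loss.backward()
# Rescale all gradients together so their global L2 norm is at most max_grad_norm
nn.utils.clip_grad_norm_(online.parameters(), max_norm=max_grad_norm)
optimizer.step()

# Soft (Polyak) update: target <- tau * online + (1 - tau) * target
old_target = [p.detach().clone() for p in target.parameters()]
with torch.no_grad():
    for p_target, p_online in zip(target.parameters(), online.parameters()):
        p_target.mul_(1 - polyak_tau).add_(p_online, alpha=polyak_tau)
```

With `τ = 0.005`, the target network moves only 0.5% of the way toward the online network per update. This keeps the bootstrapped targets slowly varying.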

In [None]:
# Seed
seed = 0
torch.manual_seed(seed)

# Devices
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

# Sampling
frames_per_batch = 1_000  # Number of team frames collected per sampling iteration
n_iters = 200  # Number of sampling and training iterations
total_frames = frames_per_batch * n_iters


# Replay buffer
memory_size = 1_000_000  # The replay buffer of each group can store this many frames

# Training
n_optimiser_steps = 100  # Number of optimization steps per training iteration
train_batch_size = 128  # Number of frames trained in each optimiser step
lr = 3e-4  # Learning rate
max_grad_norm = 1.0  # Maximum norm for the gradients

# DDPG
gamma = 0.99  # Discount factor
polyak_tau = 0.005  # Tau for the soft-update of the target network

### **Approach: Environment Configuration**
- **Max Steps**: Each episode runs for `100` steps.  
- **Agents & Landmarks**: `2` agents (`n_agents`) in the `simple_spread` scenario, in which the number of landmarks matches the number of agents.  
- **VMAS for Performance**:  
  - If `use_vmas = True`, uses `VmasEnv` (scenario `simple_spread`) for fast vectorized multi-agent simulation.  
  - Otherwise, falls back to `PettingZooEnv` in parallel mode with `simple_spread_v3`.  
- **Vectorization**: `num_vmas_envs = frames_per_batch // max_steps` (`10` here) parallel environments, so each collected batch contains one full episode per environment.
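The vectorization arithmetic can be checked directly. This sketch assumes the collector splits `frames_per_batch` evenly across the vectorized environments. That split is consistent with the leading batch dimension of 10 in the spec printouts further down.

```python
frames_per_batch = 1_000  # from the hyperparameter cell
max_steps = 100           # episode length
n_iters = 200

num_vmas_envs = frames_per_batch // max_steps       # parallel environments
steps_per_env = frames_per_batch // num_vmas_envs   # rollout length of each env per batch
total_frames = frames_per_batch * n_iters

# Each collected batch therefore holds exactly one full episode per environment
print(num_vmas_envs, steps_per_env, total_frames)  # 10 100 200000
```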

In [None]:
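max_steps = 100  # Environment steps before done

n_agents = 2
n_landmarks = 1  # Not used: simple_spread ties the number of landmarks to the number of agents

use_vmas = True  # Set this to True for a great performance speedup

if not use_vmas:
    base_env = PettingZooEnv(
        task="simple_spread_v3",
        parallel=True,
        seed=seed,
        continuous_actions=True,
        N=n_agents,  # In simple_spread_v3, N sets both the number of agents and of landmarks
        max_cycles=max_steps,  # Match the VMAS episode length
    )
else:
    # Number of vectorized envs so that each batch holds one full episode per env
    num_vmas_envs = frames_per_batch // max_steps
    base_env = VmasEnv(
        scenario="simple_spread",
        num_envs=num_vmas_envs,
        continuous_actions=True,
        max_steps=max_steps,
        local_ratio=0.5,
        device=device,
        seed=seed,
        n_agents=n_agents,
    )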

In [None]:
print(f"group_map: {base_env.group_map}")

group_map: {'agents': ['agent_0', 'agent_1']}


In [None]:
print("action_spec:", base_env.full_action_spec)
print("reward_spec:", base_env.full_reward_spec)
print("done_spec:", base_env.full_done_spec)
print("observation_spec:", base_env.observation_spec)

action_spec: Composite(
    agents: Composite(
        action: BoundedContinuous(
            shape=torch.Size([10, 2, 2]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([10, 2, 2]), device=cuda:0, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([10, 2, 2]), device=cuda:0, dtype=torch.float32, contiguous=True)),
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous),
        device=cuda:0,
        shape=torch.Size([10, 2])),
    device=cuda:0,
    shape=torch.Size([10]))
reward_spec: Composite(
    agents: Composite(
        reward: UnboundedContinuous(
            shape=torch.Size([10, 2, 1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([10, 2, 1]), device=cuda:0, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([10, 2, 1]), device=cuda:0, dtype=torch.float32, contiguous=True)),
            device=cuda:0,
           

In [None]:
print("action_keys:", base_env.action_keys)
print("reward_keys:", base_env.reward_keys)
print("done_keys:", base_env.done_keys)

action_keys: [('agents', 'action')]
reward_keys: [('agents', 'reward')]
done_keys: ['done', 'terminated']


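### **Approach: Environment Transformation**
- **Wraps `base_env` in `TransformedEnv`** to add reward bookkeeping.  
- **`RewardSum` Aggregation**:  
  - Reads each group's reward from `base_env.reward_keys` and accumulates it over the episode into an `episode_reward` entry (visible as `("agents", "episode_reward")` in the rollout below).  
  - Resets the running sum at episode boundaries; `reset_keys` provides one `_reset` key per reward key (one per agent group).  
- **Purpose**: Exposes each group's cumulative episode return, which is the quantity tracked by the success metrics.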
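The running sum that `RewardSum` maintains can be illustrated for a single environment. `running_episode_reward` is a hypothetical helper that restarts the sum after each done step. It is meant to show the resulting values, not TorchRL's implementation.

```python
import torch


def running_episode_reward(rewards: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cumulative reward within each episode, restarting after a done step.

    rewards, dones: 1-D tensors of length T for a single environment.
    """
    out = torch.empty_like(rewards)
    total = rewards.new_zeros(())
    for t in range(rewards.shape[0]):
        total = total + rewards[t]
        out[t] = total
        if dones[t]:
            total = rewards.new_zeros(())  # the next episode starts from zero
    return out


rewards = torch.tensor([1.0, 1.0, 2.0, 0.5, 0.5])
dones = torch.tensor([False, True, False, False, True])
episode_reward = running_episode_reward(rewards, dones)
print(episode_reward.tolist())  # [1.0, 2.0, 2.0, 2.5, 3.0]
```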

In [None]:
env = TransformedEnv(
    base_env,
    RewardSum(
        in_keys=base_env.reward_keys,
        reset_keys=["_reset"] * len(base_env.group_map.keys()),
    ),
)

In [None]:
check_env_specs(env)

2025-03-25 03:33:31,213 [torchrl][INFO] check_env_specs succeeded!


In [None]:
n_rollout_steps = 5
rollout = env.rollout(n_rollout_steps)
print(f"rollout of {n_rollout_steps} steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)

rollout of 5 steps: TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([10, 5, 2, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
                episode_reward: Tensor(shape=torch.Size([10, 5, 2, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                observation: Tensor(shape=torch.Size([10, 5, 2, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
            batch_size=torch.Size([10, 5, 2]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([10, 5, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        next: TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        episode_reward: Tensor(shape=torch.Size([10, 5, 2, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                        observation: Tensor(shape=torch.Size([10, 5, 2, 10]), device=cuda:0

### **Approach: Policy Network Setup**
- **Iterates over agent groups** to build one policy network per group.  
- **Defines `MultiAgentMLP`** for decentralized policies:
  - **Observations & Actions**: Input and output sizes come from `env.observation_spec` and `env.full_action_spec`.
  - **Decentralized Execution**: Each agent acts based only on its local observation (`centralised=False`).
  - **Parameter Sharing**: Controlled by `share_parameters_policy`. It is set to `False` here, so each agent has its own policy parameters.
  - **Architecture**: MLP with `2` hidden layers (`256` units each, `Tanh` activation).
- **Wraps in `TensorDictModule`**:
  - Reads `(group, "observation")` from the `TensorDict` and writes the network output to `(group, "param")`.  
  - This lets the network plug into TorchRL's `TensorDict`-based multi-agent pipeline.
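The sketch below shows what "decentralized, no parameter sharing" means for tensor shapes. `DecentralisedPolicySketch` is a hypothetical stand-in written under the assumption that `MultiAgentMLP(centralised=False, share_params=False)` behaves like one independent MLP per agent. It is not TorchRL's implementation.

```python
import torch
from torch import nn


class DecentralisedPolicySketch(nn.Module):
    """Hypothetical stand-in for MultiAgentMLP(centralised=False, share_params=False)."""

    def __init__(self, n_agents, n_obs, n_actions, num_cells=256, depth=2):
        super().__init__()

        def make_mlp():
            layers, in_dim = [], n_obs
            for _ in range(depth):  # depth hidden layers with Tanh activations
                layers += [nn.Linear(in_dim, num_cells), nn.Tanh()]
                in_dim = num_cells
            layers.append(nn.Linear(in_dim, n_actions))
            return nn.Sequential(*layers)

        # One independent network per agent (no parameter sharing)
        self.nets = nn.ModuleList([make_mlp() for _ in range(n_agents)])

    def forward(self, obs):
        # obs: [..., n_agents, n_obs]; agent i only sees obs[..., i, :]
        return torch.stack([net(obs[..., i, :]) for i, net in enumerate(self.nets)], dim=-2)


policy = DecentralisedPolicySketch(n_agents=2, n_obs=10, n_actions=2)
obs = torch.randn(10, 2, 10)  # 10 envs, 2 agents, 10-dim observations
actions = policy(obs)
print(actions.shape)  # torch.Size([10, 2, 2])
```

The output shape `[10, 2, 2]` matches the `("agents", "param")` entry printed later in the notebook.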

In [None]:
policy_modules = {}
for group, agents in env.group_map.items():
    share_parameters_policy = False  # Can change this based on the group

    policy_net = MultiAgentMLP(
        n_agent_inputs=env.observation_spec[group, "observation"].shape[
            -1
        ],  # n_obs_per_agent
        n_agent_outputs=env.full_action_spec[group, "action"].shape[
            -1
        ],  # n_actions_per_agents
        n_agents=len(agents),  # Number of agents in the group
        centralised=False,  # the policies are decentralised (i.e., each agent will act from its local observation)
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    )

    # Wrap the neural network in a :class:`~tensordict.nn.TensorDictModule`.
    # This is simply a module that will read the ``in_keys`` from a tensordict, feed them to the
    # neural networks, and write the
    # outputs in-place at the ``out_keys``.

    policy_module = TensorDictModule(
        policy_net,
        in_keys=[(group, "observation")],
        out_keys=[(group, "param")],
    )  # We just name the input and output that the network will read and write to the input tensordict
    policy_modules[group] = policy_module

### **Approach: Probabilistic Policy Definition**
- **Wraps each policy network (`policy_modules`) in a `ProbabilisticActor`**, TorchRL's wrapper for distribution-based policies.  
- **Uses the `TanhDelta` distribution**:
  - A delta (point-mass) distribution, so the resulting policy is deterministic, as DDPG requires. Exploration noise is added separately.  
  - Squashes the network output into the action bounds (`low`, `high`) taken from the action spec, so actions always stay within range.  
- **Input & Output Keys**:
  - Reads the network output from `(group, "param")`.  
  - Writes the final action to `(group, "action")`.  
- **Log Probabilities Disabled (`return_log_prob=False`)**:  
  - DDPG's deterministic policy gradient does not use action log-probabilities.
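My understanding is that `TanhDelta` places its point mass at `tanh` of the network output, rescaled affinely to `[low, high]`. The hypothetical helper `squash_to_bounds` sketches that mapping under this assumption. It is not TorchRL's code.

```python
import torch


def squash_to_bounds(param: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map an unbounded network output into [low, high] via tanh and an affine rescale."""
    loc = (high + low) / 2    # centre of the action box
    scale = (high - low) / 2  # half-width of the action box
    return loc + scale * torch.tanh(param)


low, high = torch.tensor([-1.0, 0.0]), torch.tensor([1.0, 2.0])
param = torch.tensor([[0.0, 0.0], [100.0, -100.0]])
actions = squash_to_bounds(param, low, high)
print(actions)  # values: [[0, 1], [1, 0]]
```

A zero output lands at the centre of the box, and large outputs saturate at the bounds. This is why no explicit clipping is needed for the deterministic action.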

In [None]:
policies = {}
for group, _agents in env.group_map.items():
    policy = ProbabilisticActor(
        module=policy_modules[group],
        spec=env.full_action_spec[group, "action"],
        in_keys=[(group, "param")],
        out_keys=[(group, "action")],
        distribution_class=TanhDelta,
        distribution_kwargs={
            "low": env.full_action_spec[group, "action"].space.low,
            "high": env.full_action_spec[group, "action"].space.high,
        },
        return_log_prob=False,
    )
    policies[group] = policy

### **Approach: Exploration Policy with Gaussian Noise**
- **Adds exploration noise to the deterministic policies** using `AdditiveGaussianModule`.  
- **Purpose**: Because DDPG policies are deterministic, the injected Gaussian noise is what drives exploration.  
- **Annealing Strategy**:
  - **Starts at `sigma_init = 0.9`** (high noise for exploration).  
  - **Decays to `sigma_end = 0.1`** over `total_frames // 2` frames (`100,000` with the settings above). The decay advances only when the module's `step()` method is called, typically once per training iteration with the number of frames collected.  
- **Wrapped in `TensorDictSequential`**:
  - First applies the base policy (`policies[group]`).  
  - Then adds Gaussian noise to the output action (`(group, "action")`).  
- **Effect**: Exploration is high early in training and is gradually reduced in favor of exploitation.
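The annealing schedule can be sketched as a linear interpolation, which I believe matches `AdditiveGaussianModule`'s behavior. `annealed_sigma` and `add_exploration_noise` are hypothetical helpers. Clipping the noisy action back into the bounds is an assumption made for illustration.

```python
import torch


def annealed_sigma(frames_seen: int, sigma_init: float = 0.9, sigma_end: float = 0.1,
                   annealing_num_steps: int = 100_000) -> float:
    """Linearly interpolate sigma from sigma_init to sigma_end, then hold it at sigma_end."""
    frac = min(frames_seen / annealing_num_steps, 1.0)
    return sigma_init + frac * (sigma_end - sigma_init)


def add_exploration_noise(action: torch.Tensor, sigma: float, low: float, high: float) -> torch.Tensor:
    """Add zero-mean Gaussian noise with std sigma, then clip back into [low, high]."""
    noisy = action + sigma * torch.randn_like(action)
    return noisy.clamp(min=low, max=high)


for frames in (0, 50_000, 100_000, 200_000):
    print(frames, round(annealed_sigma(frames), 3))  # sigma: 0.9, 0.5, 0.1, 0.1

noisy_action = add_exploration_noise(torch.zeros(10, 2, 2), annealed_sigma(0), low=-1.0, high=1.0)
```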

In [None]:
exploration_policies = {}
for group, _agents in env.group_map.items():
    exploration_policy = TensorDictSequential(
        policies[group],
        AdditiveGaussianModule(
            spec=policies[group].spec,
            annealing_num_steps=total_frames
            // 2,  # Number of frames after which sigma is sigma_end
            action_key=(group, "action"),
            sigma_init=0.9,  # Initial value of the sigma
            sigma_end=0.1,  # Final value of the sigma
        ),
    )
    exploration_policies[group] = exploration_policy

### **Approach: Critic Network for Value Estimation**
- **Defines a critic network for each agent group** to estimate state-action values ($Q$-values).  
- **Centralized vs. Decentralized Critic** (controlled by the `MADDPG` flag):
  - **`MADDPG = True`** (used here): a centralized critic. Each agent's value is computed from the observations and actions of all agents in the group.  
  - **`MADDPG = False`**: independent DDPG (IDDPG). Each agent's critic sees only its own observation and action.  
- **Feature Concatenation (`cat_module`)**:
  - Concatenates each agent's observation and action into a single tensor (`(group, "obs_action")`).  
- **Critic Network (`critic_module`)**:
  - Takes the concatenated observation-action inputs and predicts a **single Q-value per agent**.  
  - Uses an MLP with **2 hidden layers (256 units each, `Tanh` activation)**.  
  - Shares parameters across agents (`share_parameters_critic = True`).  
- **Final Critic Pipeline (`TensorDictSequential`)**:
  - First applies **feature concatenation (`cat_module`)**.  
  - Then runs the **`MultiAgentMLP`** to write `(group, "state_action_value")`.
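The following is a sketch of a centralized critic with shared parameters, based on my understanding of `MultiAgentMLP(centralised=True, share_params=True)`. One shared network reads the concatenation of every agent's `obs_action`, and its single output is broadcast to all agents. `CentralisedCriticSketch` is hypothetical and not TorchRL's implementation.

```python
import torch
from torch import nn


class CentralisedCriticSketch(nn.Module):
    """Hypothetical stand-in for MultiAgentMLP(centralised=True, share_params=True)."""

    def __init__(self, n_agents, n_obs_action, num_cells=256):
        super().__init__()
        # Input is every agent's (observation, action) concatenated into one vector
        self.net = nn.Sequential(
            nn.Linear(n_agents * n_obs_action, num_cells), nn.Tanh(),
            nn.Linear(num_cells, num_cells), nn.Tanh(),
            nn.Linear(num_cells, 1),
        )

    def forward(self, obs_action):
        # obs_action: [batch, n_agents, n_obs_action] -> flatten the agent dimension
        joint = obs_action.flatten(start_dim=-2)
        value = self.net(joint)  # [batch, 1]
        # Shared parameters and identical (joint) input: every agent receives the same value
        return value.unsqueeze(-2).expand(*obs_action.shape[:-1], 1)


critic = CentralisedCriticSketch(n_agents=2, n_obs_action=12)  # 10 obs + 2 action dims
values = critic(torch.randn(10, 2, 12))
print(values.shape)  # torch.Size([10, 2, 1])
```

The `[10, 2, 1]` shape matches the `state_action_value` entry printed in the next cell's output. Setting `share_params=False` would give each agent its own network over the same joint input.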

In [None]:
critics = {}
for group, agents in env.group_map.items():
    share_parameters_critic = True  # Can change for each group
    MADDPG = True  # IDDPG if False, can change for each group

    # This module applies the lambda function: reading the action and observation entries for the group
    # and concatenating them in a new ``(group, "obs_action")`` entry
    cat_module = TensorDictModule(
        lambda obs, action: torch.cat([obs, action], dim=-1),
        in_keys=[(group, "observation"), (group, "action")],
        out_keys=[(group, "obs_action")],
    )

    critic_module = TensorDictModule(
        module=MultiAgentMLP(
            n_agent_inputs=env.observation_spec[group, "observation"].shape[-1]
            + env.full_action_spec[group, "action"].shape[-1],
            n_agent_outputs=1,  # 1 value per agent
            n_agents=len(agents),
            centralised=MADDPG,
            share_params=share_parameters_critic,
            device=device,
            depth=2,
            num_cells=256,
            activation_class=torch.nn.Tanh,
        ),
        in_keys=[(group, "obs_action")],  # Read ``(group, "obs_action")``
        out_keys=[
            (group, "state_action_value")
        ],  # Write ``(group, "state_action_value")``
    )

    critics[group] = TensorDictSequential(
        cat_module, critic_module
    )  # Run them in sequence

In [None]:
reset_td = env.reset()
for group, _agents in env.group_map.items():
    print(
        f"Running value and policy for group '{group}':",
        critics[group](policies[group](reset_td)),
    )

Running value and policy for group 'agents': TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([10, 2, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
                episode_reward: Tensor(shape=torch.Size([10, 2, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                obs_action: Tensor(shape=torch.Size([10, 2, 12]), device=cuda:0, dtype=torch.float32, is_shared=True),
                observation: Tensor(shape=torch.Size([10, 2, 10]), device=cuda:0, dtype=torch.float32, is_shared=True),
                param: Tensor(shape=torch.Size([10, 2, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
                state_action_value: Tensor(shape=torch.Size([10, 2, 1]), device=cuda:0, dtype=torch.float32, is_shared=True)},
            batch_size=torch.Size([10, 2]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([10, 1]), device=cuda:0, dty

### **Approach: Data Collection for Training**
- **Chains all groups' exploration policies** into a single `TensorDictSequential`, so every collected action includes exploration noise.  
- **`SyncDataCollector` for Data Sampling**:
  - Collects experience from the environment using the **exploration policies**.  
  - Runs on `device` (GPU if available, otherwise CPU).  
  - **Frames per batch**: `1,000`, i.e., one batch of experience per training iteration.  
  - **Total frames**: `200,000` (over `200` iterations).  
- **Purpose**: Gathers off-policy experience that is stored in per-group replay buffers and reused across many gradient updates.
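The off-policy cycle described above can be sketched in plain Python. `off_policy_loop` is a hypothetical skeleton with toy callables standing in for the collector and the DDPG update. It is not the notebook's training loop.

```python
import random
from collections import deque


def off_policy_loop(collect_batch, update, n_iters, frames_per_batch,
                    n_optimiser_steps, train_batch_size, memory_size):
    """Hypothetical skeleton of the collect -> store -> sample -> update cycle."""
    replay_buffer = deque(maxlen=memory_size)
    n_updates = 0
    for _ in range(n_iters):
        # 1) Roll out the current exploration policy for frames_per_batch frames
        replay_buffer.extend(collect_batch(frames_per_batch))
        # 2) Reuse stored transitions, possibly from older policies, for several updates
        for _ in range(n_optimiser_steps):
            k = min(train_batch_size, len(replay_buffer))
            update(random.sample(list(replay_buffer), k))
            n_updates += 1
    return n_updates, len(replay_buffer)


# Toy usage: "transitions" are integers and the update just records the batch size
seen_batch_sizes = []
n_updates, buffer_len = off_policy_loop(
    collect_batch=lambda n: list(range(n)),
    update=lambda batch: seen_batch_sizes.append(len(batch)),
    n_iters=3, frames_per_batch=4, n_optimiser_steps=2,
    train_batch_size=2, memory_size=10,
)
print(n_updates, buffer_len)  # 6 10
```

With the notebook's settings (`n_iters = 200`, `n_optimiser_steps = 100`), the same structure yields 20,000 gradient steps per group.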

In [None]:
# Put exploration policies from each group in a sequence
agents_exploration_policy = TensorDictSequential(*exploration_policies.values())

collector = SyncDataCollector(
    env,
    agents_exploration_policy,
    device=device,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
)

In [None]:
# Off-policy algorithms such as DDPG reuse past experience, so each group gets its own replay buffer
replay_buffers = {}
for group, _agents in env.group_map.items():
    replay_buffer = ReplayBuffer(
        storage=LazyMemmapStorage(memory_size, device="cpu"),  # Memory-mapped storage on CPU
        sampler=RandomSampler(),
        batch_size=train_batch_size,
    )
    # Move each sampled batch to the training device (works on both GPU and CPU)
    replay_buffer.append_transform(lambda batch: batch.to(device))
    replay_buffers[group] = replay_buffer

In [None]:
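# CPT parameters for each agent in the cooperative group:
#   alpha: curvature of the utility for gains and losses
#   lam: loss-aversion coefficient that scales losses
#   w_plus_prime_const / w_minus_prime_const: constant weights applied to gains / losses
agent_params = {
    "agent_0": {
        "alpha": 0.7,
        "lam": 0.8,
        "w_plus_prime_const": 0.8,
        "w_minus_prime_const": 0.2,
    },
    "agent_1": {
        "alpha": 0.65,
        "lam": 2.8,
        "w_plus_prime_const": 0.25,
        "w_minus_prime_const": 0.75,
    },
}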


In [None]:
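import torch

w_plus_prime_const = 0.2  # Weight applied to gains by the single-profile helpers
w_minus_prime_const = 0.8  # Weight applied to losses by the single-profile helpers

epsilon = 1e-6  # small constant to avoid numerical issues (x**alpha with alpha < 1 has infinite slope at 0)

# Assumption: `u_plus` / `u_minus` are used below but were not defined in this notebook;
# here they use the Tversky & Kahneman (1992) median estimates alpha = 0.88, lambda = 2.25.
default_cpt_params = {"alpha": 0.88, "lam": 2.25}


def w_plus_prime(p):
    """Derivative of the power probability weighting w(p) = p**eta for gains."""
    eta = 0.61
    return eta * torch.pow(p, eta - 1)


def w_minus_prime(p):
    """Derivative of the power probability weighting w(p) = p**eta for losses."""
    eta = 0.69
    return eta * torch.pow(p, eta - 1)


def u_plus_agent(x, params):
    """Utility of gains, x**alpha. Clamping at epsilon keeps torch.where's unused branch free of NaN values and NaN gradients."""
    return torch.pow(x.clamp(min=epsilon), params["alpha"])


def u_minus_agent(x, params):
    """Magnitude of the disutility of losses, lam * (-x)**alpha (positive; callers apply the sign)."""
    return params["lam"] * torch.pow((-x).clamp(min=epsilon), params["alpha"])


def u_plus(x):
    return u_plus_agent(x, default_cpt_params)


def u_minus(x):
    return u_minus_agent(x, default_cpt_params)


def compute_phi_linear(R):
    """
    Compute linearized CPT sensitivity:
    φ(R) ≈ w'_+(p*) * u^+(R) for R >= 0, and -w'_-(p*) * u^-(R) for R < 0.
    """
    R = R.reshape(-1)
    # Signed value: +u^+(R) for gains, -u^-(R) for losses
    v = torch.where(R >= 0, u_plus(R), -u_minus(R))
    # v already carries the sign, so both weights enter with a plus sign
    phi = torch.where(R >= 0, w_plus_prime_const * v, w_minus_prime_const * v)
    return phi.mean()


def compute_phi_cross(R, agent_params):
    """Per-sample linearized CPT sensitivity, averaged over the agents' CPT profiles."""
    R = R.reshape(-1)  # flatten rewards
    phi_values = []
    for agent_id, params in agent_params.items():
        v = torch.where(R >= 0, u_plus_agent(R, params), -u_minus_agent(R, params))
        # v already carries the sign, so both weights enter with a plus sign
        phi = torch.where(R >= 0, params["w_plus_prime_const"] * v, params["w_minus_prime_const"] * v)
        phi_values.append(phi)
    phi_stack = torch.stack(phi_values)  # shape: [num_agents, batch_size]
    return phi_stack.mean(dim=0)  # average across agents


def C_transform_cross(x, agent_params):
    """CPT transform of x, averaged over the agents' CPT profiles."""
    transformed_vals = []
    for agent_id, params in agent_params.items():
        transformed = torch.where(
            x >= 0,
            params["w_plus_prime_const"] * u_plus_agent(x, params),
            -params["w_minus_prime_const"] * u_minus_agent(x, params),
        )
        transformed_vals.append(transformed)
    return torch.stack(transformed_vals).mean(dim=0)


def C_transform(x):
    """
    Simple CPT transformation on one-step return:
    C(x) ≈ w'_+(p*) * u^+(x) if x >= 0, else -w'_-(p*) * u^-(x).
    """
    return torch.where(x >= 0, w_plus_prime_const * u_plus(x), -w_minus_prime_const * u_minus(x))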

In [None]:
from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule

from tensordict.utils import NestedKey, unravel_key
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _cache_values,
    _GAMMA_LMBDA_DEPREC_ERROR,
    _reduce,
    default_value_kwargs,
    distance_loss,
    ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


class CPTDDPGLoss(LossModule):
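    """DDPG loss adapted from TorchRL's ``DDPGLoss`` for the CPT experiments.

    The examples below are inherited from the upstream class and construct ``DDPGLoss`` directly.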

    Args:
        actor_network (TensorDictModule): a policy operator.
        value_network (TensorDictModule): a Q value operator.
        loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
        delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
            data collection. Default is ``False``.
        delay_value (bool, optional): whether to separate the target value networks from the value networks used for
            data collection. Default is ``True``.
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
            parameters for both policy and critic losses.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.ddpg import DDPGLoss
        >>> from tensordict import TensorDict
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> value = ValueOperator(
        ...     module=module,
        ...     in_keys=["observation", "action"])
        >>> loss = DDPGLoss(actor, value)
        >>> batch = [2, ]
        >>> data = TensorDict({
        ...        "observation": torch.randn(*batch, n_obs),
        ...        "action": spec.rand(batch),
        ...        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...        ("next", "reward"): torch.randn(*batch, 1),
        ...        ("next", "observation"): torch.randn(*batch, n_obs),
        ...    }, batch)
        >>> loss(data)
        TensorDict(
            fields={
                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)

    This class is compatible with non-tensordict based modules too and can be
    used without resorting to any tensordict-related primitive. In this case,
    the expected keyword arguments are:
    ``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network.
    The return value is a tuple of tensors in the following order:
    ``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]``

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.ddpg import DDPGLoss
        >>> _ = torch.manual_seed(42)
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> value = ValueOperator(
        ...     module=module,
        ...     in_keys=["observation", "action"])
        >>> loss = DDPGLoss(actor, value)
        >>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss(
        ...     observation=torch.randn(n_obs),
        ...     action=spec.rand(),
        ...     next_done=torch.zeros(1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(1, dtype=torch.bool),
        ...     next_observation=torch.randn(n_obs),
        ...     next_reward=torch.randn(1))
        >>> loss_actor.backward()

    The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
    method.

    Examples:
        >>> loss.select_out_keys('loss_actor', 'loss_value')
        >>> loss_actor, loss_value = loss(
        ...     observation=torch.randn(n_obs),
        ...     action=spec.rand(),
        ...     next_done=torch.zeros(1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(1, dtype=torch.bool),
        ...     next_observation=torch.randn(n_obs),
        ...     next_reward=torch.randn(1))
        >>> loss_actor.backward()

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            state_action_value (NestedKey): The input tensordict key where the
                state action value is expected. Will be used for the underlying
                value estimator as value key. Defaults to ``"state_action_value"``.
            priority (NestedKey): The input tensordict key where the target
                priority is written to. Defaults to ``"td_error"``.
            reward (NestedKey): The input tensordict key where the reward is expected.
                Will be used for the underlying value estimator. Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Will be used for the underlying value estimator.
                Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Will be used for the underlying value estimator.
                Defaults to ``"terminated"``.

        """

        state_action_value: NestedKey = "state_action_value"
        priority: NestedKey = "td_error"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator: ValueEstimators = ValueEstimators.TD0
    out_keys = [
        "loss_actor",
        "loss_value",
        "pred_value",
        "target_value",
        "pred_value_max",
        "target_value_max",
    ]

    actor_network: TensorDictModule
    value_network: TensorDictModule
    actor_network_params: TensorDictParams
    value_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams
    target_value_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: TensorDictModule,
        value_network: TensorDictModule,
        *,
        loss_function: str = "l2",
        delay_actor: bool = False,
        delay_value: bool = True,
        gamma: float = None,
        separate_losses: bool = False,
        reduction: str = None,
    ) -> None:
        self._in_keys = None
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self.delay_actor = delay_actor
        self.delay_value = delay_value

        actor_critic = ActorCriticWrapper(actor_network, value_network)
        params = TensorDict.from_module(actor_critic)
        params_meta = params.apply(
            self._make_meta_params, device=torch.device("meta"), filter_empty=False
        )
        with params_meta.to_module(actor_critic):
            self.__dict__["actor_critic"] = deepcopy(actor_critic)

        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=self.delay_actor,
        )
        if separate_losses:
            # we want to make sure there are no duplicates in the params: the
            # params of critic must be refs to actor if they're shared
            policy_params = list(actor_network.parameters())
        else:
            policy_params = None
        self.convert_to_functional(
            value_network,
            "value_network",
            create_target_params=self.delay_value,
            compare_against=policy_params,
        )
        self.actor_critic.module[0] = self.actor_network
        self.actor_critic.module[1] = self.value_network

        self.actor_in_keys = actor_network.in_keys
        self.value_exclusive_keys = set(self.value_network.in_keys) - (
            set(self.actor_in_keys) | set(self.actor_network.out_keys)
        )

        self.loss_function = loss_function
        self.reduction = reduction
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                value=self._tensor_keys.state_action_value,
                reward=self._tensor_keys.reward,
                done=self._tensor_keys.done,
                terminated=self._tensor_keys.terminated,
            )
        self._set_in_keys()

    def _set_in_keys(self):
        in_keys = {
            unravel_key(("next", self.tensor_keys.reward)),
            unravel_key(("next", self.tensor_keys.done)),
            unravel_key(("next", self.tensor_keys.terminated)),
            *self.actor_in_keys,
            *[unravel_key(("next", key)) for key in self.actor_in_keys],
            *self.value_network.in_keys,
            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
        }
        self._in_keys = sorted(in_keys, key=str)

    @property
    def in_keys(self):
        if self._in_keys is None:
            self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values


    def _clear_weakrefs(self, *tds):
        if is_compiling():
            # Waiting for weakrefs reconstruct to be supported by compile
            for td in tds:
                if isinstance(td, str):
                    td = getattr(self, td, None)
                if not is_tensor_collection(td):
                    continue
                td.clear_refs_for_compile_()

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        """Computes the DDPG losses given a tensordict sampled from the replay buffer.

        This function also writes a "td_error" key that can be used by prioritized replay buffers to assign
        a priority to items in the tensordict.

        Args:
            tensordict (TensorDictBase): a tensordict with keys ["done", "terminated", "reward"] and the in_keys of the actor
                and value networks.

        Returns:
            a TensorDict containing "loss_actor", "loss_value" and the metadata returned by the value loss.

        """
        loss_value, metadata = self.loss_value(tensordict)
        loss_actor, metadata_actor = self.loss_actor(tensordict)
        metadata.update(metadata_actor)
        td_out = TensorDict(
            source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
            batch_size=[],
        )
        self._clear_weakrefs(
            tensordict,
            td_out,
            "value_network_params",
            "target_value_network_params",
            "target_actor_network_params",
            "actor_network_params",
        )
        return td_out

    def loss_actor(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, dict]:
        td_copy = tensordict.select(
            *self.actor_in_keys, *self.value_exclusive_keys, strict=False
        ).detach()

        with self.actor_network_params.to_module(self.actor_network):
            td_copy = self.actor_network(td_copy)

        with self._cached_detached_value_params.to_module(self.value_network):
            td_copy = self.value_network(td_copy)

        actions = td_copy.get((self.actor_network.in_keys[0][0], "action"))
        Q_values = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
        grad_Q = torch.autograd.grad(Q_values.sum(), actions, retain_graph=True)[0]

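        # Use the cooperative reward from the 'agents' group
        returns = tensordict.get(("agents", "episode_reward")).view(-1)
        # Detach the CPT factor so the actor loss does not accumulate gradients on the
        # CPT parameters; those are updated separately in the training loop.
        phi_factor = compute_phi_cross(returns, agent_params).detach()

        # grad_Q was computed without create_graph, so it is a constant here and the
        # gradient of (actions * grad_Q) w.r.t. the actor parameters is the deterministic
        # policy gradient, scaled by the mean CPT factor.
        policy_gradient = actions * grad_Q
        loss_actor = -phi_factor.mean() * policy_gradient.mean()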

        return _reduce(loss_actor, self.reduction), {}

    def loss_value(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, dict]:
        td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
        with self.value_network_params.to_module(self.value_network):
            self.value_network(td_copy)
        pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)

        target_value = self.value_estimator.value_estimate(
            tensordict, target_params=self._cached_target_params
        ).squeeze(-1)
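        # The TD target is treated as a constant, so no gradient flows into the CPT parameters here.
        target_value_CPT = C_transform_cross(target_value, agent_params).detach()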

        loss_value = distance_loss(pred_val, target_value_CPT, loss_function=self.loss_function)
        tensordict.set("target_value_CPT", target_value_CPT, inplace=True)

        td_error = (pred_val - target_value_CPT).pow(2).detach()
        tensordict.set(self.tensor_keys.priority, td_error, inplace=True)

        metadata = {
            "td_error": td_error,
            "pred_value": pred_val,
            "target_value_CPT": target_value_CPT,
        }
        return _reduce(loss_value, self.reduction), metadata

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        if value_type is None:
            value_type = self.default_value_estimator
        self.value_type = value_type
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        if value_type == ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} is not implemented for loss {type(self)}."
            )
        elif value_type == ValueEstimators.TDLambda:
            self._value_estimator = TDLambdaEstimator(
                value_network=self.actor_critic, **hp
            )
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

        tensor_keys = {
            "value": self.tensor_keys.state_action_value,
            "reward": self.tensor_keys.reward,
            "done": self.tensor_keys.done,
            "terminated": self.tensor_keys.terminated,
        }
        self._value_estimator.set_keys(**tensor_keys)

    @property
    @_cache_values
    def _cached_target_params(self):
        target_params = TensorDict(
            {
                "module": {
                    "0": self.target_actor_network_params,
                    "1": self.target_value_network_params,
                }
            },
            batch_size=self.target_actor_network_params.batch_size,
            device=self.target_actor_network_params.device,
        )
        return target_params

    @property
    @_cache_values
    def _cached_detached_value_params(self):
        return self.value_network_params.detach()
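The `loss_actor` method above multiplies `actions` by `grad_Q`. Because `torch.autograd.grad` is called without `create_graph=True`, the returned `grad_Q` does not require grad, so backpropagating `actions * grad_Q` gives `grad_Q` times the Jacobian of the actions with respect to the policy parameters, i.e. the deterministic policy gradient, here scaled by `phi_factor.mean()`. The toy check below is a sketch under stated assumptions: a hypothetical scalar policy `a = theta * o`, a hypothetical critic `Q(o, a) = -(a - target)**2`, and a fixed scalar `phi` standing in for the CPT factor. It compares the surrogate's analytic gradient with a finite-difference gradient of `-phi * mean(Q)`; all function names are illustrative and not part of torchrl.

```python
import math

# Hypothetical 1-D setting: policy a = theta * o, critic Q(o, a) = -(a - target) ** 2.

def critic_grad(a, target):
    """dQ/da for Q = -(a - target) ** 2; held fixed, like grad_Q in loss_actor."""
    return -2.0 * (a - target)

def surrogate_grad(theta, obs, phi, target):
    """d/dtheta of -phi * mean(a_i * g_i) with g_i frozen, which is -phi * mean(o_i * g_i)."""
    grads = [o * critic_grad(theta * o, target) for o in obs]
    return -phi * sum(grads) / len(grads)

def scaled_objective(theta, obs, phi, target):
    """-phi * mean(Q(o_i, a_i)): the loss whose gradient the surrogate should reproduce."""
    qs = [-(theta * o - target) ** 2 for o in obs]
    return -phi * sum(qs) / len(qs)

def finite_diff(theta, obs, phi, target, h=1e-5):
    """Central finite-difference derivative of scaled_objective with respect to theta."""
    up = scaled_objective(theta + h, obs, phi, target)
    down = scaled_objective(theta - h, obs, phi, target)
    return (up - down) / (2 * h)

obs = [0.5, -1.0, 2.0]
theta, phi, target = 0.3, 1.7, 1.0
analytic = surrogate_grad(theta, obs, phi, target)
numeric = finite_diff(theta, obs, phi, target)
print(analytic, numeric, math.isclose(analytic, numeric, rel_tol=1e-6))
```

Because `phi` only rescales the gradient, a positive CPT factor changes the step size of the actor update but not its direction in this toy setting.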

### **Approach: Loss Calculation & Optimization**
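#### **Defining the Loss Function (`CPTDDPGLoss`)**
- **A modified version of torchrl's `DDPGLoss` with separate actor and critic losses**:
  - **`actor_network = policies[group]`**: Optimizes agent actions.
  - **`value_network = critics[group]`**: Estimates state-action values.
- **CPT transforms**:
  - The actor loss is scaled by `compute_phi_cross(...)` applied to the group's episode returns.
  - The critic's TD target is passed through `C_transform_cross(...)` before the regression loss.
- **Target Network (`delay_value = True`)**:
  - Uses a **target critic** for more stable learning.
  - **Critic loss**: squared error (`"l2"`) between predicted and CPT-transformed target Q-values.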
- **Key Assignments**:
  - **State-action value**: `(group, "state_action_value")`.
  - **Reward Signal**: `(group, "reward")`.
  - **Termination Handling**: `(group, "done")` and `(group, "terminated")`.
- **TD(0) Estimator**: Builds one-step **Temporal Difference (TD)** targets with discount factor `γ = 0.99`.

#### **Target Network Updates**
- **Soft update mechanism (`SoftUpdate`)**:
  - **Gradually updates target networks** using `τ = 0.005`.
  - Prevents drastic changes, improving stability.

#### **Optimizers**
- **Separate Adam optimizers for actor and critic networks**:
  - **`loss_actor`**: Updates policy parameters.
  - **`loss_value`**: Updates value network parameters.
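- **Learning rate**: `lr = 3e-4` for both the actor and critic optimizers.
- **`optimizer_behavioral`**: A third Adam optimizer (`lr = 1e-4`) for the adaptive CPT parameters in `adaptive_params`.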
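The TD(0) target and the soft target update described above reduce to two short formulas. The sketch below writes them for plain Python floats; `td0_target` and `soft_update` are hypothetical helpers, not torchrl's `TD0Estimator` or `SoftUpdate`, and in `CPTDDPGLoss` the resulting target is additionally passed through `C_transform_cross` before the critic regression.

```python
import math

def td0_target(reward, next_q, terminated, gamma=0.99):
    """One-step TD(0) target: r + gamma * (1 - terminated) * Q_target(s', pi_target(s'))."""
    return reward + gamma * (0.0 if terminated else 1.0) * next_q

def soft_update(target_params, online_params, tau=0.005):
    """Polyak averaging: theta_target <- (1 - tau) * theta_target + tau * theta."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]

print(td0_target(1.0, 2.0, terminated=False))  # bootstraps from the target critic
print(td0_target(1.0, 2.0, terminated=True))   # no bootstrap after termination

# Pulling a target that starts at 0.0 toward a fixed online value of 1.0:
# after k updates it equals 1 - (1 - tau) ** k, so it moves slowly.
target = [0.0, 1.0]
for _ in range(3):
    target = soft_update(target, [1.0, 1.0])
print(target, math.isclose(target[0], 1 - (1 - 0.005) ** 3))
```

With `τ = 0.005`, the target network needs on the order of `1 / τ = 200` updates to absorb most of a change in the online network, which is what makes the TD target slow-moving.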

In [None]:
losses = {}
for group, _agents in env.group_map.items():
    loss_module = CPTDDPGLoss(
        actor_network=policies[group],  # Use the non-explorative policies
        value_network=critics[group],
        delay_value=True,  # Whether to use a target network for the value
        loss_function="l2",
    )
    loss_module.set_keys(
        state_action_value=(group, "state_action_value"),
        reward=(group, "reward"),
        done=(group, "done"),
        terminated=(group, "terminated"),
    )
    loss_module.make_value_estimator(ValueEstimators.TD0, gamma=gamma)

    losses[group] = loss_module

target_updaters = {
    group: SoftUpdate(loss, tau=polyak_tau) for group, loss in losses.items()
}

optimisers = {
    group: {
        "loss_actor": torch.optim.Adam(
            loss.actor_network_params.flatten_keys().values(), lr=lr
        ),
        "loss_value": torch.optim.Adam(
            loss.value_network_params.flatten_keys().values(), lr=lr
        ),
    }
    for group, loss in losses.items()
}

optimizer_behavioral = torch.optim.Adam(adaptive_params.parameters(), lr=1e-4)

In [None]:
def process_batch(batch: TensorDictBase) -> TensorDictBase:
    """
    If the `(group, "terminated")` and `(group, "done")` keys are not present, create them by expanding
    the environment-level `"terminated"` and `"done"` flags to every agent in the group,
    so that they have the same shape as the group reward passed to the loss.
    """
    for group in env.group_map.keys():
        keys = list(batch.keys(True, True))
        group_shape = batch.get_item_shape(group)
        nested_done_key = ("next", group, "done")
        nested_terminated_key = ("next", group, "terminated")
        if nested_done_key not in keys:
            batch.set(
                nested_done_key,
                batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
            )
        if nested_terminated_key not in keys:
            batch.set(
                nested_terminated_key,
                batch.get(("next", "terminated"))
                .unsqueeze(-1)
                .expand((*group_shape, 1)),
            )
    return batch
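`process_batch` broadcasts one environment-level flag per step to every agent in the group. The nested-list sketch below mimics the shapes involved, with a single batch dimension `B` for simplicity: `[B, 1]` becomes `[B, n_agents, 1]`. `expand_done` is a hypothetical helper that only illustrates what `unsqueeze(-1).expand((*group_shape, 1))` produces; it is not part of torchrl.

```python
def expand_done(done, n_agents):
    """Copy each environment-level flag to every agent.

    done: list of B one-element lists, mimicking a [B, 1] tensor such as ("next", "done").
    Returns nested lists of shape [B, n_agents, 1], mimicking ("next", group, "done").
    """
    return [[[row[0]] for _ in range(n_agents)] for row in done]

env_done = [[True], [False]]           # B = 2 environments
agent_done = expand_done(env_done, 3)  # 3 agents per environment
print(agent_done)
```

Every agent in an environment therefore sees the same `done` value, which matches the per-agent shape of `(group, "reward")` expected by the loss.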

### **Approach: Training Loop & Optimization**
#### **Progress Bar & Logging Setup**
- **Uses `tqdm`** to track training iterations with episode rewards.  
- **Initializes `episode_reward_mean_map`** to store reward trends per agent group.  
- **Creates `train_group_map`** as a copy of `env.group_map`, allowing dynamic updates.

#### **Main Training Loop**
- **Iterates through `collector`** to process training batches.
- **Preprocesses Data (`process_batch`)**:
  - Expands done/terminated keys for proper loss computation.
  - **Excludes data from other groups** to isolate training signals.
  - **Reshapes batch** to align with replay buffer dimensions.
- **Stores Data in Replay Buffer (`replay_buffers[group].extend(group_batch)`)**.

#### **Optimization Steps**
- **Runs `n_optimiser_steps = 100` updates per collected batch**, each on a minibatch sampled from the group's replay buffer.
- **Computes & Backpropagates Loss**:
  - Extracts actor (`loss_actor`) and critic (`loss_value`) loss.
  - **Clips gradients (`max_grad_norm = 1.0`)** to prevent instability.
  - **Optimizes parameters with Adam**, resetting gradients after each step.
- **Soft Updates (`target_updaters[group].step()`)**:
  - Gradually syncs target networks using `τ = 0.005`.

#### **Adaptive Exploration**
- **Anneals exploration noise (`sigma`)** based on the number of frames processed.

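#### **Training Halting Condition**
- **`iteration_when_stop_training_evaders = 25`** is meant to stop evader training, but the loop below contains no such check: every group in `train_group_map` is trained at every iteration.

#### **Adaptive CPT Parameter Updates**
- Once `iteration > freeze_adaptive_until` (20), and every `adaptive_update_frequency` (10) iterations, updates `adaptive_params` with `optimizer_behavioral`.
- The loss is `scale_factor` times the squared gap between the CPT-transformed and standard TD targets, plus `reg_lambda` times the squared drift of each parameter from `initial_params`.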

#### **Logging & Progress Tracking**
- **Computes mean episode reward** for each group.
- **Updates `tqdm` progress bar** with latest reward values.
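Two mechanisms in the training loop described above are easy to state in plain Python. `torch.nn.utils.clip_grad_norm_` computes the L2 norm over all gradients passed to it and, when that norm exceeds `max_norm`, rescales every gradient by `max_norm / (total_norm + 1e-6)`. For exploration, this sketch assumes a linear schedule from a hypothetical `sigma_init` to `sigma_end` over `annealing_frames` frames; the schedule actually used is whatever the module at `exploration_policies[group][-1]` implements. Both helpers and their default values are hypothetical illustrations.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a flat list of gradients so their joint L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total + eps)
    if coef < 1.0:  # only shrink, never enlarge
        grads = [g * coef for g in grads]
    return grads, total

def annealed_sigma(frames_seen, sigma_init=0.9, sigma_end=0.1, annealing_frames=1000):
    """Assumed linear decay of the Gaussian noise scale, clamped at sigma_end."""
    frac = min(frames_seen / annealing_frames, 1.0)
    return sigma_init + frac * (sigma_end - sigma_init)

clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)  # norm of [3, 4] is 5, so both entries shrink by about 1/5
print([annealed_sigma(f) for f in (0, 500, 1000, 5000)])
```

Clipping by the joint norm preserves the direction of the update and only limits its length, which is why it is applied to all parameters of an optimiser at once rather than per tensor.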

In [None]:
def compute_adaptive_loss(replay_sample, value_estimator, adaptive_params, target_params):
    """
    Compute an auxiliary loss that encourages the CPT-transformed target Q-value to stay close to the standard target.

    Args:
      replay_sample: A sample batch from your replay buffer.
      value_estimator: The value estimator (e.g., from your loss module).
      adaptive_params: The adaptive behavioral parameters (the ModuleDict).
      target_params: The target network parameters (e.g., losses["agents"]._cached_target_params).

    Returns:
      A scalar loss value.
    """
    # Compute the standard target value using the provided target parameters
    target_value = value_estimator.value_estimate(replay_sample, target_params=target_params).squeeze(-1)

    # Compute the CPT-transformed target using the adaptive parameters
    target_value_CPT = C_transform_cross(target_value, adaptive_params)

    # Define the loss as the mean squared difference between the transformed and standard targets
    adaptive_loss = torch.mean((target_value_CPT - target_value) ** 2)
    return adaptive_loss
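`compute_adaptive_loss` above returns only the squared gap between the CPT-transformed and standard targets; the training loop recomputes that term inline, scales it by `scale_factor`, and adds `reg_lambda` times the squared drift of every CPT parameter from `initial_params`. A minimal plain-Python sketch of the combined objective follows, assuming scalar parameters stored as nested dicts `{agent_id: {name: value}}`; the dict layout and the `adaptive_objective` name are hypothetical.

```python
import math

def adaptive_objective(target, target_cpt, params, initial_params,
                       scale_factor=1e-3, reg_lambda=1e-3):
    """scale_factor * MSE(C(target), target) + reg_lambda * summed squared parameter drift."""
    base = math.fsum((c - t) ** 2 for c, t in zip(target_cpt, target)) / len(target)
    reg = math.fsum(
        (value - initial_params[agent][name]) ** 2
        for agent, named in params.items()
        for name, value in named.items()
    )
    return scale_factor * base + reg_lambda * reg

initial = {"agent_0": {"alpha": 0.7, "lam": 2.5}}
params = {"agent_0": {"alpha": 0.7, "lam": 2.6}}  # lam has drifted by 0.1
print(adaptive_objective([1.0, 2.0], [1.5, 1.0], params, initial))
```

The first term alone is minimized when `C_transform_cross` acts as the identity on the sampled targets, so the regularizer is what keeps the CPT parameters from collapsing toward that trivial solution and holds them near their initial values.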



In [None]:
pbar = tqdm(
    total=n_iters,
    desc=", ".join(
        [f"episode_reward_mean_{group} = 0" for group in env.group_map.keys()]
    ),
)
episode_reward_mean_map = {group: [] for group in env.group_map.keys()}
train_group_map = copy.deepcopy(env.group_map)


adaptive_update_frequency = 10  # update adaptive parameters every 10 iterations
freeze_adaptive_until = 20  # don't update adaptive parameters until after 20 iterations
scale_factor = 1e-3  # scaling for adaptive loss
reg_lambda = 1e-3  # regularization coefficient

# Snapshot of the initial adaptive parameters; the regularization term below
# penalizes drift away from these values.
initial_params = {
    agent_id: {name: param.clone().detach() for name, param in module.get_params().items()}
    for agent_id, module in adaptive_params.items()
}



for iteration, batch in enumerate(collector):
    current_frames = batch.numel()
    batch = process_batch(batch)

    for group in train_group_map.keys():
        group_batch = batch.exclude(
            *[
                key
                for _group in env.group_map.keys()
                if _group != group
                for key in [_group, ("next", _group)]
            ]
        )
        group_batch = group_batch.reshape(-1)
        replay_buffers[group].extend(group_batch)

        for _ in range(n_optimiser_steps):
            subdata = replay_buffers[group].sample()
            loss_vals = losses[group](subdata)
            for loss_name in ["loss_actor", "loss_value"]:
                loss_value = loss_vals[loss_name]
                optimiser = optimisers[group][loss_name]
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(optimiser.param_groups[0]["params"], max_grad_norm)
                optimiser.step()
                optimiser.zero_grad()
            # Soft-update the target networks
            target_updaters[group].step()

        # Anneal exploration noise once per collected batch, not once per optimiser step
        exploration_policies[group][-1].step(current_frames)

    # Decoupled update for adaptive parameters every adaptive_update_frequency iterations,
    # but only after freeze_adaptive_until iterations.
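    if iteration > freeze_adaptive_until and iteration % adaptive_update_frequency == 0:
        # Clear any gradients that other losses may have accumulated on the adaptive parameters
        optimizer_behavioral.zero_grad()
        subdata = replay_buffers["agents"].sample()  # using group "agents"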
        # Compute standard target
        target_value = losses["agents"]._value_estimator.value_estimate(
            subdata, target_params=losses["agents"]._cached_target_params
        ).squeeze(-1)
        # Compute CPT-transformed target using adaptive parameters
        target_value_CPT = C_transform_cross(target_value, adaptive_params)
        # Compute base adaptive loss
        base_adaptive_loss = torch.mean((target_value_CPT - target_value) ** 2)

        # Add regularization term for each adaptive parameter
        reg_loss = 0.0
        for agent_id, module in adaptive_params.items():
            params = module.get_params()
            for name, param in params.items():
                reg_loss += torch.mean((param - initial_params[agent_id][name]) ** 2)

        adaptive_loss = scale_factor * base_adaptive_loss + reg_lambda * reg_loss

        print(f"Iteration {iteration}: Adaptive loss = {adaptive_loss.item()}")
        adaptive_loss.backward()
        torch.nn.utils.clip_grad_norm_(adaptive_params.parameters(), max_norm=1.0)
        optimizer_behavioral.step()
        optimizer_behavioral.zero_grad()

    # Optionally, monitor adaptive parameters here
    for agent_id, module in adaptive_params.items():
        params = module.get_params()
        for name, param in params.items():
            if param.grad is not None:
                print(f"Iter {iteration} - {agent_id} {name}: {param.item():.6f}, grad mean: {param.grad.abs().mean().item():.6f}")

    # Logging of episode rewards, etc.
    for group in env.group_map.keys():
        episode_reward_mean = (
            batch.get(("next", group, "episode_reward"))[
                batch.get(("next", group, "done"))
            ]
            .mean()
            .item()
        )
        episode_reward_mean_map[group].append(episode_reward_mean)

    pbar.set_description(
        ", ".join(
            [
                f"episode_reward_mean_{group} = {episode_reward_mean_map[group][-1]}"
                for group in env.group_map.keys()
            ]
        ),
        refresh=False,
    )
    pbar.update()



episode_reward_mean_agents = -416.9090881347656:  16%|█▌        | 32/200 [07:29<39:18, 14.04s/it]
episode_reward_mean_agents = -416.01910400390625:   0%|          | 1/200 [00:05<17:20,  5.23s/it]

Iter 0 - agent_0 alpha: 0.700000, grad mean: nan
Iter 0 - agent_0 lam: 2.500000, grad mean: 104.850670
Iter 0 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 0 - agent_0 w_minus_prime_const: 0.800000, grad mean: 327.658020
Iter 0 - agent_1 alpha: 0.650000, grad mean: nan
Iter 0 - agent_1 lam: 2.800000, grad mean: 124.319458
Iter 0 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 0 - agent_1 w_minus_prime_const: 0.750000, grad mean: 464.126251


episode_reward_mean_agents = -535.8214721679688:   1%|          | 2/200 [00:09<16:05,  4.88s/it] 

Iter 1 - agent_0 alpha: 0.700000, grad mean: nan
Iter 1 - agent_0 lam: 2.500000, grad mean: 151.585297
Iter 1 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 1 - agent_0 w_minus_prime_const: 0.800000, grad mean: 473.704620
Iter 1 - agent_1 alpha: 0.650000, grad mean: nan
Iter 1 - agent_1 lam: 2.800000, grad mean: 36.957024
Iter 1 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 1 - agent_1 w_minus_prime_const: 0.750000, grad mean: 137.972549


episode_reward_mean_agents = -596.8135986328125:   2%|▏         | 3/200 [00:15<16:26,  5.01s/it]

Iter 2 - agent_0 alpha: 0.700000, grad mean: nan
Iter 2 - agent_0 lam: 2.500000, grad mean: 812.003113
Iter 2 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 2 - agent_0 w_minus_prime_const: 0.800000, grad mean: 2537.511475
Iter 2 - agent_1 alpha: 0.650000, grad mean: nan
Iter 2 - agent_1 lam: 2.800000, grad mean: 500.255920
Iter 2 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 2 - agent_1 w_minus_prime_const: 0.750000, grad mean: 1867.621094


episode_reward_mean_agents = -676.6171264648438:   2%|▏         | 4/200 [00:19<15:53,  4.87s/it]

Iter 3 - agent_0 alpha: 0.700000, grad mean: nan
Iter 3 - agent_0 lam: 2.500000, grad mean: 1481.556152
Iter 3 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 3 - agent_0 w_minus_prime_const: 0.800000, grad mean: 4629.863281
Iter 3 - agent_1 alpha: 0.650000, grad mean: nan
Iter 3 - agent_1 lam: 2.800000, grad mean: 972.214966
Iter 3 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 3 - agent_1 w_minus_prime_const: 0.750000, grad mean: 3629.601562


episode_reward_mean_agents = -596.9724731445312:   2%|▎         | 5/200 [00:24<15:38,  4.81s/it]

Iter 4 - agent_0 alpha: 0.700000, grad mean: nan
Iter 4 - agent_0 lam: 2.500000, grad mean: 2318.615967
Iter 4 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 4 - agent_0 w_minus_prime_const: 0.800000, grad mean: 7245.676270
Iter 4 - agent_1 alpha: 0.650000, grad mean: nan
Iter 4 - agent_1 lam: 2.800000, grad mean: 1563.941772
Iter 4 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 4 - agent_1 w_minus_prime_const: 0.750000, grad mean: 5838.715820


episode_reward_mean_agents = -615.2721557617188:   3%|▎         | 6/200 [00:29<15:46,  4.88s/it]

Iter 5 - agent_0 alpha: 0.700000, grad mean: nan
Iter 5 - agent_0 lam: 2.500000, grad mean: 3182.529053
Iter 5 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 5 - agent_0 w_minus_prime_const: 0.800000, grad mean: 9945.405273
Iter 5 - agent_1 alpha: 0.650000, grad mean: nan
Iter 5 - agent_1 lam: 2.800000, grad mean: 2173.311279
Iter 5 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 5 - agent_1 w_minus_prime_const: 0.750000, grad mean: 8113.694336


episode_reward_mean_agents = -530.4470825195312:   4%|▎         | 7/200 [00:34<15:25,  4.80s/it]

Iter 6 - agent_0 alpha: 0.700000, grad mean: nan
Iter 6 - agent_0 lam: 2.500000, grad mean: 4049.897461
Iter 6 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 6 - agent_0 w_minus_prime_const: 0.800000, grad mean: 12655.936523
Iter 6 - agent_1 alpha: 0.650000, grad mean: nan
Iter 6 - agent_1 lam: 2.800000, grad mean: 2786.605957
Iter 6 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 6 - agent_1 w_minus_prime_const: 0.750000, grad mean: 10403.331055


episode_reward_mean_agents = -575.3156127929688:   4%|▍         | 8/200 [00:39<15:45,  4.92s/it]

Iter 7 - agent_0 alpha: 0.700000, grad mean: nan
Iter 7 - agent_0 lam: 2.500000, grad mean: 4889.194824
Iter 7 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 7 - agent_0 w_minus_prime_const: 0.800000, grad mean: 15278.740234
Iter 7 - agent_1 alpha: 0.650000, grad mean: nan
Iter 7 - agent_1 lam: 2.800000, grad mean: 3379.769775
Iter 7 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 7 - agent_1 w_minus_prime_const: 0.750000, grad mean: 12617.814453


episode_reward_mean_agents = -524.67724609375:   4%|▍         | 9/200 [00:43<15:26,  4.85s/it]  

Iter 8 - agent_0 alpha: 0.700000, grad mean: nan
Iter 8 - agent_0 lam: 2.500000, grad mean: 5842.537598
Iter 8 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 8 - agent_0 w_minus_prime_const: 0.800000, grad mean: 18257.935547
Iter 8 - agent_1 alpha: 0.650000, grad mean: nan
Iter 8 - agent_1 lam: 2.800000, grad mean: 4053.875000
Iter 8 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 8 - agent_1 w_minus_prime_const: 0.750000, grad mean: 15134.481445


episode_reward_mean_agents = -528.2385864257812:   5%|▌         | 10/200 [00:48<15:28,  4.89s/it]

Iter 9 - agent_0 alpha: 0.700000, grad mean: nan
Iter 9 - agent_0 lam: 2.500000, grad mean: 6826.064453
Iter 9 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 9 - agent_0 w_minus_prime_const: 0.800000, grad mean: 21331.460938
Iter 9 - agent_1 alpha: 0.650000, grad mean: nan
Iter 9 - agent_1 lam: 2.800000, grad mean: 4750.669434
Iter 9 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 9 - agent_1 w_minus_prime_const: 0.750000, grad mean: 17735.845703


episode_reward_mean_agents = -504.0941467285156:   6%|▌         | 11/200 [00:53<15:26,  4.90s/it]

Iter 10 - agent_0 alpha: 0.700000, grad mean: nan
Iter 10 - agent_0 lam: 2.500000, grad mean: 7749.073242
Iter 10 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 10 - agent_0 w_minus_prime_const: 0.800000, grad mean: 24215.837891
Iter 10 - agent_1 alpha: 0.650000, grad mean: nan
Iter 10 - agent_1 lam: 2.800000, grad mean: 5404.242188
Iter 10 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 10 - agent_1 w_minus_prime_const: 0.750000, grad mean: 20175.849609


episode_reward_mean_agents = -294.3373718261719:   6%|▌         | 12/200 [00:58<15:02,  4.80s/it]

Iter 11 - agent_0 alpha: 0.700000, grad mean: nan
Iter 11 - agent_0 lam: 2.500000, grad mean: 8789.119141
Iter 11 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 11 - agent_0 w_minus_prime_const: 0.800000, grad mean: 27465.988281
Iter 11 - agent_1 alpha: 0.650000, grad mean: nan
Iter 11 - agent_1 lam: 2.800000, grad mean: 6143.787598
Iter 11 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 11 - agent_1 w_minus_prime_const: 0.750000, grad mean: 22936.832031


episode_reward_mean_agents = -323.53173828125:   6%|▋         | 13/200 [01:03<15:17,  4.91s/it]  

Iter 12 - agent_0 alpha: 0.700000, grad mean: nan
Iter 12 - agent_0 lam: 2.500000, grad mean: 9821.275391
Iter 12 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 12 - agent_0 w_minus_prime_const: 0.800000, grad mean: 30691.453125
Iter 12 - agent_1 alpha: 0.650000, grad mean: nan
Iter 12 - agent_1 lam: 2.800000, grad mean: 6877.970215
Iter 12 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 12 - agent_1 w_minus_prime_const: 0.750000, grad mean: 25677.773438


episode_reward_mean_agents = -348.01397705078125:   7%|▋         | 14/200 [01:08<15:02,  4.85s/it]

Iter 13 - agent_0 alpha: 0.700000, grad mean: nan
Iter 13 - agent_0 lam: 2.500000, grad mean: 10803.907227
Iter 13 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 13 - agent_0 w_minus_prime_const: 0.800000, grad mean: 33762.187500
Iter 13 - agent_1 alpha: 0.650000, grad mean: nan
Iter 13 - agent_1 lam: 2.800000, grad mean: 7576.093750
Iter 13 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 13 - agent_1 w_minus_prime_const: 0.750000, grad mean: 28284.123047


episode_reward_mean_agents = -303.9656677246094:   8%|▊         | 15/200 [01:13<14:56,  4.85s/it] 

Iter 14 - agent_0 alpha: 0.700000, grad mean: nan
Iter 14 - agent_0 lam: 2.500000, grad mean: 11812.491211
Iter 14 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 14 - agent_0 w_minus_prime_const: 0.800000, grad mean: 36914.015625
Iter 14 - agent_1 alpha: 0.650000, grad mean: nan
Iter 14 - agent_1 lam: 2.800000, grad mean: 8294.081055
Iter 14 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 14 - agent_1 w_minus_prime_const: 0.750000, grad mean: 30964.609375


episode_reward_mean_agents = -314.6786193847656:   8%|▊         | 16/200 [01:18<14:59,  4.89s/it]

Iter 15 - agent_0 alpha: 0.700000, grad mean: nan
Iter 15 - agent_0 lam: 2.500000, grad mean: 12831.309570
Iter 15 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 15 - agent_0 w_minus_prime_const: 0.800000, grad mean: 40097.824219
Iter 15 - agent_1 alpha: 0.650000, grad mean: nan
Iter 15 - agent_1 lam: 2.800000, grad mean: 9018.528320
Iter 15 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 15 - agent_1 w_minus_prime_const: 0.750000, grad mean: 33669.214844


episode_reward_mean_agents = -315.873291015625:   8%|▊         | 17/200 [01:22<14:37,  4.79s/it] 

Iter 16 - agent_0 alpha: 0.700000, grad mean: nan
Iter 16 - agent_0 lam: 2.500000, grad mean: 13808.826172
Iter 16 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 16 - agent_0 w_minus_prime_const: 0.800000, grad mean: 43152.531250
Iter 16 - agent_1 alpha: 0.650000, grad mean: nan
Iter 16 - agent_1 lam: 2.800000, grad mean: 9716.123047
Iter 16 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 16 - agent_1 w_minus_prime_const: 0.750000, grad mean: 36273.601562


episode_reward_mean_agents = -312.5508728027344:   9%|▉         | 18/200 [01:27<14:56,  4.93s/it]

Iter 17 - agent_0 alpha: 0.700000, grad mean: nan
Iter 17 - agent_0 lam: 2.500000, grad mean: 14781.190430
Iter 17 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 17 - agent_0 w_minus_prime_const: 0.800000, grad mean: 46191.140625
Iter 17 - agent_1 alpha: 0.650000, grad mean: nan
Iter 17 - agent_1 lam: 2.800000, grad mean: 10410.686523
Iter 17 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 17 - agent_1 w_minus_prime_const: 0.750000, grad mean: 38866.597656


episode_reward_mean_agents = -366.3407897949219:  10%|▉         | 19/200 [01:32<14:37,  4.85s/it]

Iter 18 - agent_0 alpha: 0.700000, grad mean: nan
Iter 18 - agent_0 lam: 2.500000, grad mean: 15779.561523
Iter 18 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 18 - agent_0 w_minus_prime_const: 0.800000, grad mean: 49311.062500
Iter 18 - agent_1 alpha: 0.650000, grad mean: nan
Iter 18 - agent_1 lam: 2.800000, grad mean: 11123.583984
Iter 18 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 18 - agent_1 w_minus_prime_const: 0.750000, grad mean: 41528.121094


episode_reward_mean_agents = -466.1153259277344:  10%|█         | 20/200 [01:37<14:31,  4.84s/it]

Iter 19 - agent_0 alpha: 0.700000, grad mean: nan
Iter 19 - agent_0 lam: 2.500000, grad mean: 16750.369141
Iter 19 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 19 - agent_0 w_minus_prime_const: 0.800000, grad mean: 52344.804688
Iter 19 - agent_1 alpha: 0.650000, grad mean: nan
Iter 19 - agent_1 lam: 2.800000, grad mean: 11817.089844
Iter 19 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 19 - agent_1 w_minus_prime_const: 0.750000, grad mean: 44117.250000


episode_reward_mean_agents = -409.905029296875:  10%|█         | 21/200 [01:42<14:32,  4.87s/it] 

Iter 20 - agent_0 alpha: 0.700000, grad mean: nan
Iter 20 - agent_0 lam: 2.500000, grad mean: 17686.857422
Iter 20 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 20 - agent_0 w_minus_prime_const: 0.800000, grad mean: 55271.320312
Iter 20 - agent_1 alpha: 0.650000, grad mean: nan
Iter 20 - agent_1 lam: 2.800000, grad mean: 12486.463867
Iter 20 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 20 - agent_1 w_minus_prime_const: 0.750000, grad mean: 46616.238281


episode_reward_mean_agents = -482.4176940917969:  11%|█         | 22/200 [01:46<14:10,  4.78s/it]

Iter 21 - agent_0 alpha: 0.700000, grad mean: nan
Iter 21 - agent_0 lam: 2.500000, grad mean: 18589.232422
Iter 21 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 21 - agent_0 w_minus_prime_const: 0.800000, grad mean: 58091.238281
Iter 21 - agent_1 alpha: 0.650000, grad mean: nan
Iter 21 - agent_1 lam: 2.800000, grad mean: 13131.492188
Iter 21 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 21 - agent_1 w_minus_prime_const: 0.750000, grad mean: 49024.347656


episode_reward_mean_agents = -488.3270568847656:  12%|█▏        | 23/200 [01:51<14:22,  4.87s/it]

Iter 22 - agent_0 alpha: 0.700000, grad mean: nan
Iter 22 - agent_0 lam: 2.500000, grad mean: 19482.066406
Iter 22 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 22 - agent_0 w_minus_prime_const: 0.800000, grad mean: 60881.367188
Iter 22 - agent_1 alpha: 0.650000, grad mean: nan
Iter 22 - agent_1 lam: 2.800000, grad mean: 13769.115234
Iter 22 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 22 - agent_1 w_minus_prime_const: 0.750000, grad mean: 51404.832031


episode_reward_mean_agents = -423.0145568847656:  12%|█▏        | 24/200 [01:56<14:07,  4.81s/it]

Iter 23 - agent_0 alpha: 0.700000, grad mean: nan
Iter 23 - agent_0 lam: 2.500000, grad mean: 20387.335938
Iter 23 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 23 - agent_0 w_minus_prime_const: 0.800000, grad mean: 63710.347656
Iter 23 - agent_1 alpha: 0.650000, grad mean: nan
Iter 23 - agent_1 lam: 2.800000, grad mean: 14415.757812
Iter 23 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 23 - agent_1 w_minus_prime_const: 0.750000, grad mean: 53818.929688


episode_reward_mean_agents = -358.1033935546875:  12%|█▎        | 25/200 [02:01<14:00,  4.81s/it]

Iter 24 - agent_0 alpha: 0.700000, grad mean: nan
Iter 24 - agent_0 lam: 2.500000, grad mean: 21247.089844
Iter 24 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 24 - agent_0 w_minus_prime_const: 0.800000, grad mean: 66397.046875
Iter 24 - agent_1 alpha: 0.650000, grad mean: nan
Iter 24 - agent_1 lam: 2.800000, grad mean: 15029.782227
Iter 24 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 24 - agent_1 w_minus_prime_const: 0.750000, grad mean: 56111.292969


episode_reward_mean_agents = -371.4198303222656:  13%|█▎        | 26/200 [02:06<14:07,  4.87s/it]

Iter 25 - agent_0 alpha: 0.700000, grad mean: nan
Iter 25 - agent_0 lam: 2.500000, grad mean: 22118.900391
Iter 25 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 25 - agent_0 w_minus_prime_const: 0.800000, grad mean: 69121.437500
Iter 25 - agent_1 alpha: 0.650000, grad mean: nan
Iter 25 - agent_1 lam: 2.800000, grad mean: 15654.313477
Iter 25 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 25 - agent_1 w_minus_prime_const: 0.750000, grad mean: 58442.871094


episode_reward_mean_agents = -351.96612548828125:  14%|█▎        | 27/200 [02:11<13:56,  4.84s/it]

Iter 26 - agent_0 alpha: 0.700000, grad mean: nan
Iter 26 - agent_0 lam: 2.500000, grad mean: 22970.916016
Iter 26 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 26 - agent_0 w_minus_prime_const: 0.800000, grad mean: 71783.945312
Iter 26 - agent_1 alpha: 0.650000, grad mean: nan
Iter 26 - agent_1 lam: 2.800000, grad mean: 16263.656250
Iter 26 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 26 - agent_1 w_minus_prime_const: 0.750000, grad mean: 60717.777344


episode_reward_mean_agents = -416.7789001464844:  14%|█▍        | 28/200 [02:16<14:28,  5.05s/it] 

Iter 27 - agent_0 alpha: 0.700000, grad mean: nan
Iter 27 - agent_0 lam: 2.500000, grad mean: 23783.185547
Iter 27 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 27 - agent_0 w_minus_prime_const: 0.800000, grad mean: 74322.281250
Iter 27 - agent_1 alpha: 0.650000, grad mean: nan
Iter 27 - agent_1 lam: 2.800000, grad mean: 16845.644531
Iter 27 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 27 - agent_1 w_minus_prime_const: 0.750000, grad mean: 62890.531250


episode_reward_mean_agents = -381.2997131347656:  14%|█▍        | 29/200 [02:21<13:58,  4.90s/it]

Iter 28 - agent_0 alpha: 0.700000, grad mean: nan
Iter 28 - agent_0 lam: 2.500000, grad mean: 24580.919922
Iter 28 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 28 - agent_0 w_minus_prime_const: 0.800000, grad mean: 76815.140625
Iter 28 - agent_1 alpha: 0.650000, grad mean: nan
Iter 28 - agent_1 lam: 2.800000, grad mean: 17416.121094
Iter 28 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 28 - agent_1 w_minus_prime_const: 0.750000, grad mean: 65020.242188


episode_reward_mean_agents = -415.9543151855469:  15%|█▌        | 30/200 [02:26<13:46,  4.86s/it]

Iter 29 - agent_0 alpha: 0.700000, grad mean: nan
Iter 29 - agent_0 lam: 2.500000, grad mean: 25352.599609
Iter 29 - agent_0 w_plus_prime_const: 0.200000, grad mean: nan
Iter 29 - agent_0 w_minus_prime_const: 0.800000, grad mean: 79226.640625
Iter 29 - agent_1 alpha: 0.650000, grad mean: nan
Iter 29 - agent_1 lam: 2.800000, grad mean: 17967.779297
Iter 29 - agent_1 w_plus_prime_const: 0.250000, grad mean: nan
Iter 29 - agent_1 w_minus_prime_const: 0.750000, grad mean: 67079.726562


episode_reward_mean_agents = -385.1997985839844:  16%|█▌        | 31/200 [02:31<13:47,  4.90s/it]

Iteration 30: Adaptive loss = 0.01808582805097103


episode_reward_mean_agents = -416.9090881347656:  16%|█▌        | 32/200 [02:35<13:32,  4.84s/it]

Iter 31 - agent_0 alpha: nan, grad mean: nan
Iter 31 - agent_0 lam: nan, grad mean: nan
Iter 31 - agent_0 w_plus_prime_const: nan, grad mean: nan
Iter 31 - agent_0 w_minus_prime_const: nan, grad mean: nan
Iter 31 - agent_1 alpha: nan, grad mean: nan
Iter 31 - agent_1 lam: nan, grad mean: nan
Iter 31 - agent_1 w_plus_prime_const: nan, grad mean: nan
Iter 31 - agent_1 w_minus_prime_const: nan, grad mean: nan


AssertionError: 
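In the log above, the gradients of `alpha` and `w_plus_prime_const` are `nan` from the first printed iteration, while `lam` and `w_minus_prime_const` stay finite. By iteration 31 every parameter is `nan`. Here is one plausible cause. Assume the prospect-theory value function has the Tversky-Kahneman form, with $v(x) = x^\alpha$ for gains and $v(x) = -\lambda(-x)^\alpha$ for losses, and that it is computed on whole tensors and then masked (for example with `torch.where`). In that case the gain branch is still evaluated on negative rewards. There, $\partial(x^\alpha)/\partial\alpha = x^\alpha \ln x$ is undefined, and a masked-out zero gradient multiplied by NaN is still NaN. The pure-Python sketch below uses hypothetical helpers `pt_value` and `pt_value_grad_alpha`, which are not the notebook's code. It evaluates only the branch that applies to each input. In PyTorch, the analogous remedy is usually to clamp each branch's input to its valid range before taking the power.

```python
import math


def pt_value(x: float, alpha: float, lam: float) -> float:
    """Hypothetical Tversky-Kahneman style value: x**alpha for gains, -lam * (-x)**alpha for losses."""
    if x >= 0:
        return x ** alpha
    return -lam * (-x) ** alpha


def pt_value_grad_alpha(x: float, alpha: float, lam: float) -> float:
    """Analytic d v / d alpha, evaluating only the branch that applies to x."""
    if x > 0:
        return x ** alpha * math.log(x)
    if x < 0:
        return -lam * (-x) ** alpha * math.log(-x)
    # x**alpha * log(x) tends to 0 as x -> 0+ when alpha > 0
    return 0.0


# Masking does not rescue an invalid branch: a zero weight times NaN is still NaN.
masked_contribution = 0.0 * float("nan")
print(math.isnan(masked_contribution))

# Evaluating only the valid branch keeps the alpha-gradient finite for mixed-sign inputs.
example_rewards = [-2.0, -0.5, 0.0, 1.5]
grads = [pt_value_grad_alpha(r, alpha=0.7, lam=2.5) for r in example_rewards]
print(all(math.isfinite(g) for g in grads))
```

Both prints should report `True`. Branching per element is the scalar analogue of computing each tensor branch only on inputs where it is defined.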

This is our "test" to make sure our agents are training: after the agent stops training, the adversaries' rewards increase, and while it is still training, the rewards of both groups go to 0.
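The check described above can be made explicit rather than read off the plot. This sketch makes two assumptions. First, `episode_reward_mean_map` holds one list of per-iteration mean rewards per group, as the plotting cell uses it. Second, training of one group stops at a known iteration. The function `reward_rises_after_stop` and its parameters `stop_iter` and `window` are hypothetical. The function compares the other group's average reward over a short window just before and just after that iteration.

```python
from statistics import mean


def reward_rises_after_stop(rewards, stop_iter, window=10):
    """Return True if the mean reward over `window` iterations starting at `stop_iter`
    exceeds the mean over the `window` iterations just before it.

    `rewards` is a list of per-iteration mean episode rewards for one group.
    """
    before = rewards[max(0, stop_iter - window):stop_iter]
    after = rewards[stop_iter:stop_iter + window]
    if not before or not after:
        raise ValueError("not enough iterations on one side of stop_iter")
    return mean(after) > mean(before)


# Toy series: flat rewards, then a jump once the opponent stops learning at iteration 4.
toy_rewards = [0.0, 0.1, -0.1, 0.0, 2.0, 3.0, 2.5, 3.5]
print(reward_rises_after_stop(toy_rewards, stop_iter=4, window=4))

# Hypothetical usage with the notebook's data (group_name and stop iteration are placeholders):
# reward_rises_after_stop(episode_reward_mean_map[group_name], stop_iter=stop_iteration)
```

A windowed mean is used instead of single points because the per-iteration rewards are noisy. A larger `window` trades sensitivity for robustness.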

In [None]:
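import matplotlib.pyplot as plt

# Plot the mean episode reward of every agent group over training iterations
fig, axs = plt.subplots(1, 1)
for group in env.group_map.keys():
    axs.plot(episode_reward_mean_map[group], label=f"Episode reward mean {group}")
# Axis labels and legend only need to be set once, after all groups are plotted
axs.set_xlabel("Training iterations")
axs.set_ylabel("Reward")
axs.legend()
plt.show()
print(env.group_map.keys())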

In [None]:
if use_vmas and not is_sphinx:
    # tmpdir is deleted when this block exits, taking the video with it; replace tmpdir with a permanent path to keep the video
    with tempfile.TemporaryDirectory() as tmpdir:
        video_logger = CSVLogger("vmas_logs", tmpdir, video_format="mp4")
        print("Creating rendering env")
        env_with_render = TransformedEnv(env.base_env, env.transform.clone())
        env_with_render = env_with_render.append_transform(
            PixelRenderTransform(
                out_keys=["pixels"],
                # the np.ndarray has a negative stride and needs to be copied before being cast to a tensor
                preproc=lambda x: x.copy(),
                as_non_tensor=True,
                # asking for array rather than on-screen rendering
                mode="rgb_array",
            )
        )
        env_with_render = env_with_render.append_transform(
            VideoRecorder(logger=video_logger, tag="vmas_rendered")
        )
        with set_exploration_type(ExplorationType.DETERMINISTIC):
            print("Rendering rollout...")
            env_with_render.rollout(100, policy=agents_exploration_policy)
        print("Saving the video...")
        env_with_render.transform.dump()
        print("Saved! Saved directory tree:")
        video_logger.print_log_dir()

In [None]:
import os

# Define a permanent directory path (e.g., "local_videos")
local_dir = "local_videos"
os.makedirs(local_dir, exist_ok=True)

# Use the permanent directory instead of a temporary one
video_logger = CSVLogger("vmas_logs", local_dir, video_format="mp4")
print("Creating rendering env")
env_with_render = TransformedEnv(env.base_env, env.transform.clone())
env_with_render = env_with_render.append_transform(
    PixelRenderTransform(
        out_keys=["pixels"],
        # the np.ndarray has a negative stride and needs to be copied before being cast to a tensor
        preproc=lambda x: x.copy(),
        as_non_tensor=True,
        # asking for array rather than on-screen rendering
        mode="rgb_array",
    )
)
env_with_render = env_with_render.append_transform(
    VideoRecorder(logger=video_logger, tag="vmas_rendered")
)
with set_exploration_type(ExplorationType.DETERMINISTIC):
    print("Rendering rollout...")
    env_with_render.rollout(100, policy=agents_exploration_policy)
print("Saving the video...")
env_with_render.transform.dump()
print("Saved! Saved directory tree:")
video_logger.print_log_dir()