


## Problem Setup
This example demonstrates a multi-agent deep deterministic policy gradient (MADDPG) approach to a competitive environment where chasers aim to catch evaders. Each group of agents (chasers and evaders) has its own policy and value networks, trained either independently or in a mixed cooperative-competitive setting. It serves as our control when we try and integrate prospect theory into the policy gradient, seeing if we can get different results than the reward graphs below. Code is based from [https://pytorch.org/rl/0.6/tutorials/multiagent_competitive_ddpg.html]

## Clear Problem Statement
Train two chaser agents to minimize the evader’s cumulative reward while simultaneously training the evader agent to maximize its own cumulative reward. The environment runs for a fixed number of steps, and training can be halted for certain agents at a chosen iteration.

## Mathematical Formulation
- **Agent Policies**: $\pi_i(\mathbf{o_i}; \theta_i)$ map observations $\mathbf{o_i}$ to continuous actions.
- **Value Function**: $Q_i(\mathbf{o}, \mathbf{a}; \phi_i)$ estimates future return given all agents’ actions $\mathbf{a}$ and observations $\mathbf{o}$.
- **Loss Functions**: DDPG losses incorporate actor and critic objectives, ensuring that each agent maximizes expected returns while considering centralized training and decentralized execution.
- **Updates**: Soft updates are performed on target networks with \(\tau\) for both the policy and value functions.

## Data Requirements
- Episodes of agent interactions, collected with exploration strategies (e.g., Gaussian noise).
- Replay buffers per group for sampled training batches containing states, actions, rewards, and next states.

## Success Metrics
- Mean episode reward for each group (chasers and evaders), typically measured and plotted over training iterations.
- Convergence or stabilization of the reward signal, indicating improved policy performance.


In [2]:
!pip install torchrl==0.6.0 pettingzoo gymnasium torch tqdm

Collecting torchrl==0.6.0
  Using cached torchrl-0.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (39 kB)
Collecting tensordict>=0.6.0 (from torchrl==0.6.0)
  Using cached tensordict-0.7.2-cp311-cp311-manylinux1_x86_64.whl.metadata (9.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1

In [3]:
!pip3 install vmas
!pip3 install pettingzoo[mpe]==1.24.3
!pip3 install tqdm

Collecting vmas
  Using cached vmas-1.5.0.tar.gz (217 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyglet<=1.5.27 (from vmas)
  Using cached pyglet-1.5.27-py3-none-any.whl.metadata (7.6 kB)
Using cached pyglet-1.5.27-py3-none-any.whl (1.1 MB)
Building wheels for collected packages: vmas
  Building wheel for vmas (setup.py) ... [?25l[?25hdone
  Created wheel for vmas: filename=vmas-1.5.0-py3-none-any.whl size=257528 sha256=4946d248ffe6a59d708d6eda67342e8779248f0e8aff3f2b25bc5b00ba5e36ce
  Stored in directory: /root/.cache/pip/wheels/f1/ab/3b/f1eb0befe556b53a3f78b4780554c29dca0cb781fb664f6a81
Successfully built vmas
Installing collected packages: pyglet, vmas
Successfully installed pyglet-1.5.27 vmas-1.5.0


### **Approach: Importing Required Libraries**
This section imports essential modules for implementing MADDPG:  
- **PyTorch** for deep learning operations.  
- **`torchrl` modules** for multi-agent reinforcement learning, including environments, policies, collectors, and replay buffers.  
- **`tensordict`** for structured tensor operations.  
- **Matplotlib** for visualization.  
- **`tqdm`** for progress tracking.

In [43]:
import copy
import tempfile

import torch

from matplotlib import pyplot as plt
from tensordict import TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import multiprocessing

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, RandomSampler, ReplayBuffer

from torchrl.envs import (
    check_env_specs,
    ExplorationType,
    PettingZooEnv,
    RewardSum,
    set_exploration_type,
    TransformedEnv,
    VmasEnv,
)

from torchrl.modules import (
    AdditiveGaussianModule,
    MultiAgentMLP,
    ProbabilisticActor,
    TanhDelta,
)

from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators

from torchrl.record import CSVLogger, PixelRenderTransform, VideoRecorder

from tqdm import tqdm

try:
    is_sphinx = __sphinx_build__
except NameError:
    is_sphinx = False

In [44]:
import copy
import tempfile

import torch

from matplotlib import pyplot as plt
from tensordict import TensorDictBase, is_tensor_collection

from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import multiprocessing

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, RandomSampler, ReplayBuffer

from torchrl.envs import (
    check_env_specs,
    ExplorationType,
    PettingZooEnv,
    RewardSum,
    set_exploration_type,
    TransformedEnv,
    VmasEnv,
)

from torchrl.modules import (
    AdditiveGaussianModule,
    MultiAgentMLP,
    ProbabilisticActor,
    TanhDelta,
)

from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators

from torchrl.record import CSVLogger, PixelRenderTransform, VideoRecorder

from tqdm import tqdm

try:
    is_sphinx = __sphinx_build__
except NameError:
    is_sphinx = False


try:
    from torch.compiler import is_compiling
except ImportError:
    from torch._dynamo import is_compiling

In [45]:
import functools, numpy as np, torch
from torch import nn
from gymnasium.spaces import Box
from gymnasium.utils import seeding

# PettingZoo API
from pettingzoo import ParallelEnv
from pettingzoo.utils import parallel_to_aec, wrappers

# TorchRL wrappers & collector
from torchrl.envs.libs.pettingzoo import PettingZooWrapper
from torchrl.envs.utils            import set_exploration_type, ExplorationType
from torchrl.collectors            import SyncDataCollector

# DDPG ingredients
from torchrl.objectives.ddpg import DDPGLoss
from torchrl.modules           import Actor, ValueOperator

In [46]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [47]:
def env(render_mode=None):
    """AEC‑style view of our ParallelEnv."""
    aec = raw_env(render_mode=render_mode)
    if render_mode == "ansi":
        aec = wrappers.CaptureStdoutWrapper(aec)
    aec = wrappers.AssertOutOfBoundsWrapper(aec)
    aec = wrappers.OrderEnforcingWrapper(aec)
    return aec

def raw_env(render_mode=None):
    """Convert ParallelEnv → AEC via parallel_to_aec."""
    par = parallel_env(render_mode=render_mode)
    return parallel_to_aec(par)

class parallel_env(ParallelEnv):
    metadata = {"render_modes": ["human"], "name": "first_price_auction_v0"}

    def __init__(self, num_agents=3, max_bid=1.0, render_mode=None):
        super().__init__()
        self.max_bid     = max_bid
        self.render_mode = render_mode
        self.possible_agents = [f"agent_{i}" for i in range(num_agents)]
        self.agents           = []
        # obs = your private valuation ∈[0,1]
        self.observation_spaces = {
            a: Box(0.0, 1.0, (1,), dtype=np.float32)
            for a in self.possible_agents
        }
        # action = bid ∈[0,max_bid]
        self.action_spaces = {
            a: Box(0.0, float(max_bid), (1,), dtype=np.float32)
            for a in self.possible_agents
        }
        self.valuations = {}

    @functools.lru_cache(None)
    def observation_space(self, agent):
        return self.observation_spaces[agent]

    @functools.lru_cache(None)
    def action_space(self, agent):
        return self.action_spaces[agent]

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.np_random, _ = seeding.np_random(seed)
        # all agents active
        self.agents = self.possible_agents[:]
        # sample private valuations U[0,1]
        self.valuations = {a: float(self.np_random.random()) for a in self.agents}
        obs   = {a: np.array([self.valuations[a]], dtype=np.float32) for a in self.agents}
        infos = {a: {} for a in self.agents}
        return obs, infos

    def step(self, actions):
        if not actions:  # no agents → done
            self.agents = []
            return {}, {}, {}, {}, {}

        bids = {a: float(actions[a][0]) for a in actions}
        # highest bid wins (ties by agent order)
        winner = max(self.possible_agents, key=lambda a: bids[a])
        # payoffs
        rewards = {
            a: (self.valuations[a] - bids[a] if a == winner else 0.0)
            for a in self.possible_agents
        }
        terminated = {a: True  for a in self.possible_agents}
        truncated  = {a: False for a in self.possible_agents}
        next_obs    = {a: np.zeros((1,), dtype=np.float32) for a in self.possible_agents}
        infos       = {a: {} for a in self.possible_agents}

        # end episode
        self.agents = []
        return next_obs, rewards, terminated, truncated, infos

    def render(self):
        if self.render_mode == "human":
            print("Valuations:", self.valuations)

    def close(self):
        pass

### **Approach: Environment Setup & Hyperparameters**
- **Seed & Device**: Sets the random seed for reproducibility and selects the appropriate device (GPU if available, otherwise CPU).  
- **Sampling**: Defines frames collected per batch (`1,000`), total iterations (`50`), and total frames (`50,000`).  
- **Training Control**: Stops evader training at `iteration_when_stop_training_evaders = 25`.  
- **Replay Buffer**: Stores up to `1M` frames for experience replay.  
- **Training Parameters**:  
  - **Optimization**: `100` updates per iteration, batch size of `128`.  
  - **Learning Rate**: `3e-4`, gradient clipping at `1.0`.  
- **DDPG-Specific**: Uses discount factor (`γ = 0.99`) and soft update parameter (`τ = 0.005`).

In [48]:
# Seed
seed = 0
torch.manual_seed(seed)

# Devices
is_fork = multiprocessing.get_start_method() == "fork"
# Sampling
frames_per_batch = 1_000  # Number of team frames collected per sampling iteration
n_iters = 200  # Number of sampling and training iterations
total_frames = frames_per_batch * n_iters


# Replay buffer
memory_size = 1_000_000  # The replay buffer of each group can store this many frames

# Training
n_optimiser_steps = 100  # Number of optimization steps per training iteration
train_batch_size = 128  # Number of frames trained in each optimiser step
lr = 3e-4  # Learning rate
max_grad_norm = 1.0  # Maximum norm for the gradients

# DDPG
gamma = 0.99  # Discount factor
polyak_tau = 0.005  # Tau for the soft-update of the target network

### **Approach: Environment Configuration**
- **Max Steps**: Each episode runs for `100` steps.  
- **Agents & Obstacles**: `2` chasers, `1` evader, and `2` obstacles.  
- **VMAS for Performance**:  
  - If `use_vmas = True`, uses `VmasEnv` for efficient vectorized multi-agent simulation.  
  - Otherwise, defaults to `PettingZooEnv` (parallel mode) for `simple_tag_v3`.  
- **Vectorization**: `num_vmas_envs = frames_per_batch / max_steps` ensures efficient frame collection.

In [49]:
max_steps = 100  # Environment steps before done

n_agents = 2
n_landmarks = 1

use_vmas = False  # Set this to True for a great performance speedup

if not use_vmas:
    par = parallel_env(num_agents=3, max_bid=1.0, render_mode=None)

    # 2) wrap it directly with TorchRL's PettingZooWrapper
    base_env = PettingZooWrapper(
        env               = par,
        return_state      = False,
        group_map         = None,
        use_mask          = False,
        categorical_actions = False,
        seed              = 42,
        device            = device,
        done_on_any       = True,
    )
else:
    num_vmas_envs = (
        frames_per_batch // max_steps
    )
    base_env = VmasEnv(
        scenario="simple_spread",
        num_envs=num_vmas_envs,
        continuous_actions=True,
        max_steps=max_steps,
        local_ratio=0.5,
        device=device,
        seed=seed,
        n_agents = n_agents
    )

In [50]:
print(f"group_map: {base_env.group_map}")

group_map: {'agent': ['agent_0', 'agent_1', 'agent_2']}


In [51]:
print("action_spec:", base_env.full_action_spec)
print("reward_spec:", base_env.full_reward_spec)
print("done_spec:", base_env.full_done_spec)
print("observation_spec:", base_env.observation_spec)

action_spec: Composite(
    agent: Composite(
        action: BoundedContinuous(
            shape=torch.Size([3, 1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, contiguous=True)),
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous),
        device=cuda:0,
        shape=torch.Size([3])),
    device=cuda:0,
    shape=torch.Size([]))
reward_spec: Composite(
    agent: Composite(
        reward: UnboundedContinuous(
            shape=torch.Size([3, 1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, contiguous=True)),
            device=cuda:0,
            dtype=torch.float32,
          

In [77]:
print("action_keys:", base_env.action_keys)
print("reward_keys:", base_env.reward_keys)
print("done_keys:", base_env.done_keys)

action_keys: [('agent', 'action')]
reward_keys: [('agent', 'reward')]
done_keys: ['done', 'terminated', 'truncated', ('agent', 'done'), ('agent', 'terminated'), ('agent', 'truncated')]


### **Approach: Environment Transformation**
- **Wraps `base_env` with `TransformedEnv`** to apply reward processing.  
- **`RewardSum` Aggregation**:  
  - Uses `reward_keys` from `base_env` to sum rewards over time.  
  - Resets rewards using `_reset` keys for each agent group.  
- **Purpose**: Ensures proper reward tracking across multi-agent interactions.

In [78]:
env = TransformedEnv(
    base_env,
    RewardSum(
        in_keys=base_env.reward_keys,
    ),
)

In [79]:
check_env_specs(env)

2025-04-19 00:30:29,504 [torchrl][INFO] check_env_specs succeeded!


In [80]:
n_rollout_steps = 5
rollout = env.rollout(n_rollout_steps)
print(f"rollout of {n_rollout_steps} steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)

rollout of 5 steps: TensorDict(
    fields={
        agent: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1, 3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                done: Tensor(shape=torch.Size([1, 3, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                episode_reward: Tensor(shape=torch.Size([1, 3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                observation: Tensor(shape=torch.Size([1, 3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                terminated: Tensor(shape=torch.Size([1, 3, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                truncated: Tensor(shape=torch.Size([1, 3, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
            batch_size=torch.Size([1, 3]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([1, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        next: TensorDic

### **Approach: Policy Network Setup**
- **Iterates over agent groups** to create independent policies.  
- **Defines `MultiAgentMLP`** for decentralized policies:
  - **Observations & Actions**: Uses `env.observation_spec` and `env.full_action_spec`.
  - **Decentralized Execution**: Each agent acts based on its local observation.
  - **Parameter Sharing**: Controlled by `share_parameters_policy` (set to `True` for efficiency).
  - **Architecture**: 2-layer MLP (`256` neurons per layer, `Tanh` activation).
- **Wraps in `TensorDictModule`**:
  - Reads observations from `TensorDict` and writes action parameters.  
  - Allows structured tensor operations for multi-agent training.

In [81]:
env.group_map.items()

dict_items([('agent', ['agent_0', 'agent_1', 'agent_2'])])

In [82]:
policy_modules = {}
for group, agents in env.group_map.items():
    share_parameters_policy = False  # Can change this based on the group

    policy_net = MultiAgentMLP(
        n_agent_inputs=env.observation_spec[group, "observation"].shape[
            -1
        ],  # n_obs_per_agent
        n_agent_outputs=env.full_action_spec[group, "action"].shape[
            -1
        ],  # n_actions_per_agents
        n_agents=len(agents),  # Number of agents in the group
        centralised=False,  # the policies are decentralised (i.e., each agent will act from its local observation)
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    )

    # Wrap the neural network in a :class:`~tensordict.nn.TensorDictModule`.
    # This is simply a module that will read the ``in_keys`` from a tensordict, feed them to the
    # neural networks, and write the
    # outputs in-place at the ``out_keys``.

    policy_module = TensorDictModule(
        policy_net,
        in_keys=[(group, "observation")],
        out_keys=[(group, "param")],
    )
    policy_module = policy_module.to(device)
    policy_modules[group] = policy_module

### **Approach: Probabilistic Policy Definition**
- **Wraps policy networks (`policy_modules`) in `ProbabilisticActor`** to handle stochastic action sampling.  
- **Uses `TanhDelta` Distribution**:
  - Ensures continuous action outputs stay within predefined bounds (`low`, `high`).  
  - Helps stabilize training by keeping actions constrained.  
- **Input & Output Keys**:
  - Reads action parameters from `policy_modules` (`(group, "param")`).  
  - Outputs final actions (`(group, "action")`).  
- **Log Probabilities Disabled (`return_log_prob=False`)**:  
  - Not needed for deterministic policy updates in DDPG.

In [83]:
policies = {}
for group, _agents in env.group_map.items():
    policy = ProbabilisticActor(
        module=policy_modules[group],
        spec=env.full_action_spec[group, "action"],
        in_keys=[(group, "param")],
        out_keys=[(group, "action")],
        distribution_class=TanhDelta,
        distribution_kwargs={
            "low": env.full_action_spec[group, "action"].space.low,
            "high": env.full_action_spec[group, "action"].space.high,
        },
        return_log_prob=False,
    )
    policy = policy.to(device)
    policies[group] = policy

### **Approach: Exploration Policy with Gaussian Noise**
- **Adds exploration noise to deterministic policies** using `AdditiveGaussianModule`.  
- **Purpose**: Encourages better exploration by injecting Gaussian noise into actions.  
- **Annealing Strategy**:
  - **Starts with `sigma_init = 0.9`** (high noise for exploration).  
  - **Decays to `sigma_end = 0.1`** over `total_frames / 2` steps, reducing noise gradually.  
- **Wrapped in `TensorDictSequential`**:
  - First applies the base policy (`policies[group]`).  
  - Then adds Gaussian noise to the output action (`(group, "action")`).  
- **Ensures Smooth Transition**: High exploration at the start, stabilizing towards exploitation.

In [84]:
exploration_policies = {}
for group, _agents in env.group_map.items():
    exploration_policy = TensorDictSequential(
        policies[group],
        AdditiveGaussianModule(
            spec=policies[group].spec,
            annealing_num_steps=total_frames
            // 2,  # Number of frames after which sigma is sigma_end
            action_key=(group, "action"),
            sigma_init=0.9,  # Initial value of the sigma
            sigma_end=0.1,  # Final value of the sigma
        ),
    )
    exploration_policy = exploration_policy.to(device)
    exploration_policies[group] = exploration_policy

### **Approach: Critic Network for Value Estimation**
- **Defines critic networks for each agent group** to estimate state-action values (\(Q\)-values).  
- **Centralized vs. Decentralized Critic**:
  - **`MADDPG = True`**: Uses a centralized critic (multi-agent).  
  - **`IDDPG = False`**: Uses an independent critic per agent.  
- **Feature Concatenation (`cat_module`)**:
  - Combines agent's observation and action into a single tensor (`(group, "obs_action")`).  
- **Critic Network (`critic_module`)**:
  - Takes concatenated state-action inputs and predicts a **single Q-value per agent**.  
  - Uses a **2-layer MLP (256 neurons per layer, `Tanh` activation)**.  
  - Supports parameter sharing (`share_parameters_critic = True`).  
- **Final Critic Pipeline (`TensorDictSequential`)**:
  - First applies **feature concatenation (`cat_module`)**.  
  - Then passes through **`MultiAgentMLP` for value estimation** (`(group, "state_action_value")`).

In [85]:
critics = {}
for group, agents in env.group_map.items():
    share_parameters_critic = True  # Can change for each group
    MADDPG = True  # IDDPG if False, can change for each group

    # This module applies the lambda function: reading the action and observation entries for the group
    # and concatenating them in a new ``(group, "obs_action")`` entry
    cat_module = TensorDictModule(
        lambda obs, action: torch.cat([obs, action], dim=-1),
        in_keys=[(group, "observation"), (group, "action")],
        out_keys=[(group, "obs_action")],
    )

    critic_module = TensorDictModule(
        module=MultiAgentMLP(
            n_agent_inputs=env.observation_spec[group, "observation"].shape[-1]
            + env.full_action_spec[group, "action"].shape[-1],
            n_agent_outputs=1,  # 1 value per agent
            n_agents=len(agents),
            centralised=MADDPG,
            share_params=share_parameters_critic,
            device=device,
            depth=2,
            num_cells=256,
            activation_class=torch.nn.Tanh,
        ),
        in_keys=[(group, "obs_action")],  # Read ``(group, "obs_action")``
        out_keys=[
            (group, "state_action_value")
        ],  # Write ``(group, "state_action_value")``
    )
    critic_module = critic_module.to(device)

    critics[group] = TensorDictSequential(
        cat_module, critic_module
    )  # Run them in sequence

In [86]:
reset_td = env.reset()
for group, _agents in env.group_map.items():
    print(
        f"Running value and policy for group '{group}':",
        critics[group](policies[group](reset_td)),
    )

Running value and policy for group 'agent': TensorDict(
    fields={
        agent: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                done: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                episode_reward: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                obs_action: Tensor(shape=torch.Size([3, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
                observation: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                param: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                state_action_value: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cuda:0, dtype=torch.bo

### **Approach: Data Collection for Training**
- **Combines all group exploration policies** into a single sequential module (`TensorDictSequential`), ensuring actions include exploration noise.  
- **`SyncDataCollector` for Data Sampling**:
  - Collects experience from the environment using **exploration policies**.  
  - Runs on **`device` (GPU or CPU)** for efficiency.  
  - **Frames per batch**: `1,000`, ensuring large enough updates per iteration.  
  - **Total frames**: `50,000` (over `50` iterations).  
- **Purpose**: Efficiently gathers on-policy experiences for training with replay buffers.

In [87]:
# Put exploration policies from each group in a sequence
agents_exploration_policy = TensorDictSequential(*exploration_policies.values())

collector = SyncDataCollector(
    env,
    agents_exploration_policy,
    device=device,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
)

In [88]:
#Standard in off policy algos for efficient data collections
replay_buffers = {}
for group, _agents in env.group_map.items():
    replay_buffer = ReplayBuffer(
        storage=LazyMemmapStorage(memory_size, device="cpu"),
        sampler=RandomSampler(),
        batch_size=train_batch_size,
    )
    replay_buffer.append_transform(lambda batch: batch.to("cuda:0"))
    replay_buffers[group] = replay_buffer

In [89]:
# Define parameters for each agent in the cooperative group
agent_params = {
    "agent_0": {
        "alpha": 0.7,
        "lam": 0.8,
        "w_plus_prime_const": 0.8,
        "w_minus_prime_const": 0.2,
    },
    "agent_1": {
        "alpha": 0.65,
        "lam": 2.8,
        "w_plus_prime_const": 0.25,
        "w_minus_prime_const": 0.75,
    },
    "agent_2": {
        "alpha": 0.65,
        "lam": 2.8,
        "w_plus_prime_const": 0.25,
        "w_minus_prime_const": 0.75,
    },
}


In [90]:
import torch
import torch.nn as nn

class AdaptiveBehavioralParameters(nn.Module):
    def __init__(self, init_alpha, init_lam, init_w_plus_gamma, init_w_minus_gamma):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(init_alpha, dtype=torch.float32))
        self.lam = nn.Parameter(torch.tensor(init_lam, dtype=torch.float32))
        # Instead of fixed w constants, now learn the gamma parameters for weighting functions:
        self.w_plus_prime_gamma = nn.Parameter(torch.tensor(init_w_plus_gamma, dtype=torch.float32))
        self.w_minus_prime_gamma = nn.Parameter(torch.tensor(init_w_minus_gamma, dtype=torch.float32))

    def get_params(self):
        return {
            "alpha": self.alpha,
            "lam": self.lam,
            "w_plus_prime_gamma": self.w_plus_prime_gamma,
            "w_minus_prime_gamma": self.w_minus_prime_gamma
        }


In [91]:
adaptive_params = nn.ModuleDict({
    "agent_0": AdaptiveBehavioralParameters(init_alpha=1.2, init_lam=1.5, init_w_plus_gamma=0.5, init_w_minus_gamma=0.69),
    "agent_1": AdaptiveBehavioralParameters(init_alpha=1.2, init_lam=1.2, init_w_plus_gamma=0.5, init_w_minus_gamma=0.69),
    "agent_2": AdaptiveBehavioralParameters(init_alpha=1.2, init_lam=1.2, init_w_plus_gamma=0.5, init_w_minus_gamma=0.69)
})

In [92]:
w_plus_prime_const = 0.2
w_minus_prime_const = 0.8

def w_plus_prime_dynamic(p, params, epsilon=1e-6):
    """
    Compute a dynamic weighting for gains.
    p: a tensor of probabilities (values between 0 and 1)
    params: dictionary containing a learnable parameter 'w_plus_prime_gamma'
    """
    # Clamp p to avoid log(0)
    p = torch.clamp(p, min=epsilon, max=1.0)
    gamma = params.get("w_plus_prime_gamma", torch.tensor(0.61, dtype=p.dtype, device=p.device))
    # Prelec weighting derivative (an example formulation):
    return torch.exp(-(-torch.log(p)) ** gamma)

def w_minus_prime_dynamic(p, params, epsilon=1e-6):
    """
    Compute a dynamic weighting for losses.
    p: a tensor of probabilities (values between 0 and 1)
    params: dictionary containing a learnable parameter 'w_minus_prime_gamma'
    """
    p = torch.clamp(p, min=epsilon, max=1.0)
    gamma = params.get("w_minus_prime_gamma", torch.tensor(0.69, dtype=p.dtype, device=p.device))
    return torch.exp(-(-torch.log(p)) ** gamma)

def compute_phi_linear(R):
    """
    Compute linearized CPT sensitivity:
    φ(R) ≈ w'_+(p*) * u^+(R) for R>=0, and -w'_-(p*) * u^-(R) for R<0.
    """
    R = R.view(-1)
    v = torch.where(R >= 0, u_plus(R), -u_minus(R))
    phi = torch.where(R >= 0, w_plus_prime_const * v, -w_minus_prime_const * v)
    return phi.mean()

def compute_phi_linear_dynamic(R, adaptive_params, epsilon=1e-6):
    """
    Compute linearized CPT sensitivity using dynamic weighting:
      φ(R) ≈ w'_+(p*) * u^+(R)  for R >= 0,
           ≈ -w'_-(p*) * u^-(R) for R < 0.
    Here p* is obtained by normalizing R into (0,1) (using a sigmoid, for example).
    adaptive_params is a ModuleDict where each module’s get_params() returns a dict that
    includes learnable parameters for the weighting functions (e.g., w_plus_prime_gamma).
    """
    R = R.view(-1)
    # For example, normalize R to (0,1) via a sigmoid.
    p_star = torch.sigmoid(R)

    # Compute the basic utility:
    # (Assuming you have functions u_plus(R) and u_minus(R) already defined.)
    # If not, you can use your existing stable versions:
    v = torch.where(R >= 0, u_plus(R), -u_minus(R))

    phi_values = []
    for agent_id, param_module in adaptive_params.items():
        params = param_module.get_params()
        # For gains, use dynamic weighting from our new function:
        weight_gain = w_plus_prime_dynamic(p_star, params, epsilon)
        # For losses:
        weight_loss = w_minus_prime_dynamic(p_star, params, epsilon)
        phi = torch.where(R >= 0, weight_gain * v, -weight_loss * v)
        phi_values.append(phi)
    phi_stack = torch.stack(phi_values)
    return phi_stack.mean(dim=0)


def stable_u_plus_agent(x, params, epsilon=1e-6, min_val=1e-3):
    # Ensure that x + epsilon is not too small; then compute in log space.
    y = torch.clamp(x + epsilon, min=min_val)
    return torch.exp(params["alpha"] * torch.log(y))

def u_minus_agent(x, params, epsilon=1e-6, min_val=1e-3):
    # For the negative branch, ensure -x + epsilon is not too small.
    y = torch.clamp(-x + epsilon, min=min_val)
    return params["lam"] * torch.pow(y, params["alpha"])

def compute_phi_cross_dynamic(R, adaptive_params, epsilon=1e-6, min_val=1e-3):
    """
    Compute the cross-agent CPT sensitivity factor dynamically.
    For each reward in R, first normalize it to (0,1) via a sigmoid (p_star).
    Then compute the utility using u_plus for gains and u_minus_agent for losses.
    Finally, apply dynamic weighting using the learnable weighting functions and average across agents.

    Args:
        R (Tensor): A tensor of rewards.
        adaptive_params (ModuleDict): A dictionary (ModuleDict) of adaptive parameter modules.
        epsilon (float): A small constant to prevent log(0).
        min_val (float): A minimum value to clamp inputs.

    Returns:
        Tensor: The averaged sensitivity factor φ.
    """
    R = R.view(-1)
    # Normalize rewards to [0,1] for the weighting function:
    p_star = torch.sigmoid(R)

    phi_values = []
    for agent_id, param_module in adaptive_params.items():
        params = param_module.get_params()
        # For gains: clamp R+epsilon, then compute stable u_plus:
        y = torch.clamp(R + epsilon, min=min_val)
        u_plus_val = torch.exp(params["alpha"] * torch.log(y))
        # For losses, use u_minus_agent (as defined elsewhere)
        v = torch.where(R >= 0, u_plus_val, -u_minus_agent(R, params, epsilon, min_val))
        # Now compute dynamic weights from p_star using our new functions:
        weight_gain = w_plus_prime_dynamic(p_star, params, epsilon)
        weight_loss = w_minus_prime_dynamic(p_star, params, epsilon)
        phi = torch.where(R >= 0, weight_gain * v, -weight_loss * v)
        phi_values.append(phi)

    phi_stack = torch.stack(phi_values)
    return phi_stack.mean(dim=0)

def C_transform_cross_dynamic(x, adaptive_params, epsilon=1e-6, min_val=1e-3):
    """
    Compute the CPT-transformed target value by averaging each agent’s transformation,
    using dynamic weighting rather than fixed constants.
    """
    transformed_vals = []
    # Optionally, define a normalization for x to obtain a probability. For instance:
    p_star = torch.sigmoid(x)
    for agent_id, param_module in adaptive_params.items():
        params = param_module.get_params()
        y = torch.clamp(x + epsilon, min=min_val)
        # Compute u_plus for gains:
        u_plus_val = torch.exp(params["alpha"] * torch.log(y))
        # For losses, use u_minus_agent as before:
        transformed = torch.where(
            x >= 0,
            w_plus_prime_dynamic(p_star, params, epsilon) * u_plus_val,
            -w_minus_prime_dynamic(p_star, params, epsilon) * u_minus_agent(x, params, epsilon, min_val)
        )
        # Clamp the final output for safety:
        transformed = torch.clamp(transformed, min=-1e6, max=1e6)
        transformed_vals.append(transformed)
    return torch.stack(transformed_vals).mean(dim=0)


def C_transform(x):
    """
    Simple CPT transformation on one-step return:
    C(x) ≈ w'_+(p*) * u^+(x) if x >= 0, else -w'_-(p*) * u^-(x).
    """
    return torch.where(x >= 0, w_plus_prime_const * u_plus(x), -w_minus_prime_const * u_minus(x))

In [93]:
from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule

from tensordict.utils import NestedKey, unravel_key
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _cache_values,
    _GAMMA_LMBDA_DEPREC_ERROR,
    _reduce,
    default_value_kwargs,
    distance_loss,
    ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


class CPTDDPGLoss(LossModule):
    """The DDPG Loss class.

    Args:
        actor_network (TensorDictModule): a policy operator.
        value_network (TensorDictModule): a Q value operator.
        loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
        delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
            data collection. Default is ``False``.
        delay_value (bool, optional): whether to separate the target value networks from the value networks used for
            data collection. Default is ``True``.
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
            parameters for both policy and critic losses.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.ddpg import DDPGLoss
        >>> from tensordict import TensorDict
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> value = ValueOperator(
        ...     module=module,
        ...     in_keys=["observation", "action"])
        >>> loss = DDPGLoss(actor, value)
        >>> batch = [2, ]
        >>> data = TensorDict({
        ...        "observation": torch.randn(*batch, n_obs),
        ...        "action": spec.rand(batch),
        ...        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...        ("next", "reward"): torch.randn(*batch, 1),
        ...        ("next", "observation"): torch.randn(*batch, n_obs),
        ...    }, batch)
        >>> loss(data)
        TensorDict(
            fields={
                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)

    This class is compatible with non-tensordict based modules too and can be
    used without recurring to any tensordict-related primitive. In this case,
    the expected keyword arguments are:
    ``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network.
    The return value is a tuple of tensors in the following order:
    ``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]``

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.ddpg import DDPGLoss
        >>> _ = torch.manual_seed(42)
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> value = ValueOperator(
        ...     module=module,
        ...     in_keys=["observation", "action"])
        >>> loss = DDPGLoss(actor, value)
        >>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss(
        ...     observation=torch.randn(n_obs),
        ...     action=spec.rand(),
        ...     next_done=torch.zeros(1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(1, dtype=torch.bool),
        ...     next_observation=torch.randn(n_obs),
        ...     next_reward=torch.randn(1))
        >>> loss_actor.backward()

    The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
    method.

    Examples:
        >>> loss.select_out_keys('loss_actor', 'loss_value')
        >>> loss_actor, loss_value = loss(
        ...     observation=torch.randn(n_obs),
        ...     action=spec.rand(),
        ...     next_done=torch.zeros(1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(1, dtype=torch.bool),
        ...     next_observation=torch.randn(n_obs),
        ...     next_reward=torch.randn(1))
        >>> loss_actor.backward()

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            state_action_value (NestedKey): The input tensordict key where the
                state action value is expected. Will be used for the underlying
                value estimator as value key. Defaults to ``"state_action_value"``.
            priority (NestedKey): The input tensordict key where the target
                priority is written to. Defaults to ``"td_error"``.
            reward (NestedKey): The input tensordict key where the reward is expected.
                Will be used for the underlying value estimator. Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Will be used for the underlying value estimator.
                Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Will be used for the underlying value estimator.
                Defaults to ``"terminated"``.

        """

        state_action_value: NestedKey = "state_action_value"
        priority: NestedKey = "td_error"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator: ValueEstimators = ValueEstimators.TD0
    out_keys = [
        "loss_actor",
        "loss_value",
        "pred_value",
        "target_value",
        "pred_value_max",
        "target_value_max",
    ]

    actor_network: TensorDictModule
    value_network: actor_network
    actor_network_params: TensorDictParams
    value_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams
    target_value_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: TensorDictModule,
        value_network: TensorDictModule,
        *,
        loss_function: str = "l2",
        delay_actor: bool = False,
        delay_value: bool = True,
        gamma: float = None,
        separate_losses: bool = False,
        reduction: str = None,
    ) -> None:
        self._in_keys = None
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self.delay_actor = delay_actor
        self.delay_value = delay_value

        actor_critic = ActorCriticWrapper(actor_network, value_network)
        params = TensorDict.from_module(actor_critic)
        params_meta = params.apply(
            self._make_meta_params, device=torch.device("meta"), filter_empty=False
        )
        with params_meta.to_module(actor_critic):
            self.__dict__["actor_critic"] = deepcopy(actor_critic)

        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=self.delay_actor,
        )
        if separate_losses:
            # we want to make sure there are no duplicates in the params: the
            # params of critic must be refs to actor if they're shared
            policy_params = list(actor_network.parameters())
        else:
            policy_params = None
        self.convert_to_functional(
            value_network,
            "value_network",
            create_target_params=self.delay_value,
            compare_against=policy_params,
        )
        self.actor_critic.module[0] = self.actor_network
        self.actor_critic.module[1] = self.value_network

        self.actor_in_keys = actor_network.in_keys
        self.value_exclusive_keys = set(self.value_network.in_keys) - (
            set(self.actor_in_keys) | set(self.actor_network.out_keys)
        )

        self.loss_function = loss_function
        self.reduction = reduction
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                value=self._tensor_keys.state_action_value,
                reward=self._tensor_keys.reward,
                done=self._tensor_keys.done,
                terminated=self._tensor_keys.terminated,
            )
        self._set_in_keys()

    def _set_in_keys(self):
        in_keys = {
            unravel_key(("next", self.tensor_keys.reward)),
            unravel_key(("next", self.tensor_keys.done)),
            unravel_key(("next", self.tensor_keys.terminated)),
            *self.actor_in_keys,
            *[unravel_key(("next", key)) for key in self.actor_in_keys],
            *self.value_network.in_keys,
            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
        }
        self._in_keys = sorted(in_keys, key=str)

    @property
    def in_keys(self):
        if self._in_keys is None:
            self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values


    def _clear_weakrefs(self, *tds):
        if is_compiling():
            # Waiting for weakrefs reconstruct to be supported by compile
            for td in tds:
                if isinstance(td, str):
                    td = getattr(self, td, None)
                if not is_tensor_collection(td):
                    continue
                td.clear_refs_for_compile_()

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        """Computes the DDPG losses given a tensordict sampled from the replay buffer.

        This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
            a priority to items in the tensordict.

        Args:
            tensordict (TensorDictBase): a tensordict with keys ["done", "terminated", "reward"] and the in_keys of the actor
                and value networks.

        Returns:
            a tuple of 2 tensors containing the DDPG loss.

        """
        loss_value, metadata = self.loss_value(tensordict)
        loss_actor, metadata_actor = self.loss_actor(tensordict)
        metadata.update(metadata_actor)
        td_out = TensorDict(
            source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
            batch_size=[],
        )
        self._clear_weakrefs(
            tensordict,
            td_out,
            "value_network_params",
            "target_value_network_params",
            "target_actor_network_params",
            "actor_network_params",
        )
        return td_out

    def loss_actor(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, dict]:
        td_copy = tensordict.select(
            *self.actor_in_keys, *self.value_exclusive_keys, strict=False
        ).detach()

        with self.actor_network_params.to_module(self.actor_network):
            td_copy = self.actor_network(td_copy)

        with self._cached_detached_value_params.to_module(self.value_network):
            td_copy = self.value_network(td_copy)

        actions = td_copy.get((self.actor_network.in_keys[0][0], "action"))
        Q_values = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
        grad_Q = torch.autograd.grad(Q_values.sum(), actions, retain_graph=True)[0]

        # Retrieve the cooperative reward (assuming common group 'agents')
        returns = tensordict.get(("agent", "episode_reward")).view(-1)
        # Now use the dynamic version that computes φ using your learnable weighting functions:
        phi_factor = compute_phi_cross_dynamic(returns, adaptive_params)

        policy_gradient = actions * grad_Q
        loss_actor = -phi_factor.mean() * policy_gradient.mean()

        return _reduce(loss_actor, self.reduction), {}


    def loss_value(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, dict]:
        td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
        with self.value_network_params.to_module(self.value_network):
            self.value_network(td_copy)
        pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)

        target_value = self.value_estimator.value_estimate(
            tensordict, target_params=self._cached_target_params
        ).squeeze(-1)
        # Replace the old transformation with your new dynamic version:
        target_value_CPT = C_transform_cross_dynamic(target_value, adaptive_params)

        loss_value = distance_loss(pred_val, target_value_CPT, loss_function=self.loss_function)
        tensordict.set("target_value_CPT", target_value_CPT, inplace=True)

        td_error = (pred_val - target_value_CPT).pow(2).detach()
        tensordict.set(self.tensor_keys.priority, td_error, inplace=True)

        metadata = {
            "td_error": td_error,
            "pred_value": pred_val,
            "target_value_CPT": target_value_CPT,
        }
        return _reduce(loss_value, self.reduction), metadata


    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        if value_type is None:
            value_type = self.default_value_estimator
        self.value_type = value_type
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        if value_type == ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} it not implemented for loss {type(self)}."
            )
        elif value_type == ValueEstimators.TDLambda:
            self._value_estimator = TDLambdaEstimator(
                value_network=self.actor_critic, **hp
            )
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

        tensor_keys = {
            "value": self.tensor_keys.state_action_value,
            "reward": self.tensor_keys.reward,
            "done": self.tensor_keys.done,
            "terminated": self.tensor_keys.terminated,
        }
        self._value_estimator.set_keys(**tensor_keys)

    @property
    @_cache_values
    def _cached_target_params(self):
        target_params = TensorDict(
            {
                "module": {
                    "0": self.target_actor_network_params,
                    "1": self.target_value_network_params,
                }
            },
            batch_size=self.target_actor_network_params.batch_size,
            device=self.target_actor_network_params.device,
        )
        return target_params

    @property
    @_cache_values
    def _cached_detached_value_params(self):
        return self.value_network_params.detach()

### **Approach: Loss Calculation & Optimization**
#### **Defining the Loss Function (`DDPGLoss`)**
- **Uses separate actor and critic losses**:
  - **`actor_network = policies[group]`**: Optimizes agent actions.
  - **`value_network = critics[group]`**: Estimates state-action values.
- **Target Network (`delay_value = True`)**:
  - Uses a **target critic** for more stable learning.
  - **Loss function**: Mean Squared Error (`"l2"`).
- **Key Assignments**:
  - **State-action value**: `(group, "state_action_value")`.
  - **Reward Signal**: `(group, "reward")`.
  - **Termination Handling**: `(group, "done")` and `(group, "terminated")`.
- **TD(0) Estimator**: Uses **Temporal Difference (TD) learning** with discount factor `γ = 0.99`.

#### **Target Network Updates**
- **Soft update mechanism (`SoftUpdate`)**:
  - **Gradually updates target networks** using `τ = 0.005`.
  - Prevents drastic changes, improving stability.

#### **Optimizers**
- **Separate Adam optimizers for actor and critic networks**:
  - **`loss_actor`**: Updates policy parameters.
  - **`loss_value`**: Updates value network parameters.
- **Learning rate (`lr = 3e-4`)** ensures smooth gradient updates.

In [94]:
losses = {}
for group, _agents in env.group_map.items():
    loss_module = CPTDDPGLoss(
        actor_network=policies[group],  # Use the non-explorative policies
        value_network=critics[group],
        delay_value=True,  # Whether to use a target network for the value
        loss_function="l2",
    )
    loss_module.set_keys(
        state_action_value=(group, "state_action_value"),
        reward=(group, "reward"),
        done=(group, "done"),
        terminated=(group, "terminated"),
    )
    loss_module.make_value_estimator(ValueEstimators.TD0, gamma=gamma)

    losses[group] = loss_module

target_updaters = {
    group: SoftUpdate(loss, tau=polyak_tau) for group, loss in losses.items()
}

optimisers = {
    group: {
        "loss_actor": torch.optim.Adam(
            loss.actor_network_params.flatten_keys().values(), lr=lr
        ),
        "loss_value": torch.optim.Adam(
            loss.value_network_params.flatten_keys().values(), lr=lr
        ),
    }
    for group, loss in losses.items()
}

optimizer_behavioral = torch.optim.Adam(adaptive_params.parameters(), lr=1e-5)

In [95]:
def process_batch(batch: TensorDictBase) -> TensorDictBase:
    """
    If the `(group, "terminated")` and `(group, "done")` keys are not present, create them by expanding
    `"terminated"` and `"done"`.
    This is needed to present them with the same shape as the reward to the loss.
    """
    for group in env.group_map.keys():
        keys = list(batch.keys(True, True))
        group_shape = batch.get_item_shape(group)
        nested_done_key = ("next", group, "done")
        nested_terminated_key = ("next", group, "terminated")
        if nested_done_key not in keys:
            batch.set(
                nested_done_key,
                batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
            )
        if nested_terminated_key not in keys:
            batch.set(
                nested_terminated_key,
                batch.get(("next", "terminated"))
                .unsqueeze(-1)
                .expand((*group_shape, 1)),
            )
    return batch

### **Approach: Training Loop & Optimization**
#### **Progress Bar & Logging Setup**
- **Uses `tqdm`** to track training iterations with episode rewards.  
- **Initializes `episode_reward_mean_map`** to store reward trends per agent group.  
- **Creates `train_group_map`** as a copy of `env.group_map`, allowing dynamic updates.

#### **Main Training Loop**
- **Iterates through `collector`** to process training batches.
- **Preprocesses Data (`process_batch`)**:
  - Expands done/terminated keys for proper loss computation.
  - **Excludes data from other groups** to isolate training signals.
  - **Reshapes batch** to align with replay buffer dimensions.
- **Stores Data in Replay Buffer (`replay_buffers[group].extend(group_batch)`)**.

#### **Optimization Steps**
- **Samples batches (`n_optimiser_steps = 100`)** from replay buffer.
- **Computes & Backpropagates Loss**:
  - Extracts actor (`loss_actor`) and critic (`loss_value`) loss.
  - **Clips gradients (`max_grad_norm = 1.0`)** to prevent instability.
  - **Optimizes parameters with Adam**, resetting gradients after each step.
- **Soft Updates (`target_updaters[group].step()`)**:
  - Gradually syncs target networks using `τ = 0.005`.

#### **Adaptive Exploration**
- **Anneals exploration noise (`sigma`)** based on the number of frames processed.

#### **Training Halting Condition**
- **Stops training evaders after `iteration_when_stop_training_evaders = 25`**.

#### **Logging & Progress Tracking**
- **Computes mean episode reward** for each group.
- **Updates `tqdm` progress bar** with latest reward values.

In [96]:
def compute_adaptive_loss(replay_sample, value_estimator, adaptive_params, target_params, epsilon=1e-6, min_val=1e-3):
    # Compute the standard target Q-value:
    target_value = value_estimator.value_estimate(replay_sample, target_params=target_params).squeeze(-1)
    # Compute the CPT-transformed target using adaptive parameters:
    target_value_CPT = C_transform_cross_dynamic(target_value, adaptive_params, epsilon, min_val)
    base_adaptive_loss = torch.mean((target_value_CPT - target_value) ** 2)
    return base_adaptive_loss



In [99]:
pbar = tqdm(
    total=n_iters,
    desc=", ".join(
        [f"episode_reward_mean_{group} = 0" for group in env.group_map.keys()]
    ),
)
episode_reward_mean_map = {group: [] for group in env.group_map.keys()}
train_group_map = copy.deepcopy(env.group_map)


adaptive_update_frequency = 10  # update adaptive parameters every 10 iterations
freeze_adaptive_until = 20  # don't update adaptive parameters until after 20 iterations
scale_factor = 1e-3  # scaling for adaptive loss
reg_lambda = 1e-3  # regularization coefficient

# (Assume initial_params is defined right after adaptive_params creation)
# For example:
initial_params = {
    agent_id: {name: param.clone().detach() for name, param in module.get_params().items()}
    for agent_id, module in adaptive_params.items()
}

previous_base_loss = None

for iteration, batch in enumerate(collector):
    current_frames = batch.numel()
    batch = process_batch(batch)

    for group in train_group_map.keys():
        group_batch = batch.exclude(
            *[
                key
                for _group in env.group_map.keys()
                if _group != group
                for key in [_group, ("next", _group)]
            ]
        )
        group_batch = group_batch.reshape(-1)
        replay_buffers[group].extend(group_batch)

        for _ in range(n_optimiser_steps):
            subdata = replay_buffers[group].sample()
            loss_vals = losses[group](subdata)
            for loss_name in ["loss_actor", "loss_value"]:
                loss_value = loss_vals[loss_name]
                optimiser = optimisers[group][loss_name]
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(optimiser.param_groups[0]["params"], max_grad_norm)
                optimiser.step()
                optimiser.zero_grad()
            target_updaters[group].step()
            exploration_policies[group][-1].step(current_frames)

    # Decoupled update for adaptive parameters every adaptive_update_frequency iterations,
    # but only after freeze_adaptive_until iterations.
    if iteration > freeze_adaptive_until and iteration % adaptive_update_frequency == 0:
        # Sample a batch from the cooperative group "agents"
        subdata = replay_buffers["agents"].sample()

        # Compute the standard target Q-value:
        target_value = losses["agents"]._value_estimator.value_estimate(
            subdata, target_params=losses["agents"]._cached_target_params
        ).squeeze(-1)

        # Compute the CPT-transformed target:
        target_value_CPT = C_transform_cross_dynamic(target_value, adaptive_params)

        # Compute the base adaptive loss as mean squared error:
        base_loss = torch.mean((target_value_CPT - target_value) ** 2)

        # Compute a dynamic scaling factor based on the change in base loss
        if previous_base_loss is None:
            dynamic_factor = 1.0
        else:
            # Increase the factor if the loss has changed substantially
            dynamic_factor = 1.0 + torch.abs(base_loss - previous_base_loss)
            # If you prefer to have a plain Python float, you can call .item() here
            # dynamic_factor = 1.0 + torch.abs(base_loss - previous_base_loss).item()

        # Update the previous_base_loss for the next adaptive update (detach it to avoid gradient tracking)
        previous_base_loss = base_loss.detach()

        # Compute the L2 regularization loss for keeping adaptive parameters near their initial values:
        reg_loss = 0.0
        for agent_id, module in adaptive_params.items():
            params = module.get_params()
            for name, param in params.items():
                reg_loss += torch.mean((param - initial_params[agent_id][name]) ** 2)

        # Combine the base adaptive loss (scaled dynamically) with the regularization term
        adaptive_loss = dynamic_factor * scale_factor * base_loss + reg_lambda * reg_loss

        # Print out the loss and the dynamic factor (dynamic_factor is a float or a tensor—if it's a float, no .item() is needed)
        print(f"Iteration {iteration}: Adaptive loss = {adaptive_loss.item()}, dynamic factor = {dynamic_factor}")

        # Backward pass and update adaptive parameters:
        adaptive_loss.backward()
        torch.nn.utils.clip_grad_norm_(adaptive_params.parameters(), max_norm=1.0)
        optimizer_behavioral.step()
        optimizer_behavioral.zero_grad()

    # Optionally, monitor adaptive parameters here
    for agent_id, module in adaptive_params.items():
        params = module.get_params()
        for name, param in params.items():
            if param.grad is not None:
                print(f"Iter {iteration} - {agent_id} {name}: {param.item():.6f}, grad mean: {param.grad.abs().mean().item():.6f}")

    # Logging of episode rewards, etc.
    for group in env.group_map.keys():
        episode_reward_mean = (
            batch.get(("next", group, "episode_reward"))[
                batch.get(("next", group, "done"))
            ]
            .mean()
            .item()
        )
        episode_reward_mean_map[group].append(episode_reward_mean)

    pbar.set_description(
        ", ".join(
            [
                f"episode_reward_mean_{group} = {episode_reward_mean_map[group][-1]}"
                for group in env.group_map.keys()
            ]
        ),
        refresh=False,
    )
    pbar.update()




episode_reward_mean_agent = 0.13389629125595093:   4%|▎         | 7/200 [03:34<1:38:38, 30.67s/it]

episode_reward_mean_agent = 0.13589432835578918:   0%|          | 1/200 [00:20<1:08:48, 20.75s/it][A

Iter 0 - agent_0 alpha: 1.200000, grad mean: 2.930393
Iter 0 - agent_0 lam: 1.500000, grad mean: 1.660273
Iter 0 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 3.674483
Iter 0 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.322690
Iter 0 - agent_1 alpha: 1.200000, grad mean: 2.743518
Iter 0 - agent_1 lam: 1.200000, grad mean: 1.660273
Iter 0 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 3.674483
Iter 0 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.258152
Iter 0 - agent_2 alpha: 1.200000, grad mean: 2.743518
Iter 0 - agent_2 lam: 1.200000, grad mean: 1.660273
Iter 0 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 3.674483
Iter 0 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.258152



episode_reward_mean_agent = 0.13417001068592072:   1%|          | 2/200 [00:41<1:08:27, 20.74s/it][A

Iter 1 - agent_0 alpha: 1.200000, grad mean: 3.266326
Iter 1 - agent_0 lam: 1.500000, grad mean: 1.717242
Iter 1 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 4.220649
Iter 1 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.332104
Iter 1 - agent_1 alpha: 1.200000, grad mean: 3.071735
Iter 1 - agent_1 lam: 1.200000, grad mean: 1.717242
Iter 1 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 4.220649
Iter 1 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.265683
Iter 1 - agent_2 alpha: 1.200000, grad mean: 3.071735
Iter 1 - agent_2 lam: 1.200000, grad mean: 1.717242
Iter 1 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 4.220649
Iter 1 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.265683



episode_reward_mean_agent = 0.13828521966934204:   2%|▏         | 3/200 [01:02<1:08:19, 20.81s/it][A

Iter 2 - agent_0 alpha: 1.200000, grad mean: 3.602347
Iter 2 - agent_0 lam: 1.500000, grad mean: 1.774439
Iter 2 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 4.770781
Iter 2 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.342680
Iter 2 - agent_1 alpha: 1.200000, grad mean: 3.400511
Iter 2 - agent_1 lam: 1.200000, grad mean: 1.774439
Iter 2 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 4.770781
Iter 2 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.274144
Iter 2 - agent_2 alpha: 1.200000, grad mean: 3.400511
Iter 2 - agent_2 lam: 1.200000, grad mean: 1.774439
Iter 2 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 4.770781
Iter 2 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.274144



episode_reward_mean_agent = 0.13877764344215393:   2%|▏         | 4/200 [01:23<1:08:12, 20.88s/it][A

Iter 3 - agent_0 alpha: 1.200000, grad mean: 3.935788
Iter 3 - agent_0 lam: 1.500000, grad mean: 1.822309
Iter 3 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 5.336387
Iter 3 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.350645
Iter 3 - agent_1 alpha: 1.200000, grad mean: 3.727357
Iter 3 - agent_1 lam: 1.200000, grad mean: 1.822309
Iter 3 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 5.336387
Iter 3 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.280516
Iter 3 - agent_2 alpha: 1.200000, grad mean: 3.727357
Iter 3 - agent_2 lam: 1.200000, grad mean: 1.822309
Iter 3 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 5.336387
Iter 3 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.280516



episode_reward_mean_agent = 0.13505150377750397:   2%|▎         | 5/200 [01:44<1:07:42, 20.83s/it][A

Iter 4 - agent_0 alpha: 1.200000, grad mean: 4.271237
Iter 4 - agent_0 lam: 1.500000, grad mean: 1.868312
Iter 4 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 5.907833
Iter 4 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.358169
Iter 4 - agent_1 alpha: 1.200000, grad mean: 4.056433
Iter 4 - agent_1 lam: 1.200000, grad mean: 1.868312
Iter 4 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 5.907833
Iter 4 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.286535
Iter 4 - agent_2 alpha: 1.200000, grad mean: 4.056433
Iter 4 - agent_2 lam: 1.200000, grad mean: 1.868312
Iter 4 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 5.907833
Iter 4 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.286535



episode_reward_mean_agent = 0.13710802793502808:   3%|▎         | 6/200 [02:04<1:07:20, 20.83s/it][A

Iter 5 - agent_0 alpha: 1.200000, grad mean: 4.601694
Iter 5 - agent_0 lam: 1.500000, grad mean: 1.911206
Iter 5 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 6.465546
Iter 5 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.365762
Iter 5 - agent_1 alpha: 1.200000, grad mean: 4.381077
Iter 5 - agent_1 lam: 1.200000, grad mean: 1.911206
Iter 5 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 6.465546
Iter 5 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.292609
Iter 5 - agent_2 alpha: 1.200000, grad mean: 4.381077
Iter 5 - agent_2 lam: 1.200000, grad mean: 1.911206
Iter 5 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 6.465546
Iter 5 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.292609



episode_reward_mean_agent = 0.1374434381723404:   4%|▎         | 7/200 [02:25<1:06:54, 20.80s/it] [A

Iter 6 - agent_0 alpha: 1.200000, grad mean: 4.934233
Iter 6 - agent_0 lam: 1.500000, grad mean: 1.948621
Iter 6 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 7.028308
Iter 6 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.371223
Iter 6 - agent_1 alpha: 1.200000, grad mean: 4.707945
Iter 6 - agent_1 lam: 1.200000, grad mean: 1.948621
Iter 6 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 7.028308
Iter 6 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.296978
Iter 6 - agent_2 alpha: 1.200000, grad mean: 4.707945
Iter 6 - agent_2 lam: 1.200000, grad mean: 1.948621
Iter 6 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 7.028308
Iter 6 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.296978



episode_reward_mean_agent = 0.1365082561969757:   4%|▍         | 8/200 [02:46<1:06:53, 20.90s/it][A

Iter 7 - agent_0 alpha: 1.200000, grad mean: 5.268345
Iter 7 - agent_0 lam: 1.500000, grad mean: 1.986237
Iter 7 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 7.598986
Iter 7 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.377423
Iter 7 - agent_1 alpha: 1.200000, grad mean: 5.036709
Iter 7 - agent_1 lam: 1.200000, grad mean: 1.986237
Iter 7 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 7.598986
Iter 7 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.301938
Iter 7 - agent_2 alpha: 1.200000, grad mean: 5.036709
Iter 7 - agent_2 lam: 1.200000, grad mean: 1.986237
Iter 7 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 7.598986
Iter 7 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.301938



episode_reward_mean_agent = 0.14169543981552124:   4%|▍         | 9/200 [03:07<1:06:14, 20.81s/it][A

Iter 8 - agent_0 alpha: 1.200000, grad mean: 5.601916
Iter 8 - agent_0 lam: 1.500000, grad mean: 2.020250
Iter 8 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 8.171803
Iter 8 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.383192
Iter 8 - agent_1 alpha: 1.200000, grad mean: 5.365370
Iter 8 - agent_1 lam: 1.200000, grad mean: 2.020250
Iter 8 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 8.171803
Iter 8 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.306553
Iter 8 - agent_2 alpha: 1.200000, grad mean: 5.365370
Iter 8 - agent_2 lam: 1.200000, grad mean: 2.020250
Iter 8 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 8.171803
Iter 8 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.306553



episode_reward_mean_agent = 0.140406534075737:   5%|▌         | 10/200 [03:28<1:05:50, 20.79s/it] [A

Iter 9 - agent_0 alpha: 1.200000, grad mean: 5.935163
Iter 9 - agent_0 lam: 1.500000, grad mean: 2.050816
Iter 9 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 8.747347
Iter 9 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.387420
Iter 9 - agent_1 alpha: 1.200000, grad mean: 5.693736
Iter 9 - agent_1 lam: 1.200000, grad mean: 2.050816
Iter 9 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 8.747347
Iter 9 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.309936
Iter 9 - agent_2 alpha: 1.200000, grad mean: 5.693736
Iter 9 - agent_2 lam: 1.200000, grad mean: 2.050816
Iter 9 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 8.747347
Iter 9 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.309936



episode_reward_mean_agent = 0.13855914771556854:   6%|▌         | 11/200 [03:48<1:05:28, 20.78s/it][A

Iter 10 - agent_0 alpha: 1.200000, grad mean: 6.267298
Iter 10 - agent_0 lam: 1.500000, grad mean: 2.082183
Iter 10 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 9.322682
Iter 10 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.392466
Iter 10 - agent_1 alpha: 1.200000, grad mean: 6.021203
Iter 10 - agent_1 lam: 1.200000, grad mean: 2.082183
Iter 10 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 9.322682
Iter 10 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.313972
Iter 10 - agent_2 alpha: 1.200000, grad mean: 6.021203
Iter 10 - agent_2 lam: 1.200000, grad mean: 2.082183
Iter 10 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 9.322682
Iter 10 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.313972



episode_reward_mean_agent = 0.14111290872097015:   6%|▌         | 12/200 [04:09<1:04:58, 20.74s/it][A

Iter 11 - agent_0 alpha: 1.200000, grad mean: 6.597880
Iter 11 - agent_0 lam: 1.500000, grad mean: 2.108637
Iter 11 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 9.910788
Iter 11 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.396330
Iter 11 - agent_1 alpha: 1.200000, grad mean: 6.347651
Iter 11 - agent_1 lam: 1.200000, grad mean: 2.108637
Iter 11 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 9.910788
Iter 11 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.317064
Iter 11 - agent_2 alpha: 1.200000, grad mean: 6.347651
Iter 11 - agent_2 lam: 1.200000, grad mean: 2.108637
Iter 11 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 9.910788
Iter 11 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.317064



episode_reward_mean_agent = 0.13459251821041107:   6%|▋         | 13/200 [04:30<1:04:59, 20.85s/it][A

Iter 12 - agent_0 alpha: 1.200000, grad mean: 6.932056
Iter 12 - agent_0 lam: 1.500000, grad mean: 2.140040
Iter 12 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 10.495029
Iter 12 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.401710
Iter 12 - agent_1 alpha: 1.200000, grad mean: 6.677318
Iter 12 - agent_1 lam: 1.200000, grad mean: 2.140040
Iter 12 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 10.495029
Iter 12 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.321368
Iter 12 - agent_2 alpha: 1.200000, grad mean: 6.677318
Iter 12 - agent_2 lam: 1.200000, grad mean: 2.140040
Iter 12 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 10.495029
Iter 12 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.321368



episode_reward_mean_agent = 0.13805460929870605:   7%|▋         | 14/200 [04:51<1:04:34, 20.83s/it][A

Iter 13 - agent_0 alpha: 1.200000, grad mean: 7.268262
Iter 13 - agent_0 lam: 1.500000, grad mean: 2.166467
Iter 13 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 11.086571
Iter 13 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.406237
Iter 13 - agent_1 alpha: 1.200000, grad mean: 7.009601
Iter 13 - agent_1 lam: 1.200000, grad mean: 2.166467
Iter 13 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 11.086571
Iter 13 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.324989
Iter 13 - agent_2 alpha: 1.200000, grad mean: 7.009601
Iter 13 - agent_2 lam: 1.200000, grad mean: 2.166467
Iter 13 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 11.086571
Iter 13 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.324989



episode_reward_mean_agent = 0.1386231631040573:   8%|▊         | 15/200 [05:12<1:04:08, 20.80s/it] [A

Iter 14 - agent_0 alpha: 1.200000, grad mean: 7.603292
Iter 14 - agent_0 lam: 1.500000, grad mean: 2.193408
Iter 14 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 11.675467
Iter 14 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.410027
Iter 14 - agent_1 alpha: 1.200000, grad mean: 7.340319
Iter 14 - agent_1 lam: 1.200000, grad mean: 2.193408
Iter 14 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 11.675467
Iter 14 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.328021
Iter 14 - agent_2 alpha: 1.200000, grad mean: 7.340319
Iter 14 - agent_2 lam: 1.200000, grad mean: 2.193408
Iter 14 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 11.675467
Iter 14 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.328021



episode_reward_mean_agent = 0.14030569791793823:   8%|▊         | 16/200 [05:32<1:03:43, 20.78s/it][A

Iter 15 - agent_0 alpha: 1.200000, grad mean: 7.935717
Iter 15 - agent_0 lam: 1.500000, grad mean: 2.219351
Iter 15 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 12.243419
Iter 15 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.413633
Iter 15 - agent_1 alpha: 1.200000, grad mean: 7.668577
Iter 15 - agent_1 lam: 1.200000, grad mean: 2.219351
Iter 15 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 12.243419
Iter 15 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.330906
Iter 15 - agent_2 alpha: 1.200000, grad mean: 7.668577
Iter 15 - agent_2 lam: 1.200000, grad mean: 2.219351
Iter 15 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 12.243419
Iter 15 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.330906



episode_reward_mean_agent = 0.1353548765182495:   8%|▊         | 17/200 [05:53<1:03:27, 20.81s/it] [A

Iter 16 - agent_0 alpha: 1.200000, grad mean: 8.268738
Iter 16 - agent_0 lam: 1.500000, grad mean: 2.245110
Iter 16 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 12.815495
Iter 16 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.417577
Iter 16 - agent_1 alpha: 1.200000, grad mean: 7.997627
Iter 16 - agent_1 lam: 1.200000, grad mean: 2.245110
Iter 16 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 12.815495
Iter 16 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.334061
Iter 16 - agent_2 alpha: 1.200000, grad mean: 7.997627
Iter 16 - agent_2 lam: 1.200000, grad mean: 2.245110
Iter 16 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 12.815495
Iter 16 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.334061



episode_reward_mean_agent = 0.1383940577507019:   9%|▉         | 18/200 [06:14<1:02:54, 20.74s/it][A

Iter 17 - agent_0 alpha: 1.200000, grad mean: 8.599236
Iter 17 - agent_0 lam: 1.500000, grad mean: 2.266936
Iter 17 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 13.388602
Iter 17 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.420327
Iter 17 - agent_1 alpha: 1.200000, grad mean: 8.324405
Iter 17 - agent_1 lam: 1.200000, grad mean: 2.266936
Iter 17 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 13.388602
Iter 17 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.336261
Iter 17 - agent_2 alpha: 1.200000, grad mean: 8.324405
Iter 17 - agent_2 lam: 1.200000, grad mean: 2.266936
Iter 17 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 13.388602
Iter 17 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.336261



episode_reward_mean_agent = 0.13772131502628326:  10%|▉         | 19/200 [06:35<1:02:29, 20.71s/it][A

Iter 18 - agent_0 alpha: 1.200000, grad mean: 8.933179
Iter 18 - agent_0 lam: 1.500000, grad mean: 2.290487
Iter 18 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 13.966878
Iter 18 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.423589
Iter 18 - agent_1 alpha: 1.200000, grad mean: 8.654492
Iter 18 - agent_1 lam: 1.200000, grad mean: 2.290487
Iter 18 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 13.966878
Iter 18 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.338870
Iter 18 - agent_2 alpha: 1.200000, grad mean: 8.654492
Iter 18 - agent_2 lam: 1.200000, grad mean: 2.290487
Iter 18 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 13.966878
Iter 18 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.338870



episode_reward_mean_agent = 0.1302420198917389:  10%|█         | 20/200 [06:55<1:02:04, 20.69s/it] [A

Iter 19 - agent_0 alpha: 1.200000, grad mean: 9.262519
Iter 19 - agent_0 lam: 1.500000, grad mean: 2.314017
Iter 19 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 14.545436
Iter 19 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.427063
Iter 19 - agent_1 alpha: 1.200000, grad mean: 8.980071
Iter 19 - agent_1 lam: 1.200000, grad mean: 2.314017
Iter 19 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 14.545436
Iter 19 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.341649
Iter 19 - agent_2 alpha: 1.200000, grad mean: 8.980071
Iter 19 - agent_2 lam: 1.200000, grad mean: 2.314017
Iter 19 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 14.545436
Iter 19 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.341649



episode_reward_mean_agent = 0.13681809604167938:  10%|█         | 21/200 [07:16<1:01:53, 20.74s/it][A

Iter 20 - agent_0 alpha: 1.200000, grad mean: 9.592777
Iter 20 - agent_0 lam: 1.500000, grad mean: 2.335630
Iter 20 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 15.123823
Iter 20 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.430175
Iter 20 - agent_1 alpha: 1.200000, grad mean: 9.306856
Iter 20 - agent_1 lam: 1.200000, grad mean: 2.335630
Iter 20 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 15.123823
Iter 20 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.344140
Iter 20 - agent_2 alpha: 1.200000, grad mean: 9.306856
Iter 20 - agent_2 lam: 1.200000, grad mean: 2.335630
Iter 20 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 15.123823
Iter 20 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.344140



episode_reward_mean_agent = 0.13795658946037292:  11%|█         | 22/200 [07:37<1:01:30, 20.74s/it][A

Iter 21 - agent_0 alpha: 1.200000, grad mean: 9.920970
Iter 21 - agent_0 lam: 1.500000, grad mean: 2.354654
Iter 21 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 15.690796
Iter 21 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.432622
Iter 21 - agent_1 alpha: 1.200000, grad mean: 9.631747
Iter 21 - agent_1 lam: 1.200000, grad mean: 2.354654
Iter 21 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 15.690796
Iter 21 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.346097
Iter 21 - agent_2 alpha: 1.200000, grad mean: 9.631747
Iter 21 - agent_2 lam: 1.200000, grad mean: 2.354654
Iter 21 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 15.690796
Iter 21 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.346097



episode_reward_mean_agent = 0.13500437140464783:  12%|█▏        | 23/200 [07:58<1:01:18, 20.78s/it][A

Iter 22 - agent_0 alpha: 1.200000, grad mean: 10.252801
Iter 22 - agent_0 lam: 1.500000, grad mean: 2.375675
Iter 22 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 16.266592
Iter 22 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.435073
Iter 22 - agent_1 alpha: 1.200000, grad mean: 9.959868
Iter 22 - agent_1 lam: 1.200000, grad mean: 2.375675
Iter 22 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 16.266592
Iter 22 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.348057
Iter 22 - agent_2 alpha: 1.200000, grad mean: 9.959868
Iter 22 - agent_2 lam: 1.200000, grad mean: 2.375675
Iter 22 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 16.266592
Iter 22 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.348057



episode_reward_mean_agent = 0.14003776013851166:  12%|█▏        | 24/200 [08:18<1:00:50, 20.74s/it][A

Iter 23 - agent_0 alpha: 1.200000, grad mean: 10.584785
Iter 23 - agent_0 lam: 1.500000, grad mean: 2.394324
Iter 23 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 16.840519
Iter 23 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.437351
Iter 23 - agent_1 alpha: 1.200000, grad mean: 10.288575
Iter 23 - agent_1 lam: 1.200000, grad mean: 2.394324
Iter 23 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 16.840519
Iter 23 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.349880
Iter 23 - agent_2 alpha: 1.200000, grad mean: 10.288575
Iter 23 - agent_2 lam: 1.200000, grad mean: 2.394324
Iter 23 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 16.840519
Iter 23 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.349880



episode_reward_mean_agent = 0.1383814662694931:  12%|█▎        | 25/200 [08:39<1:00:32, 20.75s/it] [A

Iter 24 - agent_0 alpha: 1.200000, grad mean: 10.916213
Iter 24 - agent_0 lam: 1.500000, grad mean: 2.412929
Iter 24 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 17.431030
Iter 24 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.439931
Iter 24 - agent_1 alpha: 1.200000, grad mean: 10.616799
Iter 24 - agent_1 lam: 1.200000, grad mean: 2.412929
Iter 24 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 17.431030
Iter 24 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.351944
Iter 24 - agent_2 alpha: 1.200000, grad mean: 10.616799
Iter 24 - agent_2 lam: 1.200000, grad mean: 2.412929
Iter 24 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 17.431030
Iter 24 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.351944



episode_reward_mean_agent = 0.1394975334405899:  13%|█▎        | 26/200 [09:00<1:00:17, 20.79s/it][A

Iter 25 - agent_0 alpha: 1.200000, grad mean: 11.244759
Iter 25 - agent_0 lam: 1.500000, grad mean: 2.430208
Iter 25 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 18.004625
Iter 25 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.442020
Iter 25 - agent_1 alpha: 1.200000, grad mean: 10.942194
Iter 25 - agent_1 lam: 1.200000, grad mean: 2.430208
Iter 25 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 18.004625
Iter 25 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.353615
Iter 25 - agent_2 alpha: 1.200000, grad mean: 10.942194
Iter 25 - agent_2 lam: 1.200000, grad mean: 2.430208
Iter 25 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 18.004625
Iter 25 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.353615



episode_reward_mean_agent = 0.13898040354251862:  14%|█▎        | 27/200 [09:21<59:49, 20.75s/it] [A

Iter 26 - agent_0 alpha: 1.200000, grad mean: 11.574458
Iter 26 - agent_0 lam: 1.500000, grad mean: 2.447019
Iter 26 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 18.587656
Iter 26 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.444171
Iter 26 - agent_1 alpha: 1.200000, grad mean: 11.268889
Iter 26 - agent_1 lam: 1.200000, grad mean: 2.447019
Iter 26 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 18.587656
Iter 26 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.355336
Iter 26 - agent_2 alpha: 1.200000, grad mean: 11.268889
Iter 26 - agent_2 lam: 1.200000, grad mean: 2.447019
Iter 26 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 18.587656
Iter 26 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.355336



episode_reward_mean_agent = 0.13494780659675598:  14%|█▍        | 28/200 [09:41<59:25, 20.73s/it][A

Iter 27 - agent_0 alpha: 1.200000, grad mean: 11.903017
Iter 27 - agent_0 lam: 1.500000, grad mean: 2.462561
Iter 27 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 19.164442
Iter 27 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.445632
Iter 27 - agent_1 alpha: 1.200000, grad mean: 11.594463
Iter 27 - agent_1 lam: 1.200000, grad mean: 2.462561
Iter 27 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 19.164442
Iter 27 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.356505
Iter 27 - agent_2 alpha: 1.200000, grad mean: 11.594463
Iter 27 - agent_2 lam: 1.200000, grad mean: 2.462561
Iter 27 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 19.164442
Iter 27 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.356505



episode_reward_mean_agent = 0.13922740519046783:  14%|█▍        | 29/200 [10:02<59:01, 20.71s/it][A

Iter 28 - agent_0 alpha: 1.200000, grad mean: 12.228605
Iter 28 - agent_0 lam: 1.500000, grad mean: 2.477068
Iter 28 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 19.742495
Iter 28 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.447329
Iter 28 - agent_1 alpha: 1.200000, grad mean: 11.917430
Iter 28 - agent_1 lam: 1.200000, grad mean: 2.477068
Iter 28 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 19.742495
Iter 28 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.357862
Iter 28 - agent_2 alpha: 1.200000, grad mean: 11.917430
Iter 28 - agent_2 lam: 1.200000, grad mean: 2.477068
Iter 28 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 19.742495
Iter 28 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.357862



episode_reward_mean_agent = 0.13774849474430084:  15%|█▌        | 30/200 [10:23<58:53, 20.78s/it][A

Iter 29 - agent_0 alpha: 1.200000, grad mean: 12.555275
Iter 29 - agent_0 lam: 1.500000, grad mean: 2.492884
Iter 29 - agent_0 w_plus_prime_gamma: 0.500000, grad mean: 20.324966
Iter 29 - agent_0 w_minus_prime_gamma: 0.690000, grad mean: 0.449215
Iter 29 - agent_1 alpha: 1.200000, grad mean: 12.241251
Iter 29 - agent_1 lam: 1.200000, grad mean: 2.492884
Iter 29 - agent_1 w_plus_prime_gamma: 0.500000, grad mean: 20.324966
Iter 29 - agent_1 w_minus_prime_gamma: 0.690000, grad mean: 0.359371
Iter 29 - agent_2 alpha: 1.200000, grad mean: 12.241251
Iter 29 - agent_2 lam: 1.200000, grad mean: 2.492884
Iter 29 - agent_2 w_plus_prime_gamma: 0.500000, grad mean: 20.324966
Iter 29 - agent_2 w_minus_prime_gamma: 0.690000, grad mean: 0.359371


KeyError: 'agents'

This is our "test" to make sure our agents are trainng, we see after the agent stops training the adversaries rewards are increasing and then while it is trainng their rewards both go to 0

In [None]:
fig, axs = plt.subplots(1, 1)
for i, group in enumerate(env.group_map.keys()):
    axs.plot(episode_reward_mean_map[group], label=f"Episode reward mean {group}")
    axs.set_ylabel("Reward")
    axs.legend()
axs.set_xlabel("Training iterations")
plt.show()
print(env.group_map.keys())

In [None]:
if use_vmas and not is_sphinx:
    # Replace tmpdir with any desired path where the video should be saved
    with tempfile.TemporaryDirectory() as tmpdir:
        video_logger = CSVLogger("vmas_logs", tmpdir, video_format="mp4")
        print("Creating rendering env")
        env_with_render = TransformedEnv(env.base_env, env.transform.clone())
        env_with_render = env_with_render.append_transform(
            PixelRenderTransform(
                out_keys=["pixels"],
                # the np.ndarray has a negative stride and needs to be copied before being cast to a tensor
                preproc=lambda x: x.copy(),
                as_non_tensor=True,
                # asking for array rather than on-screen rendering
                mode="rgb_array",
            )
        )
        env_with_render = env_with_render.append_transform(
            VideoRecorder(logger=video_logger, tag="vmas_rendered")
        )
        with set_exploration_type(ExplorationType.DETERMINISTIC):
            print("Rendering rollout...")
            env_with_render.rollout(100, policy=agents_exploration_policy)
        print("Saving the video...")
        env_with_render.transform.dump()
        print("Saved! Saved directory tree:")
        video_logger.print_log_dir()

Creating rendering env
Rendering rollout...


ImportError: Error occurred while running `from pyglet.gl import *`, HINT: make sure you have OpenGL installed. On Ubuntu, you can run 'apt-get install python3-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work:'xvfb-run -s "-screen 0 1400x900x24" python <your_script.py>'

In [None]:
import os

# Define a permanent directory path (e.g., "local_videos")
local_dir = "local_videos"
os.makedirs(local_dir, exist_ok=True)

# Use the permanent directory instead of a temporary one
video_logger = CSVLogger("vmas_logs", local_dir, video_format="mp4")
print("Creating rendering env")
env_with_render = TransformedEnv(env.base_env, env.transform.clone())
env_with_render = env_with_render.append_transform(
    PixelRenderTransform(
        out_keys=["pixels"],
        # the np.ndarray has a negative stride and needs to be copied before being cast to a tensor
        preproc=lambda x: x.copy(),
        as_non_tensor=True,
        # asking for array rather than on-screen rendering
        mode="rgb_array",
    )
)
env_with_render = env_with_render.append_transform(
    VideoRecorder(logger=video_logger, tag="vmas_rendered")
)
with set_exploration_type(ExplorationType.DETERMINISTIC):
    print("Rendering rollout...")
    env_with_render.rollout(100, policy=agents_exploration_policy)
print("Saving the video...")
env_with_render.transform.dump()
print("Saved! Saved directory tree:")
video_logger.print_log_dir()