In [1]:
import os

import jax
import jax.numpy as jnp

from brax.envs import State as EnvState

from functools import partial
from typing import Any, Callable, Tuple, Optional, List

from qdax import environments

from qdax.core.neuroevolution.buffers.buffer import (
    ReplayBuffer,
    Transition,
)

from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition
from qdax.core.neuroevolution.networks.matd3_networks import make_matd3_networks
from qdax.core.neuroevolution.mdp_utils import TrainingState
from qdax.custom_types import Metrics
from qdax.custom_types import (
    Action,
    Descriptor,
    Mask,
    Metrics,
    Observation,
    Params,
    Reward,
    RNGKey,
)

# Multiagent shiet
from qdax.environments.multi_agent_wrappers import MultiAgentBraxWrapper

2025-11-13 21:02:43.233068: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
#@title QD Training Definitions Fields
#@markdown ---
num_iterations = 1000
log_interval = 10
env_name = 'hopper_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
parameter_sharing=False
emitter_type="mix"
homogenisation_method="concat"
episode_length = 1000 #@param {type:"integer"}
num_timesteps = 7_864_320 #@param {type:"integer"}
seed = 1 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
policy_learning_rate = 3e-4
min_bd = 0. #@param {type:"number"}
max_bd = 1.0 #@param {type:"number"}
warmstart_steps=25_600
num_evals=20

# CEM
warmup_iters: int = 10 # number of iter update with only CEM
population_size: int = 10
num_best: Optional[int] = None
damp_init: float = 1e-3
damp_final: float = 1e-5
damp_tau : float = 0.95
rank_weight_shift: float = 1.0
mirror_sampling: bool = True
weighted_update: bool = True
num_learning_offspring: Optional[int] = population_size//2
# TD3 params
num_rl_updates_per_iter: int = 4000
env_batch_size = 128 #@param {type:"number"}
batch_size=256
expl_noise = 0.1
policy_noise = 0.2
noise_clip = 0.5
grad_updates_per_step=1.0 #@param {type:"number"}
replay_buffer_size = 1000000 #@param {type:"number"}
critic_hidden_layer_sizes = (256, 256) #@param {type:"raw"}
critic_learning_rate = 1e-3
discount = 0.99 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
soft_tau_update = 0.005 #@param {type:"number"}
policy_delay = 4 #@param {type:"number"}
max_grad_norm = 30.0
use_layer_norm=True
#@markdown ---

In [3]:
from qdax.baselines.cem_matd3 import CEMMATD3, CEMMATD3Config, CEMMATD3TrainingState

## Warmstart related functions

In [4]:
import functools

# Define the fonction random
def warmstart_play_step_fn(
    env_state: EnvState,
    random_key: RNGKey,
    env: MultiAgentBraxWrapper,
):
    """
    Play an environment step and return the updated state and the transition.
    """
    random_key, subkey = jax.random.split(random_key)

    action_sizes = env.get_action_sizes()

    keys = jax.random.split(subkey, len(action_sizes))

    actions = {
        agent_idx: jax.random.uniform(agent_key, (size,), minval=-1, maxval=1)
        for (agent_idx, size), agent_key in zip(action_sizes.items(), keys)
    }

    flatten_actions = jnp.concatenate([a for a in actions.values()])

    next_state = env.step(env_state, actions)

    transition = Transition(
        obs=next_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=flatten_actions,
        truncations=next_state.info["truncation"],
    )

    return next_state, random_key, transition

def generate_unroll_warmstart(
    random_key: RNGKey,
    env_state: EnvState,
    env: MultiAgentBraxWrapper,
    warmstart_play_step_fn: Callable[
        [EnvState, RNGKey, MultiAgentBraxWrapper],
        Tuple[
            EnvState,
            RNGKey,
            Transition,
        ],
    ],
    warmstart_steps: int,
) -> Tuple[EnvState, Transition]:
    """Pre-populates the buffer with transitions. Returns the warmstarted buffer
    and the new state of the environment.
    """

    def _scan_play_step_fn(
        carry: Tuple[EnvState, RNGKey], unused_arg: Any
    ) -> Tuple[Tuple[EnvState, RNGKey], Transition]:
        env_state, random_key, transitions = warmstart_play_step_fn(*carry, env)
        return (env_state, random_key), transitions

    (env_state, random_key), transitions = jax.lax.scan(
        _scan_play_step_fn,
        (env_state, random_key),
        (),
        length=warmstart_steps
    )

    return env_state, transitions

@functools.partial(
        jax.jit,
        static_argnames=("env", "warmstart_play_step_fn", "warmstart_steps", "env_batch_size")
)
def warmstart_buffer(
    env: MultiAgentBraxWrapper,
    replay_buffer: ReplayBuffer,
    training_state: CEMMATD3TrainingState,
    warmstart_play_step_fn: Callable[
        [EnvState, RNGKey, MultiAgentBraxWrapper],
        Tuple[
            EnvState,
            RNGKey,
            Transition,
        ],
    ],
    warmstart_steps: int,
    env_batch_size: int,
):
    

    generate_unroll = functools.partial(
        generate_unroll_warmstart,
        env = env,
        warmstart_play_step_fn=warmstart_play_step_fn,
        warmstart_steps=warmstart_steps//env_batch_size
    )

    generate_unroll_vmap = jax.vmap(
        generate_unroll,
        in_axes=(0, 0)
    )

    random_key, subkey = jax.random.split(training_state.random_key)
    keys = jax.random.split(subkey, env_batch_size)


    training_state = training_state.replace(
        random_key=random_key
    )

    reset_fn = jax.vmap(env.reset)

    env_states = reset_fn(keys)

    random_key, subkey = jax.random.split(training_state.random_key)

    training_state = training_state.replace(
        random_key=random_key
    )

    keys = jax.random.split(subkey, env_batch_size)
    env_states, transitions = generate_unroll_vmap(keys, env_states)

    # jax.debug.print("obs shape {obs}", obs=transitions.obs.shape)

    replay_buffer = replay_buffer.insert(transitions)
    
    return replay_buffer, training_state

## Prepare env and agent

In [5]:
base_env_name = env_name.split("_")[0]
env = environments.create(env_name, episode_length=episode_length)
env = MultiAgentBraxWrapper(
    env,
    env_name=base_env_name,
    parameter_sharing=False,
    emitter_type=emitter_type,
    homogenisation_method=homogenisation_method
)

policy_network, critic_network = make_matd3_networks(
    action_sizes=env.get_action_sizes(),
    critic_hidden_layer_sizes=critic_hidden_layer_sizes,
    policy_hidden_layer_sizes=policy_hidden_layer_sizes,
    use_layer_norm=use_layer_norm,
)

def play_step_fn(
    env_state: EnvState,
    policy_params: List[Params],
    random_key: RNGKey,
) -> Tuple[EnvState, RNGKey, Transition]:
    """Plays a step in the environment. Selects an action according to TD3 rule and
    performs the environment step.

    Args:
        env_state: the current environment state
        training_state: the SAC training state
        env: the environment
        deterministic: whether to select action in a deterministic way.
            Defaults to False.

    Returns:
        the new environment state
        the new TD3 training state
        the played transition
    """
    obs=env.obs(env_state)
    actions = {
        agent_idx: network.apply(params, agent_obs)
        for (agent_idx, network), params, agent_obs in zip(
            policy_network.items(), policy_params, obs.values()
        )
    }

    next_env_state = env.step(env_state, actions)

    flatten_action = jnp.concatenate([a for a in actions.values()])

    transition = Transition(
        obs=env_state.obs,
        next_obs=next_env_state.obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        truncations=next_env_state.info["truncation"],
        actions=flatten_action,
    )
    return next_env_state, random_key, transition



def generate_unroll(
        policy_params,
        env_state,
        random_key,
        play_step_fn,
        episode_length,
):
    def _scan_play_step_fn(
        carry: Tuple[EnvState, Params, RNGKey], unused_arg: Any
    ):
        env_state, policy_params, random_key = carry
        next_state, random_key, transition = play_step_fn(
            env_state, policy_params, random_key
        )
        return (next_state, policy_params, random_key), transition

    (env_state, policy_params, random_key), transitions = jax.lax.scan(
        _scan_play_step_fn,
        (env_state, policy_params, random_key),
        (),
        length=episode_length
    )
    return env_state, transitions


def scoring_function(
        policies_params,
        random_key,
        play_step_fn,
        play_reset_fn,
        episode_length,
):
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(
        subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0]
    )
    reset_fn = jax.vmap(play_reset_fn)
    init_states = reset_fn(keys)

    unroll_fn = functools.partial(
        generate_unroll,
        episode_length=episode_length,
        play_step_fn=play_step_fn,
        random_key=subkey,
    )

    # jax.debug.print("init_states {a}", a=init_states.obs.shape)

    # print("num_pol:",jax.tree_util.tree_leaves(policies_params)[0].shape[0] )
    # print("init state", init_states.obs.shape)

    env_states, data = jax.vmap(unroll_fn)(
        policies_params,
        init_states,
    )

    # create a mask to extract data properly
    is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
    mask = jnp.roll(is_done, 1, axis=1)
    mask = mask.at[:, 0].set(0)

    # scores
    fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)

    return fitnesses, data, random_key


scoring_fn = functools.partial(
    scoring_function,
    play_step_fn=play_step_fn,
    play_reset_fn=env.reset,
    episode_length=episode_length,
)


print(env.get_action_sizes())
print(env.get_obs_sizes())

random_key = jax.random.PRNGKey(seed)
num_agents = len(env.get_action_sizes())

# Make sure to pass the correct config parameters
matd3_config = CEMMATD3Config(
    num_agents=len(env.get_action_sizes()),
    episode_length=episode_length,
    batch_size=batch_size,
    policy_delay=policy_delay,
    soft_tau_update=soft_tau_update,
    expl_noise=expl_noise,
    critic_hidden_layer_size=critic_hidden_layer_sizes,  
    policy_hidden_layer_size=policy_hidden_layer_sizes,  
    critic_learning_rate=critic_learning_rate,
    policy_learning_rate=policy_learning_rate,
    discount=discount,
    noise_clip=noise_clip,
    policy_noise=policy_noise,
    reward_scaling=reward_scaling,
    max_grad_norm=max_grad_norm
)

cem_matd3_agent = CEMMATD3(config=matd3_config, env=env, scoring_function=scoring_fn)

training_state = cem_matd3_agent.init(random_key=random_key)

reset_fn = jax.vmap(env.reset)

random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, env_batch_size)

env_states = reset_fn(keys)

{0: 1, 1: 1, 2: 1}
{0: 8, 1: 9, 2: 8}


In [6]:
# Init replay buffer
dummmy_transition = Transition.init_dummy(observation_dim=env.observation_size, action_dim=env.action_size)

replay_buffer = ReplayBuffer.init(buffer_size=replay_buffer_size,  transition=dummmy_transition)

In [7]:
replay_buffer, training_state = warmstart_buffer(
    env=env,
    replay_buffer=replay_buffer,
    training_state=training_state,
    warmstart_play_step_fn = warmstart_play_step_fn,
    warmstart_steps=warmstart_steps,
    env_batch_size=20,
)

## Training/logging loop

In [None]:
import wandb
import time
from datetime import datetime

# Initialize wandb with proper run naming
def init_wandb_logging():
    """Initialize wandb with descriptive run name"""
    run_name = f"CEMMATD3_{env_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    
    wandb.init(
        project="matd3-multiagent-rl",
        name=run_name,
        config={
            # Environment config
            "env_name": env_name,
            "episode_length": episode_length,
            "num_agents": num_agents,
            "parameter_sharing": parameter_sharing,
            "emitter_type": emitter_type,
            "homogenisation_method": homogenisation_method,
            
            # Training config
            "num_timesteps": num_timesteps,
            "env_batch_size": env_batch_size,
            "warmstart_steps": warmstart_steps,
            "grad_updates_per_step": grad_updates_per_step,
            "log_period": log_interval,
            "num_evals": num_evals,
            
            # MATD3 hyperparameters
            "batch_size": batch_size,
            "policy_learning_rate": policy_learning_rate,
            "critic_learning_rate": critic_learning_rate,
            "discount": discount,
            "soft_tau_update": soft_tau_update,
            "policy_delay": policy_delay,
            "expl_noise": expl_noise,
            "noise_clip": noise_clip,
            "reward_scaling": reward_scaling,
            "replay_buffer_size": replay_buffer_size,
            
            # Network architecture
            "policy_hidden_layer_sizes": policy_hidden_layer_sizes,
            "critic_hidden_layer_sizes": critic_hidden_layer_sizes,
            
            # Other
            "seed": seed,
        },
        tags=["cemmatd3", "multiagent", env_name.split("_")[0]]
    )

# Alternative: Pass parameters explicitly to the function
def run_training_loop_with_logging_v2(training_state, replay_buffer, env_states):
    """Complete training loop with wandb logging and error handling - version with explicit parameters"""
    
    # Initialize wandb
    init_wandb_logging()
    
    try:
        # Calculate training parameters
        num_loops = int(num_iterations / log_interval)
        
        print(f"Training Configuration:")
        print(f"  Total timesteps: {num_timesteps:,}")
        print(f"  Env batch size: {env_batch_size}")
        print(f"  Total iterations: {num_iterations:,}")
        print(f"  Log period: {log_interval}")
        print(f"  Number of training loops: {num_loops}")
        print(f"  Warmstart steps: {warmstart_steps:,}")
        
        # Initialize random key for evaluation
        random_key_local = jax.random.PRNGKey(seed + 1000)  # Different seed for eval
        
        # Training metrics tracking
        start_time = time.time()
        
        for i in range(num_loops):
            loop_start_time = time.time()
            
            # Training step
            (training_state, replay_buffer), train_metrics = jax.lax.scan(
                cem_matd3_agent.scan_update,
                (training_state, replay_buffer),
                (),
                length=log_interval,
            )
            train_metrics = jax.tree_util.tree_map(
                lambda x: jnp.mean(x), train_metrics
            )
            # Evaluation
            training_state, eval_metrics = cem_matd3_agent.evaluate(training_state)

            metrics = train_metrics | eval_metrics
        
            
            # Calculate additional metrics
            current_timesteps = warmstart_steps + (i + 1) * population_size * log_interval * episode_length
            loop_time = time.time() - loop_start_time
            total_time = time.time() - start_time
            timesteps_per_second = (env_batch_size * log_interval) / loop_time
            
            
            # Log to wandb
            wandb.log({
                "training/timesteps": current_timesteps,
                "training/loop": i,
                "evaluation/mean_return": metrics["center_fitness_average"],
                "evaluation/return_std": metrics["center_fitness_std"],

                "performance/timesteps_per_second": timesteps_per_second,
                "performance/loop_time": loop_time,
                "performance/total_time": total_time,
                "training/replay_buffer_size": replay_buffer.current_size,
                "training/training_steps": training_state.steps,
                "rl_in_elites_percentage": metrics["rl_in_elites_percentage"]
            })
            
            # Console output
            print(f"Loop {i:4d}/{num_loops} | "
                  f"Steps: {current_timesteps:8,} | "
                  f"Fitness: {metrics['center_fitness_average']:7.2f} | "
                  f"Time: {loop_time:6.2f}s | "
                  f"TPS: {timesteps_per_second:8.1f}")
        
        
        print(f"\nTraining completed!")
        print(f"Total time: {total_time:.2f}s")
        print(f"Final training steps: {training_state.steps}")
        
        # Final logging
        wandb.log({
            "final/total_time": total_time,
        })
        
        return training_state, replay_buffer
        
    except Exception as e:
        print(f"Training failed with error: {e}")
        wandb.log({"error": str(e)})
        raise e
    
    finally:
        wandb.finish()

# Run the training loop - use the version that passes parameters explicitly
final_training_state, final_replay_buffer = run_training_loop_with_logging_v2(
    training_state, replay_buffer, env_states
)

  return LooseVersion(v) >= LooseVersion(check)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtherealtin[0m ([33mtherealtin-uit[0m). Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668288416667565, max=1.0â€¦

Training Configuration:
  Total timesteps: 7,864,320
  Env batch size: 128
  Total iterations: 1,000
  Log period: 10
  Number of training loops: 100
  Warmstart steps: 25,600


  selected_offsprings = jax.tree_map(lambda x: x[:self._num_learning_offspring], offsprings)
  new_offsprings = jax.tree_map(


Loop    0/100 | Steps:  125,600 | Fitness:  206.74 | Time:  27.36s | TPS:     46.8
Loop    1/100 | Steps:  225,600 | Fitness:  221.67 | Time:  24.84s | TPS:     51.5
Loop    2/100 | Steps:  325,600 | Fitness:  238.00 | Time:  26.50s | TPS:     48.3
Loop    3/100 | Steps:  425,600 | Fitness:  335.20 | Time:  26.16s | TPS:     48.9
Loop    4/100 | Steps:  525,600 | Fitness:  523.78 | Time:  27.09s | TPS:     47.2
Loop    5/100 | Steps:  625,600 | Fitness:  474.54 | Time:  25.77s | TPS:     49.7
Loop    6/100 | Steps:  725,600 | Fitness:  536.76 | Time:  26.28s | TPS:     48.7
Loop    7/100 | Steps:  825,600 | Fitness:  351.44 | Time:  27.64s | TPS:     46.3
Loop    8/100 | Steps:  925,600 | Fitness:  355.10 | Time:  27.71s | TPS:     46.2
Loop    9/100 | Steps: 1,025,600 | Fitness:  381.70 | Time:  27.29s | TPS:     46.9
Loop   10/100 | Steps: 1,125,600 | Fitness:  397.64 | Time:  27.42s | TPS:     46.7
Loop   11/100 | Steps: 1,225,600 | Fitness:  409.73 | Time:  26.66s | TPS:     48.0
L