In [1]:
import os

import jax
import jax.numpy as jnp

from brax.envs import State as EnvState

from functools import partial
from typing import Any, Callable, Tuple

from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs
from qdax.baselines.td3 import TD3Config, TD3, TD3TrainingState
from qdax.core.neuroevolution.buffers.buffer import (
    QDTransition,
    ReplayBuffer,
    Transition,
)
from qdax.core.neuroevolution.sac_td3_utils import warmstart_buffer, generate_unroll, do_iteration_fn
from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition
from qdax.core.neuroevolution.mdp_utils import TrainingState
from qdax.custom_types import Metrics
from qdax.custom_types import (
    Action,
    Descriptor,
    Mask,
    Metrics,
    Observation,
    Params,
    Reward,
    RNGKey,
)

# Multiagent shiet
from qdax.environments.multi_agent_wrappers import MultiAgentBraxWrapper

2025-07-26 02:11:57.168347: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
#@title QD Training Definitions Fields
#@markdown ---
env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
parameter_sharing=False
emitter_type="mix"
homogenisation_method="concat"
episode_length = 1000 #@param {type:"integer"}
num_timesteps = 7_864_320 #@param {type:"integer"}
seed = 1 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
policy_learning_rate = 3e-4
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_bd = 0. #@param {type:"number"}
max_bd = 1.0 #@param {type:"number"}
warmstart_steps=8192*10
num_evals=20
log_period=1024
# proportion_mutation_ga = 0.5 #@param {type:"number"}

# TD3 params
env_batch_size = 128 #@param {type:"number"}
batch_size=512
expl_noise = 0.1
noise_clip = 0.5
grad_updates_per_step=0.3 #@param {type:"number"}
replay_buffer_size = 1000000 #@param {type:"number"}
critic_hidden_layer_sizes = (256, 256) #@param {type:"raw"}
critic_learning_rate = 3e-4
discount = 0.99 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
soft_tau_update = 0.005 #@param {type:"number"}
policy_delay = 2 #@param {type:"number"}
alpha_init=1.0
fix_alpha=False
#@markdown ---

In [3]:
from qdax.core.neuroevolution.networks.matd3_networks import make_matd3_networks
from qdax.core.neuroevolution.losses.matd3_loss import matd3_critic_loss_fn, matd3_policy_loss_fn
from qdax.baselines.matd3 import MATD3, MATD3Config, MATD3TrainingState

## Warmstart related functions

In [4]:
import functools

# Define the fonction random
def warmstart_play_step_fn(
    env_state: EnvState,
    random_key: RNGKey,
    env: MultiAgentBraxWrapper,
):
    """
    Play an environment step and return the updated state and the transition.
    """
    random_key, subkey = jax.random.split(random_key)

    action_sizes = env.get_action_sizes()

    keys = jax.random.split(subkey, len(action_sizes))

    actions = {
        agent_idx: jax.random.uniform(agent_key, (size,), minval=-1, maxval=1)
        for (agent_idx, size), agent_key in zip(action_sizes.items(), keys)
    }

    flatten_actions = jnp.concatenate([a for a in actions.values()])

    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=next_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=flatten_actions,
        truncations=next_state.info["truncation"],
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
    )

    return next_state, random_key, transition

def generate_unroll_warmstart(
    random_key: RNGKey,
    env_state: EnvState,
    env: MultiAgentBraxWrapper,
    warmstart_play_step_fn: Callable[
        [EnvState, RNGKey, MultiAgentBraxWrapper],
        Tuple[
            EnvState,
            RNGKey,
            Transition,
        ],
    ],
    warmstart_steps: int,
) -> Tuple[EnvState, Transition]:
    """Pre-populates the buffer with transitions. Returns the warmstarted buffer
    and the new state of the environment.
    """

    def _scan_play_step_fn(
        carry: Tuple[EnvState, RNGKey], unused_arg: Any
    ) -> Tuple[Tuple[EnvState, RNGKey], Transition]:
        env_state, random_key, transitions = warmstart_play_step_fn(*carry, env)
        return (env_state, random_key), transitions

    (env_state, random_key), transitions = jax.lax.scan(
        _scan_play_step_fn,
        (env_state, random_key),
        (),
        length=warmstart_steps
    )

    return env_state, transitions

@functools.partial(
        jax.jit,
        static_argnames=("env", "warmstart_play_step_fn", "warmstart_steps", "env_batch_size")
)
def warmstart_buffer(
    env: MultiAgentBraxWrapper,
    replay_buffer: ReplayBuffer,
    training_state: MATD3TrainingState,
    warmstart_play_step_fn: Callable[
        [EnvState, RNGKey, MultiAgentBraxWrapper],
        Tuple[
            EnvState,
            RNGKey,
            Transition,
        ],
    ],
    warmstart_steps: int,
    env_batch_size: int,
):
    

    generate_unroll = functools.partial(
        generate_unroll_warmstart,
        env = env,
        warmstart_play_step_fn=warmstart_play_step_fn,
        warmstart_steps=warmstart_steps//env_batch_size
    )

    generate_unroll_vmap = jax.vmap(
        generate_unroll,
        in_axes=(0, 0)
    )

    random_key, subkey = jax.random.split(training_state.random_key)
    keys = jax.random.split(subkey, env_batch_size)


    training_state = training_state.replace(
        random_key=random_key
    )

    reset_fn = jax.vmap(env.reset)

    env_states = reset_fn(keys)

    random_key, subkey = jax.random.split(training_state.random_key)

    training_state = training_state.replace(
        random_key=random_key
    )

    keys = jax.random.split(subkey, env_batch_size)
    env_states, transitions = generate_unroll_vmap(keys, env_states)

    # jax.debug.print("obs shape {obs}", obs=transitions.obs.shape)

    replay_buffer = replay_buffer.insert(transitions)
    
    return replay_buffer, training_state

## Prepare env and agent

In [5]:
base_env_name = env_name.split("_")[0]
env = environments.create(env_name, episode_length=episode_length)
env = MultiAgentBraxWrapper(
    env,
    env_name=base_env_name,
    parameter_sharing=False,
    emitter_type=emitter_type,
    homogenisation_method=homogenisation_method
)


random_key = jax.random.PRNGKey(seed)
num_agents = len(env.get_action_sizes())

# Make sure to pass the correct config parameters
matd3_config = MATD3Config(
    num_agents=len(env.get_action_sizes()),
    episode_length=episode_length,
    batch_size=batch_size,
    policy_delay=policy_delay,
    soft_tau_update=soft_tau_update,
    expl_noise=expl_noise,
    critic_hidden_layer_size=critic_hidden_layer_sizes,  # Fixed parameter name
    policy_hidden_layer_size=policy_hidden_layer_sizes,  # Fixed parameter name
    critic_learning_rate=critic_learning_rate,
    policy_learning_rate=policy_learning_rate,
    discount=discount,
    noise_clip=noise_clip,
    policy_noise=expl_noise,  # or define separate policy_noise
    reward_scaling=reward_scaling,
    max_grad_norm=10.0
)

matd3_agent = MATD3(config=matd3_config, action_sizes=env.get_action_sizes())

training_state = matd3_agent.init(random_key=random_key, 
                                  action_sizes_each_agent=env.get_action_sizes(),
                                  observation_size_raw=env.observation_size,
                                  observation_sizes_each_agent=env.get_obs_sizes())

# env_state = env.reset(random_key)

reset_fn = jax.vmap(env.reset)

random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, env_batch_size)

env_states = reset_fn(keys)

In [6]:
# Init replay buffer
dummmy_transition = QDTransition.init_dummy(observation_dim=env.observation_size, action_dim=env.action_size, descriptor_dim=env.behavior_descriptor_length)

replay_buffer = ReplayBuffer.init(buffer_size=replay_buffer_size,  transition=dummmy_transition)

In [7]:
replay_buffer, training_state = warmstart_buffer(
    env=env,
    replay_buffer=replay_buffer,
    training_state=training_state,
    warmstart_play_step_fn = warmstart_play_step_fn,
    warmstart_steps=warmstart_steps,
    env_batch_size=8,
)

## Policy | Env | replay_buffer vmap/scan interaction

In [8]:
play_step_fn = functools.partial(
    matd3_agent.play_qd_step_fn,
    env=env,
    deterministic=False
)

# Create the scan_update function
@functools.partial(jax.jit, static_argnames=("unflatten_obs_fn", "unflatten_actions_fn"))
def scan_update(
    carry: Tuple[MATD3TrainingState, ReplayBuffer],
    unused: Any,
    unflatten_obs_fn: Callable[[jnp.ndarray], dict[int, jnp.ndarray]],
    unflatten_actions_fn: Callable[[jnp.ndarray], dict[int, jnp.ndarray]],
) -> Tuple[Tuple[MATD3TrainingState, ReplayBuffer], Metrics]:
    """Single update step for the scan operation"""
    training_state, replay_buffer = carry
    
    # Perform one update step
    new_training_state, new_replay_buffer, metrics = matd3_agent.update(
        training_state, 
        replay_buffer,
        unflatten_obs_fn,
        unflatten_actions_fn,
    )
    
    return (new_training_state, new_replay_buffer), metrics

# Now create the clean single_step_and_update function
@functools.partial(jax.jit, static_argnames=("play_step_fn", "unflatten_obs_fn", "unflatten_actions_fn", "num_updates"))
def single_step_and_update(
    carry: [EnvState, MATD3TrainingState, ReplayBuffer],
    _,
    # env_states: EnvState,
    # replay_buffer: ReplayBuffer,
    # training_state: MATD3TrainingState,
    play_step_fn: Callable[
        [EnvState, MATD3TrainingState],
        Tuple[
            EnvState,
            MATD3TrainingState,
            QDTransition,
        ],
    ],
    unflatten_obs_fn: Callable[[jnp.ndarray], dict[int, jnp.ndarray]],
    unflatten_actions_fn: Callable[[jnp.ndarray], dict[int, jnp.ndarray]],
    num_updates: int
) -> Tuple[Tuple[EnvState, MATD3TrainingState, ReplayBuffer], Metrics]:
    """Performs one environment step followed by multiple gradient updates"""
    
    # Vectorized environment step
    play_step_fn_vmap = jax.vmap(
        play_step_fn, 
        in_axes=(0, None), 
        out_axes=(0, None, 0)
    )

    env_states,training_state, replay_buffer = carry

    env_states, training_state, transitions = play_step_fn_vmap(env_states, training_state)

    # Insert transitions into replay buffer
    replay_buffer = replay_buffer.insert(transitions)
    
    # Create partial function for scan_update with fixed unflatten_obs_fn
    scan_update_partial = functools.partial(scan_update, unflatten_obs_fn=unflatten_obs_fn, unflatten_actions_fn=unflatten_actions_fn)

    # Perform multiple gradient updates
    (training_state, replay_buffer), metrics = jax.lax.scan(
        scan_update_partial,
        (training_state, replay_buffer),
        (),
        length=num_updates
    )

    return (env_states, training_state, replay_buffer), metrics

## Functions to change flatten obs, actions to dictionary

In [9]:
def unflatten_obs_fn(global_obs: jnp.ndarray, env:MultiAgentBraxWrapper) -> dict[int, jnp.ndarray]:
    agent_obs = {}
    for agent_idx, obs_indices in env.agent_obs_mapping.items():
            agent_obs[agent_idx] = global_obs[obs_indices]
    return agent_obs

def unflatten_actions_fn(flatten_action: jnp.ndarray, env:MultiAgentBraxWrapper) -> dict[int, jax.Array]:
    """Because the actions in the form of Dict[int, jnp.array] is flatten by 
    flatten_actions = jnp.concatenate([a for a in actions.values()]) so we do this way
    """

    actions = {}
    start = 0
    for agent_idx, size in env.get_action_sizes().items():
        end = start + size
        actions[agent_idx] = flatten_action[start:end]
        start = end
    return actions

unflatten_obs_fn = functools.partial(
    unflatten_obs_fn,
    env=env
)

unflatten_actions_fn = functools.partial(
    unflatten_actions_fn,
    env=env
)


In [10]:
step_and_update = functools.partial(
    single_step_and_update,
    play_step_fn=play_step_fn,
    unflatten_obs_fn=unflatten_obs_fn,
    unflatten_actions_fn=unflatten_actions_fn,
    num_updates=int(grad_updates_per_step * env_batch_size)
)

## Functions related to evaluating

In [11]:
play_eval_step_fn = functools.partial(
    matd3_agent.play_qd_step_fn,
    env=env,
    deterministic=True
)

play_eval_step_fn = jax.vmap(
    play_eval_step_fn,
    in_axes=(0, None),
    out_axes=(0, None, 0)
)


## Training/logging loop

In [None]:
import wandb
import time
from datetime import datetime

# Initialize wandb with proper run naming
def init_wandb_logging():
    """Initialize wandb with descriptive run name"""
    run_name = f"MATD3_{env_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    
    wandb.init(
        project="matd3-multiagent-rl",
        name=run_name,
        config={
            # Environment config
            "env_name": env_name,
            "episode_length": episode_length,
            "num_agents": num_agents,
            "parameter_sharing": parameter_sharing,
            "emitter_type": emitter_type,
            "homogenisation_method": homogenisation_method,
            
            # Training config
            "num_timesteps": num_timesteps,
            "env_batch_size": env_batch_size,
            "warmstart_steps": warmstart_steps,
            "grad_updates_per_step": grad_updates_per_step,
            "log_period": log_period,
            "num_evals": num_evals,
            
            # MATD3 hyperparameters
            "batch_size": batch_size,
            "policy_learning_rate": policy_learning_rate,
            "critic_learning_rate": critic_learning_rate,
            "discount": discount,
            "soft_tau_update": soft_tau_update,
            "policy_delay": policy_delay,
            "expl_noise": expl_noise,
            "noise_clip": noise_clip,
            "reward_scaling": reward_scaling,
            "replay_buffer_size": replay_buffer_size,
            
            # Network architecture
            "policy_hidden_layer_sizes": policy_hidden_layer_sizes,
            "critic_hidden_layer_sizes": critic_hidden_layer_sizes,
            
            # Other
            "seed": seed,
        },
        tags=["matd3", "multiagent", env_name.split("_")[0]]
    )

# Alternative: Pass parameters explicitly to the function
def run_training_loop_with_logging_v2(training_state, replay_buffer, env_states):
    """Complete training loop with wandb logging and error handling - version with explicit parameters"""
    
    # Initialize wandb
    init_wandb_logging()
    
    try:
        # Calculate training parameters
        num_iters = num_timesteps // env_batch_size
        num_loops = num_iters // log_period
        
        print(f"Training Configuration:")
        print(f"  Total timesteps: {num_timesteps:,}")
        print(f"  Env batch size: {env_batch_size}")
        print(f"  Total iterations: {num_iters:,}")
        print(f"  Log period: {log_period}")
        print(f"  Number of training loops: {num_loops}")
        print(f"  Warmstart steps: {warmstart_steps:,}")
        
        # Initialize random key for evaluation
        random_key_local = jax.random.PRNGKey(seed + 1000)  # Different seed for eval
        
        # Training metrics tracking
        start_time = time.time()
        best_fitness = float('-inf')
        
        for i in range(num_loops):
            loop_start_time = time.time()
            
            # Training step
            (env_states, training_state, replay_buffer), metrics = jax.lax.scan(
                step_and_update,
                (env_states, training_state, replay_buffer),
                (),
                length=log_period
            )
            metrics = jax.tree_util.tree_map(
                lambda x: x.flatten(),
                metrics
            )

            # print(metrics['actor_losses'][0].shape, metrics['critic_loss'].shape)
            # Evaluation
            random_key_local, subkey = jax.random.split(random_key_local)
            keys = jax.random.split(subkey, num=num_evals)
            reset_states = reset_fn(keys)
            
            true_return, true_returns = matd3_agent.eval_policy_fn(
                training_state,
                reset_states,
                play_eval_step_fn,
            )
            
            actor_losses = metrics['actor_losses']
            
            critic_loss = jnp.mean(metrics['critic_loss'], axis=0)
            
            # Calculate additional metrics
            current_timesteps = warmstart_steps + (i + 1) * env_batch_size * log_period
            loop_time = time.time() - loop_start_time
            total_time = time.time() - start_time
            timesteps_per_second = (env_batch_size * log_period) / loop_time
            
            # Update best fitness
            if true_return > best_fitness:
                best_fitness = true_return
            
            # Log to wandb
            wandb.log({
                "training/timesteps": current_timesteps,
                "training/loop": i,
                "evaluation/mean_return": true_return,
                "evaluation/best_return": best_fitness,
                "evaluation/return_std": jnp.std(true_returns),
                "losses/critic_loss": critic_loss,
                "losses/actor_loss_mean": jnp.mean(jnp.array(actor_losses), axis=(0, 1)),

                "performance/timesteps_per_second": timesteps_per_second,
                "performance/loop_time": loop_time,
                "performance/total_time": total_time,
                "training/replay_buffer_size": replay_buffer.current_size,
                "training/training_steps": training_state.steps,
            })
            
            # Log individual agent losses
            for agent_idx, loss in enumerate(actor_losses):
                wandb.log({f"losses/agent_{agent_idx}_loss": jnp.mean(loss)})
            
            # Console output
            print(f"Loop {i:4d}/{num_loops} | "
                  f"Steps: {current_timesteps:8,} | "
                  f"Fitness: {true_return:7.2f} | "
                #   f"Best: {best_fitness:7.2f} | "
                  f"Critic Loss: {critic_loss:8.4f} | "
                  f"Actor Losses: {jnp.mean(jnp.array(actor_losses), axis=(0, 1))} | "
                  f"Time: {loop_time:6.2f}s | "
                  f"TPS: {timesteps_per_second:8.1f}")
            
            # Save checkpoint periodically
            if i % (num_loops // 10) == 0 and i > 0:
                print(f"Checkpoint at loop {i} - Best fitness: {best_fitness:.2f}")
        
        print(f"\nTraining completed!")
        print(f"Total time: {total_time:.2f}s")
        print(f"Best fitness achieved: {best_fitness:.2f}")
        print(f"Final training steps: {training_state.steps}")
        
        # Final logging
        wandb.log({
            "final/best_return": best_fitness,
            "final/total_time": total_time,
            "final/final_return": true_return,
        })
        
        return training_state, replay_buffer
        
    except Exception as e:
        print(f"Training failed with error: {e}")
        wandb.log({"error": str(e)})
        raise e
    
    finally:
        wandb.finish()

# Run the training loop - use the version that passes parameters explicitly
final_training_state, final_replay_buffer = run_training_loop_with_logging_v2(
    training_state, replay_buffer, env_states
)

  return LooseVersion(v) >= LooseVersion(check)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtherealtin[0m ([33mtherealtin-uit[0m). Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import HTML, display  # type: ignore


Training Configuration:
  Total timesteps: 7,864,320
  Env batch size: 128
  Total iterations: 61,440
  Log period: 1024
  Number of training loops: 60
  Warmstart steps: 81,920


  return jax.tree_map(mask_episodes, transition)  # type: ignore


Loop    0/60 | Steps:  212,992 | Fitness:  250.28 | Critic Loss:  14.4302 | Actor Losses: -56.590782165527344 | Time:  38.20s | TPS:   3431.5
Loop    1/60 | Steps:  344,064 | Fitness:  255.47 | Critic Loss:  23.6425 | Actor Losses: -100.24248504638672 | Time:  27.43s | TPS:   4779.1
Loop    2/60 | Steps:  475,136 | Fitness:  256.92 | Critic Loss:  15.2652 | Actor Losses: -80.9264144897461 | Time:  22.08s | TPS:   5935.2
Loop    3/60 | Steps:  606,208 | Fitness:  321.89 | Critic Loss:  17.2899 | Actor Losses: -62.99038314819336 | Time:  22.06s | TPS:   5942.5
Loop    4/60 | Steps:  737,280 | Fitness:  378.99 | Critic Loss:  11.8071 | Actor Losses: -75.56999206542969 | Time:  22.06s | TPS:   5940.9
Loop    5/60 | Steps:  868,352 | Fitness:   15.50 | Critic Loss:   8.4849 | Actor Losses: -93.60055541992188 | Time:  22.03s | TPS:   5950.6
Loop    6/60 | Steps:  999,424 | Fitness:  401.29 | Critic Loss:  15.7653 | Actor Losses: -102.37384796142578 | Time:  22.02s | TPS:   5953.3
Checkpoint 