In [1]:
import jax
import jax.numpy as jnp
jax.config.update("jax_log_compiles", False)

In [2]:
from qdax.core.emitters.mutation_operators import proximal_mutation, _proximal_mutation

2025-07-31 01:16:55.392421: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
random_key = jax.random.PRNGKey(0)
obs = jax.random.normal(random_key, (256, 8))

In [4]:
def test_proximal_mutation():
    """Test function to verify proximal mutation works correctly."""
    
    import jax
    import jax.numpy as jnp
    from jax import random
    
    print("Testing Proximal Mutation...")
    
    # Setup test data
    key = random.PRNGKey(42)
    
    # Create dummy genotypes (batch of simple linear network params)
    batch_size = 3
    input_dim = 4
    output_dim = 2
    
    # Simple genotype: just weight matrix and bias
    def make_dummy_genotype(key, batch_size, input_dim, output_dim):
        key1, key2 = random.split(key)
        weights = random.normal(key1, (batch_size, input_dim, output_dim)) * 0.1
        biases = random.normal(key2, (batch_size, output_dim)) * 0.1
        return {'weights': weights, 'biases': biases}
    
    # Simple policy function for testing
    def dummy_policy_fn(params, obs):
        """Simple linear policy: obs @ weights + bias"""
        # obs shape: [batch_size, input_dim]
        # params['weights'] shape: [input_dim, output_dim]  
        # params['biases'] shape: [output_dim]
        return obs @ params['weights'] + params['biases']
    
    # Create test data
    key, subkey = random.split(key)
    genotypes = make_dummy_genotype(subkey, batch_size, input_dim, output_dim)
    
    key, subkey = random.split(key)
    obs_batch_size = 10
    obs = random.normal(subkey, (obs_batch_size, input_dim))
    
    print(f"Genotype shapes:")
    for k, v in genotypes.items():
        print(f"  {k}: {v.shape}")
    print(f"Observations shape: {obs.shape}")
    
    # Test the policy function first
    try:
        print("\n1. Testing policy function...")
        single_genotype = jax.tree_map(lambda x: x[0], genotypes)  # First genotype
        output = dummy_policy_fn(single_genotype, obs)
        print(f"   Policy output shape: {output.shape}")
        print("   ✓ Policy function works!")
    except Exception as e:
        print(f"   ✗ Policy function failed: {e}")
        return False
    
    # Test single genotype proximal mutation
    try:
        print("\n2. Testing single genotype proximal mutation...")
        key, subkey = random.split(key)
        mutated_single = _proximal_mutation(
            x=single_genotype,
            random_key=subkey,
            policy_fn=dummy_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        print(f"   Original genotype shapes:")
        for k, v in single_genotype.items():
            print(f"     {k}: {v.shape}")
        print(f"   Mutated genotype shapes:")
        for k, v in mutated_single.items():
            print(f"     {k}: {v.shape}")
        
        # Check if mutation actually changed parameters
        weight_diff = jnp.mean(jnp.abs(mutated_single['weights'] - single_genotype['weights']))
        print(f"   Mean weight change: {weight_diff:.6f}")
        print("   ✓ Single genotype mutation works!")
    except Exception as e:
        print(f"   ✗ Single genotype mutation failed: {e}")
        return False
    
    # Test batch proximal mutation
    try:
        print("\n3. Testing batch proximal mutation...")
        key, subkey = random.split(key)
        mutated_batch, new_key = proximal_mutation(
            x=genotypes,
            random_key=subkey,
            policy_fn=dummy_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        
        print(f"   Original batch shapes:")
        for k, v in genotypes.items():
            print(f"     {k}: {v.shape}")
        print(f"   Mutated batch shapes:")
        for k, v in mutated_batch.items():
            print(f"     {k}: {v.shape}")
        
        # Check if mutation changed parameters for each genotype
        for i in range(batch_size):
            orig_weights = genotypes['weights'][i]
            mut_weights = mutated_batch['weights'][i]
            weight_diff = jnp.mean(jnp.abs(mut_weights - orig_weights))
            print(f"   Genotype {i} mean weight change: {weight_diff:.6f}")
        
        print("   ✓ Batch proximal mutation works!")
        
    except Exception as e:
        print(f"   ✗ Batch proximal mutation failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    
    # Test with mutation noise
    try:
        print("\n4. Testing with mutation noise...")
        key, subkey = random.split(key)
        mutated_noise, _ = proximal_mutation(
            x=genotypes,
            random_key=subkey,
            policy_fn=dummy_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=True,  # Enable noise
        )
        print("   ✓ Mutation with noise works!")
        
    except Exception as e:
        print(f"   ✗ Mutation with noise failed: {e}")
        return False
    
    print("\n🎉 All tests passed! Proximal mutation is working correctly.")
    return True

# Run the test
if __name__ == "__main__":
    test_proximal_mutation()

Testing Proximal Mutation...
Genotype shapes:
  weights: (3, 4, 2)
  biases: (3, 2)
Observations shape: (10, 4)

1. Testing policy function...
   Policy output shape: (10, 2)
   ✓ Policy function works!

2. Testing single genotype proximal mutation...
   ✗ Single genotype mutation failed: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute
Genotype shapes:
  weights: (3, 4, 2)
  biases: (3, 2)
Observations shape: (10, 4)

1. Testing policy function...
   Policy output shape: (10, 2)
   ✓ Policy function works!

2. Testing single genotype proximal mutation...
   ✗ Single genotype mutation failed: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute


  single_genotype = jax.tree_map(lambda x: x[0], genotypes)  # First genotype


In [5]:
from jax import random
def test_multiagent_integration():
    """Test proximal mutation in a multi-agent scenario."""
    
    print("\nTesting Multi-Agent Integration...")
    
    key = random.PRNGKey(123)
    
    # Multi-agent setup
    num_agents = 2
    batch_size = 2  # Number of genotypes per agent
    obs_dim = 6
    action_dim_agent1 = 3
    action_dim_agent2 = 2
    
    # Create agent-specific genotypes
    def make_agent_genotype(key, batch_size, obs_dim, action_dim):
        key1, key2 = random.split(key)
        weights = random.normal(key1, (batch_size, obs_dim, action_dim)) * 0.1
        biases = random.normal(key2, (batch_size, action_dim)) * 0.1
        return {'weights': weights, 'biases': biases}
    
    def agent_policy_fn(params, obs):
        return jnp.tanh(obs @ params['weights'] + params['biases'])
    
    # Create genotypes for each agent
    key1, key2, key3 = random.split(key, 3)
    agent1_genotypes = make_agent_genotype(key1, batch_size, obs_dim, action_dim_agent1)
    agent2_genotypes = make_agent_genotype(key2, batch_size, obs_dim, action_dim_agent2)
    
    # Shared observations
    obs = random.normal(key3, (50, obs_dim))  # 50 observation samples
    
    # Test each agent separately
    for agent_idx, (agent_genotypes, action_dim) in enumerate([
        (agent1_genotypes, action_dim_agent1),
        (agent2_genotypes, action_dim_agent2)
    ]):
        print(f"\n  Agent {agent_idx}:")
        # try:
        key, subkey = random.split(key)
        mutated_agent, _ = proximal_mutation(
            x=agent_genotypes,
            random_key=subkey,
            policy_fn=agent_policy_fn,
            obs=obs,
            mutation_mag=0.005,
        )
        
        # Check shapes
        print(f"    Original weights: {agent_genotypes['weights'].shape}")
        print(f"    Mutated weights: {mutated_agent['weights'].shape}")
        print(f"    ✓ Agent {agent_idx} mutation successful!")
            
        # except Exception as e:
        #     print(f"    ✗ Agent {agent_idx} mutation failed: {e}")
        #     return False
    
    print("  ✓ Multi-agent integration test passed!")
    return True

# Run integration test
test_multiagent_integration()


Testing Multi-Agent Integration...

  Agent 0:

  Agent 0:


TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute

In [None]:
def test_proximal_mutation_with_jit():
    """Test proximal mutation with JAX JIT compilation."""
    
    print("\nTesting Proximal Mutation with JAX JIT...")
    
    from jax import jit
    import time
    
    key = random.PRNGKey(456)
    
    # Create test data
    batch_size = 4
    obs_dim = 8
    action_dim = 3
    obs_batch_size = 100
    
    def make_test_genotype(key, batch_size, obs_dim, action_dim):
        key1, key2 = random.split(key)
        weights = random.normal(key1, (batch_size, obs_dim, action_dim)) * 0.1
        biases = random.normal(key2, (batch_size, action_dim)) * 0.1
        return {'weights': weights, 'biases': biases}
    
    def test_policy_fn(params, obs):
        return jnp.tanh(obs @ params['weights'] + params['biases'])
    
    # Create test data
    key, subkey1, subkey2 = random.split(key, 3)
    genotypes = make_test_genotype(subkey1, batch_size, obs_dim, action_dim)
    obs = random.normal(subkey2, (obs_batch_size, obs_dim))
    
    print(f"Test setup:")
    print(f"  Batch size: {batch_size}")
    print(f"  Obs dim: {obs_dim}, Action dim: {action_dim}")
    print(f"  Obs batch size: {obs_batch_size}")
    
    # Test 1: JIT compile the main function
    try:
        print("\n1. Testing JIT compilation of proximal_mutation...")
        
        # JIT compile the function
        jit_proximal_mutation = jit(proximal_mutation, static_argnames=['policy_fn'])
        
        # First call (compilation + execution)
        key, subkey = random.split(key)
        start_time = time.time()
        mutated_jit, _ = jit_proximal_mutation(
            x=genotypes,
            random_key=subkey,
            policy_fn=test_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        first_call_time = time.time() - start_time
        
        print(f"   First call (with compilation): {first_call_time:.4f}s")
        print("   ✓ JIT compilation successful!")
        
    except Exception as e:
        print(f"   ✗ JIT compilation failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    
    # Test 2: Compare JIT vs non-JIT performance
    try:
        print("\n2. Performance comparison (JIT vs non-JIT)...")
        
        # Non-JIT version
        key, subkey = random.split(key)
        start_time = time.time()
        mutated_no_jit, _ = proximal_mutation(
            x=genotypes,
            random_key=subkey,
            policy_fn=test_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        no_jit_time = time.time() - start_time
        
        # JIT version (already compiled)
        key, subkey = random.split(key)
        start_time = time.time()
        mutated_jit_second, _ = jit_proximal_mutation(
            x=genotypes,
            random_key=subkey,
            policy_fn=test_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        jit_time = time.time() - start_time
        
        print(f"   Non-JIT time: {no_jit_time:.4f}s")
        print(f"   JIT time (after compilation): {jit_time:.4f}s")
        print(f"   Speedup: {no_jit_time/jit_time:.2f}x")
        print("   ✓ Performance comparison complete!")
        
    except Exception as e:
        print(f"   ✗ Performance comparison failed: {e}")
        return False
    
    # Test 3: Verify numerical consistency
    try:
        print("\n3. Testing numerical consistency...")
        
        # Use same random key for both versions
        test_key = random.PRNGKey(999)
        
        mutated_no_jit, _ = proximal_mutation(
            x=genotypes,
            random_key=test_key,
            policy_fn=test_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        
        mutated_jit, _ = jit_proximal_mutation(
            x=genotypes,
            random_key=test_key,
            policy_fn=test_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=False,
        )
        
        # Compare results
        weight_diff = jnp.max(jnp.abs(mutated_jit['weights'] - mutated_no_jit['weights']))
        bias_diff = jnp.max(jnp.abs(mutated_jit['biases'] - mutated_no_jit['biases']))
        
        print(f"   Max weight difference: {weight_diff:.10f}")
        print(f"   Max bias difference: {bias_diff:.10f}")
        
        # Should be exactly the same (or very close due to floating point)
        if weight_diff < 1e-6 and bias_diff < 1e-6:
            print("   ✓ JIT and non-JIT results are numerically consistent!")
        else:
            print("   ⚠ Warning: JIT and non-JIT results differ slightly")
            
    except Exception as e:
        print(f"   ✗ Numerical consistency test failed: {e}")
        return False
    
    # Test 4: JIT with different batch sizes (shape polymorphism)
    try:
        print("\n4. Testing JIT with different batch sizes...")
        
        # Test with different batch sizes
        for test_batch_size in [2, 6, 8]:
            key, subkey1, subkey2 = random.split(key, 3)
            test_genotypes = make_test_genotype(subkey1, test_batch_size, obs_dim, action_dim)
            
            mutated, _ = jit_proximal_mutation(
                x=test_genotypes,
                random_key=subkey2,
                policy_fn=test_policy_fn,
                obs=obs,
                mutation_mag=0.01,
            )
            
            print(f"   Batch size {test_batch_size}: ✓")
        
        print("   ✓ JIT works with different batch sizes!")
        
    except Exception as e:
        print(f"   ✗ Different batch sizes test failed: {e}")
        return False
    
    # Test 5: JIT with mutation noise
    try:
        print("\n5. Testing JIT with mutation noise...")
        
        key, subkey = random.split(key)
        mutated_noise, _ = jit_proximal_mutation(
            x=genotypes,
            random_key=subkey,
            policy_fn=test_policy_fn,
            obs=obs,
            mutation_mag=0.01,
            mutation_noise=True,  # Enable noise
        )
        
        print("   ✓ JIT with mutation noise works!")
        
    except Exception as e:
        print(f"   ✗ JIT with mutation noise failed: {e}")
        return False
    
    print("\n🚀 All JIT tests passed! Proximal mutation is JIT-compatible and efficient.")
    return True

# Run JIT test
test_proximal_mutation_with_jit()


Testing Proximal Mutation with JAX JIT...
Test setup:
  Batch size: 4
  Obs dim: 8, Action dim: 3
  Obs batch size: 100

1. Testing JIT compilation of proximal_mutation...
   ✗ JIT compilation failed: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function proximal_mutation at /home/tin/Desktop/HaiDang/RL/Mix-ME/MA-QDax/qdax/core/emitters/mutation_operators.py:327 for jit. 
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Test setup:
  Batch size: 4
  Obs dim: 8, Action dim: 3
  Obs batch size: 100

1. Testing JIT compilation of proximal_mutation...
   ✗ JIT compilation failed: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function proximal_mutation at /home/tin/Desktop/HaiDang/RL/Mix-ME/MA-QDax/qdax/core/emitters/mutation_operators.py:327 for jit. 
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionErr

Traceback (most recent call last):
  File "/tmp/ipykernel_27525/2353898674.py", line 46, in test_proximal_mutation_with_jit
    mutated_jit, _ = jit_proximal_mutation(
  File "/home/tin/anaconda3/envs/mix-me/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/tin/anaconda3/envs/mix-me/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/home/tin/anaconda3/envs/mix-me/lib/python3.10/site-packages/jax/_src/pjit.py", line 171, in _python_pjit_helper
    attrs_tracked) = _infer_params(jit_info, args, kwargs)
  File "/home/tin/anaconda3/envs/mix-me/lib/python3.10/site-packages/jax/_src/pjit.py", line 598, in _infer_params
    jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
  File "/home/tin/anaconda3/envs/mix-me/lib/python3.10/site-packages/jax/_src/p

False