In [1]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""

import jax
import jax.numpy as jnp
from jax import jit, vmap, config
import pyswarms as ps
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import numpy as np
import jax.random as random
from jaxopt import LBFGS

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:

    master_key = random.PRNGKey(0)
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, p = inputs[:, :n_dim], inputs[:, n_dim:]
    true_value = outputs
    
    @jit
    def compute_dynamics(q, p, params):
        def hamiltonian(q, p, params):
            return equation(q, p, params)  
        q_dot = jax.grad(hamiltonian, 1)(q, p, params)  
        p_dot = -jax.grad(hamiltonian, 0)(q, p, params)  
        return jnp.concatenate([q_dot, p_dot])

    batch_compute_dynamics = jit(vmap(compute_dynamics, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_dynamics(q, p, params)
        return jnp.mean(jnp.square(pred - true_value))
    
    def run_optimization(objective_fn, initial_guess, key, num_pso_runs=5, pso_iters=300, swarmsize=100, gg = False):
        print(f"Initial guess size: {initial_guess.size}, MAX_NPARAMS: {MAX_NPARAMS}")

        grads   = jax.grad(equation,2)(q[1],p[1],params)           # (MAX_NPARAMS,)           
        active  = jnp.where(jnp.abs(grads) > 0)[0]        # e.g. [0,1,2,3,4,5,6,7]
        n_params = active.shape[0]                        # 只保留活跃参数
        print(n_params)
        
        solver = LBFGS(objective_fn, maxiter=100, tol=1e-8)

        if  gg == True:
            result =  solver.run(initial_guess)
            return result.params
        else:
            # --- Use pyswarms with JAX-controlled initial positions ---
            @jit
            def pso_objective_wrapper(particles_matrix):
                return vmap(objective_fn)(particles_matrix)

            min_bound_np = np.full(n_params, -100.0, dtype=np.float64)
            max_bound_np = np.full(n_params, 100.0, dtype=np.float64)
            bounds = (min_bound_np, max_bound_np)

            options = {'c1': 1.49445, 'c2': 1.49445, 'w': 0.729}

            best_pso_params = None
            best_pso_loss = jnp.inf

            current_key = key # Use the passed-in key

            print(f"Running pyswarms PSO {num_pso_runs} times with unique initial swarms...")
            for i in range(num_pso_runs):
                # Split the key for this run to ensure unique randomness
                current_key, subkey = random.split(current_key)

                # Generate initial positions using JAX PRNG within bounds
                # Use jnp arrays for min/max bounds in jax.random.uniform
                min_bound_jnp = jnp.full(n_params, -10.0, dtype=jnp.float64)
                max_bound_jnp = jnp.full(n_params, 10.0, dtype=jnp.float64)
                init_pos_jax = random.uniform(subkey,
                                              shape=(swarmsize, n_params),
                                              dtype=jnp.float64,
                                              minval=min_bound_jnp,
                                              maxval=max_bound_jnp)
                # Convert to NumPy array for pyswarms
                init_pos_np = np.array(init_pos_jax)

                print(f"  PSO Run {i+1}/{num_pso_runs} (using JAX key split for init_pos)")
                optimizer = ps.single.GlobalBestPSO(n_particles=swarmsize,
                                                     dimensions=n_params,
                                                     options=options,
                                                     bounds=bounds,
                                                     # Pass the generated initial positions
                                                     init_pos=init_pos_np) # <-- Pass init_pos here

                # Perform optimization (pyswarms will use the provided init_pos)
                current_pso_loss, current_pso_params = optimizer.optimize(
                    pso_objective_wrapper,
                    iters=pso_iters,
                    verbose=False
                )
                current_pso_params = jnp.array(current_pso_params, dtype=jnp.float64)

                print(f"    Run {i+1} completed. Loss: {current_pso_loss}")
                if current_pso_loss < best_pso_loss:
                    best_pso_loss = current_pso_loss
                    best_pso_params = current_pso_params
                    print(f"    New best PSO loss found: {best_pso_loss}")

            # ... (rest of the function: handling no solution, BFGS refinement) ...
            if best_pso_params is None:
                 print("Warning: PSO did not find any valid solution after multiple runs. Using initial guess for BFGS.")
                 best_pso_params = initial_guess # Fallback

            print(f"\nBest PSO loss after {num_pso_runs} runs: {best_pso_loss}")
            print("Refining best PSO result with L- BFGS...")

            result =  solver.run(best_pso_params)
            
            return result.params

    
    def calculate_sensitivities(opt_params: jnp.ndarray, base_loss: float) -> jnp.ndarray:
        n_active = opt_params.size
        masks = 1.0 - jnp.eye(n_active, dtype=opt_params.dtype)
        @jit
        def solve_one(mask_i):
            init = opt_params * mask_i
            def sub_loss(p):
                return loss_fn(p * mask_i)
            solver = LBFGS(fun=sub_loss, maxiter=100, tol=1e-8)
            out    = solver.run(init)         
            loss_i = loss_fn(out.params * mask_i)       
            return jnp.log2(loss_i / base_loss)

        sensitivities = vmap(solve_one)(masks)
        return sensitivities

    try:
        opt_key, sensi_key = random.split(master_key)
        optimized_params = run_optimization(loss_fn, params, key=opt_key)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss after L-BFGS: {final_loss}")
        if optimized_params is None:
             print("Optimization failed to produce parameters.")
             return None

        final_loss = loss_fn(optimized_params)
        print(f"Final optimized loss: {final_loss}")
    except Exception as e:
        # ... (error handling remains the same) ...
        print(f"Optimization or final loss calculation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    if not jnp.isfinite(final_loss):
        print("Final loss is not finite.")
        return None

    # Pass a key to sensitivity analysis if it also needs randomness
    # For now, assuming calculate_sensitivities doesn't need a separate key
    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"sensitive of params[{i}]": round(float(sensitivities[i]), 4) 
                       for i in range(len(sensitivities))}


    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': sensitivity_dict
    }

@jit
def equation(q: jnp.array, p: jnp.array, params: jnp.array) -> jnp.array:
    q1, q2 = q[...,0], q[...,1]
    p1, p2 = p[...,0], p[...,1]

    T_num = (params[0] * p1**2 + params[1] * p2**2 + params[2] * p1 * p2)
    T_den = params[3] + params[4] * jnp.cos(q1 - q2) 
    T = T_num / T_den

    # Potential Energy (V)
    V = params[5]* jnp.cos(q1) - params[6] * jnp.cos(q2)

    return T + V

INFO:2025-04-25 01:11:55,085:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-04-25 01:11:55,085 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-04-25 01:11:55,101:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
2025-04-25 01:11:55,101 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.


In [2]:
import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('../double_pendulum/double_pendulum_high_precision.csv')
#data0 = pd.read_csv('./hamiltonian_spring_mass_energy_data.csv')
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, 1:3], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 3:-1], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
print(state)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
    'energy': energy
}


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[ 0.          0.        ]
 [-0.0654109  -0.03270545]
 [-0.1308218  -0.0654109 ]
 ...
 [ 8.65638418  3.16140084]
 [ 8.57651316  3.21913912]
 [ 8.49518264  3.27744751]]
[[  1.57079633   1.57079633]
 [  1.57074181   1.57079633]
 [  1.57057825   1.57079633]
 ...
 [  0.72130633 -25.8504519 ]
 [  0.73509878 -25.84159421]
 [  0.74876059 -25.83246209]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Initial guess size: 10, MAX_NPARAMS: 10
7
Running pyswarms PSO 5 times with unique initial swarms...
  PSO Run 1/5 (using JAX key split for init_pos)
    Run 1 completed. Loss: 24.169338790776283
    New best PSO loss found: 24.169338790776283
  PSO Run 2/5 (using JAX key split for init_pos)
    Run 2 completed. Loss: 24.169338790770542
    New best PSO loss found: 24.169338790770542
  PSO Run 3/5 (using JAX key split for init_pos)
    Run 3 completed. Loss: 24.16933879077056
  PSO Run 4/5 (using JAX key split for init_pos)
    Run 4 completed. Loss: 24.169338790770542
  PSO Run 5/5 (using JAX key split for init_