# 测试以下DF+LBFGS优化器

### 以oscillator为例子
会出现敏感度计算失误的问题。

In [19]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap, config
from scipy.optimize import minimize
from pyswarm import pso
from pyswarms.utils.plotters import plot_cost_history
config.update("jax_enable_x64", True)
import jax.random as random
from scipy.optimize import differential_evolution

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    t, x, v = inputs[:,0], inputs[:,1], inputs[:,2]
    a = outputs
    print(t.shape, x.shape, v.shape, a.shape)

    @jit
    def loss_fn(params):
        pred = equation(t, x, v, params)
        true_accelerations = a
        return jnp.mean(jnp.square(pred - true_accelerations))
        
    def run_optimization(objective_fn, initial_guess):
        
        solver = LBFGS(fun=objective_fn, maxiter=100, tol=1e-8)
        if initial_guess.size > MAX_NPARAMS:
            result =  solver.run(initial_guess)
            return result.params
        else:
            bounds = [(-10.0, 10.0)] * MAX_NPARAMS
            # JIT 编译目标函数
            jit_loss = jax.jit(objective_fn)
            # NumPy wrapper 供 SciPy 调用
            def loss_numpy(x: np.ndarray) -> float:
                return float(jit_loss(jnp.array(x, dtype=jnp.float64)))
            # 1) 差分进化全局搜索
            result_de = differential_evolution(
                loss_numpy,
                bounds=bounds,
                maxiter=300,
                popsize=15,
                tol=0.01,
                disp=True
            )
            x0 = jnp.array(result_de.x, dtype=jnp.float64)
            print(f"[DE] 初始点 loss = {result_de.fun:.6f}")
            # 2) L-BFGS 局部精炼
            sol = solver.run(x0)
            return sol.params
    
    def calculate_sensitivities(loss_fn, opt_params, base_loss,  tol=1e-20):
        # 1) 自动识别“活跃”参数索引
        grads   = jax.grad(loss_fn)(opt_params)             # (MAX_NPARAMS,)
        active  = jnp.where(jnp.abs(grads) > tol)[0]        # e.g. [0,1,2,3,4,5,6,7]
        n_active = active.shape[0]

        # 2) 构造掩码矩阵 M，shape = (n_active, MAX_NPARAMS)
        #    每一行 M[i] 都是全 1，只有 active[i] 这一列为 0
        eye      = jnp.eye(opt_params.size)
        masks    = 1.0 - eye[active]                         # (n_active, MAX_NPARAMS)

        # 3) 生成每个子问题的初始猜测：masked_params = masks * opt_params
        init_batch = masks * opt_params                      # (n_active, MAX_NPARAMS)

        # 4) 定义一个“单点 L-BFGS 解算器”，并向量化
        solver    = LBFGS(fun=loss_fn, maxiter=100, tol=1e-8)

        @jax.jit
        def solve_one(p0):
            sol = solver.run(p0)
            return sol.params                              # shape = (MAX_NPARAMS,)

        # 5) 并行跑 n_active 次 L-BFGS
        sol_batch = vmap(solve_one)(init_batch)             # (n_active, MAX_NPARAMS)

        # 6) 计算每个子解对应的损失
        loss_batch = vmap(loss_fn)(sol_batch)               # (n_active,)

        # 7) 输出相对增益 log2(loss_i / base_loss)
        sens = jnp.log2(loss_batch / base_loss)       # (n_active,)

        return  sens


    # 主流程
    # Main execution flow
    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss: {final_loss}")
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(loss_fn,optimized_params, final_loss)
    }



@jit
def equation(t: jnp.array, x: jnp.array, v: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for acceleration in a damped nonlinear oscillator

    Args:
        t: A jax array representing time.
        x: A jax array representing observations of current position.
        v: A jax array representing observations of velocity.
        params: Jax array of numeric constants or parameters to be optimized

    Return:
        A jax array representing acceleration as the result of applying the mathematical function to the inputs.
    """
    dv = params[0] * jnp.sin(params[1]*t) + params[2] * x*v + params[3] * v**3 + params[4]*x * jnp.exp(params[5]*x) + x**2 * params[6] #+params[7]#+ params[8]*jnp.cos(params[9]*t) + params[10]*x*v**2
    return dv

In [20]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train_os2.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :3], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:,-1], dtype=jnp.float64)
print(state.shape)
print(energy)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': energy
}


evaluate(data)


(10000, 3)
[0.66038351 0.6594464  0.65849712 ... 0.24477721 0.24244884 0.24011387]
(10000,) (10000,) (10000,) (10000,)
differential_evolution step 1: f(x)= 0.09313057349369772
differential_evolution step 2: f(x)= 0.05858682986039899
differential_evolution step 3: f(x)= 0.040736511032269385
differential_evolution step 4: f(x)= 0.040736511032269385
differential_evolution step 5: f(x)= 0.040736511032269385
differential_evolution step 6: f(x)= 0.040736511032269385
differential_evolution step 7: f(x)= 0.040673256457241444
differential_evolution step 8: f(x)= 0.03513630338988301
differential_evolution step 9: f(x)= 0.03513630338988301
differential_evolution step 10: f(x)= 0.032930633049247315
differential_evolution step 11: f(x)= 0.032930633049247315
differential_evolution step 12: f(x)= 0.032930633049247315
differential_evolution step 13: f(x)= 0.015069052149403366
differential_evolution step 14: f(x)= 0.015069052149403366
differential_evolution step 15: f(x)= 0.015069052149403366
different

{'params': Array([ 0.29999887,  1.00000036, -0.99994083, -0.50003713, -5.00015655,
        -0.49650175, -4.98746969, -7.07820057, -4.6952015 ,  0.90605955],      dtype=float64),
 'loss': Array(4.00140317e-10, dtype=float64),
 'sensitivities': Array([-5.32266674e-05,  2.62484073e+01,  1.41231323e-03,  5.70507663e-04,
         2.42037339e+01,  2.37875258e+00,  3.95585604e+00], dtype=float64)}

# 测试原算法，优化掩膜情况

In [None]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap, config
from scipy.optimize import minimize
from pyswarm import pso
from pyswarms.utils.plotters import plot_cost_history
config.update("jax_enable_x64", True)
import jax.random as random
from scipy.optimize import differential_evolution
from jaxopt import LBFGS
import pyswarms as ps

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    t, x, v = inputs[:,0], inputs[:,1], inputs[:,2]
    a = outputs
    print(t.shape, x.shape, v.shape, a.shape)
    master_key = random.PRNGKey(0)

    @jit
    def loss_fn(params):
        pred = equation(t, x, v, params)
        true_accelerations = a
        return jnp.mean(jnp.square(pred - true_accelerations))

    '''def run_optimization(objective_fn, initial_guess):
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                            method='BFGS', options={'maxiter': 100})
            return result.x
        else:
            def pso_wrapper(x):
                return objective_fn(jnp.array(x))
            
            lb = [-10.0]*MAX_NPARAMS
            ub = [10.0]*MAX_NPARAMS
            
            pso_params, _ = pso(pso_wrapper, lb, ub, 
                            swarmsize=100, maxiter=300,omega=0.729, phip=1.49445, phig=1.4944,minstep=1e-6, debug=False)
            
            result = minimize(objective_fn, jnp.array(pso_params),
                            method='BFGS', options={'maxiter': 500})
            return result.x'''
    def run_optimization(objective_fn, initial_guess, key, num_pso_runs=5, pso_iters=300, swarmsize=100):
        print(f"Initial guess size: {initial_guess.size}, MAX_NPARAMS: {MAX_NPARAMS}")
        n_params = initial_guess.size

        if n_params > MAX_NPARAMS:
            # ... (BFGS only part remains the same) ...
            print("Using BFGS directly due to large number of parameters.")
            result = minimize(objective_fn, initial_guess,
                              method='BFGS', options={'maxiter': 500})
            if not result.success:
                print(f"BFGS optimization failed: {result.message}")
                return initial_guess # Fallback
            return result.x
        else:
            # --- Use pyswarms with JAX-controlled initial positions ---
            @jit
            def pso_objective_wrapper(particles_matrix):
                return vmap(objective_fn)(particles_matrix)

            min_bound_np = np.full(n_params, -10.0, dtype=np.float64)
            max_bound_np = np.full(n_params, 10.0, dtype=np.float64)
            bounds = (min_bound_np, max_bound_np)

            options = {'c1': 1.49445, 'c2': 1.49445, 'w': 0.729}

            best_pso_params = None
            best_pso_loss = jnp.inf

            current_key = key # Use the passed-in key

            print(f"Running pyswarms PSO {num_pso_runs} times with unique initial swarms...")
            for i in range(num_pso_runs):
                # Split the key for this run to ensure unique randomness
                current_key, subkey = random.split(current_key)

                # Generate initial positions using JAX PRNG within bounds
                # Use jnp arrays for min/max bounds in jax.random.uniform
                min_bound_jnp = jnp.full(n_params, -10.0, dtype=jnp.float64)
                max_bound_jnp = jnp.full(n_params, 10.0, dtype=jnp.float64)
                init_pos_jax = random.uniform(subkey,
                                              shape=(swarmsize, n_params),
                                              dtype=jnp.float64,
                                              minval=min_bound_jnp,
                                              maxval=max_bound_jnp)
                # Convert to NumPy array for pyswarms
                init_pos_np = np.array(init_pos_jax)

                print(f"  PSO Run {i+1}/{num_pso_runs} (using JAX key split for init_pos)")
                optimizer = ps.single.GlobalBestPSO(n_particles=swarmsize,
                                                     dimensions=n_params,
                                                     options=options,
                                                     bounds=bounds,
                                                     # Pass the generated initial positions
                                                     init_pos=init_pos_np) # <-- Pass init_pos here

                # Perform optimization (pyswarms will use the provided init_pos)
                current_pso_loss, current_pso_params = optimizer.optimize(
                    pso_objective_wrapper,
                    iters=pso_iters,
                    verbose=False
                )
                current_pso_params = jnp.array(current_pso_params, dtype=jnp.float64)

                print(f"    Run {i+1} completed. Loss: {current_pso_loss}")
                if current_pso_loss < best_pso_loss:
                    best_pso_loss = current_pso_loss
                    best_pso_params = current_pso_params
                    print(f"    New best PSO loss found: {best_pso_loss}")

            # ... (rest of the function: handling no solution, BFGS refinement) ...
            if best_pso_params is None:
                 print("Warning: PSO did not find any valid solution after multiple runs. Using initial guess for BFGS.")
                 best_pso_params = initial_guess # Fallback

            print(f"\nBest PSO loss after {num_pso_runs} runs: {best_pso_loss}")
            print("Refining best PSO result with BFGS...")

            result = minimize(objective_fn, jnp.array(best_pso_params),
                              method='BFGS', options={'maxiter': 500})

            if not result.success:
                 print(f"BFGS refinement failed: {result.message}")
                 return best_pso_params # Return PSO best

            print(f"BFGS refinement successful. Final loss: {result.fun}")
            return result.x
        
    
    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        grads   = jax.grad(loss_fn)(params)           # (MAX_NPARAMS,)
        print(grads)                                      #能够准确识别梯度            
        active  = jnp.where(jnp.abs(grads) > 0)[0]        # e.g. [0,1,2,3,4,5,6,7]
        opt_params = opt_params[active]                     # 只保留活跃参数
        n_active = active.shape[0]
        mask    = 1.0 - jnp.eye(n_active) 
        print(mask )

        @jit      
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(n_active,n_active)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat,key=opt_key)   
            optimized_matrix = optimized_flat.reshape(n_active,n_active)
            losses = batch_loss(optimized_matrix)
            relative_loss = jnp.log2(losses / base_loss)
            return relative_loss
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return jnp.zeros(MAX_NPARAMS, dtype=jnp.float64)
    

    # 主流程
    # Main execution flow
    try:
        opt_key, sensi_key = random.split(master_key)
        optimized_params = run_optimization(loss_fn, params, key=opt_key, num_pso_runs=5)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss: {final_loss}")
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(optimized_params, final_loss)
    }


@jit
def equation(t: jnp.array, x: jnp.array, v: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for acceleration in a damped nonlinear oscillator

    Args:
        t: A jax array representing time.
        x: A jax array representing observations of current position.
        v: A jax array representing observations of velocity.
        params: Jax array of numeric constants or parameters to be optimized

    Return:
        A jax array representing acceleration as the result of applying the mathematical function to the inputs.
    """
    dv = params[0] * jnp.sin(params[1]*t) + params[2] * x*v + params[3] * v**3 + params[4]*x * jnp.exp(params[5]*x) + x**2 * params[6] +params[7]#+ params[8]*jnp.cos(params[9]*t) + params[10]*x*v**2
    return dv

In [18]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train_os2.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :3], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:,-1], dtype=jnp.float64)

# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': energy
}


evaluate(data)


(10000,) (10000,) (10000,) (10000,)
Initial guess size: 10, MAX_NPARAMS: 10
Running pyswarms PSO 5 times with unique initial swarms...
  PSO Run 1/5 (using JAX key split for init_pos)
    Run 1 completed. Loss: 1.8779801776864734e-08
    New best PSO loss found: 1.8779801776864734e-08
  PSO Run 2/5 (using JAX key split for init_pos)
    Run 2 completed. Loss: 0.0026699999824137534
  PSO Run 3/5 (using JAX key split for init_pos)
    Run 3 completed. Loss: 0.008855897474132195
  PSO Run 4/5 (using JAX key split for init_pos)
    Run 4 completed. Loss: 0.005595129434093178
  PSO Run 5/5 (using JAX key split for init_pos)
    Run 5 completed. Loss: 6.63555642553374e-09
    New best PSO loss found: 6.63555642553374e-09

Best PSO loss after 5 runs: 6.63555642553374e-09
Refining best PSO result with BFGS...
BFGS refinement successful. Final loss: 6.5940665863206395e-09
Final loss: 6.5940665863206395e-09
[ 1.05987578e+00 -1.11681569e+00  1.98393131e-03  7.80097456e-04
  1.62920544e-01  2.1800

{'params': array([ 2.99983654e-01,  1.00000068e+00, -9.99896612e-01, -5.01206299e-01,
        -5.00180918e+00,  4.64317718e-01, -1.79077881e-01,  8.11270132e-06,
        -9.25496229e-02,  9.13170731e+00]),
 'loss': Array(6.59406659e-09, dtype=float64),
 'sensitivities': Array([22.20403872, 22.20403872, 15.03910487, 12.47087436, 24.55283238,
         5.61911479, -8.63815805, -0.11304308], dtype=float64)}