F sin(ωx) − αv3 − βx3 − γx · v − x cos(x)   
F sin(ωt) − αv3 − βx · v − δx · exp(γx).

In [3]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap, config
from scipy.optimize import minimize
from pyswarm import pso
from pyswarms.utils.plotters import plot_cost_history
config.update("jax_enable_x64", True)
import jax.random as random
from scipy.optimize import differential_evolution
from jaxopt import LBFGS

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    t, x, v = inputs[:,0], inputs[:,1], inputs[:,2]
    a = outputs
    print(t.shape, x.shape, v.shape, a.shape)

    @jit
    def loss_fn(params):
        pred = equation(t, x, v, params)
        true_accelerations = a
        return jnp.mean(jnp.square(pred - true_accelerations))

    def run_optimization(objective_fn, initial_guess):
        print(f"Running optimization with initial_guess size: {initial_guess.size}")
        print(initial_guess.size, MAX_NPARAMS)
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                            method='BFGS', options={'maxiter': 100})
            return result.x
        else:
            def pso_wrapper(x):
                return objective_fn(jnp.array(x))
            
            lb = [-10.0]*initial_guess.size
            ub = [10.0]*initial_guess.size
            
            pso_params, _ = pso(pso_wrapper, lb, ub, 
                            swarmsize=100, maxiter=300,omega=0.729, phip=1.49445, phig=1.4944,minstep=1e-6, debug=False)
            
            result = minimize(objective_fn, jnp.array(pso_params),
                            method='BFGS', options={'maxiter': 500})
            return result.x
        
    
    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit      
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)   
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            relative_loss = jnp.log2(losses / base_loss)
            return relative_loss
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return jnp.zeros(MAX_NPARAMS, dtype=jnp.float64)
    

    # 主流程
    # Main execution flow
    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss: {final_loss}")
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(optimized_params, final_loss)
    }



@jit
def equation(t: jnp.array, x: jnp.array, v: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for acceleration in a damped nonlinear oscillator

    Args:
        t: A jax array representing time.
        x: A jax array representing observations of current position.
        v: A jax array representing observations of velocity.
        params: Jax array of numeric constants or parameters to be optimized

    Return:
        A jax array representing acceleration as the result of applying the mathematical function to the inputs.
    """
    dv = params[0] * jnp.sin(params[1]*t) + params[2] * x*v + params[3] * v**3 + params[4]*x * jnp.exp(params[5]*x) + x**2 * params[6] #+params[7]#+ params[8]*jnp.cos(params[9]*t) + params[10]*x*v**2
    return dv

In [8]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train_os2.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :3], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:,-1], dtype=jnp.float64)
print(state.shape)
print(energy)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': energy
}


evaluate(data)


(10000, 3)
[0.66038351 0.6594464  0.65849712 ... 0.24477721 0.24244884 0.24011387]
(10000,) (10000,) (10000,) (10000,)
Initial guess size: 10, MAX_NPARAMS: 10
Running pyswarms PSO 5 times with unique initial swarms...
  PSO Run 1/5 (using JAX key split for init_pos)
    Run 1 completed. Loss: 7.848604236887877e-09
    New best PSO loss found: 7.848604236887877e-09
  PSO Run 2/5 (using JAX key split for init_pos)
    Run 2 completed. Loss: 1.795317025864746e-06
  PSO Run 3/5 (using JAX key split for init_pos)
    Run 3 completed. Loss: 1.3294243751197797e-07
  PSO Run 4/5 (using JAX key split for init_pos)
    Run 4 completed. Loss: 8.381093380532152e-07
  PSO Run 5/5 (using JAX key split for init_pos)
    Run 5 completed. Loss: 7.411056914504466e-07

Best PSO loss after 5 runs: 7.848604236887877e-09
Refining best PSO result with BFGS...
BFGS refinement successful. Final loss: 7.847743673937094e-09
Final optimized loss: 7.847743673937094e-09
Initial guess size: 100, MAX_NPARAMS: 10
Usin

{'params': array([ 3.00046113e-01,  9.99999924e-01, -9.99660021e-01, -4.99654140e-01,
        -4.99835597e+00, -5.36088866e-01, -5.18742236e+00,  3.77517554e-05,
        -9.77342550e-01,  8.06835574e+00]),
 'loss': Array(7.84774367e-09, dtype=float64),
 'sensitivities': {'sensitive of params[0]': 21.9529,
  'sensitive of params[1]': 21.9529,
  'sensitive of params[2]': 14.7852,
  'sensitive of params[3]': 12.2185,
  'sensitive of params[4]': 24.2893,
  'sensitive of params[5]': 5.2098,
  'sensitive of params[6]': -10.3903,
  'sensitive of params[7]': 0.0699,
  'sensitive of params[8]': -0.0882,
  'sensitive of params[9]': -0.0696}}

# 优化快速，5轮次不过4秒，损失非常低-21都有

In [7]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""


import jax.numpy as jnp
from jax import jit, vmap, config
from scipy.optimize import minimize
import pyswarms as ps
config.update("jax_enable_x64", True)
import numpy as np
import jax.random as random

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    t, x, v = inputs[:,0], inputs[:,1], inputs[:,2]
    a = outputs
    print(t.shape, x.shape, v.shape, a.shape)
    master_key = random.PRNGKey(0)

    @jit
    def loss_fn(params):
        pred = equation(t, x, v, params)
        true_accelerations = a
        return jnp.mean(jnp.square(pred - true_accelerations))

    def run_optimization(objective_fn, initial_guess, key, num_pso_runs=5, pso_iters=300, swarmsize=100):
        print(f"Initial guess size: {initial_guess.size}, MAX_NPARAMS: {MAX_NPARAMS}")
        n_params = initial_guess.size

        if n_params > MAX_NPARAMS:
            # ... (BFGS only part remains the same) ...
            print("Using BFGS directly due to large number of parameters.")
            result = minimize(objective_fn, initial_guess,
                              method='BFGS', options={'maxiter': 500})
            if not result.success:
                print(f"BFGS optimization failed: {result.message}")
                return initial_guess # Fallback
            return result.x
        else:
            # --- Use pyswarms with JAX-controlled initial positions ---
            @jit
            def pso_objective_wrapper(particles_matrix):
                return vmap(objective_fn)(particles_matrix)

            min_bound_np = np.full(n_params, -10.0, dtype=np.float64)
            max_bound_np = np.full(n_params, 10.0, dtype=np.float64)
            bounds = (min_bound_np, max_bound_np)

            options = {'c1': 1.49445, 'c2': 1.49445, 'w': 0.729}

            best_pso_params = None
            best_pso_loss = jnp.inf

            current_key = key # Use the passed-in key

            print(f"Running pyswarms PSO {num_pso_runs} times with unique initial swarms...")
            for i in range(num_pso_runs):
                # Split the key for this run to ensure unique randomness
                current_key, subkey = random.split(current_key)

                # Generate initial positions using JAX PRNG within bounds
                # Use jnp arrays for min/max bounds in jax.random.uniform
                min_bound_jnp = jnp.full(n_params, -10.0, dtype=jnp.float64)
                max_bound_jnp = jnp.full(n_params, 10.0, dtype=jnp.float64)
                init_pos_jax = random.uniform(subkey,
                                              shape=(swarmsize, n_params),
                                              dtype=jnp.float64,
                                              minval=min_bound_jnp,
                                              maxval=max_bound_jnp)
                # Convert to NumPy array for pyswarms
                init_pos_np = np.array(init_pos_jax)

                print(f"  PSO Run {i+1}/{num_pso_runs} (using JAX key split for init_pos)")
                optimizer = ps.single.GlobalBestPSO(n_particles=swarmsize,
                                                     dimensions=n_params,
                                                     options=options,
                                                     bounds=bounds,
                                                     # Pass the generated initial positions
                                                     init_pos=init_pos_np) # <-- Pass init_pos here

                # Perform optimization (pyswarms will use the provided init_pos)
                current_pso_loss, current_pso_params = optimizer.optimize(
                    pso_objective_wrapper,
                    iters=pso_iters,
                    verbose=False
                )
                current_pso_params = jnp.array(current_pso_params, dtype=jnp.float64)

                print(f"    Run {i+1} completed. Loss: {current_pso_loss}")
                if current_pso_loss < best_pso_loss:
                    best_pso_loss = current_pso_loss
                    best_pso_params = current_pso_params
                    print(f"    New best PSO loss found: {best_pso_loss}")

            # ... (rest of the function: handling no solution, BFGS refinement) ...
            if best_pso_params is None:
                 print("Warning: PSO did not find any valid solution after multiple runs. Using initial guess for BFGS.")
                 best_pso_params = initial_guess # Fallback

            print(f"\nBest PSO loss after {num_pso_runs} runs: {best_pso_loss}")
            print("Refining best PSO result with BFGS...")

            result = minimize(objective_fn, jnp.array(best_pso_params),
                              method='BFGS', options={'maxiter': 500})

            if not result.success:
                 print(f"BFGS refinement failed: {result.message}")
                 return best_pso_params # Return PSO best

            print(f"BFGS refinement successful. Final loss: {result.fun}")
            return result.x



    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit      
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat, key=opt_key, num_pso_runs=5)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            relative_loss = jnp.log2(losses / base_loss)
            return relative_loss
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return jnp.zeros(MAX_NPARAMS, dtype=jnp.float64)


    # 主流程
    # Main execution flow
    try:
        # Split the master key for the main optimization run
        opt_key, sensi_key = random.split(master_key) # Keep keys separate if needed later

        optimized_params = run_optimization(loss_fn, params, key=opt_key, num_pso_runs=5) # Pass the key
        if optimized_params is None:
             print("Optimization failed to produce parameters.")
             return None

        final_loss = loss_fn(optimized_params)
        print(f"Final optimized loss: {final_loss}")
    except Exception as e:
        # ... (error handling remains the same) ...
        print(f"Optimization or final loss calculation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    if not jnp.isfinite(final_loss):
        print("Final loss is not finite.")
        return None

    # Pass a key to sensitivity analysis if it also needs randomness
    # For now, assuming calculate_sensitivities doesn't need a separate key
    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"sensitive of params[{i}]": round(float(sensitivities[i]), 4) 
                       for i in range(len(sensitivities))}


    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': sensitivity_dict
    }


@jit
def equation(t: jnp.array, x: jnp.array, v: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for acceleration in a damped nonlinear oscillator

    Args:
        t: A jax array representing time.
        x: A jax array representing observations of current position.
        v: A jax array representing observations of velocity.
        params: Jax array of numeric constants or parameters to be optimized

    Return:
        A jax array representing acceleration as the result of applying the mathematical function to the inputs.
    """
    dv = params[0] * jnp.sin(params[1]*t) + params[2] * x*v + params[3] * v**3 + params[4]*x * jnp.exp(params[5]*x) + x**2 * params[6] +params[7]
    
    return dv