In [None]:
import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

def evaluate(data: dict, params=initial_params) -> dict:
    t = data['t']
    x = data['x']
    v = data['v']
    a = data['a']


    @jit
    def loss_fn(params):
        pred = equation(t, x, v, params)
        true_accelerations = a[...,0]
        return jnp.mean(jnp.square(pred - true_accelerations))

    def run_optimization(objective_fn, initial_guess):
        print(initial_guess.size, MAX_NPARAMS)
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                            method='BFGS', options={'maxiter': 500})
            return result.x
        else:
            def pso_wrapper(x):
                return objective_fn(jnp.array(x))
            
            lb = [-10.0]*initial_guess.size
            ub = [10.0]*initial_guess.size
            
            pso_params, _ = pso(pso_wrapper, lb, ub, 
                            swarmsize=100, maxiter=300,omega=0.729, phip=1.49445, phig=1.49445)
            
            result = minimize(objective_fn, jnp.array(pso_params),
                            method='BFGS', options={'maxiter': 500})
            return result.x


    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            relative_loss = (losses - base_loss) / base_loss
            return {i: float(jnp.nan_to_num(relative_loss[i])) for i in range(MAX_NPARAMS)}
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return {i: 0.0 for i in range(MAX_NPARAMS)}


    # 主流程
    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss: {final_loss}")
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(optimized_params, final_loss)
    }

@jit
def equation(t: jnp.array, x: jnp.array, v: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for acceleration in a damped nonlinear oscillator

    Args:
        t: A numpy array representing time.
        x: A numpy array representing observations of current position.
        v: A numpy array representing observations of velocity.
        params: Array of numeric constants or parameters to be optimized

    Return:
        A numpy array representing acceleration as the result of applying the mathematical function to the inputs.
    """
    t = t[...,0]
    x = x[...,0]
    v = v[...,0]

    dv = params[0] * jnp.sin(params[6]*t) + params[1] * x * jnp.exp(params[5] * x) +  params[2] * x * v  + params[3]*v**3 #+ params[4]*x**2 
    return dv


In [28]:
import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train_os2.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
t = jnp.array(tae[:, 0:1], dtype=jnp.float64)  # 转换为 JAX 数组
x = jnp.array(tae[:, 1:2], dtype=jnp.float64)  # 转换为 JAX 数组
v = jnp.array(tae[:, 2:3], dtype=jnp.float64)  # 转换为 JAX 数组
a = jnp.array(tae[:, 3:4], dtype=jnp.float64)  # 转换为 JAX 数组    

# 将数据存储在字典中
data = {
    't': t,  # 时间
    'x': x,  # 位置
    'v': v,
    'a': a  # 真实的加速度
}


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
10 10
Stopping search: Swarm best objective change less than 1e-08
Final loss: 1.4375045274456086e-08
100 10
最终损失值 (MSE): {'params': Array([ -0.3000032 ,  -4.99969831,  -1.00296103,  -0.49121451,
         1.20742845,   0.50022786,  -0.99999536,   9.01338693,
         2.82007058, -10.        ], dtype=float64), 'loss': Array(1.43750453e-08, dtype=float64), 'sensitivities': {0: 2226086.90175513, 1: 11303429.103475606, 2: 15449.119850859408, 3: 2604.7434241437636, 4: -0.9991142345887681, 5: 74025.71894724574, 6: 2226086.90175513, 7: -0.9991138505159284, 8: -0.9981675517218953, 9: -0.9987966680106485}}
