# 生成非线性项梯度

In [None]:
import jax
import jax.numpy as jnp

# --- 你的原始函数 ---
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array):
    """ Mathematical function for lagrangian in a one-dimensional physical system
    Args:
        q (jnp.array): observation of current generalized coordinate.
                         Expected shape (..., 1) or (...,) if already squeezed.
        q_t (jnp.array): observation of generalized velocity.
                         Expected shape (..., 1) or (...,) if already squeezed.
        params (jnp.array): List of numeric constants or parameters. Shape (N,).
    Returns:
        jnp.array: lagrangian as the result of applying the mathematical function.
                   Shape will match the leading dimensions of q/q_t.
    """
    # Ensure we are working with the core dimension
    if q.ndim > 0 and q.shape[-1] == 1:
        q = q[..., 0]
    if q_t.ndim > 0 and q_t.shape[-1] == 1:
        q_t = q_t[..., 0]

    # Original Calculation (assuming param indices match the description)
    # T = p[0]*(1-q_t^2)^(-1) + p[1] + p[2]*q_t^2 - p[4]
    # V = -p[3]*q + p[5]
    # L = T - V
    # It seems the original function might have missed some powers or had typos?
    # Let's refine based on common forms or stick strictly to the provided text:
    # T = params[0]*jnp.power(1-q_t**2,-1)+params[1] + params[2]*jnp.power(q_t,2) - params[4] # Term -params[4] seems unusual in T, maybe constant shift?
    # V = -params[3]*q + params[5] # Term +params[5] seems unusual in V, maybe constant shift?

    # Re-interpreting based on potential physics (e.g., relativistic-like kinetic term, linear potential)
    # Let's assume a simplified form for clarity, adjust if needed based on the *exact* intended physics.
    # Example: T = 0.5 * m * q_t^2 (using params[0] for mass 'm')
    # Example: V = 0.5 * k * q^2 (using params[1] for spring constant 'k')
    # L = T - V
    # --- Using the EXACT provided function for demonstration ---
    # Ensure params has enough elements if indices 0-5 are used.
    # Add safety checks or padding if params might be shorter.
    safe_params = jnp.pad(params, (0, max(0, 6 - len(params)))) # Pad if less than 6 params

    term0 = safe_params[0] * jnp.power(1 - q_t**2, -1) # Potential singularity if q_t approaches 1
    term1 = safe_params[1]
    term2 = safe_params[2] * jnp.power(q_t, 2)
    term3 = -safe_params[3] * q
    term4 = -safe_params[4] # Constant shift from T
    term5 = safe_params[5]  # Constant shift from V

    T = term0 + term1 + term2 + term4
    V = term3 - term5 # V = -p[3]*q + p[5] => -V = p[3]*q - p[5]

    result = T - V # L = (term0 + term1 + term2 + term4) - (term3 - term5)
                 # L = term0 + term1 + term2 - term3 + term4 + term5
    return result

# --- Function to calculate influence ---

def calculate_param_influence(q: jnp.array, q_t: jnp.array, params: jnp.array):
    """
    Calculates the influence of each parameter on the equation's output.

    Args:
        q (jnp.array): observation of current generalized coordinate.
        q_t (jnp.array): observation of generalized velocity.
        params (jnp.array): The original parameters. Shape (N,).

    Returns:
        jnp.array: The calculated 'influence' for each parameter.
                   Shape (N, ...) matching leading dimensions of q/q_t.
    """
    num_params = params.shape[0]
    indices = jnp.arange(num_params)

    # Calculate the original output
    original_output = equation(q, q_t, params)

    # Define a function that calculates the influence for a *single* index 'i'
    def influence_for_index(i, q, q_t, params, original_output):
        # Create parameters with the i-th element set to 0
        modified_params = params.at[i].set(0.0)
        # Calculate the output with the modified parameters
        modified_output = equation(q, q_t, modified_params)
        # Calculate the absolute difference
        diff = jnp.abs(original_output - modified_output)
        # Get the absolute value of the original parameter
        original_param_val = jnp.abs(params[i])
        # Calculate influence, handle division by zero if param was originally 0
        # If original_param_val is 0, setting it to 0 causes no change, so influence is 0.
        influence = jnp.where(original_param_val == 0.0,
                              0.0,
                              diff / original_param_val)
        return influence

    # Use vmap to apply 'influence_for_index' across all indices
    # Inputs to vmap:
    # - The function to vectorize: influence_for_index
    # - in_axes: Specifies how arguments map to the vectorized dimension
    #   - 0: Vectorize over the first argument (indices)
    #   - None: Broadcast the argument (q, q_t, params, original_output don't change per index)
    vectorized_influence_calc = jax.vmap(
        influence_for_index,
        in_axes=(0, None, None, None, None)
    )

    # Run the vectorized calculation
    all_influences = vectorized_influence_calc(indices, q, q_t, params, original_output)

    return all_influences

# --- Example Usage ---
key = jax.random.PRNGKey(42)

# Example inputs (adjust shapes as needed, e.g., add a batch dimension)
# Assume q and q_t represent observations for a single instance (or averaged)
# Add batch dim: q = jax.random.normal(key, (batch_size, 1))
q_example = jnp.array([0.5]) # Example coordinate shape (1,) -> becomes scalar in equation()
q_t_example = jnp.array([0.1]) # Example velocity shape (1,) -> becomes scalar in equation()

# Initial parameters (10 as requested)
params_example = jnp.arange(1, 11, dtype=jnp.float32) # Example: [1.0, 2.0, ..., 10.0]
# If your equation truly only uses params 0-5, you might adjust this:
# params_example = jnp.arange(1, 7, dtype=jnp.float32) # Example: [1.0, ..., 6.0]


print("Example q:", q_example)
print("Example q_t:", q_t_example)
print("Example Parameters:", params_example)

# Calculate influences
influences = calculate_param_influence(q_example, q_t_example, params_example)

print("\nCalculated Influences (per parameter):")
print(influences)

# Optional: If q/q_t had batch/time dimensions, you might want to average
# influences_mean = influences.mean(axis=tuple(range(1, influences.ndim))) # Average over all non-parameter axes
# print("\nMean Influences (averaged over input dimensions):")
# print(influences_mean)

# Optional: JIT compile for potential speedup on repeated calls with same shapes
jit_calculate_param_influence = jax.jit(calculate_param_influence)

print("\nRunning JIT compiled version:")
influences_jit = jit_calculate_param_influence(q_example, q_t_example, params_example)
print(influences_jit)

# Verify results are the same (within float tolerance)
assert jnp.allclose(influences, influences_jit)



In [None]:
import jax
import jax.numpy as jnp
import numpy as np # Still needed for data loading and potentially scipy interaction
from scipy.optimize import minimize

# Initialize parameters (as standard Python list or NumPy array is fine here)
MAX_NPARAMS = 10
# Initial guess for optimization (use NumPy array for scipy)
initial_params_np = np.ones(MAX_NPARAMS, dtype=np.float64) * 1.0

@jax.jit
def equation_jax(q: jnp.array, q_t: jnp.array, params: jnp.array) -> tuple[jnp.array, jnp.array]:
    """
    JAX version: Mathematical function for lagrangian and energy.
    Args:
        q (jnp.array): Observations of generalized coordinates. Shape (batch, 2).
        q_t (jnp.array): Observations of generalized velocities. Shape (batch, 2).
        params (jnp.array): List of numeric constants or parameters. Shape (MAX_NPARAMS,).
    Returns:
        tuple[jnp.array, jnp.array]: Lagrangian (T-V), Energy (T+V). Shape (batch,).
    """
    q1, q2 = q[:, 0], q[:, 1]
    q1_t, q2_t = q_t[:, 0], q_t[:, 1]

    # Ensure params has enough elements, padding with zero if necessary
    safe_params = jnp.pad(params, (0, max(0, MAX_NPARAMS - len(params))))

    # Kinetic energy (T)
    # Assuming params[0] for q1_t^2 and params[1] for q2_t^2 is more standard
    # T = safe_params[0] * q1_t**2 + safe_params[1] * q1_t**2 # Original had typo?
    T = safe_params[0] * q1_t**2 + safe_params[1] * q2_t**2
    # Note: params[2] is currently unused in this formula.

    # Potential energy (V)
    V = safe_params[3] * q1 + safe_params[4] * q2 + safe_params[5] *jnp.cos(q1)
    # Note: params[5] through params[9] are currently unused.

    lagrangian = T - V
    energy = T + V
    return lagrangian, energy

@jax.jit
def calculate_loss_jax(params: jnp.array, q: jnp.array, q_t: jnp.array, outputs1: jnp.array, outputs2: jnp.array) -> jnp.array:
    """Calculates the loss using JAX operations."""
    y_pred, energy_pred = equation_jax(q, q_t, params)
    loss1 = jnp.mean((y_pred - outputs1) ** 2)
    loss2 = 2 * jnp.mean((energy_pred - outputs2) ** 2)
    total_loss = loss1 + loss2
    return total_loss

@jax.jit
def calculate_param_influence_on_loss(
    params: jnp.array, # Typically the *optimized* parameters
    q: jnp.array,
    q_t: jnp.array,
    outputs1: jnp.array,
    outputs2: jnp.array
) -> jnp.array:
    """
    Calculates the influence of each parameter on the final loss value.

    Args:
        params (jnp.array): The parameters for which to calculate influence (usually optimized).
        q, q_t, outputs1, outputs2: Data arrays.

    Returns:
        jnp.array: The calculated 'influence' for each parameter. Shape (MAX_NPARAMS,).
    """
    num_params = params.shape[0]
    indices = jnp.arange(num_params)

    # Calculate the original loss with the given parameters
    original_loss = calculate_loss_jax(params, q, q_t, outputs1, outputs2)

    # Define a function that calculates the loss difference for a *single* index 'i'
    def loss_influence_for_index(i, current_params, q, q_t, outputs1, outputs2, original_loss_val):
        # Create parameters with the i-th element set to 0
        modified_params = current_params.at[i].set(0.0)
        # Calculate the loss with the modified parameters
        modified_loss = calculate_loss_jax(modified_params, q, q_t, outputs1, outputs2)
        # Calculate the absolute difference in loss
        loss_diff = jnp.abs(original_loss_val - modified_loss)
        # Get the absolute value of the original parameter
        original_param_val = jnp.abs(current_params[i])
        # Calculate influence, handle division by zero if param was originally 0
        influence = jnp.where(original_param_val == 0.0,
                              0.0,
                              loss_diff / original_param_val)
        return influence

    # Use vmap to apply 'loss_influence_for_index' across all indices
    vectorized_influence_calc = jax.vmap(
        loss_influence_for_index,
        in_axes=(0, None, None, None, None, None, None) # i changes, others are constant
    )

    # Run the vectorized calculation
    all_influences = vectorized_influence_calc(indices, params, q, q_t, outputs1, outputs2, original_loss)

    return all_influences

# Removed the @evaluate.run decorator
def evaluate_jax(data: dict) -> tuple[float | None, np.ndarray | None]:
    """ Evaluate the equation, optimize, and calculate parameter influences using JAX."""

    # Load data observations (as NumPy arrays initially)
    inputs_np, outputs_np = data['inputs'], data['outputs']
    q_np = inputs_np[:, :2]
    q_t_np = inputs_np[:, 2:4]
    outputs1_np = outputs_np[:, 0]
    outputs2_np = outputs_np[:, 1]

    # --- Convert data to JAX arrays for JAX functions ---
    q = jnp.asarray(q_np)
    q_t = jnp.asarray(q_t_np)
    outputs1 = jnp.asarray(outputs1_np)
    outputs2 = jnp.asarray(outputs2_np)
    # ---

    # Define the loss function wrapper for scipy.optimize.minimize
    # It takes NumPy array (from scipy), converts to JAX, calls JAX loss, returns float
    def loss_for_scipy(params_np):
        params_jax = jnp.asarray(params_np)
        loss_val = calculate_loss_jax(params_jax, q, q_t, outputs1, outputs2)
        return float(loss_val) # Return standard float for scipy

    # Optimize parameters using scipy.optimize.minimize
    # We still use the NumPy-based initial guess
    result = minimize(loss_for_scipy, initial_params_np, method='BFGS')

    optimized_params_np = result.x
    final_loss = result.fun # This is the scalar loss value

    print('Optimized parameters (NumPy array):', optimized_params_np)
    print('Final loss:', final_loss)

    if np.isnan(final_loss) or np.isinf(final_loss):
        return None, None

    # --- Calculate Parameter Influences on the Final Loss ---
    # Convert optimized params to JAX array for the influence calculation
    optimized_params_jax = jnp.asarray(optimized_params_np)

    # Calculate influences using the JAX function
    param_influences_jax = calculate_param_influence_on_loss(
        optimized_params_jax, q, q_t, outputs1, outputs2
    )

    # Convert influences back to NumPy array for return, if desired, or keep as JAX array
    param_influences_np = np.asarray(param_influences_jax)
    print('Parameter influences on loss:', param_influences_np)
    # ---

    # Return evaluation score (negative loss) and parameter influences
    return -final_loss, param_influences_np

# --- Example Usage (requires a data dictionary) ---
if __name__ == '__main__':
    # Create some dummy data for demonstration
    import pandas as pd
    # 读取 CSV 文件并转换为 NumPy 数组
    data0 = pd.read_csv('./train.csv')#
    tae = data0.to_numpy()

    # 使用 JAX 的数组操作替换 PyTorch 的操作
    state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
    true_q_ddot = jnp.array(tae[:, 4:-1], dtype=jnp.float64)  # 转换为 JAX 数组
    energy = jnp.array(tae[:, -1], dtype=jnp.float64)
    print(true_q_ddot)
    # 将数据存储在字典中
    data = {
        'inputs': state,
        'outputs': true_q_ddot,  # 真实的加速度
        'energy': energy
    }


    print("--- Running JAX Evaluation ---")
    neg_loss, influences = evaluate_jax(data)

    if neg_loss is not None:
        print("\n--- Results ---")
        print(f"Negative Loss (Score): {neg_loss}")
        print(f"Parameter Influences: {influences}")

        # You can now analyze the 'influences' array. Higher absolute values
        # indicate parameters that, when set to zero, caused a larger change
        # in the optimized loss (normalized by the parameter's optimized value).
        # Parameters with very low influence might be candidates for removal.
        # Also note which parameters were unused in the equation (e.g., index 2, 5-9).
        # Their influence should ideally be zero or very close to it due to floating point noise.



# 为简单起见，我们先检验使用拉格朗日量和能量作为损失的框架来测试效果


In [None]:
import jax
import jax.numpy as jnp
from jax import config, jit, vmap
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict) -> float:
    """ Evaluate the equation on data observations."""

    # Load data observations
    inputs, outputs = data['inputs'], data['outputs']
    q = inputs[:,:2]
    q_t = inputs[:,2:4]
    outputs1 = outputs[:,0]
    outputs2 = outputs[:,1]

    def loss_fn(params):
        y_pred, energy = equation(q, q_t, params)
        # 使用 jnp.mean 计算损失
        return jnp.mean((y_pred - outputs1) ** 2) + 2*jnp.mean((energy - outputs2) ** 2)

    # 使用 loss_fn 定义 loss_partial
    loss_partial = lambda params: loss_fn(params)
    loss_partial = jit(loss_fn)  # 确保JIT编译
    # 使用 initial_params 初始化优化参数
    optimized_params = initial_params
    result = minimize(loss_partial, optimized_params, method='BFGS', options={'maxiter': 1000})
    # 从优化结果中获取优化后的参数和损失
    optimized_params = result.x
    loss = result.fun

    print('optimized_params:',optimized_params)

    # 使用 jnp 检查损失值是否有效
    if jnp.isnan(loss) or jnp.isinf(loss):
        return None
    else:
        return -loss


def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a physical system,which is a conservative system.
    Args:
        q (jnp.array): observations of current generalized coordinate,which dimension is rad.
        q_t (jnp.array): observations of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian (T-V) and total energy (T+V).
    """

    q1, q2 = q[:,0],q[:,1]
    q1_t, q2_t = q_t[:,0],q_t[:,1]

    # Kinetic energy (T)
    T =  params[0]*q1_t**2 + params[1]*q2_t**2
    # potential energy (V)
    V =  params[3]*q1 +params[4]*q2 + params[5]*jnp.cos(q1) +  params[6]*jnp.cos(q2) + params[7]*jnp.sin(q1) #+ params[8]*jnp.sin(q2) 
    return T - V, T + V




In [20]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 4:6], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, 6:], dtype=jnp.float64)
print(energy)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': energy
}


f = evaluate(data)
print(f)

[[ 1.80206776e-15 -1.80206776e-15]
 [ 7.70196859e-04 -1.14977777e-11]
 [ 3.08078740e-03 -1.21308289e-11]
 ...
 [ 4.98795246e+01 -2.30539589e-06]
 [ 4.99611381e+01 -2.79509812e-06]
 [ 5.00323659e+01 -3.35014171e-06]]
[[-9.81000000e+00  0.00000000e+00]
 [-9.80999999e+00 -2.26759624e-08]
 [-9.80999979e+00 -3.62815393e-07]
 ...
 [ 1.96515787e+01 -1.87375710e+01]
 [ 1.54711058e+01 -1.33573195e+01]
 [ 1.11144949e+01 -7.73028138e+00]]
优化参数：[  0.61962555   1.          -1.74516287  -0.99169285 -16.06962547
  -9.46751544  -0.03475219   4.1065319    1.        ]
优化参数：[  1.20214258   1.          -1.07756867  -0.30950188 -16.72852577
  -7.17620188  -0.08846496   2.146836     1.        ]
优化参数：[  0.92578317   0.51176874  -2.27120662  -1.05742179 -18.3589969
  -9.44787378  -0.09683702   5.04498894   1.        ]
优化参数：[  0.92532286   0.51136107   1.          -0.84686675 -18.43831607
  -9.37900276  -0.32259573   1.81515464   1.        ]
优化参数：[  0.92554156   0.51069574   1.          -1.38673051 -18.9628145

In [None]:
import jax
import jax.numpy as jnp
from jax import config, jit, vmap
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

def evaluate(data: dict, initial_params_override: jnp.array = None) -> dict | None:
    """
    优化参数，评估损失，并计算所有参数的敏感性。
    将损失函数和敏感性分析逻辑封装在此函数内部。

    Args:
        data (dict): 包含 'inputs' 和 'outputs' 的数据字典。
                     outputs 预期有两列：目标 T-V 和 目标 T+V。
        initial_params_override (jnp.array, optional): 用于覆盖全局 initial_params 的初始参数。

    Returns:
        dict | None: 如果成功，返回包含以下键的字典：
                     'optimized_params': 优化后的参数数组。
                     'final_loss': 优化后的标量损失值。
                     'negative_loss': -final_loss (用于与旧版输出比较)。
                     'sensitivities': 包含所有参数索引到其敏感性得分的字典。
                     如果优化失败或损失无效，则返回 None。
    """

    # --- 损失函数定义 (嵌套在 evaluate 内部) ---
    @jit
    def loss_fn_internal(params: jnp.array, current_data: dict) -> float:
        """ 计算给定参数和数据的损失 (evaluate 内部版本)。"""
        inputs, outputs = current_data['inputs'], current_data['outputs']
        q = inputs[:,:2]
        q_t = inputs[:,2:4]
        outputs_lagrangian = outputs[:, 0] # 对应 T-V
        outputs_energy = outputs[:, 1]     # 对应 T+V
        y_pred, energy_pred = equation(q, q_t, params)
        loss_lagrangian = jnp.mean((y_pred - outputs_lagrangian) ** 2)
        loss_energy = 2 * jnp.mean((energy_pred - outputs_energy) ** 2)
        return loss_lagrangian + loss_energy

    # --- 内部辅助函数：敏感性计算 (嵌套在 evaluate 内部) ---
    def _calculate_sensitivity_internal(optimized_params: jnp.array, original_loss: float, current_data: dict, param_indices_to_test: list[int]) -> dict:
        """ 内部辅助函数：计算损失对每个指定参数的敏感性。"""
        sensitivities = {}
        print("\n--- (内部) 开始敏感性分析 ---")
        print(f"用于分析的原始损失: {original_loss:.6f}")

        for i in param_indices_to_test:
            original_param_value = optimized_params[i]
            print(f"\n测试参数索引 {i} (原始值: {original_param_value:.6f})")

            # 创建一个屏蔽函数，强制将特定参数保持为0
            def masked_loss_fn(params_subset, data, mask_idx):
                # 创建完整参数集，但将被屏蔽的参数设为0
                full_params = jnp.zeros_like(optimized_params)
                
                # 将子集参数放入正确的位置，跳过被屏蔽的参数
                idx = 0
                for j in range(len(optimized_params)):
                    if j != mask_idx:
                        full_params = full_params.at[j].set(params_subset[idx])
                        idx += 1
                
                return loss_fn_internal(full_params, data)

            # 创建不包含被测试参数的初始参数子集
            initial_params_subset = jnp.array([optimized_params[j] for j in range(len(optimized_params)) if j != i])
            
            # 为当前屏蔽的参数创建部分损失函数
            loss_partial_masked = jit(lambda p: masked_loss_fn(p, current_data, i))
            
            print(f"开始重新优化（忽略参数 {i}）...")
            try:
                # 尝试不同的优化方法（如果BFGS有问题）
                result_masked = minimize(loss_partial_masked,initial_params_subset,method='BFGS',options={'maxiter': 500})

                # 打印更详细的优化结果信息
                print(f"优化状态: {'成功' if getattr(result_masked, 'success', False) else '未成功'}")
                if hasattr(result_masked, 'status'):
                    print(f"状态码: {result_masked.status}")
                if hasattr(result_masked, 'message'):
                    print(f"消息: {result_masked.message}")
                if hasattr(result_masked, 'nfev'):
                    print(f"函数评估次数: {result_masked.nfev}")
                if hasattr(result_masked, 'nit'):
                    print(f"迭代次数: {result_masked.nit}")
                
                # 构建完整的参数集（屏蔽的参数为0）
                optimized_params_with_mask = jnp.zeros_like(optimized_params)
                idx = 0
                for j in range(len(optimized_params)):
                    if j != i:
                        optimized_params_with_mask = optimized_params_with_mask.at[j].set(result_masked.x[idx])
                        idx += 1
                
                # 计算新的最优损失
                modified_loss = loss_fn_internal(optimized_params_with_mask, current_data)
                
                # 验证损失是否有效
                if jnp.isnan(modified_loss) or jnp.isinf(modified_loss):
                    print(f"警告: 参数 {i} 的重新优化结果包含无效损失值 (NaN 或 Inf)")
                    sensitivities[i] = float('nan')  # 标记为无效
                else:
                    print(f"参数 {i} 忽略后重新优化的损失: {modified_loss:.6f}")
                    
                    loss_difference = modified_loss - original_loss
                    sensitivity = loss_difference / original_loss
                    sensitivities[i] = sensitivity
                    
                    print(f"损失差异: {loss_difference:.6f}")
                    print(f"敏感性得分: {sensitivity:.6f}")
                
            except Exception as e:
                print(f"对参数 {i} 进行重新优化时发生错误: {e}")
                sensitivities[i] = float('nan')  # 标记失败的敏感性计算
        
        return sensitivities

    # --- evaluate 函数主体逻辑 ---
    # ... (确定 current_initial_params 的逻辑不变) ...
    current_initial_params = initial_params_override if initial_params_override is not None else initial_params
    print(f"使用的初始参数: {current_initial_params}")

    # 1. 参数优化
    loss_partial = jit(lambda p: loss_fn_internal(p, data))
    print("开始优化...")
    try:
        result = minimize(loss_partial, current_initial_params, method='BFGS', options={'maxiter': 500})
    except Exception as e:
        print(f"优化过程中发生错误: {e}")
        return None

    if not result.success:
        print(f"优化失败: {result.message}")
        # pass # 允许继续，但结果可能不可靠

    optimized_params = result.x
    final_loss = loss_fn_internal(optimized_params, data)

    print('优化后的参数:', optimized_params)
    # print('最终损失 (来自优化器):', result.fun) # 可以取消注释以进行比较
    print('最终损失 (重新计算):', final_loss)

    if jnp.isnan(final_loss) or jnp.isinf(final_loss):
        print("警告: 优化结果包含无效 (NaN 或 Inf) 的损失值。")
        return None

    # 2. 敏感性分析 (如果优化成功且损失有效)
    # ***更正***: 定义要测试的所有参数索引
    # 使用 range(len(optimized_params)) 来获取从 0 到 n-1 的所有索引
    all_param_indices = list(range(len(optimized_params)))

    print(f"\n将对以下参数索引进行敏感性分析: {all_param_indices}")

    if not all_param_indices: # 检查列表是否为空 (虽然不太可能，除非优化参数为空)
         print("警告：没有有效的参数索引用于敏感性分析。")
         sensitivities = {}
    else:
        # 调用内部定义的敏感性分析函数，传入所有索引
        sensitivities = _calculate_sensitivity_internal(
            optimized_params,
            final_loss,
            data,
            all_param_indices # <--- 使用包含所有索引的列表
        )

    # 3. 返回结果
    return {
        'optimized_params': optimized_params,
        'final_loss': final_loss,
        'negative_loss': -final_loss,
        'sensitivities': sensitivities # 现在包含所有测试过的参数的敏感性
    }

def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a physical system,which is a conservative system.
    Args:
        q (jnp.array): observations of current generalized coordinate,which dimension is rad.
        q_t (jnp.array): observations of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian (T-V) and total energy (T+V).
    """

    q1, q2 = q[:,0],q[:,1]
    q1_t, q2_t = q_t[:,0],q_t[:,1]

    # Kinetic energy (T)
    T =  params[0]*q1_t**2 + params[1]*q2_t**2
    # potential energy (V)
    V = params[3]*q1 + params[4]*q2 + params[5]*jnp.cos(q1) +   params[6]*jnp.cos(q2) + params[8]*jnp.sin(q1) + params[7]*jnp.sin(q2) 
    return T - V, T + V

### 循环计算每一个参数的伪梯度，效果良好，最终优化的代码，最简洁

In [19]:
import jax
import jax.numpy as jnp
from jax import config, jit, vmap
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, initial_params_override: jnp.array = None) -> dict | None:
    @jit
    def loss_fn_internal(params: jnp.array, current_data: dict) -> float:
        inputs, outputs = current_data['inputs'], current_data['outputs']
        q = inputs[:,:2]
        q_t = inputs[:,2:4]
        outputs_lagrangian = outputs[:, 0]
        outputs_energy = outputs[:, 1]
        y_pred, energy_pred = equation(q, q_t, params)
        loss_lagrangian = jnp.mean((y_pred - outputs_lagrangian) ** 2)
        loss_energy = 2 * jnp.mean((energy_pred - outputs_energy) ** 2)
        return loss_lagrangian + loss_energy

    def _calculate_sensitivity_internal(optimized_params: jnp.array, original_loss: float, 
                                   current_data: dict, param_indices_to_test: list[int]) -> dict:
        sensitivities = {}
        
        for i in param_indices_to_test:
            mask = jnp.ones_like(optimized_params, dtype=bool)
            mask = mask.at[i].set(False)
            initial_params_subset = optimized_params[mask]
    
            @jit
            def masked_loss_fn(params_subset):
                full_params = jnp.zeros_like(optimized_params)
                full_params = full_params.at[mask].set(params_subset)
                return loss_fn_internal(full_params, current_data)
            
            try:
                result_masked = minimize(masked_loss_fn, initial_params_subset, method='BFGS', options={'maxiter': 500})
                print(f"优化参数：{result_masked.x}")
                optimized_params_with_mask = jnp.zeros_like(optimized_params)
                optimized_params_with_mask = optimized_params_with_mask.at[mask].set(result_masked.x)
                
                modified_loss = loss_fn_internal(optimized_params_with_mask, current_data)
                
                sensitivities[i] = jnp.where(
                    jnp.isnan(modified_loss) | jnp.isinf(modified_loss),
                    jnp.nan,
                    (modified_loss - original_loss) / original_loss
                )
            except Exception as e:
                sensitivities[i] = float('nan')
        
        return sensitivities

    current_initial_params = initial_params_override if initial_params_override is not None else initial_params

    loss_partial = jit(lambda p: loss_fn_internal(p, data))
    try:
        result = minimize(loss_partial, current_initial_params, method='BFGS', options={'maxiter': 500})
    except Exception as e:
        return None

    optimized_params = result.x
    final_loss = loss_fn_internal(optimized_params, data)

    if jnp.isnan(final_loss) or jnp.isinf(final_loss):
        return None

    all_param_indices = list(range(len(optimized_params)))
    sensitivities = _calculate_sensitivity_internal(optimized_params, final_loss, data, all_param_indices)

    return {
        'optimized_params': optimized_params,
        'final_loss': final_loss,
        'negative_loss': -final_loss,
        'sensitivities': sensitivities
    }


def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a physical system,which is a conservative system.
    Args:
        q (jnp.array): observations of current generalized coordinate,which dimension is rad.
        q_t (jnp.array): observations of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian (T-V) and total energy (T+V).
    """

    q1, q2 = q[:,0],q[:,1]
    q1_t, q2_t = q_t[:,0],q_t[:,1]

    # Kinetic energy (T)
    T =  params[0]*q1_t**2 + params[1]*q2_t**2
    # potential energy (V)
    V = params[3]*q1 + params[4]*q2 + params[5]*jnp.cos(q1) +   params[6]*jnp.cos(q2) + params[8]*jnp.sin(q1) + params[7]*jnp.sin(q2) 
    return T - V, T + V

# 以上代码已经能够运行，我们尝试能否优化代码，加速计算.
优化前：12.4s，10.0S   
对应的参数的敏感度：{0: Array(1.84394218, dtype=float64), 1: Array(2.61179105, dtype=float64), 2: Array(0., dtype=float64), 3: Array(0.00395862, dtype=float64), 4: Array(0.0077421, dtype=float64), 5: Array(2.23812998, dtype=float64), 6: Array(0.4276709, dtype=float64), 7: Array(2.36861228e-05, dtype=float64), 8: Array(0.00891643, dtype=float64), 9: Array(0., dtype=float64)}   
优化后：2.3s   
对应参数的敏感度：   
'sensitivities': {0: 1.8530223881547223, 1: 2.627426792901825, 2: -1.6278262024711505e-14, 3: 0.2583657524769522, 4: 0.07432524506118177, 5: 2.6259165000759235, 6: 0.5372272257631501, 7: 5.758876017124561e-05, 8: 0.029088609674932472, 9: -1.6278262024711505e-14}}

In [5]:
import jax
import jax.numpy as jnp
from jax import config, jit, vmap, grad
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, initial_params_override: jnp.array = None) -> dict | None:
    @jit
    def loss_fn_internal(params: jnp.array, current_data: dict) -> float:
        inputs, outputs = current_data['inputs'], current_data['outputs']
        q = inputs[:,:2]
        q_t = inputs[:,2:4]
        outputs_lagrangian = outputs[:, 0]
        outputs_energy = outputs[:, 1]
        y_pred, energy_pred = equation(q, q_t, params)
        loss_lagrangian = jnp.mean((y_pred - outputs_lagrangian) ** 2)
        loss_energy = 2 * jnp.mean((energy_pred - outputs_energy) ** 2)
        return loss_lagrangian + loss_energy
    
    def calculate_masked_sensitivities_batch(optimized_params, original_loss, current_data):
        # 创建掩码参数矩阵 - 每行代表一组去掉一个参数的配置
        # 初始化为与optimized_params相同的值
        masked_params_matrix = jnp.tile(optimized_params, (MAX_NPARAMS, 1))
        
        # 对每行设置对应位置的参数为0（即掩码）
        for i in range(MAX_NPARAMS):
            masked_params_matrix = masked_params_matrix.at[i, i].set(0.0)
        
        # 定义批量损失函数 - 计算每个掩码配置的损失
        @jit
        def batch_masked_loss_fn(params_matrix):
            # 计算每行参数配置的损失
            def compute_loss_for_row(row_params):
                return loss_fn_internal(row_params, current_data)
            
            # 使用vmap并行计算所有行的损失
            losses = jax.vmap(compute_loss_for_row)(params_matrix)
            
            # 返回所有损失的总和 - 这是我们要优化的目标
            return jnp.sum(losses)
        
        # 对batch_masked_loss_fn进行优化
        # 注意：我们需要定义一个函数来维持特定位置的掩码（即保持对应位置为0）
        def optimize_with_masks():
            # 创建初始掩码矩阵的平坦版本（展平为一维数组）
            initial_flat = masked_params_matrix.flatten()
            
            # 定义优化函数，保持掩码位置为0
            def masked_opt_fn(flat_params):
                # 重塑为矩阵
                reshaped = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
                # 确保掩码位置为0
                for i in range(MAX_NPARAMS):
                    reshaped = reshaped.at[i, i].set(0.0)
                # 计算损失
                return batch_masked_loss_fn(reshaped)
            
            # 执行优化
            try:
                result = minimize(masked_opt_fn, initial_flat, method='BFGS', options={'maxiter': 500})
                # 将结果重塑为矩阵，并确保掩码位置为0
                optimized_matrix = result.x.reshape(MAX_NPARAMS, MAX_NPARAMS)
                for i in range(MAX_NPARAMS):
                    optimized_matrix = optimized_matrix.at[i, i].set(0.0)
                return optimized_matrix
            except Exception as e:
                print(f"Optimization error: {str(e)}")
                return masked_params_matrix  # 返回初始矩阵作为后备
        
        # 执行优化并获取优化后的掩码参数矩阵
        optimized_matrix = optimize_with_masks()
        
        # 计算每个掩码配置的损失和敏感度
        sensitivities = {}
        
        for i in range(MAX_NPARAMS):
            # 获取第i行的掩码参数配置
            masked_params = optimized_matrix[i]
            
            # 计算该配置下的损失
            masked_loss = loss_fn_internal(masked_params, current_data)
            
            # 计算敏感度
            sensitivity = jnp.where(
                jnp.isnan(masked_loss) | jnp.isinf(masked_loss),
                jnp.nan,
                (masked_loss - original_loss) / original_loss
            )
            
            sensitivities[i] = float(sensitivity)
        
        return sensitivities

    current_initial_params = initial_params_override if initial_params_override is not None else initial_params

    loss_partial = jit(lambda p: loss_fn_internal(p, data))
    try:
        result = minimize(loss_partial, current_initial_params, method='BFGS', options={'maxiter': 500})
    except Exception as e:
        return None

    optimized_params = result.x
    final_loss = loss_fn_internal(optimized_params, data)

    if jnp.isnan(final_loss) or jnp.isinf(final_loss):
        return None

    # 计算各个参数的敏感度
    sensitivities = calculate_masked_sensitivities_batch(optimized_params, final_loss, data)

    return {
        'optimized_params': optimized_params,
        'final_loss': final_loss,
        'negative_loss': -final_loss,
        'sensitivities': sensitivities
    }


def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a physical system,which is a conservative system.
    Args:
        q (jnp.array): observations of current generalized coordinate,which dimension is rad.
        q_t (jnp.array): observations of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian (T-V) and total energy (T+V).
    """

    q1, q2 = q[:,0],q[:,1]
    q1_t, q2_t = q_t[:,0],q_t[:,1]

    # Kinetic energy (T)
    T = params[0]*q1_t**2 + params[1]*q2_t**2 
    # potential energy (V)
    V = params[3]*q1 + params[4]*q2 + params[5]*jnp.cos(q1) + params[6]*jnp.cos(q2) + \
        params[8]*jnp.sin(q1) + params[7]*jnp.sin(q2) 
    
    return T - V, T + V


In [6]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 4:6], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, 6:], dtype=jnp.float64)
print(energy)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': energy
}


f = evaluate(data)
print(f)

[[ 1.80206776e-15 -1.80206776e-15]
 [ 7.70196859e-04 -1.14977777e-11]
 [ 3.08078740e-03 -1.21308289e-11]
 ...
 [ 4.98795246e+01 -2.30539589e-06]
 [ 4.99611381e+01 -2.79509812e-06]
 [ 5.00323659e+01 -3.35014171e-06]]
[[-9.81000000e+00  0.00000000e+00]
 [-9.80999999e+00 -2.26759624e-08]
 [-9.80999979e+00 -3.62815393e-07]
 ...
 [ 1.96515787e+01 -1.87375710e+01]
 [ 1.54711058e+01 -1.33573195e+01]
 [ 1.11144949e+01 -7.73028138e+00]]
{'optimized_params': Array([  0.92578317,   0.51176874,   1.        ,  -2.27120662,
        -1.05742179, -18.3589969 ,  -9.44787378,  -0.09683702,
         5.04498894,   1.        ], dtype=float64), 'final_loss': Array(105.6324943, dtype=float64), 'negative_loss': Array(-105.6324943, dtype=float64), 'sensitivities': {0: 1.8530223881547223, 1: 2.627426792901825, 2: -1.6278262024711505e-14, 3: 0.2583657524769522, 4: 0.07432524506118177, 5: 2.6259165000759235, 6: 0.5372272257631501, 7: 5.758876017124561e-05, 8: 0.029088609674932472, 9: -1.6278262024711505e-14}}


In [21]:
import jax
import jax.numpy as jnp
from jax import config, jit, vmap, grad
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, initial_params_override: jnp.array = None) -> dict | None:
    @jit
    def loss_fn_internal(params: jnp.array, current_data: dict) -> float:
        inputs, outputs = current_data['inputs'], current_data['outputs']
        q = inputs[:,:2]
        q_t = inputs[:,2:4]
        outputs_lagrangian = outputs[:, 0]
        outputs_energy = outputs[:, 1]
        y_pred, energy_pred = equation(q, q_t, params)
        loss_lagrangian = jnp.mean((y_pred - outputs_lagrangian) ** 2)
        loss_energy = 2 * jnp.mean((energy_pred - outputs_energy) ** 2)
        return loss_lagrangian + loss_energy
    
    def calculate_masked_sensitivities_batch(optimized_params, original_loss, current_data):
        # 使用矩阵运算创建掩码参数矩阵
        mask = 1 - jnp.eye(MAX_NPARAMS, dtype=optimized_params.dtype)
        masked_params_matrix = optimized_params * mask

        @jit
        def batch_masked_loss(params_matrix):
            return vmap(loss_fn_internal, in_axes=(0, None))(params_matrix, current_data)

        # 优化后的掩码优化函数
        def optimize_with_masks():
            def masked_opt_fn(flat_params):
                reshaped = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS) * mask
                return jnp.sum(batch_masked_loss(reshaped))
            
            result = minimize(masked_opt_fn, masked_params_matrix.flatten(),
                            method='BFGS', options={'maxiter': 500})
            return result.x.reshape(MAX_NPARAMS, MAX_NPARAMS) * mask

        try:
            optimized_matrix = optimize_with_masks()
        except Exception as e:
            print(f"Optimization error: {str(e)}")
            return {i: float('nan') for i in range(MAX_NPARAMS)}

        # 向量化计算敏感度
        losses = batch_masked_loss(optimized_matrix)
        relative_loss = (losses - original_loss) / original_loss
        return {i: float(jnp.nan_to_num(relative_loss[i])) for i in range(MAX_NPARAMS)}

    
    current_initial_params = initial_params_override if initial_params_override is not None else initial_params

    loss_partial = jit(lambda p: loss_fn_internal(p, data))
    try:
        result = minimize(loss_partial, current_initial_params, method='BFGS', options={'maxiter': 500})
    except Exception as e:
        return None

    optimized_params = result.x
    final_loss = loss_fn_internal(optimized_params, data)

    if jnp.isnan(final_loss) or jnp.isinf(final_loss):
        return None

    # 计算各个参数的敏感度
    sensitivities = calculate_masked_sensitivities_batch(optimized_params, final_loss, data)

    return {
        'optimized_params': optimized_params,
        'final_loss': final_loss,
        'negative_loss': -final_loss,
        'sensitivities': sensitivities
    }


def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a physical system,which is a conservative system.
    Args:
        q (jnp.array): observations of current generalized coordinate,which dimension is rad.
        q_t (jnp.array): observations of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian (T-V) and total energy (T+V).
    """

    q1, q2 = q[:,0],q[:,1]
    q1_t, q2_t = q_t[:,0],q_t[:,1]

    # Kinetic energy (T)
    T = params[0]*q1_t**2 + params[1]*q2_t**2 
    # potential energy (V)
    V = params[3]*q1 + params[4]*q2 + params[5]*jnp.cos(q1) + params[6]*jnp.cos(q2) + \
        params[8]*jnp.sin(q1) + params[7]*jnp.sin(q2) 
    
    return T - V, T + V


In [22]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./train.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 4:6], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, 6:], dtype=jnp.float64)
print(energy)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': energy
}


f = evaluate(data)
print(f)

[[ 1.80206776e-15 -1.80206776e-15]
 [ 7.70196859e-04 -1.14977777e-11]
 [ 3.08078740e-03 -1.21308289e-11]
 ...
 [ 4.98795246e+01 -2.30539589e-06]
 [ 4.99611381e+01 -2.79509812e-06]
 [ 5.00323659e+01 -3.35014171e-06]]
[[-9.81000000e+00  0.00000000e+00]
 [-9.80999999e+00 -2.26759624e-08]
 [-9.80999979e+00 -3.62815393e-07]
 ...
 [ 1.96515787e+01 -1.87375710e+01]
 [ 1.54711058e+01 -1.33573195e+01]
 [ 1.11144949e+01 -7.73028138e+00]]
{'optimized_params': Array([  0.92578317,   0.51176874,   1.        ,  -2.27120662,
        -1.05742179, -18.3589969 ,  -9.44787378,  -0.09683702,
         5.04498894,   1.        ], dtype=float64), 'final_loss': Array(105.6324943, dtype=float64), 'negative_loss': Array(-105.6324943, dtype=float64), 'sensitivities': {0: 1.8530223881597834, 1: 2.6274267929103012, 2: -1.6143730933598184e-14, 3: 0.2583657521740584, 4: 0.07432524495188332, 5: 2.625916500241012, 6: 0.5372272258175684, 7: 5.758875910575937e-05, 8: 0.029088609628322156, 9: -1.6143730933598184e-14}}


### 包含pso优化的双摆例子进行梯度计算

In [25]:
import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
import numpy as np
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = [jnp.array(1.0)]*MAX_NPARAMS



def evaluate(data: dict, params=initial_params) -> float:

    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q = inputs[:, :n_dim]
    q_t = inputs[:, n_dim:]
    true_accelerations = outputs


    # 修改：将lagrangian直接嵌入compute_acceleration中，避免将函数作为参数传递
    @jit
    def compute_acceleration(q, q_t, params):
        # 直接使用equation函数而不是传递lagrangian函数
        
        # 预先计算梯度
        def lag(q, q_t, params):
            return equation(q, q_t, params)
            
        # 计算二阶导数
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        q_tt = jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)
        return q_tt

    # 使用vmap批处理
    batch_compute_acceleration = jit(vmap(compute_acceleration, in_axes=(0, 0, None)))

    # 使用jit装饰器优化损失函数
    @jit
    def loss_fn(params):
        predicted_accelerations = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(predicted_accelerations - true_accelerations))

    # 针对numpy数组的包装函数
    def objective(params):
        try:
            params = jnp.array(params)  # 确保转换为JAX数组
            loss_value = loss_fn(params)
            return float(loss_value)
        except Exception as e:
            print(f"Error in objective function: {e}")
            # 返回一个大的损失值，避免优化器选择这个点
            return 1e10

    # 粒子群优化的参数
    lb = [-1.0] * len(initial_params)  # 参数下限
    ub = [10.0] * len(initial_params)   # 参数上限

    # 调用 pso 函数进行优化
    optimized_params, optimized_loss = pso(objective, lb, ub, swarmsize=30, maxiter=500)

    print("pso Optimized parameters:", optimized_params)
    print("pso Optimized loss:", optimized_loss)

    # 使用JAX的优化器进一步优化
    loss_partial = jit(loss_fn)  # 确保JIT编译
    result = minimize(loss_partial, optimized_params, method='BFGS', options={'maxiter': 1000})
    optimized_params = result.x
    loss = result.fun

    if jnp.isnan(loss) or jnp.isinf(loss):
        return None
    else:
        print(optimized_params)
        return -loss.item()


@jit  # 添加JIT装饰器
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a physical system
    Args:
        q (jnp.array): observations of current generalized coordinate.
        q_t (jnp.array): observations of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian as the result of applying the mathematical function to the inputs.
    """
    """Improved version of `equation_v1` with additional parameters for damping and external forces."""
    q1, q2 = q[...,0], q[..., 1]
    q1_t, q2_t = q_t[..., 0], q_t[..., 1]

    m1, m2, l1, l2, g = params[0], params[1], params[2], params[3], params[4]
    c1, c2 = params[5], params[6]  # damping coefficients for each pendulum
    F1, F2 = params[7], params[8]  # external forces for each pendulum
    k = params[9]  # coefficient for interaction term between q1 and q2

    # Kinetic energy terms (considering interactions between coordinates)
    T = (0.5 * m1 * l1**2 * q1_t**2 +
         0.5 * m2 * (l1**2 * q1_t**2 + l2**2 * q2_t**2 + 2 * l1 * l2 * q1_t * q2_t * jnp.cos(q1 - q2)))

    # Damping terms
    D = (c1 * q1_t +
         c2 * q2_t)

    # Potential energy terms (gravity and external forces)
    V = -g * (m1 * l1 * jnp.cos(q1) + m2 * (l1 * jnp.cos(q1) + l2 * jnp.cos(q2))) + \
        F1 * q1 + F2 * q2


    # Lagrangian including potential, kinetic, damping, and interaction terms
    return T - V - D 

In [1]:
import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, q_t = inputs[:, :n_dim], inputs[:, n_dim:]
    true_accelerations = outputs

    @jit
    def compute_acceleration(q, q_t, params):
        def lag(q, q_t, params):
            return equation(q, q_t, params)
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        return jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)

    batch_compute_acceleration = jit(vmap(compute_acceleration, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(pred - true_accelerations))

    def run_optimization(objective_fn, initial_guess):
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                            method='BFGS', options={'maxiter': 500})
            return result.x
        else:
            def pso_wrapper(x):
                return objective_fn(jnp.array(x))
            
            lb = [-1.0]*initial_guess.size
            ub = [10.0]*initial_guess.size
            
            pso_params, _ = pso(pso_wrapper, lb, ub, 
                            swarmsize=30, maxiter=200)
            
            result = minimize(objective_fn, jnp.array(pso_params),
                            method='BFGS', options={'maxiter': 500})
            return result.x


    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            relative_loss = (losses - base_loss) / base_loss
            return {i: float(jnp.nan_to_num(relative_loss[i])) for i in range(MAX_NPARAMS)}
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return {i: 0.0 for i in range(MAX_NPARAMS)}


    # 主流程
    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(optimized_params, final_loss)
    }

@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    q1, q2 = q[...,0], q[...,1]
    q1_t, q2_t = q_t[...,0], q_t[...,1]
    
    m1, m2, l1, l2, g = params[0], params[1], params[2], params[3], params[4]
    c1, c2, F1, F2, k = params[5], params[6], params[7], params[8], params[9]
    
    T = 0.5*(m1*l1**2*q1_t**2 + m2*(l1**2*q1_t**2 + l2**2*q2_t**2 + 2*l1*l2*q1_t*q2_t*jnp.cos(q1-q2)))
    V = -g*(m1*l1*jnp.cos(q1) + m2*(l1*jnp.cos(q1)+l2*jnp.cos(q2))) + F1*q1 + F2*q2
    D = c1*q1_t + c2*q2_t
    return T - V - D


In [2]:
import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./double_pendulum_data_with_energy.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 4:-1], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
    'energy': energy
}

''''inputs, outputs = data['inputs'], data['outputs']
print(inputs.shape, outputs.shape)
n_dim = inputs.shape[1] // 2
q = inputs[:, :n_dim]
q_t = inputs[:, n_dim:]
true_accelerations = outputs

q1, q2 = q[...,0], q[..., 1]
q1_t, q2_t = q_t[..., 0], q_t[..., 1]
print(q1.shape, q2.shape, q1_t.shape, q2_t.shape)

print(q.shape, q_t.shape, true_accelerations.shape)



print(equation(q,q_t,initial_params))'''


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[-9.81000000e+00  0.00000000e+00]
 [-9.80999997e+00 -5.53723182e-08]
 [-9.80999948e+00 -8.85956966e-07]
 ...
 [ 1.52940364e+01 -1.63022950e+01]
 [ 1.48602131e+01 -1.55387182e+01]
 [ 1.43807650e+01 -1.47005543e+01]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Stopping search: Swarm best objective change less than 1e-08
最终损失值 (MSE): {'params': Array([1.03445108e+01, 1.03445105e+01, 2.01300011e+00, 2.01300034e+00,
       1.97475313e+01, 9.84050994e-01, 4.90573269e+00, 1.18473294e-05,
       9.51485829e-06, 2.57651355e-01], dtype=float64), 'loss': Array(9.69115154e-13, dtype=float64), 'sensitivities': {0: 3.20130858136045e+19, 1: 139222546490502.81, 2: 119653707896404.27, 3: 139222547621080.39, 4: 298745128231405.3, 5: 2676.4397483168295, 6: 2676.4397483168295, 7: 2658.04703368037, 8: 2721.4478237104254, 9: 2676.4397483168295}}


#### 结构化优化的过程中，我们发现，对于复杂式子的优化和伪梯度计算，会出现不该有的项敏感度为负值，这是我们所希望的结果，但偶尔也会出现敏感度为几千，但该有的项敏感度为8，9，甚至10几次方量级，故而，可以使用概率归一化来缩小数值，以便让大模型更好的分析。
不行，数值相差太大是，其他值都为0，一个数为1，绝对值归一化后也会导致过于大的值占主导，其余值基本保持一致。
好吧，尝试各种方法，对于绝对大的敏感度好像无法量化衡量，其余相比之下都为0.

使得中位数为敏感度为1,不去归一化，某一项绝对大   
$$ sensetive = e^{\frac{x-x_d}{x_d}}$$ 

In [None]:
import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, q_t = inputs[:, :n_dim], inputs[:, n_dim:]
    true_accelerations = outputs

    @jit
    def compute_acceleration(q, q_t, params):
        def lag(q, q_t, params):
            return equation(q, q_t, params)
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        return jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)

    batch_compute_acceleration = jit(vmap(compute_acceleration, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(pred - true_accelerations))

    def run_optimization(objective_fn, initial_guess):
        """处理任意维度的参数优化"""
        # PSO阶段需要特殊处理高维参数
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                            method='BFGS', options={'maxiter': 500})
            return result.x
        else:
            # 原始混合优化逻辑
            def pso_wrapper(x):
                return objective_fn(jnp.array(x))
            
            lb = [-1.0]*initial_guess.size
            ub = [10.0]*initial_guess.size
            
            pso_params, _ = pso(pso_wrapper, lb, ub, 
                            swarmsize=30, maxiter=200)
            
            result = minimize(objective_fn, jnp.array(pso_params),
                            method='BFGS', options={'maxiter': 500})
            return result.x


    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit
        def batch_loss(params_matrix):
            # 确保输入是正确形状的矩阵
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            # 显式重塑参数形状
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            # 生成正确的初始参数（复制参数矩阵并应用掩码）
            initial_flat = (opt_params * mask).flatten()
            
            # 执行优化（确保处理100个参数）
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            
            # 计算损失差异
            losses = batch_loss(optimized_matrix)
            raw_relative = (losses - base_loss) / base_loss
            scaled = jnp.nan_to_num(raw_relative, nan=0.0, posinf=0.0, neginf=0.0)
            
            # 新增加的核心计算逻辑
            x_d = jnp.median(scaled)
            epsilon = 1e-8  # 防止除以零
            adjusted_xd = x_d 
            
            # 应用指数公式
            relative_to_median = (scaled - x_d) / adjusted_xd
            weights = jnp.exp(relative_to_median)
            
            # 保持符号（相对于中位数的变化方向）
            final_sensitivities = weights * jnp.sign(scaled)

            return {i: float(final_sensitivities[i]) for i in range(MAX_NPARAMS)}
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return {i: 0.0 for i in range(MAX_NPARAMS)}


    # 主流程
    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(optimized_params, final_loss)
    }

@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    q1, q2 = q[...,0], q[...,1]
    q1_t, q2_t = q_t[...,0], q_t[...,1]
    
    m1, m2, l1, l2, g = params[0], params[1], params[2], params[3], params[4]
    c1, c2, F1, F2, k = params[5], params[6], params[7], params[8], params[9]
    
    T = 0.5*(m1*l1**2*q1_t**2 + m2*(l1**2*q1_t**2 + l2**2*q2_t**2 + 2*l1*l2*q1_t*q2_t*jnp.cos(q1-q2)))
    V = -g*(m1*l1*jnp.cos(q1) + m2*(l1*jnp.cos(q1)+l2*jnp.cos(q2))) #+ F1*q1 + F2*q2
    D = c1*q1_t + c2*q2_t
    return T - V - D


In [10]:
import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./double_pendulum_data_with_energy.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 4:-1], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
    'energy': energy
}


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[-9.81000000e+00  0.00000000e+00]
 [-9.80999997e+00 -5.53723182e-08]
 [-9.80999948e+00 -8.85956966e-07]
 ...
 [ 1.52940364e+01 -1.63022950e+01]
 [ 1.48602131e+01 -1.55387182e+01]
 [ 1.43807650e+01 -1.47005543e+01]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Stopping search: Swarm best position change less than 1e-08
最终损失值 (MSE): {'params': Array([ 5.728103  ,  5.72810327,  1.0260507 ,  1.02605066, 10.06555711,
        5.79314061,  6.02788142,  9.16491405,  5.19117204, 10.        ],      dtype=float64), 'loss': Array(1.41403035e-13, dtype=float64), 'sensitivities': {0: inf, 1: 3.3242815168480804, 2: 2.718281828323312, 3: 3.3242815168480804, 4: 366.9478979839086, 5: 0.3678794411898118, 6: 0.3678794411898118, 7: 0.3678794411898118, 8: 0.3678794411898118, 9: 0.3678794411898118}}


## 来测试一下其他的优化器：
### 1.量子隧穿优化，有一次优化到-19次方，牛逼，运行16S左右，和粒子群差不多,但在参数未全部使用的情况下，无效   
优化损失: [3.10366193e+07 1.42791394e+02 1.36837775e+02 1.42791518e+02
 1.73510532e+07 1.66615206e-14 1.66615206e-14 1.74181776e-14
 1.55317123e-14 1.66615206e-14]
最终损失值 (MSE): {'params': Array([ 1.38116467e+00,  1.38116467e+00,  5.92812112e-01,  5.92812112e-01,
        5.81548682e+00, -2.47551587e-02,  4.32660101e-01,  1.83811304e-10,
        9.94207883e-11,  7.87441955e-01], dtype=float64), 'loss': Array(1.10887512e-19, dtype=float64), 'sensitivities': {0: 2.7989282841833538e+26, 1: 1.2877139359385067e+21, 2: 1.2340233210216482e+21, 3: 1.287715049505326e+21, 4: 1.5647436669708444e+26, 5: 150255.05974498438, 6: 150255.05974498438, 7: 157078.7052214516, 8: 140066.28182291775, 9: 150255.05974498438}}
### 2.神经引导优化，7.6s就能优化到-13次方，，同量子隧穿，参数未全部使用则无效。但是对于掩码后的参数优化，不能有效适配,不同网络不稳定，使用了稳定的网络到-15，但是要1分半钟，优化到22秒，但是调了几个超参，发现影响很大，这说明对于不同问题，还需要调超惨，这种方法不太行。   
优化损失: [1.93514032e+07 1.43640117e+02 1.43390734e+02 1.54258418e+02
 3.66413656e+03 5.23901447e+07 1.34418381e+07 3.54652590e+04
 2.75959884e+02 2.89812589e+02]
最终损失值 (MSE): {'params': Array([ 1.26675148e+00,  1.26675151e+00,  7.75105047e-01,  7.75105064e-01,
        7.60378114e+00,  7.11072584e-01,  7.69675069e-01, -3.48068037e-07,
       -1.08772388e-07,  7.48417784e-01], dtype=float64), 'loss': Array(3.22607551e-13, dtype=float64), 'sensitivities': {0: 5.99843465581859e+19, 1: 445247225168795.9, 2: 444474201985906.1, 3: 478161213559623.4, 4: 1.1357875980563062e+16, 5: 1.6239590331551914e+20, 6: 4.166622282788226e+19, 7: 1.0993313352637762e+17, 8: 855404293090404.2, 9: 898344097753334.1}}
### 3.混沌优化效果也不错，但是需要1分钟甚至2分钟的运行时间
优化损失: [3.10520114e+07 1.30095089e+02 1.12380677e+02 1.30095090e+02
 5.15585807e+02 2.23281006e-13 2.23281006e-13 2.33281008e-13
 2.25769055e-13 2.23281006e-13]
最终损失值 (MSE): {'params': Array([ 1.07311760e+01,  1.07311760e+01,  3.13674746e+00,  3.13674747e+00,
        3.07714928e+01,  8.77470277e-01,  9.34985895e+00, -6.56614019e-06,
        4.65663855e-07,  8.99687204e+00], dtype=float64), 'loss': Array(5.16725456e-15, dtype=float64), 'sensitivities': {0: 6.009382954031375e+21, 1: 2.5176829932118064e+16, 2: 2.1748624192127092e+16, 3: 2.517683001464005e+16, 4: 9.977944796813381e+16, 5: 42.210761888552184, 6: 42.210761888552184, 7: 44.14602586173936, 8: 42.692265023123255, 9: 42.210761888552184}}

In [13]:
import jax
from jax import random
import optax
from flax import linen as nn
from scipy.optimize import differential_evolution
from flax.training import train_state 
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, q_t = inputs[:, :n_dim], inputs[:, n_dim:]
    true_accelerations = outputs

    @jit
    def compute_acceleration(q, q_t, params):
        def lag(q, q_t, params):
            return equation(q, q_t, params)
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        return jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)

    batch_compute_acceleration = jit(vmap(compute_acceleration, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(pred - true_accelerations))


        # 量子隧穿优化
    def run_optimization(objective_fn, initial_guess, method='quantum', key=random.PRNGKey(0)):
        """
        Optimized function to minimize objective_fn using quantum tunneling optimization.
        
        Args:
            objective_fn: Function to minimize
            initial_guess: Initial parameters (jnp.array)
            method: Optimization method ('quantum', 'bfgs')
            key: JAX random key
        
        Returns:
            Optimized parameters (jnp.array)
        """
        @jit
        def bfgs_optimize(params):
            result = minimize(objective_fn, params, method='BFGS', 
                            options={'maxiter': 200, 'gtol': 1e-6})
            return result.x

        if initial_guess.size > MAX_NPARAMS or method == 'bfgs':
            return bfgs_optimize(initial_guess)
        
        # Quantum tunneling optimization parameters
        TUNNEL_PROB = 0.05  # Reduced tunneling probability for better exploration
        SWARM_SIZE = 30     # Increased swarm size for diversity
        MAX_ITER = 150      # More iterations for convergence
        DIM = initial_guess.size  # Parameter dimension (10)
        
        # Initialize swarm
        keys = random.split(key, 5)
        positions = initial_guess + random.normal(keys[0], (SWARM_SIZE, DIM)) * 0.2
        velocities = random.normal(keys[1], (SWARM_SIZE, DIM)) * 0.02
        pbest = jnp.copy(positions)
        pbest_losses = vmap(objective_fn)(positions)
        gbest_idx = jnp.argmin(pbest_losses)
        gbest = positions[gbest_idx]
        potential_well = jnp.copy(gbest)
        
        def update_particle(carry, _):
            pos, vel, pbest, pbest_losses, gbest, potential_well = carry
            
            # Compute losses
            losses = vmap(objective_fn)(pos)
            
            # Quantum tunneling mask
            tunnel_mask = random.bernoulli(keys[2], TUNNEL_PROB, (SWARM_SIZE,))
            
            # Velocity update with adaptive quantum term
            r1 = random.uniform(keys[3], (SWARM_SIZE, DIM))
            r2 = random.uniform(keys[4], (SWARM_SIZE, DIM))
            cognitive = 1.0 * r1 * (pbest - pos)  # Emphasize personal best
            social = 1.0 * r2 * (gbest - pos)     # Emphasize global best
            quantum_term = 0.05 * random.normal(keys[3], (SWARM_SIZE, DIM))
            vel = 0.8 * vel + cognitive + social + quantum_term
            
            # Position update with selective tunneling
            tunnel_pos = potential_well + 0.1 * random.normal(keys[4], (SWARM_SIZE, DIM))
            pos = jnp.where(tunnel_mask[:, None], tunnel_pos, pos + vel)
            
            # Update personal and global best
            pbest = jnp.where(losses[:, None] < pbest_losses[:, None], pos, pbest)
            pbest_losses = jnp.minimum(losses, pbest_losses)
            gbest_idx = jnp.argmin(losses)
            gbest = pos[gbest_idx]
            
            # Update potential well adaptively
            well_update_factor = 0.05 / (1 + jnp.exp(-jnp.min(losses)))  # Adaptive based on loss
            potential_well = (1 - well_update_factor) * potential_well + well_update_factor * gbest
            
            return (pos, vel, pbest, pbest_losses, gbest, potential_well), None
        
        # Run quantum optimization
        (positions, _, pbest, pbest_losses, gbest, _), _ = jax.lax.scan(
            update_particle,
            (positions, velocities, pbest, pbest_losses, gbest, potential_well),
            None,
            length=MAX_ITER
        )
        
        # BFGS refinement
        quantum_params = gbest
        return bfgs_optimize(quantum_params)
    
    # 神经引导优化算法
    '''%autoawaitclass GuideNN(nn.Module):
        hidden_dim: int
        output_dim: int

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(features=self.hidden_dim)(x)
            x = nn.relu(x)
            x = nn.Dense(features=self.output_dim)(x) # Output matches parameter dimension
            return x

    # 神经引导优化算法
    def run_optimization(loss_fn_to_optimize, init_params_opt, num_steps=1000, learning_rate=1e-3, guide_strength=0.05, seed=42):
        """
        Runs optimization using a fixed neuro-guided approach.

        Args:
            loss_fn_to_optimize: The loss function to minimize.
            init_params_opt: Initial parameters for optimization.
            num_steps: Number of optimization steps.
            learning_rate: Learning rate for the Adam optimizer.
            guide_strength: Factor to scale the guide network's output.
            seed: PRNG seed for reproducibility.

        Returns:
            Optimized parameters.
        """
        key = random.PRNGKey(seed)
        guide_key, opt_key = random.split(key)

        param_dim = init_params_opt.shape[0] # Dimension of parameters being optimized

        # Initialize Guide Network (defined above within evaluate)
        guide_model = GuideNN(hidden_dim=32, output_dim=param_dim)
        # Use a dummy input of the correct shape for initialization
        try:
            guide_nn_params = guide_model.init(guide_key, jnp.zeros_like(init_params_opt))['params']
        except Exception as e:
             print(f"Error initializing GuideNN: {e}. Input shape: {jnp.zeros_like(init_params_opt).shape}")
             raise # Re-raise the error after logging

        # JIT the apply function for the guide network with static parameters
        guide_apply_jit = jit(lambda p, nn_params: guide_model.apply({'params': nn_params}, p))

        # Initialize Optimizer (e.g., Adam)
        optimizer = optax.adam(learning_rate)
        opt_state = optimizer.init(init_params_opt)

        # Gradient function
        value_and_grad_fn = jit(jax.value_and_grad(loss_fn_to_optimize))

        current_params = init_params_opt

        # Define the optimization step function
        @jit
        def step(params_step, opt_state_step, guide_nn_params_frozen):
            loss_val, grads = value_and_grad_fn(params_step)
            # Get guidance direction (using the frozen NN parameters)
            guidance = guide_apply_jit(params_step, guide_nn_params_frozen)

            # --- Core NGO Logic ---
            # Combine gradient and guidance.
            combined_update_direction = grads + guide_strength * guidance
            # ---------------------

            # Compute updates using the optimizer
            updates, new_opt_state = optimizer.update(combined_update_direction, opt_state_step, params_step)
            # Apply updates
            new_params = optax.apply_updates(params_step, updates)
            return new_params, new_opt_state, loss_val

        # Optimization loop
        print(f"Starting NGO: {param_dim} params, {num_steps} steps, lr={learning_rate}, guide_strength={guide_strength}")
        for i in range(num_steps):
            # Pass the static guide network parameters to the step function
            current_params, opt_state, loss_val = step(current_params, opt_state, guide_nn_params)
            if i % (num_steps // 10) == 0 or i == num_steps - 1:
                 # Check for NaNs in loss or params during optimization
                 if not jnp.isfinite(loss_val) or not jnp.all(jnp.isfinite(current_params)):
                     print(f"Warning: Non-finite value encountered at step {i+1}. Loss: {loss_val}, Params Finite: {jnp.all(jnp.isfinite(current_params))}")
                     # Optionally break or handle the divergence
                     # break
                 print(f"  Step {i+1}/{num_steps}, Loss: {loss_val:.6f}")


        final_loss_check = loss_fn_to_optimize(current_params)
        print(f"NGO finished. Final Loss: {final_loss_check:.6f}")
        if not jnp.isfinite(final_loss_check):
            print(f"Warning: Final loss is non-finite: {final_loss_check}")
        if not jnp.all(jnp.isfinite(current_params)):
            print(f"Warning: Final parameters contain non-finite values.")

        return current_params'''
    
    #差分进化算法
    '''def run_optimization(objective_fn, initial_guess):
        """处理任意维度的参数优化"""
        # 高维参数（通常在敏感度分析中遇到）使用纯BFGS优化
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                              method='BFGS', options={'maxiter': 500})
            if not result.success:
                 print(f"High-dim BFGS optimization failed: {result.message}")
                 # 可以考虑返回 initial_guess 或 raise error
                 return initial_guess
            return result.x
        else:
            # 低维参数（主优化流程）使用差分进化 + BFGS
            # 定义参数边界
            bounds = [(-1.0, 10.0)] * initial_guess.size

            # 差分进化（DE）阶段
            de_result = differential_evolution(objective_fn, bounds,
                                               maxiter=100, # 可以调整迭代次数
                                               popsize=15,   # 可以调整种群大小
                                               tol=0.01,     # 可以调整容忍度
                                               mutation=(0.5, 1), # 变异因子范围
                                               recombination=0.7, # 交叉概率
                                               seed=None)     # 可设置随机种子以便复现

            if not de_result.success:
                print(f"Differential Evolution failed: {de_result.message}")
                # 如果DE失败，可以考虑直接使用初始猜测进行BFGS或返回DE的最佳结果
                de_best_params = de_result.x
            else:
                de_best_params = de_result.x

            # BFGS 精炼阶段
            # 使用 DE 找到的最优参数作为 BFGS 的初始点
            bfgs_result = minimize(objective_fn, jnp.array(de_best_params),
                                   method='BFGS', options={'maxiter': 500}) # 增加BFGS迭代次数

            if not bfgs_result.success:
                 print(f"BFGS refinement failed: {bfgs_result.message}")
                 # 如果BFGS失败，返回DE的结果
                 return de_best_params

            return bfgs_result.x'''
    
    #神经引导优化
    '''def run_optimization(objective_fn, initial_guess):
        """使用简化的神经引导优化算法处理参数优化，提高计算速度"""
        import flax.linen as nn
        from flax.training import train_state
        import optax
        
        param_dim = initial_guess.size
        
        # 简化网络结构 - 减少层数和神经元数量
        class PredictorNetwork(nn.Module):
            @nn.compact
            def __call__(self, x):
                x = nn.Dense(features=16)(x)
                x = nn.relu(x)
                x = nn.Dense(features=param_dim)(x)
                return x
        
        # 初始化神经网络
        predictor = PredictorNetwork()
        key = jax.random.PRNGKey(0)
        params = predictor.init(key, jnp.ones((1, param_dim)))
        
        # 使用更快的优化器，增大学习率
        optimizer = optax.adam(learning_rate=0.03)
        state = train_state.TrainState.create(
            apply_fn=predictor.apply,
            params=params,
            tx=optimizer
        )
        
        # 使用JIT加速训练过程
        @jit
        def train_step(state, x, y):
            def loss_fn(params):
                pred = state.apply_fn(params, x)
                return jnp.mean(jnp.square(pred - y))
            
            grad_fn = jax.value_and_grad(loss_fn)
            loss, grads = grad_fn(state.params)
            return state.apply_gradients(grads=grads), loss
        
        # 简化数据收集 - 减少样本数量
        def collect_data(current_params, n_samples=10):  # 减少样本量
            key = jax.random.PRNGKey(int(jnp.sum(current_params) * 1000) % 2**32)
            
            # 一次性生成所有随机噪声，避免循环
            key, subkey = jax.random.split(key)
            noises = jax.random.normal(subkey, shape=(n_samples, param_dim)) * 0.1
            samples = current_params + noises
            
            # 批量计算梯度
            grad_fn = jax.grad(objective_fn)
            batch_grad_fn = jax.vmap(grad_fn)
            gradients = batch_grad_fn(samples)
            
            return samples, gradients
        
        # 优化神经引导主循环 - 减少迭代次数
        current_params = initial_guess
        best_params = initial_guess
        best_loss = objective_fn(initial_guess)
        
        # 减少总迭代次数
        for epoch in range(15):  # 降低迭代次数
            # 收集数据并训练预测器
            x_data, y_data = collect_data(current_params)
            
            # 减少训练步数
            for _ in range(30):  # 降低训练步数
                state, loss = train_step(state, x_data, y_data)
            
            # 使用神经网络预测梯度
            predicted_direction = -state.apply_fn(state.params, current_params.reshape(1, -1))[0]
            normalized_direction = predicted_direction / (jnp.linalg.norm(predicted_direction) + 1e-8)
            
            # 简化线搜索 - 减少步长尝试
            step_sizes = jnp.geomspace(0.02, 0.5, num=5)  # 减少步长选项
            best_step = 0.1  # 默认步长
            best_step_loss = objective_fn(current_params)
            
            for step_size in step_sizes:
                new_params = current_params + step_size * normalized_direction
                new_loss = objective_fn(new_params)
                
                if new_loss < best_step_loss:
                    best_step = step_size
                    best_step_loss = new_loss
            
            # 更新参数
            current_params = current_params + best_step * normalized_direction
            
            # 降低BFGS频率
            if epoch % 5 == 0:  # 每5次神经引导后进行一次BFGS优化
                result = minimize(objective_fn, current_params, 
                                method='BFGS', 
                                options={'maxiter': 30})  # 减少最大迭代次数
                current_params = result.x
            
            # 记录最佳结果
            current_loss = objective_fn(current_params)
            if current_loss < best_loss:
                best_loss = current_loss
                best_params = current_params
        
        # 最后进行一次完整的BFGS优化
        result = minimize(objective_fn, best_params, 
                        method='BFGS', 
                        options={'maxiter': 100})  # 减少最终优化的迭代次数
        
        return result.x'''


    # 混沌优化
    '''def run_optimization(objective_fn, initial_guess, key=random.PRNGKey(0)):

        @jit
        def bfgs_optimize(params):
            result = minimize(objective_fn, params, method='BFGS', 
                            options={'maxiter': 200, 'gtol': 1e-6})
            return result.x

        DIM = initial_guess.size  # 10
        N_POINTS = 20
        MAX_ITER = 100
        
        # 初始化混沌点
        keys = random.split(key, 2)
        chaos = random.uniform(keys[0], (N_POINTS,)) * 0.9 + 0.1  # [0.1, 1]
        r = 3.9  # 混沌区
        bounds = jnp.array([-10., 10.])  # 参数范围
        
        def chaos_step(carry, _):
            chaos, points, best_point, best_loss = carry
            # 逻辑映射
            chaos = r * chaos * (1 - chaos)
            # 映射到参数空间
            scaled = bounds[0] + (bounds[1] - bounds[0]) * chaos
            new_points = points.at[:, 0].set(scaled)
            for i in range(1, DIM):
                chaos = r * chaos * (1 - chaos)
                scaled = bounds[0] + (bounds[1] - bounds[0]) * chaos
                new_points = new_points.at[:, i].set(scaled)
            
            # 评估损失
            losses = vmap(objective_fn)(new_points)
            min_idx = jnp.argmin(losses)
            new_best_point = new_points[min_idx]
            new_best_loss = losses[min_idx]
            
            # 更新最佳点
            best_point = jnp.where(new_best_loss < best_loss, new_best_point, best_point)
            best_loss = jnp.minimum(best_loss, new_best_loss)
            return (chaos, new_points, best_point, best_loss), None
        
        # 运行混沌优化
        points = initial_guess + random.normal(keys[1], (N_POINTS, DIM)) * 0.2
        best_loss = objective_fn(initial_guess)
        (chaos, points, best_point, best_loss), _ = jax.lax.scan(
            chaos_step, (chaos, points, initial_guess, best_loss), None, length=MAX_ITER)
        
        # BFGS精修
        return bfgs_optimize(best_point)'''
    


    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit
        def batch_loss(params_matrix):
            # 确保输入是正确形状的矩阵
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            # 显式重塑参数形状
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            # 生成正确的初始参数（复制参数矩阵并应用掩码）
            initial_flat = (opt_params * mask).flatten()
            
            # 执行优化（确保处理100个参数）
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            
            # 计算损失差异
            losses = batch_loss(optimized_matrix)
            print('优化损失:',losses)
            raw_relative = (losses - base_loss) / base_loss
            '''scaled = jnp.nan_to_num(raw_relative, nan=0.0, posinf=0.0, neginf=0.0)
            
            # 新增加的核心计算逻辑
            x_d = jnp.median(scaled)
            epsilon = 1e-8  # 防止除以零
            adjusted_xd = x_d 
            
            # 应用指数公式
            relative_to_median = (scaled - x_d) / adjusted_xd
            weights = jnp.exp(relative_to_median)
            
            # 保持符号（相对于中位数的变化方向）
            final_sensitivities = weights * jnp.sign(scaled)'''

            return {i: float(raw_relative[i]) for i in range(MAX_NPARAMS)}
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return {i: 0.0 for i in range(MAX_NPARAMS)}


    # 主流程
    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': calculate_sensitivities(optimized_params, final_loss)
    }

@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    q1, q2 = q[...,0], q[...,1]
    q1_t, q2_t = q_t[...,0], q_t[...,1]
    
    m1, m2, l1, l2, g = params[0], params[1], params[2], params[3], params[4]
    c1, c2, F1, F2, k = params[5], params[6], params[7], params[8], params[9]
    
    T = 0.5*(m1*l1**2*q1_t**2 + m2*(l1**2*q1_t**2 + l2**2*q2_t**2 + 2*l1*l2*q1_t*q2_t*jnp.cos(q1-q2)))
    V = -g*(m1*l1*jnp.cos(q1) + m2*(l1*jnp.cos(q1)+l2*jnp.cos(q2))) + F1*q1 + F2*q2
    D = c1*q1_t + c2*q2_t
    return T - V - D


In [6]:
import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./double_pendulum_data_with_energy.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :4], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 4:-1], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
    'energy': energy
}


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[-9.81000000e+00  0.00000000e+00]
 [-9.80999997e+00 -5.53723182e-08]
 [-9.80999948e+00 -8.85956966e-07]
 ...
 [ 1.52940364e+01 -1.63022950e+01]
 [ 1.48602131e+01 -1.55387182e+01]
 [ 1.43807650e+01 -1.47005543e+01]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
最终损失值 (MSE): {'params': Array([2.92511250e-01, 2.93263427e-01, 5.90578390e-02, 5.90171138e-02,
       5.79794518e-01, 1.61471544e-08, 0.00000000e+00, 1.00000000e+00,
       1.00000000e+00, 1.00000000e+00], dtype=float64), 'loss': Array(0.0003686, dtype=float64), 'sensitivities': {0: 48669615238.17453, 1: 408578.53417865915, 2: 391923.4500097751, 3: 408706.6548063402, 4: 1096953.6475976415, 5: -0.00039801818556416883, 6: -0.00039801818556416883, 7: -0.271558974297123, 8: -0.271558974297123, 9: -0.271558974297123}}
