## 生成哈密顿量的txt文件：

In [50]:
"""
Find the mathematical function skeleton that represents acceleration in a system, given data on time, position, and velocity.
A large positive sensitivity value for param[i] means that removing it significantly hurts the function's performance,
You should select those formulas with high sensitivity of parameter and remove formulas with low sensitivity of parameter, then add one or two new formula.
"""

import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

# Length of the parameter vector handed to candidate equations.
MAX_NPARAMS = 10
# Default starting point for the optimizer: every coefficient set to one.
initial_params = jnp.ones((MAX_NPARAMS,), dtype=jnp.float64)

#@evaluate.run
def evaluate(data: dict, params=initial_params) -> dict:
    """Fit the parameters of ``equation`` as a Hamiltonian and score each one.

    ``equation(q, p, params)`` is treated as a Hamiltonian H. Its gradients
    give the predicted dynamics ``q_dot = dH/dp`` and ``p_dot = -dH/dq``,
    which are compared against ``data['outputs']`` with a mean squared error.

    Args:
        data: Mapping with at least the keys ``'inputs'`` and ``'outputs'``.
            ``'inputs'`` is indexed as a 2-D array whose columns are split in
            half: the first half is ``q`` and the second half is ``p``.
            ``'outputs'`` must be broadcast-compatible with the concatenated
            ``[q_dot, p_dot]`` prediction for every row.
        params: Initial parameter vector for the optimizer.

    Returns:
        A dict with

        * ``'params'``: the optimized parameter vector,
        * ``'loss'``: the *negated* final MSE as a Python float,
        * ``'sensitivities'``: a dict keyed ``"sensitive of params[i]"`` whose
          value is ``log2(loss_without_param_i / final_loss)`` rounded to four
          decimals,

        or ``None`` when the optimization raises or the final loss is not
        finite.
    """
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, p = inputs[:, :n_dim], inputs[:, n_dim:]
    true_value = outputs

    @jit
    def compute_dynamics(q, p, params):
        """Return ``[dH/dp, -dH/dq]`` for a single (q, p) sample."""
        def hamiltonian(q, p, params):
            return equation(q, p, params)
        q_dot = jax.grad(hamiltonian, 1)(q, p, params)
        p_dot = -jax.grad(hamiltonian, 0)(q, p, params)
        return jnp.concatenate([q_dot, p_dot])

    # Map over the sample axis of q and p; params is shared by all samples.
    batch_compute_dynamics = jit(vmap(compute_dynamics, (0, 0, None)))

    @jit
    def loss_fn(params):
        """Mean squared error between predicted and reference dynamics."""
        pred = batch_compute_dynamics(q, p, params)
        return jnp.mean(jnp.square(pred - true_value))

    def run_optimization(objective_fn, initial_guess):
        """Minimize ``objective_fn`` and return the best parameter vector.

        Guesses with more than ``MAX_NPARAMS`` entries go straight to BFGS
        from ``initial_guess``. Smaller problems first run a particle-swarm
        search inside the box [-10, 10] (which ignores ``initial_guess``
        apart from its size) and then refine the swarm's best point with BFGS.
        """
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                              method='BFGS', options={'maxiter': 500})
            return result.x

        def pso_wrapper(x):
            # pyswarm passes NumPy vectors; convert before calling JAX code.
            return objective_fn(jnp.array(x))

        lb = [-10.0] * initial_guess.size
        ub = [10.0] * initial_guess.size
        pso_params, _ = pso(pso_wrapper, lb, ub,
                            swarmsize=100, maxiter=300,
                            omega=0.729, phip=1.49445, phig=1.49445)

        result = minimize(objective_fn, jnp.array(pso_params),
                          method='BFGS', options={'maxiter': 500})
        return result.x

    def calculate_sensitivities(opt_params, base_loss):
        """Return ``log2`` loss ratios after removing each parameter in turn.

        Row ``i`` of the working matrix is a full parameter vector with entry
        ``i`` forced to zero by ``mask``. All rows are re-optimized together
        (the objective is the sum of the per-row losses), so each remaining
        parameter may compensate for the one that was removed. On any error
        the analysis is abandoned and an all-zero vector is returned.
        """
        mask = 1 - jnp.eye(MAX_NPARAMS)

        @jit
        def batch_loss(params_matrix):
            # Re-applying the mask keeps the removed entry at zero no matter
            # what value the optimizer writes into that slot.
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            # The flattened matrix has MAX_NPARAMS**2 entries, so
            # run_optimization takes its BFGS-only branch here.
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            # A perfect fit can drive base_loss to exactly zero, which would
            # turn the ratio into inf or NaN (0/0). Adding the smallest normal
            # float to both sides leaves every ordinary ratio unchanged while
            # keeping the logarithm finite.
            tiny = jnp.finfo(losses.dtype).tiny
            relative_loss = jnp.log2((losses + tiny) / (base_loss + tiny))
            return jnp.round(relative_loss, 4)

        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return jnp.zeros(MAX_NPARAMS, dtype=jnp.float64)

    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"sensitive of params[{i}]": float(value)
                        for i, value in enumerate(sensitivities.tolist())}

    return {
        'params': optimized_params,
        # The loss is reported with its sign flipped.
        'loss': -final_loss.item(),
        'sensitivities': sensitivity_dict
    }

#@equation.evolve
@jit
def equation(q: jnp.array, p: jnp.array, params: jnp.array) -> jnp.array:
    """Candidate Hamiltonian: quadratic kinetic term plus a quadratic potential.

    Only the first component along the last axis of ``q`` and ``p`` is used.
    The result is
    ``params[1] * p**2 + params[3] * q**2 + params[0] * q + params[2]``;
    all other entries of ``params`` are ignored.
    """
    position = q[..., 0]
    momentum = p[..., 0]

    kinetic = params[1] * jnp.square(momentum)

    quadratic_term = params[3] * jnp.square(position)
    linear_term = params[0] * position
    potential = quadratic_term + linear_term + params[2]

    return kinetic + potential


In [51]:
import pandas as pd
# Read the CSV file and convert it to a NumPy array.
# Alternative dataset (currently disabled):
#data0 = pd.read_csv('./pendulum_hamilton_data.csv')#
data0 = pd.read_csv('./hamiltonian_spring_mass_energy_data.csv')
tae = data0.to_numpy()

# Build JAX arrays from the table (this replaces an earlier PyTorch version).
# Columns 0-1 are the model inputs, the last column is the energy, and
# everything in between is the regression target.
state = jnp.array(tae[:, 0:2], dtype=jnp.float64)  # convert to a JAX array
# NOTE(review): despite the name, evaluate() compares this target against
# [dH/dp, -dH/dq], i.e. presumably (q_dot, p_dot) - confirm against the CSV.
true_q_ddot = jnp.array(tae[:, 2:-1], dtype=jnp.float64)  # convert to a JAX array
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
print(state)
# Store the data in a dictionary.
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # reference target (original comment: "true acceleration")
    'energy': energy
}


print(initial_params)
# Evaluate and optimize the parameters.
# NOTE: evaluate() returns a result dict (or None), not a bare loss value,
# so the whole dict is what gets printed below.
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[  0.         -10.981     ]
 [ -0.27456504 -10.9775669 ]
 [ -0.5489584  -10.96726975]
 ...
 [  5.76889777  -9.34101908]
 [  5.53353465  -9.48234218]
 [  5.29471152  -9.61773616]]
[[ 2.          0.        ]
 [ 1.99965669 -0.0274565 ]
 [ 1.99862697 -0.05489584]
 ...
 [ 1.83600191  0.57688978]
 [ 1.85013422  0.55335346]
 [ 1.86367362  0.52947115]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Stopping search: Swarm best objective change less than 1e-08
最终损失值 (MSE): {'params': Array([ -9.019     ,   5.        , -10.        ,   5.        ,
       -10.        , -10.        ,  -7.99325263,   9.95620674,
       -10.        , -10.        ], dtype=float64), 'loss': -3.194261037627877e-30, 'sensitivities': {'sensitive of params[0]': 102.10210000000001, 'sensitive of params[1]': 102.9021, 'sensitive of params[2]': 18.116, 'sensitive of params[3]': 102.8904, 'sensitive of params[4]': 18.116, 'sensitive of params[5]': 18.116, 'sensitive of params[6]': 18.116, 'sensitive of params[7]': 18.116, 'sensitive of param