## 生成哈密顿量的txt文件：

In [None]:
"""
Find the mathematical function skeleton that represents acceleration in a system, given data on time, position, and velocity.
A large positive sensitivity value for param[i] means that removing it significantly hurts the function's performance,
You should select those formulas with high sensitivity of parameter and remove formulas with low sensitivity of parameter, then add one or two new formula.
"""

import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)


MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

#@evaluate.run
def evaluate(data: dict, params=initial_params) -> dict:
    inputs, outputs = data['inputs'], data['outputs']
    #n_dim = inputs.shape[1] // 2
    #q, p = inputs[:, :n_dim], inputs[:, n_dim:]
    q, p = inputs[:, 0], inputs[:,1]
    true_value = outputs
    
    @jit
    def compute_dynamics(q, p, params):
        def hamiltonian(q, p, params):
            return equation(q, p, params)  
        q_dot = jax.grad(hamiltonian, 1)(q, p, params)  
        p_dot = -jax.grad(hamiltonian, 0)(q, p, params)  
        return jnp.array([q_dot, p_dot])     #jnp.concatenate([q_dot, p_dot])

    batch_compute_dynamics = jit(vmap(compute_dynamics, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_dynamics(q, p, params)
        return jnp.mean(jnp.square(pred - true_value))

    '''def run_optimization(objective_fn, initial_guess):
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                            method='BFGS', options={'maxiter': 500})
            return result.x
        else:
            def pso_wrapper(x):
                return objective_fn(jnp.array(x))   

            lb = [-10.0]*initial_guess.size
            ub = [10.0]*initial_guess.size
            pso_params, _ = pso(pso_wrapper, lb, ub, 
                            swarmsize=100, maxiter=300,omega=0.729, phip=1.49445, phig=1.49445)
            
            result = minimize(objective_fn, jnp.array(pso_params),
                            method='BFGS', options={'maxiter': 500})
            return result.x'''

    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit      
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            optimized_flat = run_optimization(sensitivity_objective, initial_flat)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            relative_loss = jnp.log2(losses / base_loss)
            return jnp.round(relative_loss,4)
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return jnp.zeros(MAX_NPARAMS, dtype=jnp.float64)

    try:
        optimized_params = run_optimization(loss_fn, params)
        final_loss = loss_fn(optimized_params)
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"sensitive of params[{i}]": float(sensitivities[i])
                       for i in range(len(sensitivities))}

    return {
        'params': optimized_params,
        'loss': -final_loss.item(),
        'sensitivities': sensitivity_dict
    }

#@equation.evolve
@jit
def equation(q: jnp.array, p: jnp.array, params: jnp.array) -> jnp.array:
    #q = q[...,0]
    #p = p[...,0]
     
    #T = params[1] * jnp.square(p)
    #V =  params[3] * jnp.square(q) +params[0]*q +params[2]                   #懂了，本质上是我在调用equation时，传入的数据本就只有1组，

    #T = params[1] * jnp.dot(p,p)                                             #并没有区别，因为p就只有一个，q只有一个。jnp.sum(q)=q。
    #V = params[3] * jnp.dot(q,q) + params[0]*jnp.sum(q) + params[2]
    T = params[1] * jnp.square(p) 
    V = params[3] * jnp.cos(q)+params[0]

    return T + V


In [43]:
import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./pendulum_hamilton_data.csv')
#data0 = pd.read_csv('./hamiltonian_spring_mass_energy_data.csv')
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, 0:2], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 2:-1], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
print(state)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
    'energy': energy
}


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[ 0.00000000e+00 -4.15778787e+01]
 [-8.66419284e-03 -4.15774283e+01]
 [-1.73281980e-02 -4.15760771e+01]
 ...
 [-1.08776333e+00  3.30105593e+01]
 [-1.08087271e+00  3.31423153e+01]
 [-1.07395339e+00  3.32730311e+01]]
[[  0.78539816   0.        ]
 [  0.78538733  -0.10397031]
 [  0.78535483  -0.20793838]
 ...
 [ -0.5960815  -13.05316   ]
 [ -0.59879171 -12.97047246]
 [ -0.60148549 -12.8874407 ]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Optimization or final loss calculation failed: iteration over a 0-d array
最终损失值 (MSE): None


Traceback (most recent call last):
  File "C:\Users\19464\AppData\Local\Temp\ipykernel_16488\3614080711.py", line 173, in evaluate
    optimized_params = run_optimization(loss_fn, params, key=random.PRNGKey(0), maxiter=500)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\19464\AppData\Local\Temp\ipykernel_16488\3614080711.py", line 135, in run_optimization
    sol = solver.run(initial_guess)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\anacoda\envs\pytorch\Lib\site-packages\jaxopt\_src\base.py", line 358, in run
    return run(init_params, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\anacoda\envs\pytorch\Lib\site-packages\jaxopt\_src\implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\anacoda\envs\pytorch\Lib\site-packages\jax\_src\traceback_util.p

In [58]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""


import jax.numpy as jnp
from jax import jit, vmap, config
from scipy.optimize import minimize
import pyswarms as ps
config.update("jax_enable_x64", True)
import numpy as np
import jax.random as random
from scipy.optimize import differential_evolution
from jaxopt import LBFGS

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:

    master_key = random.PRNGKey(0)
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, p = inputs[:, :n_dim], inputs[:, n_dim:]
    true_value = outputs
    
    @jit
    def compute_dynamics(q, p, params):
        def hamiltonian(q, p, params):
            return equation(q, p, params)  
        q_dot = jax.grad(hamiltonian, 1)(q, p, params)  
        p_dot = -jax.grad(hamiltonian, 0)(q, p, params)  
        return jnp.concatenate([q_dot, p_dot])

    batch_compute_dynamics = jit(vmap(compute_dynamics, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_dynamics(q, p, params)
        return jnp.mean(jnp.square(pred - true_value))

    '''def run_optimization(objective_fn, initial_guess, key, num_pso_runs=5, pso_iters=300, swarmsize=100):
        print(f"Initial guess size: {initial_guess.size}, MAX_NPARAMS: {MAX_NPARAMS}")
        n_params = initial_guess.size

        if n_params > MAX_NPARAMS:
            # ... (BFGS only part remains the same) ...
            print("Using BFGS directly due to large number of parameters.")
            result = minimize(objective_fn, initial_guess,
                              method='BFGS', options={'maxiter': 500})
            if not result.success:
                print(f"BFGS optimization failed: {result.message}")
                return initial_guess # Fallback
            return result.x
        else:
            # --- Use pyswarms with JAX-controlled initial positions ---
            @jit
            def pso_objective_wrapper(particles_matrix):
                return vmap(objective_fn)(particles_matrix)

            min_bound_np = np.full(n_params, -10.0, dtype=np.float64)
            max_bound_np = np.full(n_params, 10.0, dtype=np.float64)
            bounds = (min_bound_np, max_bound_np)

            options = {'c1': 1.49445, 'c2': 1.49445, 'w': 0.729}

            best_pso_params = None
            best_pso_loss = jnp.inf

            current_key = key # Use the passed-in key

            print(f"Running pyswarms PSO {num_pso_runs} times with unique initial swarms...")
            for i in range(num_pso_runs):
                # Split the key for this run to ensure unique randomness
                current_key, subkey = random.split(current_key)

                # Generate initial positions using JAX PRNG within bounds
                # Use jnp arrays for min/max bounds in jax.random.uniform
                min_bound_jnp = jnp.full(n_params, -10.0, dtype=jnp.float64)
                max_bound_jnp = jnp.full(n_params, 10.0, dtype=jnp.float64)
                init_pos_jax = random.uniform(subkey,
                                              shape=(swarmsize, n_params),
                                              dtype=jnp.float64,
                                              minval=min_bound_jnp,
                                              maxval=max_bound_jnp)
                # Convert to NumPy array for pyswarms
                init_pos_np = np.array(init_pos_jax)

                print(f"  PSO Run {i+1}/{num_pso_runs} (using JAX key split for init_pos)")
                optimizer = ps.single.GlobalBestPSO(n_particles=swarmsize,
                                                     dimensions=n_params,
                                                     options=options,
                                                     bounds=bounds,
                                                     # Pass the generated initial positions
                                                     init_pos=init_pos_np) # <-- Pass init_pos here

                # Perform optimization (pyswarms will use the provided init_pos)
                current_pso_loss, current_pso_params = optimizer.optimize(
                    pso_objective_wrapper,
                    iters=pso_iters,
                    verbose=False
                )
                current_pso_params = jnp.array(current_pso_params, dtype=jnp.float64)

                print(f"    Run {i+1} completed. Loss: {current_pso_loss}")
                if current_pso_loss < best_pso_loss:
                    best_pso_loss = current_pso_loss
                    best_pso_params = current_pso_params
                    print(f"    New best PSO loss found: {best_pso_loss}")

            # ... (rest of the function: handling no solution, BFGS refinement) ...
            if best_pso_params is None:
                 print("Warning: PSO did not find any valid solution after multiple runs. Using initial guess for BFGS.")
                 best_pso_params = initial_guess # Fallback

            print(f"\nBest PSO loss after {num_pso_runs} runs: {best_pso_loss}")
            print("Refining best PSO result with BFGS...")

            result = minimize(objective_fn, jnp.array(best_pso_params),
                              method='BFGS', options={'maxiter': 500})

            if not result.success:
                 print(f"BFGS refinement failed: {result.message}")
                 return best_pso_params # Return PSO best

            print(f"BFGS refinement successful. Final loss: {result.fun}")
            return result.x'''
    def run_optimization(objective_fn,
                        n_params: int,
                        bound_min: float = -10.0,
                        bound_max: float = 10.0,
                        maxiter_de: int = 1000,
                        popsize: int = 15,
                        maxiter_lbfgs: int = 500,
                        tol_lbfgs: float = 1e-6):
        """
        1) 差分进化寻找全局初始解
        2) L-BFGS 精炼

        参数：
        objective_fn: fn(params: jnp.ndarray) -> scalar loss
        n_params: 参数维度
        bound_min, bound_max: 每个参数的取值范围
        maxiter_de, popsize: 差分进化设置
        maxiter_lbfgs, tol_lbfgs: L-BFGS 设置

        返回：
        优化后的参数 jnp.ndarray
        """
        # 自动构造 bounds
        bounds = [(bound_min, bound_max)] * n_params

        # JIT 编译目标函数
        jit_loss = jax.jit(objective_fn)

        # NumPy wrapper 供 SciPy 调用
        def loss_numpy(x: np.ndarray) -> float:
            return float(jit_loss(jnp.array(x, dtype=jnp.float64)))

        # 1) 差分进化全局搜索
        result_de = differential_evolution(
            loss_numpy,
            bounds=bounds,
            maxiter=maxiter_de,
            popsize=popsize,
            tol=0.01,
            disp=True
        )
        x0 = jnp.array(result_de.x, dtype=jnp.float64)
        print(f"[DE] 初始点 loss = {result_de.fun:.6f}")

        # 2) L-BFGS 局部精炼
        solver = LBFGS(fun=objective_fn, maxiter=maxiter_lbfgs, tol=tol_lbfgs)
        sol = solver.run(x0)
        final_loss = float(objective_fn(sol.params))
        print(f"[L-BFGS] 最终 loss = {final_loss:.6f}")

        return sol.params
        

    # 敏感度分析模块
    def calculate_sensitivities(opt_params, base_loss):
        mask = 1 - jnp.eye(MAX_NPARAMS)
        
        @jit      
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            matrix_params = flat_params.reshape(MAX_NPARAMS, MAX_NPARAMS)
            return jnp.sum(batch_loss(matrix_params))

        try:
            initial_flat = (opt_params * mask).flatten()
            #optimized_flat = run_optimization(sensitivity_objective, initial_flat, key=opt_key, num_pso_runs=5)
            optimized_flat = run_optimization(sensitivity_objective,n_params=10)
            optimized_matrix = optimized_flat.reshape(MAX_NPARAMS, MAX_NPARAMS)
            losses = batch_loss(optimized_matrix)
            relative_loss = jnp.log2(losses / base_loss)
            return relative_loss
            
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            return jnp.zeros(MAX_NPARAMS, dtype=jnp.float64)


    # 主流程
    # Main execution flow
    try:
        # Split the master key for the main optimization run
        #opt_key, sensi_key = random.split(master_key) # Keep keys separate if needed later

        #optimized_params = run_optimization(loss_fn, params, key=opt_key, num_pso_runs=5) # Pass the key
        optimized_params = run_optimization(loss_fn,n_params=10)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss after L-BFGS: {final_loss}")
        if optimized_params is None:
             print("Optimization failed to produce parameters.")
             return None

        final_loss = loss_fn(optimized_params)
        print(f"Final optimized loss: {final_loss}")
    except Exception as e:
        # ... (error handling remains the same) ...
        print(f"Optimization or final loss calculation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    if not jnp.isfinite(final_loss):
        print("Final loss is not finite.")
        return None

    # Pass a key to sensitivity analysis if it also needs randomness
    # For now, assuming calculate_sensitivities doesn't need a separate key
    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"sensitive of params[{i}]": round(float(sensitivities[i]), 4) 
                       for i in range(len(sensitivities))}


    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': sensitivity_dict
    }


@jit
def equation(q: jnp.array, p: jnp.array, params: jnp.array) -> jnp.array:
    q = q[...,0]
    p = p[...,0]

    #T = params[1] * jnp.square(p)
    #V =  params[3] * jnp.square(q) +params[0]*q +params[2]

    T = params[1] * jnp.square(p) 
    V = params[3] * jnp.cos(params[2]*q)+params[0]

    return T + V

In [10]:

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./pendulum_hamilton_data.csv')
#data0 = pd.read_csv('./hamiltonian_spring_mass_energy_data.csv')
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, 0:2], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 2:-1], dtype=jnp.float64)  # 转换为 JAX 数组
energy = jnp.array(tae[:, -1], dtype=jnp.float64)
print(true_q_ddot)
print(state)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
    'energy': energy
}


print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

[[ 0.00000000e+00 -4.15778787e+01]
 [-8.66419284e-03 -4.15774283e+01]
 [-1.73281980e-02 -4.15760771e+01]
 ...
 [-1.08776333e+00  3.30105593e+01]
 [-1.08087271e+00  3.31423153e+01]
 [-1.07395339e+00  3.32730311e+01]]
[[  0.78539816   0.        ]
 [  0.78538733  -0.10397031]
 [  0.78535483  -0.20793838]
 ...
 [ -0.5960815  -13.05316   ]
 [ -0.59879171 -12.97047246]
 [ -0.60148549 -12.8874407 ]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Running pyswarms PSO 3 times with unique initial swarms...
  PSO Run 1/3 (using JAX key split for init_pos)
    Run 1 completed. Loss: 4.4069574625617666e-30
    New best PSO loss found: 4.4069574625617666e-30
  PSO Run 2/3 (using JAX key split for init_pos)
    Run 2 completed. Loss: 3.3123587074527874e-29
  PSO Run 3/3 (using JAX key split for init_pos)
    Run 3 completed. Loss: 1.0004304544837082e-28
Final loss after L-BFGS: 4.352300786408347e-30
最终损失值 (MSE): {'params': Array([ 4.16666667e-02, -5.88000000e+01], dtype=float64), 'loss': -4.352300786408347e-30, 'se

# 又遇到了优化困难的问题
### 尝试了JAX自带的L-BFGS，垃圾的一批：好几百的损失。
### 尝试差分进化+bfgs：效果很不错，再测试一下双摆的例子优化效果。

In [9]:
"""
Find the mathematical function skeleton that represents Hamiltonian in a system, given generalized position and velocity.
Parameter importance is defined as: log2(optimal loss after removing expressions related to the parameter / current optimal loss of the expression)
You should select some formulas with high importance of parameter and remove formulas with low and negative importance of parameter, then add one new formula.
"""

import jax
import jax.numpy as jnp
from jax import jit, vmap, config
import pyswarms as ps
import numpy as np
import jax.random as random
from jaxopt import LBFGS
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = jnp.ones(MAX_NPARAMS)


def evaluate(data: dict, params=initial_params) -> dict:
    master_key = random.PRNGKey(0)
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, p = inputs[:, :n_dim], inputs[:, n_dim:]
    true_value = outputs

    @jit
    def compute_dynamics(q, p, params):
        def hamiltonian(q, p, params):
            return equation(q, p, params)
        q_dot = jax.grad(hamiltonian, 1)(q, p, params)
        p_dot = -jax.grad(hamiltonian, 0)(q, p, params)
        return jnp.concatenate([q_dot, p_dot])

    batch_compute_dynamics = jit(vmap(compute_dynamics, (0, 0, None)))

    @jit
    def loss_fn(params):
        pred = batch_compute_dynamics(q, p, params)
        return jnp.mean(jnp.square(pred - true_value))

    def run_optimization(objective_fn, initial_guess, key, num_pso_runs=3, pso_iters=300, swarmsize=100, gg = False):
        grads   = jax.grad(equation,2)(q[1],p[1],params)
        active  = jnp.where(jnp.abs(grads) > 0)[0]
        n_params = active.shape[0]
        solver = LBFGS(objective_fn, maxiter=100, tol=1e-8)

        @jit
        def pso_objective_wrapper(particles_matrix):
            return vmap(objective_fn)(particles_matrix)

        min_bound_np = np.full(n_params, -100.0, dtype=np.float64)
        max_bound_np = np.full(n_params, 100.0, dtype=np.float64)
        bounds = (min_bound_np, max_bound_np)
        options = {'c1': 1.49445, 'c2': 1.49445, 'w': 0.729}
        best_pso_params = None
        best_pso_loss = jnp.inf
        current_key = key
        print(f"Running pyswarms PSO {num_pso_runs} times with unique initial swarms...")
        for i in range(num_pso_runs):
            current_key, subkey = random.split(current_key)
            min_bound_jnp = jnp.full(n_params, -10.0, dtype=jnp.float64)
            max_bound_jnp = jnp.full(n_params, 10.0, dtype=jnp.float64)
            init_pos_jax = random.uniform(subkey, shape=(swarmsize, n_params), dtype=jnp.float64, minval=min_bound_jnp, maxval=max_bound_jnp)
            init_pos_np = np.array(init_pos_jax)

            print(f"  PSO Run {i+1}/{num_pso_runs} (using JAX key split for init_pos)")
            optimizer = ps.single.GlobalBestPSO(n_particles=swarmsize, dimensions=n_params, options=options, bounds=bounds, init_pos=init_pos_np)
            current_pso_loss, current_pso_params = optimizer.optimize( pso_objective_wrapper, iters=pso_iters, verbose=False)
            current_pso_params = jnp.array(current_pso_params, dtype=jnp.float64)

            print(f"    Run {i+1} completed. Loss: {current_pso_loss}")
            if current_pso_loss < best_pso_loss:
                best_pso_loss = current_pso_loss
                best_pso_params = current_pso_params
                print(f"    New best PSO loss found: {best_pso_loss}")
        result =  solver.run(best_pso_params)
        return result.params


    def calculate_sensitivities(opt_params: jnp.ndarray, base_loss: float) -> jnp.ndarray:
        n_active = opt_params.size
        masks = 1.0 - jnp.eye(n_active, dtype=opt_params.dtype)
        @jit
        def solve_one(mask_i):
            init = opt_params * mask_i
            def sub_loss(p):
                return loss_fn(p * mask_i)
            solver = LBFGS(fun=sub_loss, maxiter=100, tol=1e-8)
            out    = solver.run(init)
            loss_i = loss_fn(out.params * mask_i)
            return jnp.log2(loss_i / base_loss)
        sensitivities = vmap(solve_one)(masks)
        return sensitivities

    try:
        opt_key, sensi_key = random.split(master_key)
        optimized_params = run_optimization(loss_fn, params, key=opt_key)
        final_loss = loss_fn(optimized_params)
        print(f"Final loss after L-BFGS: {final_loss}")

    except Exception as e:
        print(f"Optimization or final loss calculation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    if not jnp.isfinite(final_loss):
        print("Final loss is not finite.")
        return None

    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"importance of params[{i}]": round(float(sensitivities[i]),4)
                       for i in range(len(sensitivities))}

    return {
        'params': optimized_params,
        'loss': -float(final_loss),
        'sensitivities': sensitivity_dict
    }


@jit
def equation(q: jnp.array, p: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for Hamiltonian in system

    Args:
        q: A  jnp.array of position.
        p: A  jnp.array of momentum.
        params: Array of numeric constants or parameters to be optimized

    Return:
        A jnp.array representing Hamiltonian as the result of applying the mathematical function to the inputs.
    """
    p = p[...,0]
    q = q[...,0]    
    T = params[0] * p**2 
    V = params[5] * jnp.cos(q) 

    return T + V

