In [None]:
"""
Find the mathematical function skeleton that represents the Lagrangian of a physical system, given data on the generalized coordinate and generalized velocity.
Tips: You may use no more than 10 parameters. With this limited parameter budget, you can incorporate nonlinear terms rather than continually adding new ones.
"""

import jax
import jax.numpy as jnp
from jax import config, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

# Maximum number of free constants the `equation` skeleton may use (see the
# tip above); also the length of the default parameter vector.
MAX_NPARAMS = 10
# Default parameter vector: a list of MAX_NPARAMS scalar ones.
initial_params = [jnp.array(1.0)]*MAX_NPARAMS

def evaluate(data: dict, params=initial_params) -> float:
    """Fit the free constants of `equation` to observed accelerations.

    `equation` is treated as a Lagrangian L(q, q_t, params). For each sample
    the acceleration is recovered from the Euler-Lagrange equation with JAX
    autodiff. The mean-squared error against `data['outputs']` is minimized
    first with PSO (pyswarm, box [-10, 10] per parameter, 40 particles,
    500 iterations) and then refined with BFGS (up to 1000 iterations).

    Args:
        data: dict with 'inputs' (2-D; the first half of the columns is q,
            the second half is q_t) and 'outputs' (observed accelerations;
            assumed to match the predicted shape, e.g. (N, n_dim) -- TODO
            confirm).
        params: NOTE(review): not used by this version. PSO starts from
            random particles, and the bounds are sized from the module-level
            `initial_params`.

    Returns:
        The negated final MSE (larger is better), or None if the BFGS loss is
        NaN or infinite (despite the `float` annotation).
    """

    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    # Columns are split in half: the first n_dim are q, the rest are q_t.
    q = inputs[:, :n_dim]
    q_t = inputs[:, n_dim:]
    true_accelerations = outputs

    @jit
    def compute_acceleration(q, q_t, params):
        # Acceleration for a single sample from the Euler-Lagrange equation:
        #   (d2L/dq_t2) q_tt = dL/dq - (d/dq dL/dq_t) q_t

        @jit
        def lag(q_single, q_t_single, params):
            # Sum to a scalar so grad/hessian can be applied.
            result = equation(q_single, q_t_single, params)
            return jnp.sum(result)

        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        # pinv gives a least-squares solution even when the Hessian is singular.
        q_tt = jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)
        return q_tt

    # Map over samples; params are shared across the batch.
    batch_compute_acceleration = jit(vmap(compute_acceleration, in_axes=(0, 0, None)))

    @jit
    def loss_fn(params):
        # Mean-squared error between predicted and observed accelerations.
        predicted_accelerations = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(predicted_accelerations - true_accelerations))


    def objective(params):
        # Python-float objective for pyswarm. Any exception is reported and
        # turned into a large penalty instead of aborting the swarm.
        try:
            params = jnp.array(params)
            loss_value = loss_fn(params)
            return float(loss_value)
        except Exception as e:
            print(f"Error in objective function: {e}")

            return 1e10


    # Search box [-10, 10] per parameter, sized from the module-level
    # initial_params (the `params` argument is not consulted).
    lb = [-10.0] * len(initial_params)
    ub = [10.0] * len(initial_params)


    optimized_params, optimized_loss = pso(objective, lb, ub, swarmsize=40, maxiter=500)

    print("pso Optimized parameters:", optimized_params)
    print("pso Optimized loss:", optimized_loss)

    # Refine the best PSO point with BFGS. loss_fn is already jitted, so the
    # extra jit here is redundant but harmless.
    loss_partial = jit(loss_fn)
    result = minimize(loss_partial, optimized_params, method='BFGS', options={'maxiter': 1000})
    optimized_params = result.x
    loss = result.fun

    if jnp.isnan(loss) or jnp.isinf(loss):
        return None
    else:
        print(optimized_params)
        # Negated so that a smaller error gives a larger score.
        return -loss.item()

@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array):
    """Candidate Lagrangian L(q, q_t) for a one-dimensional system.

    L = params[0] * (1 - q_t**2) ** (-1/2) - (params[2] - params[3] * q)

    Only the first component along the last axis of ``q`` and ``q_t`` is used,
    so inputs of shape (..., n) produce a result of shape (...).

    Args:
        q: observations of the generalized coordinate.
        q_t: observations of the generalized velocity.
        params: parameter vector; only indices 0, 2 and 3 are read.

    Returns:
        The Lagrangian evaluated over the leading axes of the inputs.
    """
    coord = q[..., 0]
    vel = q_t[..., 0]

    kinetic_term = params[0] * jnp.power(1 - vel**2, -0.5)
    potential_term = params[2] - params[3] * coord

    return kinetic_term - potential_term



In [6]:
import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

# Maximum number of free constants the `equation` skeleton may use.
MAX_NPARAMS = 10
# Default starting point for the optimization: a float64 vector of ones.
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)

def evaluate(data: dict, params=initial_params) -> dict:
    """Fit `equation` (a Lagrangian skeleton) to observed accelerations.

    Accelerations are predicted from the Euler-Lagrange equation using JAX
    autodiff. Parameters are found with PSO (pyswarm) followed by BFGS. Each
    parameter's importance is then measured by re-fitting with that parameter
    pinned to zero.

    Args:
        data: dict with 'inputs' (2-D; the first half of the columns is q,
            the second half is q_t) and 'outputs' (observed accelerations;
            assumed to match the predicted shape, e.g. (N, n_dim) -- TODO
            confirm).
        params: starting parameter vector (any array-like).

    Returns:
        dict with
          'params': optimized values as a list,
          'loss': negated final MSE (larger is better),
          'sensitivities': {"sensitive of params[i]": log2(refit loss with
              params[i] = 0 / final loss)} rounded to 4 places, NaN if the
              sensitivity analysis itself failed;
        or None if the main optimization raises or its loss is not finite.
    """
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q, q_t = inputs[:, :n_dim], inputs[:, n_dim:]
    true_accelerations = outputs

    @jit
    def compute_acceleration(q, q_t, params):
        # Euler-Lagrange for one sample, solved for q_tt:
        #   (d2L/dq_t2) q_tt = dL/dq - (d/dq dL/dq_t) q_t
        # pinv yields a least-squares solution even if the Hessian is singular.
        def lag(q, q_t, params):
            return equation(q, q_t, params)
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        return jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)

    batch_compute_acceleration = jit(vmap(compute_acceleration, (0, 0, None)))

    @jit
    def loss_fn(params):
        # Mean-squared error between predicted and observed accelerations.
        pred = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(pred - true_accelerations))

    def run_optimization(objective_fn, initial_guess):
        """Minimize objective_fn: BFGS alone for large vectors, else PSO + BFGS."""
        if initial_guess.size > MAX_NPARAMS:
            result = minimize(objective_fn, initial_guess,
                              method='BFGS', options={'maxiter': 500})
            return result.x

        def pso_wrapper(x):
            return objective_fn(jnp.array(x))

        # NOTE(review): the lower bound is -1 while the upper bound is 10
        # (the other versions of this search use +/-10). Confirm the
        # asymmetry is intended.
        lb = [-1.0] * initial_guess.size
        ub = [10.0] * initial_guess.size

        pso_params, _ = pso(pso_wrapper, lb, ub,
                            swarmsize=30, maxiter=200)

        result = minimize(objective_fn, jnp.array(pso_params),
                          method='BFGS', options={'maxiter': 500})
        return result.x

    def calculate_sensitivities(opt_params, base_loss):
        """Return log2(refit loss with params[i] pinned to 0 / base_loss) per i.

        All n leave-one-out refits are solved jointly. Row i of an (n, n)
        parameter matrix has entry i masked to zero, and the summed loss over
        the rows is minimized with BFGS. The rows do not share variables, so
        each row ends up fitted to its own loss.
        """
        n = opt_params.size
        mask = 1 - jnp.eye(n)

        @jit
        def batch_loss(params_matrix):
            return vmap(loss_fn)(params_matrix * mask)

        def sensitivity_objective(flat_params):
            return jnp.sum(batch_loss(flat_params.reshape(n, n)))

        try:
            initial_flat = (opt_params * mask).flatten()
            result = minimize(sensitivity_objective, initial_flat,
                              method='BFGS', options={'maxiter': 500})
            losses = batch_loss(result.x.reshape(n, n))
            return jnp.log2(losses / base_loss)
        except Exception as e:
            print(f"Sensitivity analysis error: {str(e)}")
            # NaN rather than 0: a sensitivity of 0 would falsely claim the
            # parameter has no effect on the loss.
            return jnp.full(n, jnp.nan)

    try:
        optimized_params = run_optimization(loss_fn, jnp.asarray(params))
        final_loss = loss_fn(optimized_params)
    except Exception as e:
        print(f"Optimization failed: {e}")
        return None

    if not jnp.isfinite(final_loss):
        return None

    sensitivities = calculate_sensitivities(optimized_params, final_loss)
    sensitivity_dict = {f"sensitive of params[{i}]": round(float(s), 4)
                        for i, s in enumerate(sensitivities)}

    return {
        'params': optimized_params.tolist(),
        'loss': - float(final_loss),
        'sensitivities': sensitivity_dict
    }

@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array):
    """Candidate Lagrangian L(q, q_t) for a one-dimensional system.

    L = params[0] * (1 - q_t**2) ** (-1/2) - (params[2] - params[3] * q)

    Only the first component along the last axis of ``q`` and ``q_t`` is used,
    so inputs of shape (..., n) produce a result of shape (...).

    Args:
        q: observations of the generalized coordinate.
        q_t: observations of the generalized velocity.
        params: parameter vector; only indices 0, 2 and 3 are read.

    Returns:
        The Lagrangian evaluated over the leading axes of the inputs.
    """
    coord = q[..., 0]
    vel = q_t[..., 0]

    kinetic_term = params[0] * jnp.power(1 - vel**2, -0.5)
    potential_term = params[2] - params[3] * coord

    return kinetic_term - potential_term


In [7]:
import jax
import jax.numpy as jnp
import pandas as pd
from jax.experimental.ode import odeint
from jax import config
config.update("jax_enable_x64", True)



def qdotdot(q, q_t, conditionals):
    """Analytic acceleration as a function of velocity.

    Computes q_tt = g * (1 - q_t**2) ** (5/2) / (1 + 2 * q_t**2), where
    g is ``conditionals``. ``q`` is not used.

    Returns:
        Tuple ``(q_t, q_tt)``.
    """
    gravity = conditionals
    velocity_factor = (1 - q_t**2) ** (5.0 / 2)
    denom = 1 + 2 * q_t**2
    return q_t, gravity * velocity_factor / denom


# Read the CSV file and convert it to a NumPy array.
# Column layout implied by the variable names below (TODO confirm against the
# CSV header): 0 -> q, 1 -> q_t, 2 -> g, 3 -> true acceleration q_tt.
data0 = pd.read_csv('./relative_particle.csv')
tae = data0.to_numpy()

# Use JAX array operations in place of the original PyTorch ones.
state = jnp.array(tae[:, :2], dtype=jnp.float64)  # convert to a JAX array: columns [q, q_t]
g = jnp.array(tae[:, 2], dtype=jnp.float64)  # convert to a JAX array; not passed to evaluate
true_q_ddot = jnp.array(tae[:, 3:4], dtype=jnp.float64)  # convert to a JAX array; 3:4 keeps it 2-D, shape (N, 1)
# Store the data in the dict layout that evaluate() reads.
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # ground-truth accelerations
}

# Fit the current `equation` skeleton and print the resulting report.
f = evaluate(data)
print(f)

# Disabled sanity check, kept as a string literal. It compares the analytic
# qdotdot() acceleration against the data. As the last expression in the cell,
# this string is echoed by the notebook as the cell's output.
'''
inputs, outputs = data['inputs'], data['outputs']
n_dim = inputs.shape[1] // 2
q = inputs[:, :n_dim]
q_t = inputs[:, n_dim:]
true_accelerations = outputs
print(q[:5],q_t[:5],true_accelerations[:5])
print(g)


dd = qdotdot(q, q_t, g)[1]
loss = jnp.mean((dd - true_accelerations)**2)  # 计算损失
print('Loss:', loss)
'''


Stopping search: Swarm best objective change less than 1e-08
{'params': [0.7727671435875787, 0.9889311251799003, 6.853301227843082, 7.5731173436548405, 5.214923455301921, 1.305482912520252, 5.780432597713505, 0.4861455058073598, 3.049734631454361, 2.537328650951322], 'loss': -3.6944959872168676e-14, 'sensitivities': {'sensitive of params[0]': 46.8886, 'sensitive of params[1]': 0.0, 'sensitive of params[2]': 0.0, 'sensitive of params[3]': 46.8886, 'sensitive of params[4]': 0.0, 'sensitive of params[5]': 0.0, 'sensitive of params[6]': 0.0, 'sensitive of params[7]': 0.0, 'sensitive of params[8]': 0.0, 'sensitive of params[9]': 0.0}}


"\ninputs, outputs = data['inputs'], data['outputs']\nn_dim = inputs.shape[1] // 2\nq = inputs[:, :n_dim]\nq_t = inputs[:, n_dim:]\ntrue_accelerations = outputs\nprint(q[:5],q_t[:5],true_accelerations[:5])\nprint(g)\n\n\ndd = qdotdot(q, q_t, g)[1]\nloss = jnp.mean((dd - true_accelerations)**2)  # 计算损失\nprint('Loss:', loss)\n"

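# The optimization below works well but is too slow. Another remaining problem is that the parameter indices need to be re-aligned (mapped back to their original positions).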

In [None]:
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity. 
"""

import jax
import jax.numpy as jnp
from jax import jit, vmap, config
import pyswarms as ps
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import numpy as np
import jax.random as random
from jaxopt import LBFGS

# Maximum number of free constants the `equation` skeleton may use.
MAX_NPARAMS = 10
# Default starting point for the optimization: a float64 vector of ones.
initial_params = jnp.ones(MAX_NPARAMS, dtype=jnp.float64)


def evaluate(data: dict, params=initial_params) -> dict:
    """Fit the free constants of `equation` (a Lagrangian skeleton) to data.

    For each sample the acceleration is recovered from the Euler-Lagrange
    equation with JAX autodiff. Only the parameters `equation` actually
    depends on ("active" parameters) are searched. Several pyswarms PSO runs
    pick a starting point and L-BFGS refines it. The reduced vector is then
    scattered back into a full-length vector, so index i of the result is
    always `params[i]` of `equation`. Finally, each active parameter's
    sensitivity is measured by re-fitting with that parameter pinned to 0.

    Args:
        data: dict with 'inputs' (2-D; the first half of the columns is q,
            the second half is q_t) and 'outputs' (observed accelerations;
            assumed to match the predicted shape, e.g. (N, n_dim) -- TODO
            confirm).
        params: starting parameter vector. It is also the point at which
            parameter activity is probed, and inactive parameters keep these
            values.

    Returns:
        dict with
          'params': full-length optimized parameter vector (JAX array),
          'loss': final mean-squared error (JAX scalar, not negated),
          'sensitivities': {"sensitive of params[i]": log2(refit loss with
              params[i] = 0 / final loss)}; 0.0 for inactive parameters,
              NaN if the sensitivity analysis failed;
        or None if optimization raised or the final loss is not finite.
    """
    master_key = random.PRNGKey(0)
    params = jnp.asarray(params, dtype=jnp.float64)
    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q = inputs[:, :n_dim]
    q_t = inputs[:, n_dim:]
    true_accelerations = outputs

    def lag(q_single, q_t_single, params):
        # Reduce to a scalar so grad/hessian can be applied.
        return jnp.sum(equation(q_single, q_t_single, params))

    @jit
    def compute_acceleration(q, q_t, params):
        # Euler-Lagrange for one sample, solved for q_tt:
        #   (d2L/dq_t2) q_tt = dL/dq - (d/dq dL/dq_t) q_t
        # pinv yields a least-squares solution even if the Hessian is singular.
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        return jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)

    batch_compute_acceleration = jit(vmap(compute_acceleration, in_axes=(0, 0, None)))

    @jit
    def loss_fn(params):
        # Mean-squared error between predicted and observed accelerations.
        predicted_accelerations = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(predicted_accelerations - true_accelerations))

    def find_active_indices(probe_params):
        """Return the sorted indices of parameters that `equation` uses.

        dL/dparams is evaluated at every sample (not just one), and a
        parameter is kept if its gradient is nonzero for at least one sample.
        NaN != 0 is True, so non-finite gradients also count as active.
        Falls back to all indices if none are detected.
        """
        per_sample = vmap(jax.grad(lag, 2), in_axes=(0, 0, None))(q, q_t, probe_params)
        used = np.asarray(jnp.any(per_sample != 0, axis=0))
        active = np.flatnonzero(used)
        if active.size == 0:
            active = np.arange(probe_params.size)
        return active

    def run_optimization(objective_fn, initial_guess, key, num_pso_runs=3, pso_iters=100,
                         swarmsize=50, gg=False, active=None):
        """Minimize `objective_fn` over full-length parameter vectors.

        gg=True: plain L-BFGS over the full vector, starting from `initial_guess`.
        Otherwise: `num_pso_runs` pyswarms PSO runs (each with a fresh,
        `key`-derived initial swarm in [-10, 10]) over only the `active`
        coordinates, then L-BFGS over the same coordinates from the best PSO
        point. Inactive coordinates keep their `initial_guess` values. The
        result always has the full length of `initial_guess`.
        """
        print(f"Initial guess size: {initial_guess.size}, MAX_NPARAMS: {MAX_NPARAMS}")

        if gg:
            solver = LBFGS(objective_fn, maxiter=100, tol=1e-8)
            return solver.run(initial_guess).params

        if active is None:
            active = np.arange(initial_guess.size)
        n_params = int(active.size)
        print(f"Active parameter indices: {active.tolist()}")

        def expand(sub_params):
            # Put the reduced vector back at its original indices so that
            # `equation` always sees params[i] at position i.
            return initial_guess.at[active].set(sub_params)

        def reduced_objective(sub_params):
            return objective_fn(expand(sub_params))

        batched_objective = jit(vmap(reduced_objective))

        def pso_objective_wrapper(particles_matrix):
            # pyswarms expects one NumPy cost per particle. Map NaN/inf to +inf
            # so a diverged particle can never become the best position.
            costs = np.asarray(batched_objective(particles_matrix))
            return np.where(np.isfinite(costs), costs, np.inf)

        bounds = (np.full(n_params, -10.0, dtype=np.float64),
                  np.full(n_params, 10.0, dtype=np.float64))
        options = {'c1': 1.49445, 'c2': 1.49445, 'w': 0.729}

        best_pso_params = None
        best_pso_loss = np.inf
        current_key = key

        print(f"Running pyswarms PSO {num_pso_runs} times with unique initial swarms...")
        for i in range(num_pso_runs):
            # Fresh initial swarm per run, reproducible from `key`.
            current_key, subkey = random.split(current_key)
            init_pos_np = np.array(random.uniform(subkey,
                                                  shape=(swarmsize, n_params),
                                                  dtype=jnp.float64,
                                                  minval=-10.0,
                                                  maxval=10.0))

            print(f"  PSO Run {i+1}/{num_pso_runs} (using JAX key split for init_pos)")
            optimizer = ps.single.GlobalBestPSO(n_particles=swarmsize,
                                                dimensions=n_params,
                                                options=options,
                                                bounds=bounds,
                                                init_pos=init_pos_np)
            current_pso_loss, current_pso_params = optimizer.optimize(
                pso_objective_wrapper,
                iters=pso_iters,
                verbose=False
            )

            print(f"    Run {i+1} completed. Loss: {current_pso_loss}")
            if current_pso_loss < best_pso_loss:
                best_pso_loss = current_pso_loss
                best_pso_params = jnp.array(current_pso_params, dtype=jnp.float64)
                print(f"    New best PSO loss found: {best_pso_loss}")

        if best_pso_params is None:
            print("Warning: PSO did not find any valid solution after multiple runs. Using initial guess for BFGS.")
            # Use the active-only slice so it matches the solver's variables.
            best_pso_params = initial_guess[active]

        print(f"\nBest PSO loss after {num_pso_runs} runs: {best_pso_loss}")
        print("Refining best PSO result with L- BFGS...")

        solver = LBFGS(reduced_objective, maxiter=100, tol=1e-8)
        refined = solver.run(best_pso_params).params
        refined_loss = reduced_objective(refined)
        if not bool(jnp.isfinite(refined_loss)) or bool(refined_loss > best_pso_loss):
            # L-BFGS diverged or made things worse; keep the PSO point.
            print("L-BFGS did not improve on the PSO result; keeping the PSO parameters.")
            refined = best_pso_params
        return expand(refined)

    def calculate_sensitivities(opt_params, base_loss, active=None):
        """Return log2(refit loss with params[i] pinned to 0 / base_loss) per i.

        Only `active` indices are re-fitted (one L-BFGS solve each, vmapped).
        Inactive indices get 0.0 (= log2(1)): they were found not to enter
        `equation` at the probe point, so pinning them leaves the loss unchanged.
        """
        n_total = opt_params.size
        if active is None:
            active = np.arange(n_total)
        # Row k pins parameter active[k] to zero and leaves the rest free.
        masks = 1.0 - jnp.eye(n_total, dtype=opt_params.dtype)[active]

        @jit
        def solve_one(mask_i):
            def sub_loss(p):
                return loss_fn(p * mask_i)
            solver = LBFGS(fun=sub_loss, maxiter=100, tol=1e-8)
            out = solver.run(opt_params * mask_i)
            loss_i = loss_fn(out.params * mask_i)
            return jnp.log2(loss_i / base_loss)

        active_sens = vmap(solve_one)(masks)
        return jnp.zeros(n_total, dtype=active_sens.dtype).at[active].set(active_sens)

    opt_key, _ = random.split(master_key)
    try:
        active = find_active_indices(params)
        optimized_params = run_optimization(loss_fn, params, key=opt_key, active=active)
        final_loss = loss_fn(optimized_params)
        print(f"Final optimized loss: {final_loss}")
    except Exception as e:
        print(f"Optimization or final loss calculation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    if not jnp.isfinite(final_loss):
        print("Final loss is not finite.")
        return None

    try:
        sensitivities = calculate_sensitivities(optimized_params, final_loss, active)
    except Exception as e:
        print(f"Sensitivity analysis failed: {e}")
        # NaN, not 0: a sensitivity of 0 would claim the parameter has no effect.
        sensitivities = jnp.full(optimized_params.size, jnp.nan)

    sensitivity_dict = {f"sensitive of params[{i}]": round(float(s), 4)
                        for i, s in enumerate(sensitivities)}

    # NOTE(review): the earlier evaluate versions in this notebook return the
    # params as a list and the loss negated as a Python float. This version
    # returns raw JAX values with a positive loss. Confirm what the caller expects.
    return {
        'params': optimized_params,
        'loss': final_loss,
        'sensitivities': sensitivity_dict
    }

