In [13]:
import jax
import jax.numpy as jnp
from jax import config, grad, jit, vmap
import numpy as np
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = [jnp.array(1.0)]*MAX_NPARAMS



def evaluate(data: dict, params=initial_params) -> float:

    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q = inputs[:, :n_dim]
    q_t = inputs[:, n_dim:]
    true_accelerations = outputs


    # 修改：将lagrangian直接嵌入compute_acceleration中，避免将函数作为参数传递
    @jit
    def compute_acceleration(q, q_t, params):
        # 直接使用equation函数而不是传递lagrangian函数
        
        # 预先计算梯度
        @jit
        def lag(q_single, q_t_single, params):
            # 确保对单个样本调用时返回标量
            result = equation(q_single, q_t_single, params)
            # 如果结果是数组，取和或平均值来确保返回标量
            return jnp.sum(result)
            
        # 计算二阶导数
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        q_tt = jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)
        return q_tt

    # 使用vmap批处理
    batch_compute_acceleration = jit(vmap(compute_acceleration, in_axes=(0, 0, None)))

    # 使用jit装饰器优化损失函数
    @jit
    def loss_fn(params):
        predicted_accelerations = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(predicted_accelerations - true_accelerations))
    
    #return loss_fn(params)

    # 针对numpy数组的包装函数
    def objective(params):
        try:
            params = jnp.array(params)  # 确保转换为JAX数组
            loss_value = loss_fn(params)
            return float(loss_value)
        except Exception as e:
            print(f"Error in objective function: {e}")
            # 返回一个大的损失值，避免优化器选择这个点
            return 1e10

    # 粒子群优化的参数
    lb = [-1.0] * len(initial_params)  # 参数下限
    ub = [10.0] * len(initial_params)   # 参数上限

    # 调用 pso 函数进行优化
    optimized_params, optimized_loss = pso(objective, lb, ub, swarmsize=30, maxiter=500)

    print("pso Optimized parameters:", optimized_params)
    print("pso Optimized loss:", optimized_loss)

    # 使用JAX的优化器进一步优化
    loss_partial = jit(loss_fn)  # 确保JIT编译
    result = minimize(loss_partial, optimized_params, method='BFGS', options={'maxiter': 1000})
    optimized_params = result.x
    loss = result.fun

    if jnp.isnan(loss) or jnp.isinf(loss):
        return None
    else:
        print(optimized_params)
        return -loss.item()

'''
@jit  # 添加JIT装饰器
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array) -> jnp.array:
    """ Mathematical function for lagrangian in a one-dimensional physical system
    Args:
        q (jnp.array): observation of current generalized coordinate.
        q_t (jnp.array): observation of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian as the result of applying the mathematical function to the inputs.
    """ 
    

    T =  1/(jnp.sqrt(1-params[0]*q_t**2))-params[1]
    V =  params[3]*q   
    

    return T - V'''


@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array):

    #q = q[...,0]
    #q_t = q_t[...,0]
    T = 1/(jnp.sqrt(1-q_t**2))-params[1]
    V = -params[3]*q   
    
    result = T - V
    '''# 确保返回标量
    if hasattr(result, 'shape') and result.shape == (1,):
        return result[0]  # 返回标量而不是数组'''
    return result



In [14]:
from jax import random

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./relative_particle/relative_particle.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :2], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 3:], dtype=jnp.float64)  # 转换为 JAX 数组
print(state.shape)
print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
}

'''inputs, outputs = data['inputs'], data['outputs']
n_dim = inputs.shape[1] // 2
q = inputs[:, :n_dim]
q_t = inputs[:, n_dim:]
true_accelerations = outputs

initial_params = jnp.array([1.0, 1.0, 1.0, 9.787379841057392 ], dtype=jnp.float64)  # 初始参数

print(equation(q,q_t,initial_params))'''

print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

(1000, 2)
[[0.15145342]
 [0.15184868]
 [0.1522456 ]
 [0.1526442 ]
 [0.15304448]
 [0.15344645]
 [0.15385011]
 [0.15425547]
 [0.15466253]
 [0.1550713 ]
 [0.15548178]
 [0.15589401]
 [0.156308  ]
 [0.15672377]
 [0.15714133]
 [0.15756067]
 [0.15798181]
 [0.15840475]
 [0.15882949]
 [0.15925604]
 [0.15968441]
 [0.16011402]
 [0.16054629]
 [0.16098041]
 [0.16141639]
 [0.16185423]
 [0.16229395]
 [0.16273556]
 [0.16317907]
 [0.16362449]
 [0.16407184]
 [0.16452113]
 [0.16497238]
 [0.16542561]
 [0.16588084]
 [0.16633806]
 [0.1667973 ]
 [0.16725858]
 [0.16772189]
 [0.16818726]
 [0.1686547 ]
 [0.16912421]
 [0.16959583]
 [0.17006955]
 [0.17054539]
 [0.17102337]
 [0.1715035 ]
 [0.17198579]
 [0.17247026]
 [0.17295691]
 [0.17344577]
 [0.17393685]
 [0.17443015]
 [0.17492571]
 [0.17542352]
 [0.1759236 ]
 [0.17642598]
 [0.17693065]
 [0.17743765]
 [0.17794697]
 [0.17845864]
 [0.17897267]
 [0.17948907]
 [0.18000786]
 [0.18052907]
 [0.1810527 ]
 [0.18157877]
 [0.18210731]
 [0.18263831]
 [0.1831718 ]
 [0.183707

# 检验.sum是否影响求解梯度

In [None]:
import jax.numpy as jnp
import jax

# 定义方程
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array):
    T = 1 / jnp.sqrt(1 - q_t**2) - params[1]
    V = -params[3] * q   
    result = T - V
    return result

# 示例数据
key = jax.random.PRNGKey(0)
q = jax.random.uniform(key, (100,), minval=-0.5, maxval=0.5)  # 随机 q
q_t = jax.random.uniform(key, (100,), minval=-0.9, maxval=0.9)  # 随机 q_t，|q_t| < 1
params = jnp.array([0.0, 1.0, 0.0, 2.0])  # 示例参数，params[1] = 1.0, params[3] = 2.0

# 方式 1：分别计算每个数据点的梯度
def single_point_grad(q_i, q_t_i, params):
    # 计算单个数据点的梯度
    grad_fn = jax.grad(equation, argnums=(0, 1))
    grad_q, grad_q_t = grad_fn(q_i, q_t_i, params)
    return grad_q, grad_q_t

# 使用 vmap 向量化处理 100 个数据点
vectorized_grad = jax.vmap(single_point_grad, in_axes=(0, 0, None))
grad_q_1, grad_q_t_1 = vectorized_grad(q, q_t, params)

# 方式 2：对输出求和后再计算梯度
def summed_equation(q, q_t, params):
    return equation(q, q_t, params).sum()

grad_fn_sum = jax.grad(summed_equation, argnums=(0, 1))
grad_q_2, grad_q_t_2 = grad_fn_sum(q, q_t, params)

# 比较两种方式的结果
print("关于 q 的梯度（方式 1）：", grad_q_1[:5])  # 打印前 5 个
print("关于 q 的梯度（方式 2）：", grad_q_2[:5])
print("关于 q 的梯度差异：", jnp.abs(grad_q_1 - grad_q_2).max())

print("关于 q_t 的梯度（方式 1）：", grad_q_t_1[:5])
print("关于 q_t 的梯度（方式 2）：", grad_q_t_2[:5])
print("关于 q_t 的梯度差异：", jnp.abs(grad_q_t_1 - grad_q_t_2).max())


In [None]:
f = jnp.power(2,-2.5)
print(f)

# 检查完毕代码的数据结构，开始撰写SPECIFICATION文档供LLM-SR模型使用

In [None]:
"""
Find the mathematical function skeleton that represents lagrangian in a physical system, given data on generalized coordinate and generalized velocity.
Tips:You may only use no more than 10 parameters,Under limited parameter conditions, you can incorporate nonlinear terms rather than continuously adding new ones.
"""



import jax
import jax.numpy as jnp
from jax import config, jit, vmap
from pyswarm import pso
from jax.scipy.optimize import minimize
config.update("jax_enable_x64", True)

MAX_NPARAMS = 10
initial_params = [jnp.array(1.0)]*MAX_NPARAMS

@evaluate.run
def evaluate(data: dict, params=initial_params) -> float:

    inputs, outputs = data['inputs'], data['outputs']
    n_dim = inputs.shape[1] // 2
    q = inputs[:, :n_dim]
    q_t = inputs[:, n_dim:]
    true_accelerations = outputs


    # 修改：将lagrangian直接嵌入compute_acceleration中，避免将函数作为参数传递
    @jit
    def compute_acceleration(q, q_t, params):
        # 直接使用equation函数而不是传递lagrangian函数
        
        # 预先计算梯度
        @jit
        def lag(q_single, q_t_single, params):
            # 确保对单个样本调用时返回标量
            result = equation(q_single, q_t_single, params)
            # 如果结果是数组，取和或平均值来确保返回标量
            return result
            
        # 计算二阶导数
        hessian_q_t = jax.hessian(lag, 1)(q, q_t, params)
        grad_q = jax.grad(lag, 0)(q, q_t, params)
        jacobian_q_q_t = jax.jacobian(jax.grad(lag, 1), 0)(q, q_t, params)
        q_tt = jnp.linalg.pinv(hessian_q_t) @ (grad_q - jacobian_q_q_t @ q_t)
        return q_tt

    # 使用vmap批处理
    batch_compute_acceleration = jit(vmap(compute_acceleration, in_axes=(0, 0, None)))

    # 使用jit装饰器优化损失函数
    @jit
    def loss_fn(params):
        predicted_accelerations = batch_compute_acceleration(q, q_t, params)
        return jnp.mean(jnp.square(predicted_accelerations - true_accelerations))
    
    #return loss_fn(params)

    # 针对numpy数组的包装函数
    def objective(params):
        try:
            params = jnp.array(params)  # 确保转换为JAX数组
            loss_value = loss_fn(params)
            return float(loss_value)
        except Exception as e:
            print(f"Error in objective function: {e}")
            # 返回一个大的损失值，避免优化器选择这个点
            return 1e10

    # 粒子群优化的参数
    lb = [-10.0] * len(initial_params)  # 参数下限
    ub = [10.0] * len(initial_params)   # 参数上限

    # 调用 pso 函数进行优化
    optimized_params, optimized_loss = pso(objective, lb, ub, swarmsize=30, maxiter=500)

    print("pso Optimized parameters:", optimized_params)
    print("pso Optimized loss:", optimized_loss)

    # 使用JAX的优化器进一步优化
    loss_partial = jit(loss_fn)  # 确保JIT编译
    result = minimize(loss_partial, optimized_params, method='BFGS', options={'maxiter': 1000})
    optimized_params = result.x
    loss = result.fun

    if jnp.isnan(loss) or jnp.isinf(loss):
        return None
    else:
        print(optimized_params)
        return -loss.item()

@equation.evolve
@jit
def equation(q: jnp.array, q_t: jnp.array, params: jnp.array):
    """ Mathematical function for lagrangian in a one-dimensional physical system
    Args:
        q (jnp.array): observation of current generalized coordinate.
        q_t (jnp.array): observation of generalized velocity.
        params (jnp.array): List of numeric constants or parameters to be optimized.
    Returns:
        jnp.array: lagrangian as the result of applying the mathematical function to the inputs.
    """ 
    q = q[...,0]
    q_t = q_t[...,0]
    T = params[0]*jnp.power(1-q_t**2,-1)+params[1] + params[2]*jnp.power(q_t,2) - params[4]
    V = -params[3]*q + params[5]   
    
    result = T - V
    return result



In [6]:
from jax import random

import pandas as pd
# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./relative_particle/relative_particle.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :2], dtype=jnp.float64)  # 转换为 JAX 数组
g = jnp.array(tae[:, 2:3], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 3:], dtype=jnp.float64)  # 转换为 JAX 数组
print(state.shape)
#print(g)
#print(true_q_ddot)
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,
    'condition': g  # 真实的加速度
}



print(initial_params)
# 评估并优化参数
final_loss = evaluate(data, initial_params)
print("最终损失值 (MSE):", final_loss)

(1000, 2)
[Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True)]
Stopping search: Swarm best objective change less than 1e-08
pso Optimized parameters: [0.25336819 8.11832287 0.15681473 7.84612288 7.27479332 1.46364404
 9.3834202  3.20691351 1.65209292 3.38102884]
pso Optimized loss: 0.010096556562511948
[0.25338648 8.11832287 0.15680736 7.84612243 7.27479332 1.46364404
 9.3834202  3.20691351 1.65209292 3.38102884]
最终损失值 (MSE): -0.0100965436749942
