# 均匀势场中的相对论粒子
$ L = (1-\dot{q}^2)^{-\frac{1}{2}}-1 + gq $   
$ \ddot{q} = \frac{g (1-\dot{q}^2)^{2.5}}{1+2\dot{q}^2}$

In [None]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# 定义微分方程
def dqdt(t, y, g):
    q, qdot = y  # y = [q, qdot]
    qddot = g * (1 - qdot**2)**(5/2) / (1 + 2 * qdot**2)
    return [qdot, qddot]

# 初始条件
q0 = 0.0      # q(0) = 0
qdot0 = -0.9   # qdot(0) = 0.4 (确保 |qdot| < 1)
g = 9.8       # 假设 g = 9.8（根据物理问题调整）

# 时间范围
t_span = (0, 10)  # 从 t=0 到 t=10
t_eval = np.linspace(t_span[0], t_span[1], 1000)  # 时间点

# 求解
sol = solve_ivp(dqdt, t_span, [q0, qdot0], args=(g,), t_eval=t_eval, method='RK45')

# 提取结果
q = sol.y[0]      # q(t)
qdot = sol.y[1]   # qdot(t)

# 绘制结果
plt.figure(figsize=(10, 6))
plt.subplot(2, 1, 1)
plt.plot(sol.t, q, label='$q(t)$')
plt.xlabel('Time')
plt.ylabel('Position $q$')
plt.title('Solution of the Euler-Lagrange Equation')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(sol.t, qdot, label='$\dot{q}(t)$', color='r')
plt.xlabel('Time')
plt.ylabel('Velocity $\dot{q}$')
plt.legend()
plt.tight_layout()
plt.show()


In [2]:
from jax.experimental.ode import odeint
import jax
import jax.numpy as jnp
import numpy as np # get rid of this eventually
import argparse
from jax import jit
from jax.experimental.ode import odeint
from functools import partial # reduces arguments to function by making some subset implicit

from jax.example_libraries import stax
from jax.example_libraries import optimizers

@jax.jit
def qdotdot(q, q_t, conditionals):
    g = conditionals
    
    q_tt = (
        g * (1 - q_t**2)**(5./2) / 
        (1 + 2 * q_t**2)
    )
    
    return q_t, q_tt

@jax.jit
def ofunc(y, t=None):
    q = y[::3]
    q_t = y[1::3]
    g = y[2::3]
    
    q_t, q_tt = qdotdot(q, q_t, g)
    return jnp.stack([q_t, q_tt, jnp.zeros_like(g)]).T.ravel()


@partial(jax.jit, static_argnums=(1, 2), backend='cpu')
def gen_data(seed, batch, num):
    rng = jax.random.PRNGKey(seed)
    q0 = jax.random.uniform(rng, (batch,), minval=-10, maxval=10)
    qt0 = jax.random.uniform(rng+1, (batch,), minval=-0.99, maxval=0.99)
    g = jax.random.normal(rng+2, (batch,))*10

    y0 = jnp.stack([q0, qt0, g]).T.ravel()
    y1 = jnp.stack([q0, qt0, g]).T
    print(y1.shape)

    yt = odeint(ofunc, y0, jnp.linspace(0, 1, num=num),  atol=1e-6, rtol=1e-6)

    
    qall = yt[:, ::3]
    qtall = yt[:, 1::3]
    gall = yt[:, 2::3]
    
    return jnp.stack([qall, qtall]).reshape(2, -1).T, gall.reshape(1, -1).T, qdotdot(qall, qtall, gall)[1].reshape(1, -1).T

'''@partial(jax.jit, static_argnums=(1,))
def gen_data_batch(seed, batch):
    rng = jax.random.PRNGKey(seed)
    q0 = jax.random.uniform(rng, (batch,), minval=-10, maxval=10)
    qt0 = (jnp.tanh(jax.random.normal(jax.random.PRNGKey(1), (batch,))*2)*0.99999)#jax.random.uniform(rng+1, (batch,), minval=-1, maxval=1)
    g = jax.random.normal(rng+2, (batch,))*10
    
    return jnp.stack([q0, qt0]).reshape(2, -1).T, g.reshape(1, -1).T, qdotdot(q0, qt0, g)[1].reshape(1, -1).T
'''

dd = gen_data(0, 1, 1000)


print(dd[0][0,2])
print('qt', dd[0][:5, 1], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])
print('qt', dd[0][:5, 0], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])


(1, 3)
0.41314822
qt [0.41314822 0.4152955  0.4174257  0.41953892 0.42163536] g [4.612761 4.612761 4.612761 4.612761 4.612761] qtt [2.1537657 2.1365683 2.1195455 2.1026988 2.0860243]
qt [8.953341 8.953755 8.954172 8.954591 8.955012] g [4.612761 4.612761 4.612761 4.612761 4.612761] qtt [2.1537657 2.1365683 2.1195455 2.1026988 2.0860243]


In [None]:
dd = gen_data(0, 128, 1000)

In [None]:
print('qt', dd[0][:5, 1], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])
print('qt', dd[0][:5, 0], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])
print(dd[1][0, 0],dd[1][128, 0])
print(dd[1][1, 0],dd[1][129, 0])
print(dd[0][1, 0],dd[0][129, 0])


In [2]:
from jax.experimental.ode import odeint
import jax
import jax.numpy as jnp
import numpy as np
import argparse
from functools import partial

def qdotdot(q, q_t, conditionals):
    g = conditionals
    
    q_tt = (
        g * (1 - q_t**2)**(5./2) / 
        (1 + 2 * q_t**2)
    )
    
    return q_t, q_tt

def ofunc(y, t=None):
    q = y[::3]
    q_t = y[1::3]
    g = y[2::3]
    
    q_t, q_tt = qdotdot(q, q_t, g)
    return jnp.stack([q_t, q_tt, jnp.zeros_like(g)]).T.ravel()

def gen_data(seed, batch, num):
    rng = jax.random.PRNGKey(seed)
    q0 = jax.random.uniform(rng, (batch,), minval=-10, maxval=10)
    qt0 = jax.random.uniform(rng+1, (batch,), minval=-0.99, maxval=0.99)
    g = jax.random.normal(rng+2, (batch,))*10

    y0 = jnp.stack([q0, qt0, g]).T.ravel()
    y1 = jnp.stack([q0, qt0, g]).T
    print(y1.shape)

    yt = odeint(ofunc, y0, jnp.linspace(0, 1, num=num), atol=1e-6, rtol=1e-6)
    
    qall = yt[:, ::3]
    qtall = yt[:, 1::3]
    gall = yt[:, 2::3]
    
    return jnp.stack([qall, qtall]).reshape(2, -1).T, gall.reshape(1, -1).T, qdotdot(qall, qtall, gall)[1].reshape(1, -1).T

dd = gen_data(0, 128, 1000)
f = dd[1]

print(f.shape)
print(dd[0][0,2])
print('qt', dd[0][:5, 1], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])
print('qt', dd[0][:5, 0], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])


(128, 3)
(128000, 1)
0.41314822
qt [ 0.41314822  0.6002258   0.11193506  0.05001947 -0.9819116 ] g [  4.612761    2.1626668 -12.121061   -6.3339176  17.351053 ] qtt [ 2.1537657e+00  4.1144755e-01 -1.1457828e+01 -6.2630343e+00
  1.4418560e-03]
qt [ 8.953341   9.571598  -3.3541703 -0.6266308  1.3977742] g [  4.612761    2.1626668 -12.121061   -6.3339176  17.351053 ] qtt [ 2.1537657e+00  4.1144755e-01 -1.1457828e+01 -6.2630343e+00
  1.4418560e-03]


In [4]:
import numpy as np
from scipy.integrate import odeint
import argparse
from functools import partial

def qdotdot(q, q_t, conditionals):
    g = conditionals
    
    q_tt = (
        g * (1 - q_t**2)**(5./2) / 
        (1 + 2 * q_t**2)
    )
    
    return q_t, q_tt

def ofunc(y, t):
    q = y[::3]
    q_t = y[1::3]
    g = y[2::3]
    
    q_t, q_tt = qdotdot(q, q_t, g)
    return np.stack([q_t, q_tt, np.zeros_like(g)]).T.ravel()

def gen_data(seed, batch, num):
    np.random.seed(seed)
    q0 = np.random.uniform(low=-10, high=10, size=(batch,))
    qt0 = np.random.uniform(low=-0.99, high=-0.8, size=(batch,))
    g = np.random.normal(loc=0, scale=10, size=(batch,))

    y0 = np.stack([q0, qt0, g]).T.ravel()
    y1 = np.stack([q0, qt0, g]).T
    print(y1.shape)

    t = np.linspace(0, 1, num=num)
    yt = odeint(ofunc, y0, t, atol=1e-6, rtol=1e-6)
    
    qall = yt[:, ::3]
    qtall = yt[:, 1::3]
    gall = yt[:, 2::3]
    
    return np.stack([qall, qtall]).reshape(2, -1).T, gall.reshape(1, -1).T, qdotdot(qall, qtall, gall)[1].reshape(1, -1).T

dd = gen_data(0, 1, 1000)



print("数据已保存为ode_results.csv")
print(dd[0][0,0])
print('qt', dd[0][:5, 1], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])
print('qt', dd[0][:5, 0], 'g', dd[1][:5, 0], 'qtt', dd[2][:5, 0])


(1, 3)
数据已保存为ode_results.csv
0.9762700785464951
qt [-0.85411402 -0.85396221 -0.85381    -0.85365739 -0.85350437] g [9.78737984 9.78737984 9.78737984 9.78737984 9.78737984] qtt [0.15145342 0.15184868 0.1522456  0.1526442  0.15304448]
qt [0.97627008 0.97541519 0.97456045 0.97370586 0.97285143] g [9.78737984 9.78737984 9.78737984 9.78737984 9.78737984] qtt [0.15145342 0.15184868 0.1522456  0.1526442  0.15304448]


In [5]:
import pandas as pd

# 解构dd中的数据
xy, g, qtt = dd

# 创建DataFrame
df = pd.DataFrame({
    'q': xy[:, 0],   # 位置q
    'qt': xy[:, 1],  # 速度qt
    'g': g[:, 0],    # 条件g
    'qtt': qtt[:, 0] # 加速度qtt
})

# 保存为CSV文件
df.to_csv('relative_particle.csv', index=False)


# 检查代码是否正确

In [11]:
import jax
import jax.numpy as jnp
import pandas as pd
from jax.experimental.ode import odeint
from jax import config
config.update("jax_enable_x64", True)



def qdotdot(q, q_t, conditionals):
    g = conditionals
    
    q_tt = (
        g * (1 - q_t**2)**(5./2) / 
        (1 + 2 * q_t**2)
    )
    
    return q_t, q_tt


# 读取 CSV 文件并转换为 NumPy 数组
data0 = pd.read_csv('./relative_particle.csv')#
tae = data0.to_numpy()

# 使用 JAX 的数组操作替换 PyTorch 的操作
state = jnp.array(tae[:, :2], dtype=jnp.float64)  # 转换为 JAX 数组
g = jnp.array(tae[:, 2], dtype=jnp.float64)  # 转换为 JAX 数组
true_q_ddot = jnp.array(tae[:, 3:], dtype=jnp.float64)  # 转换为 JAX 数组
# 将数据存储在字典中
data = {
    'inputs': state,
    'outputs': true_q_ddot,  # 真实的加速度
}

inputs, outputs = data['inputs'], data['outputs']
n_dim = inputs.shape[1] // 2
q = inputs[:, :n_dim]
q_t = inputs[:, n_dim:]
true_accelerations = outputs
print(q[:5],q_t[:5],true_accelerations[:5])
print(g)


dd = qdotdot(q, q_t, g)[1]
loss = jnp.mean((dd - true_accelerations)**2)  # 计算损失
print('Loss:', loss)



[[0.97627008]
 [0.97541519]
 [0.97456045]
 [0.97370586]
 [0.97285143]] [[-0.85411402]
 [-0.85396221]
 [-0.85381   ]
 [-0.85365739]
 [-0.85350437]] [[0.15145342]
 [0.15184868]
 [0.1522456 ]
 [0.1526442 ]
 [0.15304448]]
[9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984 9.78737984
 9.78737984 9.78737984 9.78737984 9.78737984 