In [None]:
import os

import torch
from torch.func import vmap, grad, jacrev

torch.set_default_dtype(torch.float64)

from symplearn.datasets import VectorFieldDataset
from symplearn.numerics import ExplicitEulerSimulation, EulerDVISimulation

from massless_charged_particle.models import MasslessChargedParticle

# Model, step size and number of simulated steps. Two steps are simulated
# because the checks below need the consecutive states at indices 1 and 2.
model = MasslessChargedParticle()
dt = 0.05
nt = 2

# DVI scheme under test, and the integrator used to generate the states
# z1, z2 the checks start from. The latter is named after its class
# (`ExplicitEulerSimulation`), not "rk4".
dvi = EulerDVISimulation(model, dt)
explicit_euler = ExplicitEulerSimulation(model, dt)
data = VectorFieldDataset("val", os.path.join("massless_charged_particle", "data"))
z0 = data.z
t, zt = explicit_euler.simulate(z0, nt)

# Entries 1 and 2 along dim 1 of the simulated output (presumably the time
# axis — confirm against `simulate`).
z1, t1 = zt[:, 1], t[:, 1]
z2, t2 = zt[:, 2], t[:, 2]
# Each state is split into two equal halves (x, y) along the last dimension.
x1, y1 = torch.tensor_split(z1, 2, dim=-1)
x2, y2 = torch.tensor_split(z2, 2, dim=-1)

In [None]:
# Reconstruct a previous state by stepping back from (z1, t1) with -dt, and
# check `dvi.backstep_x` against its x-part. The result is kept under its own
# name: it is not the dataset initial state `z0` bound in the first cell
# (that one was stepped forward using the vector field at z0, not at z1).
z0_back = z1 - dt * model.vector_field(z1, t1)
x0, y0 = torch.tensor_split(z0_back, 2, -1)

err = vmap(dvi.backstep_x)(x1, y1, t1) - x0
# RMS over the batch (dim 0, the vmap axis), one value per component.
print("Backstep error:", err.square().mean(0).sqrt())

In [None]:
# Lagrangian maps (q, h) and their Euler-Lagrange derivatives at steps 1 and 2,
# vectorised over the batch (dim 0) with vmap.
batched_lagrangian_maps = vmap(model.lagrangian_maps)
batched_euler_lagrange_maps = vmap(model.euler_lagrange_maps)

q1, h1 = batched_lagrangian_maps(x1, y1, t1)
(dx_q1, _), (dx_h1, _) = batched_euler_lagrange_maps(x1, y1, t1)
q2, h2 = batched_lagrangian_maps(x2, y2, t2)
(dx_q2, dy_q2), (dx_h2, dy_h2) = batched_euler_lagrange_maps(x2, y2, t2)


def rms(err):
    """Root-mean-square of `err` over dim 0 (the batch), kept per component."""
    return err.square().mean(0).sqrt()


# D_x of the current-step discrete Lagrangian: closed form vs. `dvi.d_lag_cur`.
# The trailing singleton axis on dx1 is contracted against axis -2 of dx_q1.
dx1 = (x1 - x0)[..., None] / dt
theo_dx_lag_cur = torch.linalg.vecdot(dx1, dx_q1, dim=-2) + q1 / dt - dx_h1
implem_dx_lag_cur = vmap(dvi.d_lag_cur)(x0, x1, y1, t1)
print("Dx L cur:", rms(theo_dx_lag_cur - implem_dx_lag_cur))

# Derivatives of the next-step discrete Lagrangian: closed form vs.
# the pair returned by `dvi.d_lag_next`.
dx2 = (x2 - x1).unsqueeze(-1) / dt
theo_dx_lag_next = -q2 / dt
theo_dy_lag_next = torch.linalg.vecdot(dx2, dy_q2, dim=-2) - dy_h2
implem_dx_lag_next, implem_dy_lag_next = vmap(dvi.d_lag_next)(x1, x2, y2, t2)

err_dx_lag_next = theo_dx_lag_next - implem_dx_lag_next
err_dy_lag_next = theo_dy_lag_next - implem_dy_lag_next
print("Dx L next:", rms(err_dx_lag_next))
print("Dy L next:", rms(err_dy_lag_next))

# Residual of the full scheme at z2, then of the variant taking D_x L_cur
# pre-computed; the final line compares the two residuals.
res = vmap(dvi.scheme)(z2, x1, y1, t1)
print("Scheme residual:", rms(res))

res_with_lag = vmap(dvi.scheme_with_lag)(z2, x1, t1, theo_dx_lag_cur)
mean_res_with_lag = rms(res_with_lag)
print("Scheme residual with pre-computed Lagrangian:", mean_res_with_lag)

print("Difference of residuals:", rms(res - res_with_lag))

In [None]:
# Scheme residual at the initial guess z2, assembled from the derivative
# pieces of the previous cell: x-equations first, then y-equations.
dx_lag_cur = implem_dx_lag_cur
schx = dx_lag_cur + implem_dx_lag_next
schy = implem_dy_lag_next
sch = torch.cat((schx, schy), -1)

# Second derivatives of the one-form q and of the Hamiltonian h at step 2.
# jacrev appends the differentiated input's axes last, so
#   dyx_q2[..., k, l, j] = d^2 q_k / (dx_l dy_j)
#   dyx_h2[..., l, j]    = d^2 h   / (dx_l dy_j)
d2q = vmap(jacrev(jacrev(model.oneform, (0, 1)), 1))
(dyx_q2, dyy_q2) = d2q(x2, y2)

d2h = vmap(jacrev(grad(model.hamiltonian, (0, 1)), 1))
(dyx_h2, dyy_h2) = d2h(x2, y2, t2)

# Step-2 discrete velocity with two trailing singleton axes, for contracting
# against the component axis (-3) of the rank-3 second derivatives. Kept
# under a different name from the previous cell's `dx2` (one trailing axis).
dx2_bcast = (x2 - x1)[..., None, None] / dt

# Analytic Jacobian of the residual w.r.t. z2 = (x2, y2): rows are equations,
# columns are unknowns (the same layout jacrev produces).
dx_schx = -dx_q2 / dt
dy_schx = -dy_q2 / dt
# d(schy_j)/d(x2_l) needs the mixed second derivatives indexed [y_j, x_l],
# i.e. the transpose of the [x_l, y_j] layout produced above. (For 1-D x and
# y the transposes are no-ops.)
dx_schy = (
    dy_q2.transpose(-1, -2) / dt
    + torch.sum(dx2_bcast * dyx_q2, -3).transpose(-1, -2)
    - dyx_h2.transpose(-1, -2)
)
dy_schy = torch.sum(dx2_bcast * dyy_q2, -3) - dyy_h2

dx_sch = torch.cat((dx_schx, dx_schy), -2)
dy_sch = torch.cat((dy_schx, dy_schy), -2)
theo_jac = torch.cat((dx_sch, dy_sch), -1)

# NOTE(review): t2 is passed to scheme_with_lag here, whereas the residual
# cell passes t1 (alongside dvi.scheme(..., t1)). Both can only be right if
# the model is autonomous — confirm which time scheme_with_lag expects.
implem_jac = vmap(jacrev(dvi.scheme_with_lag))(z2, x1, t2, dx_lag_cur)

err_jac = theo_jac - implem_jac
print("Error on the Jacobian:", err_jac.square().mean(0).sqrt())

In [None]:
# NOTE(review): scheme_with_lag is evaluated with t2 below (as in the Jacobian
# cell) but with t1 in the residual cell; these agree only if the model is
# autonomous — confirm which time the method expects.


def rms(err):
    """Root-mean-square of `err` over dim 0 (the batch), kept per component."""
    return err.square().mean(0).sqrt()


def scheme_residual(z_next):
    """Batched `dvi.scheme_with_lag` residual at candidate next state `z_next`."""
    return vmap(dvi.scheme_with_lag)(z_next, x1, t2, dx_lag_cur)


def scheme_jacobian(z_next):
    """Batched autodiff Jacobian of `dvi.scheme_with_lag` w.r.t. `z_next`."""
    return vmap(jacrev(dvi.scheme_with_lag))(z_next, x1, t2, dx_lag_cur)


# One Newton-Raphson step from the initial guess z2 with the analytic Jacobian.
theo_dz = torch.linalg.solve(theo_jac, sch)
z2_theo_1nr = z2 - theo_dz
res_theo_1nr = scheme_residual(z2_theo_1nr)
print("Res. after 1 N-R iteration (theo. jac.):", rms(res_theo_1nr))

# Newton-Raphson iterations with the implementation's autodiff Jacobian.
# The first step reuses `implem_jac` and `sch`, both already evaluated at z2;
# later steps re-linearise at the current iterate.
n_newton_iters = 4
z2_implem, res_implem, jac_implem = z2, sch, implem_jac
for it in range(1, n_newton_iters + 1):
    if it > 1:
        jac_implem = scheme_jacobian(z2_implem)
    z2_implem = z2_implem - torch.linalg.solve(jac_implem, res_implem)
    res_implem = scheme_residual(z2_implem)
    print(f"Res. after {it} N-R iteration (implem. jac.):", rms(res_implem))