# In [None]:

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

# Scenario: two equal point masses starting 10 units apart on the x-axis and
# moving directly away from each other along it.
mass1 = mass2 = 4
pos1 = [5.0, 0.0]    # body 1 initial (x, y)
pos2 = [-5.0, 0.0]   # body 2 initial (x, y)
vel1 = [5.0, 0.0]    # body 1 initial (vx, vy)
vel2 = [-5.0, 0.0]   # body 2 initial (vx, vy)
grav_const = 1       # gravitational constant G in simulation units
# NOTE: the integration span is built from this value as [0, 2 * duration].
duration = 10

def simulate_gravity(t, y):
    """Right-hand side of the planar two-body problem, in ``solve_ivp`` form.

    Args:
        t: Current time. Unused because the system is autonomous, but
            ``solve_ivp`` requires the ``f(t, y)`` signature.
        y: State ``[x1, y1, x2, y2, vx1, vy1, vx2, vy2]``.

    Returns:
        A list with the time derivative of ``y``: the four velocity
        components followed by the four Newtonian accelerations
        (body 1 pulled towards body 2 and vice versa).
    """
    px1, py1, px2, py2, ux1, uy1, ux2, uy2 = y
    sep_x = px2 - px1
    sep_y = py2 - py1
    r_cubed = np.sqrt(sep_x**2 + sep_y**2) ** 3
    pull_on_1 = grav_const * mass2  # G * m2 scales body 1's acceleration
    pull_on_2 = grav_const * mass1  # G * m1 scales body 2's acceleration
    return [
        ux1, uy1, ux2, uy2,
        pull_on_1 * sep_x / r_cubed,
        pull_on_1 * sep_y / r_cubed,
        pull_on_2 * -sep_x / r_cubed,
        pull_on_2 * -sep_y / r_cubed,
    ]

# Reference trajectory from a numerical ODE solve; it serves as ground truth
# for both training and testing.
initial_state = pos1 + pos2 + vel1 + vel2  # [x1, y1, x2, y2, vx1, vy1, vx2, vy2]
time_span = [0, 2 * duration]
time_eval = np.linspace(*time_span, 100)
solution = solve_ivp(simulate_gravity, time_span, initial_state, t_eval=time_eval)
# solve_ivp reports failure through ``success``/``message`` instead of raising,
# and a failed run only contains the samples reached before the failure. Fail
# loudly rather than split a truncated trajectory against the full time grid.
if not solution.success:
    raise RuntimeError(f"solve_ivp failed: {solution.message}")

# Chronological 80/20 split: the test window is the later part of the
# trajectory, so testing measures extrapolation in time.
train_split = int(0.8 * len(time_eval))
train_time, test_time = np.split(time_eval, [train_split])
train_x1, test_x1 = np.split(solution.y[0], [train_split])
train_y1, test_y1 = np.split(solution.y[1], [train_split])
train_x2, test_x2 = np.split(solution.y[2], [train_split])
train_y2, test_y2 = np.split(solution.y[3], [train_split])

class GravityNet(nn.Module):
    """MLP mapping a time column of shape ``(N, 1)`` to ``(N, 4)`` outputs.

    The rest of the script reads the four outputs as the positions
    ``(x1, y1, x2, y2)`` of the two bodies. Layer widths are
    1 -> 64 -> 128 -> 64 -> 4, with tanh between consecutive linear layers.
    A smooth activation matters here: the physics loss differentiates the
    output twice with respect to time, and piecewise-linear activations
    would make that second derivative zero almost everywhere.
    """

    def __init__(self):
        super().__init__()
        widths = [1, 64, 128, 64, 4]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        layers.pop()  # the output layer stays linear (no trailing tanh)
        self.net = nn.Sequential(*layers)

    def forward(self, t):
        out = self.net(t)
        return out

def compute_loss(model, t_vals, x1_gt, y1_gt, x2_gt, y2_gt, wd=10, wp=1.0, wi=1.0):
    """Physics-informed training loss for the two-body network.

    The loss is ``wd * data + wp * physics + wi * initial``, where

    * ``data`` is the per-sample sum of squared position errors over
      (x1, y1, x2, y2), averaged over the samples in ``t_vals``;
    * ``physics`` is the same aggregation applied to the mismatch between
      the network's autograd second time derivatives and the Newtonian
      accelerations implied by its own predicted positions;
    * ``initial`` is the squared error of the predicted positions and
      velocities at t = 0 against ``pos1``/``pos2``/``vel1``/``vel2``.

    Args:
        model: Module mapping a ``(N, 1)`` float32 time tensor to ``(N, 4)``
            positions ordered ``(x1, y1, x2, y2)``.
        t_vals: 1-D array of sample times.
        x1_gt, y1_gt, x2_gt, y2_gt: Reference positions aligned with ``t_vals``.
        wd, wp, wi: Weights of the data, physics and initial-condition terms.

    Returns:
        ``(total, parts)``. ``total`` is a 0-dim tensor ready for
        ``backward()``. ``parts`` maps ``"data"``, ``"physics"`` and
        ``"initial"`` to the unweighted terms as Python floats.
    """

    def d_dt(out, t):
        # Differentiating the column sum yields per-sample derivatives only if
        # the model treats rows independently (true for the MLP in this file).
        # create_graph=True keeps the result differentiable, both for the
        # second derivative and for backprop into the model parameters.
        return torch.autograd.grad(out.sum(), t, create_graph=True)[0]

    def as_column(values):
        return torch.tensor(values, dtype=torch.float32).view(-1, 1)

    # Positions, velocities and accelerations on the supplied time grid.
    t_tensor = as_column(t_vals).requires_grad_(True)
    preds = model(t_tensor)
    x1, y1, x2, y2 = [preds[:, i:i + 1] for i in range(4)]
    vx1, vy1, vx2, vy2 = [d_dt(p, t_tensor) for p in (x1, y1, x2, y2)]
    ax1, ay1, ax2, ay2 = [d_dt(v, t_tensor) for v in (vx1, vy1, vx2, vy2)]

    # Data term: fit the reference positions.
    x1_ref, y1_ref, x2_ref, y2_ref = [as_column(g) for g in (x1_gt, y1_gt, x2_gt, y2_gt)]
    loss_data = ((x1 - x1_ref)**2 + (y1 - y1_ref)**2 + (x2 - x2_ref)**2 + (y2 - y2_ref)**2).mean()

    # Physics term: predicted accelerations vs. Newton's law evaluated on the
    # predicted positions. The epsilon keeps r (and its gradient) finite if
    # the predicted bodies ever coincide.
    dx, dy = x2 - x1, y2 - y1
    r = torch.sqrt(dx**2 + dy**2 + 1e-10)
    ax1_gt = grav_const * mass2 * dx / r**3
    ay1_gt = grav_const * mass2 * dy / r**3
    ax2_gt = grav_const * mass1 * -dx / r**3
    ay2_gt = grav_const * mass1 * -dy / r**3
    loss_phys = ((ax1 - ax1_gt)**2 + (ay1 - ay1_gt)**2 + (ax2 - ax2_gt)**2 + (ay2 - ay2_gt)**2).mean()

    # Initial-condition term. The initial state applies at t = 0, the start of
    # the integration span. Evaluate the network there explicitly instead of
    # assuming t_vals[0] == 0.
    t0 = torch.zeros(1, 1, requires_grad=True)
    pos0 = model(t0)                                                      # (1, 4)
    vel0 = torch.cat([d_dt(pos0[:, i:i + 1], t0) for i in range(4)], dim=1)  # (1, 4)
    target_pos = torch.tensor([pos1 + pos2], dtype=torch.float32)
    target_vel = torch.tensor([vel1 + vel2], dtype=torch.float32)
    init_loss = ((pos0 - target_pos)**2).sum() + ((vel0 - target_vel)**2).sum()

    total = wd * loss_data + wp * loss_phys + wi * init_loss
    return total, {"data": loss_data.item(), "physics": loss_phys.item(), "initial": init_loss.item()}

# Full-batch Adam on the training window. The data term is weighted 7x, while
# the physics and initial-condition terms keep their default weight of 1.
model = GravityNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(30_000):
    opt.zero_grad()
    loss, parts = compute_loss(
        model,
        train_time,
        train_x1, train_y1, train_x2, train_y2,
        wd=7.0,
    )
    loss.backward()
    opt.step()
    if not step % 500:  # progress report every 500 steps
        print(f"Step {step}: Loss={loss.item():.5f}, Parts={parts}")

def evaluate(model, t_vals, x1_gt=None, y1_gt=None, x2_gt=None, y2_gt=None):
    """Score ``model`` on a time grid, print the metrics and plot the orbits.

    Two metrics are computed on ``t_vals``:

    * ``mae``: the per-sample sum of absolute position errors over the four
      coordinates (x1, y1, x2, y2), averaged over samples. It is summed over
      coordinates, so it is 4x a per-coordinate mean absolute error.
    * ``mean_residual``: the same aggregation applied to the difference
      between the network's autograd accelerations and the Newtonian
      accelerations implied by its own predicted positions.

    Args:
        model: Module mapping a ``(N, 1)`` float32 time tensor to ``(N, 4)``
            positions ordered ``(x1, y1, x2, y2)``.
        t_vals: 1-D array of evaluation times.
        x1_gt, y1_gt, x2_gt, y2_gt: Reference positions aligned with
            ``t_vals``. Each defaults to the matching module-level test split
            (``test_x1``, ...), so ``evaluate(model, test_time)`` behaves as
            before.

    Returns:
        ``(mae, mean_residual)`` as Python floats.

    Raises:
        ValueError: If any reference array's length differs from ``len(t_vals)``.
    """
    defaults = (test_x1, test_y1, test_x2, test_y2)
    refs = [
        np.asarray(default if given is None else given)
        for given, default in zip((x1_gt, y1_gt, x2_gt, y2_gt), defaults)
    ]
    if any(len(ref) != len(t_vals) for ref in refs):
        raise ValueError("ground-truth arrays must have the same length as t_vals")

    def d_dt(out, t):
        # create_graph=True keeps first derivatives differentiable so the
        # accelerations can be taken from them.
        return torch.autograd.grad(out.sum(), t, create_graph=True)[0]

    # Only the forward pass depends on the module mode. Restore the caller's
    # train/eval setting afterwards instead of leaving the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        t_tensor = torch.tensor(t_vals, dtype=torch.float32).view(-1, 1).requires_grad_(True)
        pred = model(t_tensor)
    finally:
        model.train(was_training)

    x1, y1, x2, y2 = [pred[:, i:i + 1] for i in range(4)]
    vx1, vy1, vx2, vy2 = [d_dt(p, t_tensor) for p in (x1, y1, x2, y2)]
    ax1, ay1, ax2, ay2 = [d_dt(v, t_tensor) for v in (vx1, vy1, vx2, vy2)]

    # Physics residual: autograd accelerations vs. Newton's law on the
    # predicted positions (epsilon guards against coincident bodies).
    dx, dy = x2 - x1, y2 - y1
    r = torch.sqrt(dx**2 + dy**2 + 1e-10)
    ax1_gt = grav_const * mass2 * dx / r**3
    ay1_gt = grav_const * mass2 * dy / r**3
    ax2_gt = grav_const * mass1 * -dx / r**3
    ay2_gt = grav_const * mass1 * -dy / r**3
    residuals = torch.abs(ax1 - ax1_gt) + torch.abs(ay1 - ay1_gt) + torch.abs(ax2 - ax2_gt) + torch.abs(ay2 - ay2_gt)
    mean_residual = residuals.mean().item()

    gt = [torch.tensor(ref, dtype=torch.float32).view(-1, 1) for ref in refs]
    mae = torch.mean(
        torch.abs(x1 - gt[0]) + torch.abs(y1 - gt[1]) + torch.abs(x2 - gt[2]) + torch.abs(y2 - gt[3])
    ).item()
    print(f"MAE: {mae:.5f}")
    print(f"Mean Residual: {mean_residual:.5f}")

    ref_x1, ref_y1, ref_x2, ref_y2 = refs
    plt.figure(figsize=(8, 8))
    plt.plot(ref_x1, ref_y1, 'b-', label='Body 1 True')
    plt.plot(ref_x2, ref_y2, 'r-', label='Body 2 True')
    plt.plot(x1.detach().numpy(), y1.detach().numpy(), 'b--', label='Body 1 Pred')
    plt.plot(x2.detach().numpy(), y2.detach().numpy(), 'r--', label='Body 2 Pred')
    plt.legend()
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title("Trajectories: True vs Predicted")
    plt.grid(True)
    plt.axis("equal")
    plt.show()
    return mae, mean_residual

# Score the trained network on the held-out (later) time window.
mae_val, res_val = evaluate(model, t_vals=test_time)
