In [None]:
# For reverse (backward) integration, pass sf smaller than s0; the step size ds then becomes negative.
def ABM_aug(f, x0, sf, s0, theta, aug=1, n_steps=10000):
    """Integrate ``dx/ds = f(s, x, theta, aug)`` with a 4th-order Adams-Bashforth-Moulton scheme.

    The first (up to) three steps use classical RK4 to build the derivative
    history the multistep method needs. Every later step is an Adams-Bashforth
    predictor followed by one Adams-Moulton corrector (PECE).

    Args:
        f: Right-hand side, called as ``f(s, x, theta, aug)`` with ``s`` a
            NumPy float and ``x`` a 1-D state tensor. It should return a
            tensor with the same shape as ``x``.
        x0: Initial state, shape ``[d]`` or ``[1, d]``. Non-tensors are
            converted to a ``float32`` tensor.
        sf: Final value of the independent variable. Pass ``sf < s0`` to
            integrate backwards (the step size is then negative).
        s0: Initial value of the independent variable.
        theta: Forwarded unchanged to ``f``.
        aug: Forwarded unchanged to ``f``.
        n_steps: Number of grid points. The step size is ``(sf - s0) / n_steps``.

    Returns:
        Tensor of shape ``[n_steps, d]`` with the dtype/device of ``x0``.
        Row ``k`` is the state at ``s0 + k * ds``. The grid matches
        ``np.arange(s0, sf, ds)``, so the last row is at ``sf - ds``, not ``sf``.

    Raises:
        ValueError: if ``n_steps < 1``, ``sf == s0``, or ``x0`` is not a
            single state of shape ``[d]`` / ``[1, d]``.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if sf == s0:
        raise ValueError("sf and s0 must differ; the integration interval is empty")

    ds = (sf - s0) / n_steps
    print("Step size (ds):", ds)

    # Evenly spaced grid s0, s0 + ds, ..., s0 + (n_steps - 1) * ds. It is built
    # from an integer range, so its length is exactly n_steps. np.arange with a
    # float step can return one extra point, depending on rounding.
    s = s0 + ds * np.arange(n_steps)

    # Normalise x0 to a [1, d] tensor.
    if not isinstance(x0, torch.Tensor):
        x0 = torch.tensor(x0, dtype=torch.float32)
    if x0.dim() == 1:
        x0 = x0.unsqueeze(0)
    if x0.dim() != 2 or x0.size(0) != 1:
        raise ValueError(f"x0 must have shape [d] or [1, d], got {tuple(x0.shape)}")

    def _as_state(t):
        # Store every state in x0's dtype/device, like a preallocated output
        # tensor would.
        return t.to(dtype=x0.dtype, device=x0.device)

    # Collect states in a list and stack them once at the end, instead of
    # writing into a preallocated tensor. In-place slice writes bump the version
    # counter shared by every view of that tensor. That can invalidate views
    # autograd saved while evaluating f, breaking backward() when theta
    # requires grad.
    states = [x0[0].clone()]

    # Derivatives at already-accepted grid points, oldest first. The ABM steps
    # reuse them instead of re-evaluating f at old points.
    history = []

    # Start-up: classical RK4 for the first (up to) three steps. This supplies
    # the four back values the 4-step Adams-Bashforth predictor needs.
    for k in range(min(3, n_steps - 1)):
        f_k = f(s[k], states[k], theta, aug)
        history.append(f_k)
        k1 = ds * f_k
        k2 = ds * f(s[k] + ds / 2, states[k] + k1 / 2, theta, aug)
        k3 = ds * f(s[k] + ds / 2, states[k] + k2 / 2, theta, aug)
        k4 = ds * f(s[k] + ds, states[k] + k3, theta, aug)
        states.append(_as_state(states[k] + (k1 + 2 * k2 + 2 * k3 + k4) / 6))

    # ABM in PECE form: AB4 predictor, then an AM corrector evaluated at the
    # predicted point. The corrector is applied on every step, including the
    # last one.
    for k in range(3, n_steps - 1):
        f_m3, f_m2, f_m1 = history
        f_0 = f(s[k], states[k], theta, aug)

        # Predictor (Adams-Bashforth, 4 steps)
        x_pred = _as_state(
            states[k] + (ds / 24) * (55 * f_0 - 59 * f_m1 + 37 * f_m2 - 9 * f_m3)
        )
        f_p1 = f(s[k + 1], x_pred, theta, aug)

        # Corrector (Adams-Moulton)
        states.append(
            _as_state(states[k] + (ds / 24) * (9 * f_p1 + 19 * f_0 - 5 * f_m1 + f_m2))
        )

        history = [f_m2, f_m1, f_0]

    return torch.stack(states)