# Import libs

In [1]:
import time
from typing import List, Dict

import scipy.integrate
import autograd.numpy as np

import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

#import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# ODE

This step was to solve the ODE. Here, an implemetation of the ABM method for integration was used to solve the differential equation in the form of $ \frac{d\vec{z}(t)}{dt} = f(\vec{z}(t), t, \theta)$.

## State vector derivative ($f = \frac{dz}{dt}$)

Here we define the derivative of the state vector:

In [2]:
def X_ddot(t, X):
    '''
    Returns the derivative of the state vector X
    ----------
    Arguments:
        t {float} -- time, in seconds
        X {np.array} -- state vector=(x, y, z, vx, vy, vz)
    ----------
    Returns:
        (6,1) np.array -- (xdot, ydot, zdot, vxdot, vydot, vzdot)
    '''
    
    x_dot = X[3:]
    v_dot = -mu*X[:3]/np.linalg.norm(X.detach().numpy())**3 # simply the acceleration
    X_dot_dot = np.concatenate((x_dot.detach().numpy(), v_dot.detach().numpy()), axis=None)

    X_dot_dot = torch.tensor(X_dot_dot, dtype=torch.float32, requires_grad=True)
    
    return X_dot_dot

## f derivative ($\frac{df}{dt}$)

In [None]:
def X_dddot(t, X, mu):
    '''
    Calculates the second derivative of the gravitational velocity and the first derivative of the acceleration
    ----------
    Arguments:
        t {float} -- time, in seconds
        X {torch.tensor} -- state vector = (x, y, z, vx, vy, vz), requires_grad should be True
        mu {float} -- standard gravitational parameter (GM)
    ----------
    Returns:
        torch.tensor -- (vx_dot, vy_dot, vz_dot, ax_dot, ay_dot, az_dot)
    '''

    # Ensure X is a tensor with gradient computation enabled
    if not X.requires_grad:
        X = X.clone().detach().requires_grad_(True)

    # Position (x, y, z) and velocity (vx, vy, vz)
    position = X[:3]
    velocity = X[3:]

    # Compute gravitational acceleration
    r = torch.norm(position)
    acceleration = -mu * position / r**3

    # Concatenate velocity and acceleration to form the state derivative
    state_derivative = torch.cat((velocity, acceleration))

    # Use automatic differentiation to compute the derivative of the state derivative
    state_derivative_grad = torch.autograd.grad(outputs=state_derivative, inputs=X,
                                                grad_outputs=torch.ones_like(state_derivative),
                                                create_graph=True)[0]

    return state_derivative_grad

## ODE Solver (RK4 & ABM)

In [3]:
def RK4(f, x0, tf, dt, t0=0):
    # Time vector
    t = np.arange(t0, tf, dt)
    nt = t.size
    
    # Constructing final vector
    nx = x0.size
    x = np.zeros((nx, nt))
    
    # Initial conditions
    x[:,0] = x0
    
    for k in range(nt-1):
        # At the kth time step, each parameter k_n is computed
        k1 = dt*f(t[k], x[:, k])
        k2 = dt*f(t[k] + dt/2, x[:, k] + k1/2)
        k3 = dt*f(t[k] + dt/2, x[:, k] + k2/2)
        k4 = dt*f(t[k] + dt, x[:, k] + k3)
        
        # k_ns are used to calculate dx
        dx = (k1 + 2*k2 + 2* k3 + k4)/6
        
        #dx is used to calculate x at the next time step
        x[:, k+1] = x[:, k] + dx
        
    return x, t

In [None]:
def ABM_aug(f, x0, sf, s0,theta,aug=1):
    # Constants
    a = 600000  # m
    mu = 3.9860064E+14  # m^3/s^2
    aug = 0
    # Orbit period calculated through Kepler's Third Law
    T = np.sqrt(a**3 * (4 * np.pi**2 / mu))
    ds = T/4000
    #print("Step size (ds):", ds)

    # Calculate number of steps based on the given time interval and step size
    s = np.arange(s0, sf, ds)
    ns = len(s)  # Get the exact number of elements in s

    # Ensure x0 is a torch tensor with the correct shape
    if not isinstance(x0, torch.Tensor):
        x0 = torch.tensor(x0, dtype=torch.float32)
    if x0.dim() == 1:
        x0 = x0.unsqueeze(0)  # Ensures x0 is [1, 6] if it's provided as [6]

    # Initialize the tensor to store the simulation results
    x = torch.zeros((ns, x0.size(1)), dtype=x0.dtype, device=x0.device)
    # print("x0",x0)
    # Set initial state
    x[0, :] = x0.squeeze()  # Make sure x0 is squeezed to [6]
    # First initialize with an RK4 step for stability in starting the integration
    for k in range(3):
        if k + 1 < ns:
            k1 = ds * f(s[k], x[k, :],theta,aug)
            k2 = ds * f(s[k] + ds/2, x[k, :] + k1/2,theta,aug)
            k3 = ds * f(s[k] + ds/2, x[k, :] + k2/2,theta,aug)
            k4 = ds * f(s[k] + ds, x[k, :] + k3,theta,aug)
            dx = (k1 + 2*k2 + 2*k3 + k4) / 6
            x[k + 1, :] = x[k, :] + dx

    # ABM integration
    for k in range(3, ns - 1):
        if k - 3 >= 0:  # Make sure indices don't go out of bounds
            f_m3 = f(s[k-3], x[k-3, :],theta,aug)
            f_m2 = f(s[k-2], x[k-2, :],theta,aug)
            f_m1 = f(s[k-1], x[k-1, :],theta,aug)
            f_0 = f(s[k], x[k, :],theta,aug)

            # Predictor
            dx = (ds/24) * (55 * f_0 - 59 * f_m1 + 37 * f_m2 - 9 * f_m3)
            x[k + 1, :] = x[k, :] + dx

            # Evaluate at the predicted next step (ensure not at the last step)
            if k + 1 < ns - 1:
                f_p1 = f(s[k + 1], x[k + 1, :],theta,aug)
                # Corrector
                dx = (ds/24) * (9 * f_p1 + 19 * f_0 - 5 * f_m1 + f_m2)
                x[k + 1, :] = x[k, :] + dx

    # Return the results

    return x

In [4]:
def ABM(f, x0, sf, ds, s0=0):
    # Time vector
    s = np.arange(s0, sf, ds)
    ns = s.size
    
    # Constructing final vector
    nx = x0.size
    x = np.zeros((nx, ns))
    
    # Initial conditions
    x[:,0] = x0
    
    # First we initialize with an RK4:
    for k in range(0, 3):
        # At the kth time step, each parameter k_n is computed
        k1 = ds*f(s[k], x[:, k])
        k2 = ds*f(s[k] + ds/2, x[:, k] + k1/2)
        k3 = ds*f(s[k] + ds/2, x[:, k] + k2/2)
        k4 = ds*f(s[k] + ds, x[:, k] + k3)
        
        # k_ns are used to calculate dx
        dx = (k1 + 2*k2 + 2* k3 + k4)/6
        
        #dx is used to calculate x at the next time step
        x[:, k+1] = x[:, k] + dx
    
    # Proceeding to the ABM integration:
    for k in range(3, ns-1):
        f_m3 = f(s[k-3], x[:, k-3]) #f_{n-3}
        f_m2 = f(s[k-2], x[:, k-2]) #f_{n-2}
        f_m1 = f(s[k-1], x[:, k-1]) #f_{n-1}
        f_0 = f(s[k], x[:, k]) #f_{n}
        
        ### Predictor ###
        dx = (ds/24) * (55*f_0 - 59*f_m1 + 37*f_m2 - 9*f_m3)
        x[:, k+1] = x[:, k] + dx
        
        f_p1 = f(s[k+1], x[:, k+1]) #f_{n+1}
        
        ### Corrector ###
        dx = (ds/24) * (9*f_p1 + 19*f_0 - 5*f_m1 + f_m2)
        #dx is used to calculate x at the next time step
        x[:, k+1] = x[:, k] + dx
        
    return x

# Augmented State Dynamics

Next, we defined the dynamics of the augmented state.

## Augmented dynamics


The next function, `aug_dynamics`, is used to calculate how the augmented state $\vec{s}(t) = [\vec{z}(t), \vec{a}(t), \frac{\partial L}{\partial\theta}, \frac{\partial L}{\partial t}]$ evolves over time. In other words, it returns $\frac{d\vec{s}(t)}{dt}$.

This new version uses as inputs the gradient of the function f (parameterized by the NN described previously) with respect to z, t and $\theta$.

In [49]:
def aug_dynamics_new(z, t, theta, a, module, df_dz, df_dt, df_dtheta):
  '''
  Defines dynamics of augmented state.
  ---
  Inputs:
    z: np.array
      Hidden state
    t: float
      Time
    theta: ##type##
      Dynamic parameters
    a: ##type##
      Adjoint, a
    delL_deltheta: ##type##
      Derivative of loss wrt theta
    delL_delt: ##type##
      Derivative of loss wrt t
    module: function
      Function f

  ---

  Returns:
    delz_delt: np.array
      Time derivative of state, z

    dela_delt: ##[type]##
      Time derivative of adjoint, a

    deldelL_deltheta_delt: ##[type]##
      Time derivative of loss gradient wrt dynamic parameters, delL_deltheta

    delL_delt: ##[type]##
      Time derivative of loss, L

  '''
  # Build augmented state: [z, a, delL_deltheta, delL_delt]
  s = [z, a, _, _] # delL_deltheta and delL_delt are missing since we don't have them yet! But we won`t need them just yet

  # Calculate f itself (delz_delt)
  delz_delt = module(z, t)

  # Time derivative of adjoint
  dela_delt = -a.T@df_dz # Vector-jacobian product

  # Time derivative of loss gradient
  deldelL_deltheta_delt = -a.T@df_dtheta # Vector-jacobian product

  # Time derivative of loss
  delL_delt = -a.T@df_dt # Vector-jacobian product

  return (delz_delt, dela_delt, deldelL_deltheta_delt, delL_delt)

In [6]:
def aug_dynamics(z, t, theta, a, module):
  '''
  Defines dynamics of augmented state.
  ---
  Inputs:
    z: np.array
      Hidden state
    t: float
      Time
    theta: ##type##
      Dynamic parameters
    a: ##type##
      Adjoint, a
    delL_deltheta: ##type##
      Derivative of loss wrt theta
    delL_delt: ##type##
      Derivative of loss wrt t
    module: function
      Function f

  ---

  Returns:
    delz_delt: np.array
      Time derivative of state, z

    dela_delt: ##[type]##
      Time derivative of adjoint, a

    deldelL_deltheta_delt: ##[type]##
      Time derivative of loss gradient wrt dynamic parameters, delL_deltheta

    delL_delt: ##[type]##
      Time derivative of loss, L

  '''
  # Build augmented state: [z, a, delL_deltheta, delL_delt]
  s = [z, a, _, _] # delL_deltheta and delL_delt are missing since we don't have them yet! But we won`t need them just yet

  # Calculate derivatives of f, as well as f itself (delz_delt)
  delf_delz, delf_delt, delf_deltheta, delz_delt = grad_f(z, t, theta, a, module)

  # Time derivative of adjoint
  dela_delt = -a.T@delf_delz # Vector-jacobian product

  # Time derivative of loss gradient
  deldelL_deltheta_delt = -a.T@delf_deltheta # Vector-jacobian product

  # Time derivative of loss
  delL_delt = -a.T@delf_delt # Vector-jacobian product

  return (delz_delt, dela_delt, deldelL_deltheta_delt, delL_delt)

# Testing

## Initial conditions

The initial conditions are as follows:

In [7]:
x = 1888980.04103698 #m
y = 6652209.67475597 #m
z = 902482.883545056 #m
v_x = -9585.79511076297 #m/s
v_y = 2413.57051166562 #m/s
v_z = 2273.50409709003 #m/s

x_vec = np.array([x,y,z])
v_vec = np.array([v_x,v_y,v_z])

a = 34869261 #m
mu = 3.9860064E+14 #𝑚3/𝑠2
R = 6378139 #m

# Orbit period is calculated through Kepler's Third Law:
T = np.sqrt(a**3*(4*np.pi**2/mu)) #s
delta_t = T/400

## NN model

In [43]:
# Define a simple model
model = nn.Linear(in_features=6, out_features=6)  # Change architecture as necessary

# Example input and target
state = torch.tensor(x0, dtype=torch.float32, requires_grad=True)
#state = torch.randn(1, 6, requires_grad=True)  # State must require_grad to compute gradients w.r.t. it
target = torch.randn(1, 6)

# Loss function
criterion = nn.MSELoss()

# Forward pass
prediction = model(state)

# Compute loss
loss = criterion(prediction, target)

# Backward pass
loss.backward()

# Gradient of the loss with respect to the state
delL_delz1 = state.grad

In [44]:
# Define (initial) state
z = state

# Define time tensor
t = np.arange(t0, tf, dt)
t = torch.tensor(t, dtype=torch.float32, requires_grad=True)

# Define (initial) adjunct
a = delL_delz1 # Gradient of the loss with respect to the state

# Define theta collecting all parameters into a single vector
theta_vector = torch.cat([p.view(-1) for p in model.parameters() if p.requires_grad])
theta = theta_vector

# Define function that goes inside the ODE solver
module = X_ddot

In [70]:
# Forward pass through the model
f = model(z)

# Create a tensor of ones, which will be the initial gradient for each component of f
grad_output = torch.ones_like(f)

# Now call backward with this grad_output
f.backward(gradient=grad_output)

###### Extracting gradients
### df_dz
df_dz = z.grad

### df_dt
df_dt = X_dddot(t, z, mu)

### df_dtheta
# Collect all parameter gradients
grads = [p.grad for p in model.parameters()]

# Check if the gradients exist and are in the expected form
if grads[0] is not None and grads[1] is not None:
    # Assuming grads[0] is the weight matrix and grads[1] is the bias vector
    # We reshape the bias to have the same number of columns as the weight matrix and a single row
    reshaped_bias_grad = grads[1].view(1, -1)
    
    # Concatenate along the rows
    df_dtheta = torch.cat((grads[0], reshaped_bias_grad), dim=0).T
else:
    print("Gradients are not available")

print("Gradient of the output with respect to the input z:", df_dz)
print("Gradient df/dt:", df_dt)
print("Gradients df/dtheta:", df_dtheta)

Gradient of the output with respect to the input z: tensor([ 908430.5000, 1511378.5000,  735810.2500, -784080.0000, -111573.0078,
         674891.4375])
Gradient df/dt: tensor([ 1.1797e-07,  3.3789e-06, -5.5738e-07,  1.0000e+00,  1.0000e+00,
         1.0000e+00], grad_fn=<AddBackward0>)
Gradients df/dtheta: tensor([[-2.2396e+12, -1.7759e+12, -2.0931e+12, -2.1508e+11, -6.4707e+11,
         -1.3356e+12, -1.1856e+06],
        [-7.8870e+12, -6.2540e+12, -7.3710e+12, -7.5741e+11, -2.2787e+12,
         -4.7036e+12, -9.4013e+05],
        [-1.0700e+12, -8.4845e+11, -1.0000e+12, -1.0275e+11, -3.0915e+11,
         -6.3812e+11, -1.1081e+06],
        [ 1.1365e+10,  9.0119e+09,  1.0622e+10,  1.0914e+09,  3.2836e+09,
          6.7779e+09, -1.1386e+05],
        [-2.8616e+09, -2.2691e+09, -2.6744e+09, -2.7480e+08, -8.2677e+08,
         -1.7066e+09, -3.4255e+05],
        [-2.6955e+09, -2.1374e+09, -2.5192e+09, -2.5886e+08, -7.7879e+08,
         -1.6075e+09, -7.0707e+05]])


In [71]:
aug_dynamics_new(z, t, theta, a, module, df_dz, df_dt, df_dtheta)

(tensor([ 4.8600e+02,  6.4800e+02,  8.1000e+02,  9.7200e+02,  1.1340e+03,
          1.2960e+03,  1.4580e+03,  1.6200e+03,  1.7820e+03,  1.9440e+03,
          2.1060e+03,  2.2680e+03,  2.4300e+03,  2.5920e+03,  2.7540e+03,
          2.9160e+03,  3.0780e+03,  3.2400e+03,  3.4020e+03,  3.5640e+03,
          3.7260e+03,  3.8880e+03,  4.0500e+03,  4.2120e+03,  4.3740e+03,
          4.5360e+03,  4.6980e+03,  4.8600e+03,  5.0220e+03,  5.1840e+03,
          5.3460e+03,  5.5080e+03,  5.6700e+03,  5.8320e+03,  5.9940e+03,
          6.1560e+03,  6.3180e+03,  6.4800e+03,  6.6420e+03,  6.8040e+03,
          6.9660e+03,  7.1280e+03,  7.2900e+03,  7.4520e+03,  7.6140e+03,
          7.7760e+03,  7.9380e+03,  8.1000e+03,  8.2620e+03,  8.4240e+03,
          8.5860e+03,  8.7480e+03,  8.9100e+03,  9.0720e+03,  9.2340e+03,
          9.3960e+03,  9.5580e+03,  9.7200e+03,  9.8820e+03,  1.0044e+04,
          1.0206e+04,  1.0368e+04,  1.0530e+04,  1.0692e+04,  1.0854e+04,
          1.1016e+04,  1.1178e+04,  1.