# Setup for this notebook

1) Poliastro:

```bash
git clone https://github.com/s-m-e/poliastro.git
cd poliastro
git checkout 879f7ab62d05361aff88575bd060d8ff9f880a14
pip install -e .
```

2) Extra packages

```bash
pip install joblib orbitalpy psutil
```

3) Data: 

```bash
wget https://minorplanetcenter.net/Extended_Files/nea_extended.json.gz
gzip -d nea_extended.json.gz
```

# Loading test data: Near Earth Asteroid (NEA) orbits from the MPC

In [1]:
# Path to the decompressed MPC NEA file produced by setup step 3 (gzip -d)
FN = 'nea_extended.json'

import json

from astropy import units as u
from joblib import Parallel, delayed
import numpy as np
from orbital.utilities import true_anomaly_from_mean
import psutil

from poliastro.bodies import Sun

In [2]:
# Sun's gravitational parameter `k` as a unitless float64 in km^3 / s^2, so it
# can be handed directly to the jitted kernels below (they take plain floats).
K = Sun.k.to_value(u.km**3 / u.s**2)

def _orbit_from_mpc(body):
    """
    Convert one MPC record into classical orbital elements.

    Parameters
    ----------
    body : dict
        One entry of the MPC JSON list. Reads ``'a'`` (treated as AU),
        ``'e'``, and ``'i'``, ``'Node'``, ``'Peri'``, ``'M'`` (all treated
        as degrees).

    Returns
    -------
    tuple of float
        ``(p, ecc, inc, raan, argp, nu)``. ``p`` is the semi-latus rectum
        ``a * (1 - e**2)`` in km (NOT the semi-major axis). All angles are in
        radians, and ``nu`` is wrapped into ``[-pi, pi)``.
    """
    ecc = body['e']
    mean_anomaly = (body['M'] * u.deg).to_value(u.rad)
    nu = true_anomaly_from_mean(e = ecc, M = float(mean_anomaly)) * u.rad
    if not -np.pi * u.rad <= nu < np.pi * u.rad:
        # Shift by pi, reduce modulo 2*pi, shift back -> lands in [-pi, pi).
        nu = (nu + np.pi * u.rad) % (2 * np.pi * u.rad) - np.pi * u.rad
    semi_major_axis_km = (body['a'] * u.AU).to_value(u.km)
    return (
        semi_major_axis_km * (1 - ecc**2),  # p (semi-latus rectum) -- coe2rv_ expects p, not a
        ecc,  # ecc
        (body['i'] * u.deg).to_value(u.rad),  # inc
        (body['Node'] * u.deg).to_value(u.rad),  # raan
        (body['Peri'] * u.deg).to_value(u.rad),  # argp
        nu.to_value(u.rad),  # nu
    )

def _read_mpc(fn):
    """
    Load the MPC JSON file ``fn`` and convert every record with
    ``_orbit_from_mpc``.

    The conversion is spread across as many joblib workers as psutil reports
    logical CPUs. Returns the list of per-record element tuples produced by
    joblib's ``Parallel``.
    """
    with open(fn, mode = 'r', encoding = 'utf-8') as handle:
        records = json.load(handle)
    worker_count = psutil.cpu_count(logical = True)
    jobs = (delayed(_orbit_from_mpc)(record) for record in records)
    return Parallel(n_jobs = worker_count)(jobs)

# One float64 row per MPC record: (p [km], ecc, inc [rad], raan [rad], argp [rad], nu [rad])
mpc_orbits = np.array(_read_mpc(FN), dtype = 'f8')

# Importing poliastro's new `jit` infrastructure (the `POLIASTRO_*` environment variables are set before the import)

In [3]:
from math import cos, sin, sqrt
import os

import numpy as np

# POLIASTRO_TARGET can be set to `cpu`, `parallel` or `cuda`. Default: `cpu`
os.environ['POLIASTRO_TARGET'] = 'cuda'

# POLIASTRO_INLINE can be set to `always` or `never`. Default: `never`
os.environ['POLIASTRO_INLINE'] = 'never'

# The variables above are assigned before this import. Presumably
# poliastro.core.jit reads them once when it is first imported, so changing
# them later in the same kernel has no effect -- TODO confirm.
# NOTE(review): `poliastro.bodies` was imported in an earlier cell; verify it
# does not already import `poliastro.core.jit`, or these settings are ignored.
from poliastro.core.jit import gjit, hjit, vjit, TARGET, INLINE

print(TARGET, INLINE)  # verification: should echo the values set above

cuda never


# Isolated code-path for conversion of classical orbital elements to state vectors

In [4]:
@hjit('M(M)')
def transpose_M_(a):
    """Transpose a 3x3 matrix stored as a tuple of three row tuples."""
    row0, row1, row2 = a
    return (
        (row0[0], row1[0], row2[0]),
        (row0[1], row1[1], row2[1]),
        (row0[2], row1[2], row2[2]),
    )

@hjit('M(M,M)')
def matmul_MM_(a, b):
    """
    Matrix product ``a @ b`` of two 3x3 matrices stored as tuples of row tuples.

    Each entry is summed left to right over the shared index (0, 1, 2).
    """
    a0, a1, a2 = a
    b0, b1, b2 = b
    return (
        (
            a0[0] * b0[0] + a0[1] * b1[0] + a0[2] * b2[0],
            a0[0] * b0[1] + a0[1] * b1[1] + a0[2] * b2[1],
            a0[0] * b0[2] + a0[1] * b1[2] + a0[2] * b2[2],
        ),
        (
            a1[0] * b0[0] + a1[1] * b1[0] + a1[2] * b2[0],
            a1[0] * b0[1] + a1[1] * b1[1] + a1[2] * b2[1],
            a1[0] * b0[2] + a1[1] * b1[2] + a1[2] * b2[2],
        ),
        (
            a2[0] * b0[0] + a2[1] * b1[0] + a2[2] * b2[0],
            a2[0] * b0[1] + a2[1] * b1[1] + a2[2] * b2[1],
            a2[0] * b0[2] + a2[1] * b1[2] + a2[2] * b2[2],
        ),
    )

@hjit('V(V,M)')
def matmul_VM_(a, b):
    """Row-vector times matrix: ``a @ b`` for a 3-tuple ``a`` and a 3x3 tuple ``b``."""
    x, y, z = a
    b0, b1, b2 = b
    return (
        x * b0[0] + y * b1[0] + z * b2[0],
        x * b0[1] + y * b1[1] + z * b2[1],
        x * b0[2] + y * b1[2] + z * b2[2],
    )

@hjit('M(f,u1)')
def rotation_matrix_(angle, axis):
    """
    3x3 rotation matrix by ``angle`` (radians) about coordinate axis ``axis``.

    All three cases follow the same cyclic pattern: for axes (i, j, k) taken
    cyclically starting at ``axis``, R[j][k] = -sin and R[k][j] = +sin. Every
    result is therefore orthogonal with determinant +1.

    angle : rotation angle in radians
    axis : 0, 1 or 2 (x, y or z)
    Returns the matrix as a tuple of three row tuples.
    Raises ValueError for any other axis.
    """
    c = cos(angle)
    s = sin(angle)
    if axis == 0:
        return (
            (1.0, 0.0, 0.0),
            (0.0,   c,  -s),
            (0.0,   s,   c),
        )
    if axis == 1:
        # About y the cyclic order is (y, z, x), so the +sin sits at [0][2]
        # and the -sin at [2][0]. The previous version had +sin at [2][0],
        # which made the matrix non-orthogonal (not a rotation).
        return (
            (  c, 0.0,   s),
            (0.0, 1.0, 0.0),
            ( -s, 0.0,   c),
        )
    if axis == 2:
        return (
            (  c,  -s, 0.0),
            (  s,   c, 0.0),
            (0.0, 0.0, 1.0),
        )
    raise ValueError("Invalid axis: must be one of 0, 1 or 2")

@hjit('Tuple([V,V])(f,f,f,f)')
def rv_pqw_(k, p, ecc, nu):
    """
    Position and velocity in the perifocal (PQW) frame.

    k : gravitational parameter
    p : semi-latus rectum (same length unit as implied by k)
    ecc : eccentricity
    nu : true anomaly in radians
    Returns (r, v), two 3-tuples of floats whose third (W) component is zero.
    """
    sinnu = sin(nu)
    cosnu = cos(nu)
    # Orbit equation: |r| = p / (1 + e cos(nu)).
    radius = p / (1 + ecc * cosnu)
    # Velocity scale factor sqrt(k / p).
    vscale = sqrt(k / p)
    # Use float literals so every element is float64, matching the `V`
    # return type, instead of relying on an implicit int -> float tuple cast.
    return (
        (cosnu * radius, sinnu * radius, 0.0),
        (-sinnu * vscale, (ecc + cosnu) * vscale, 0.0),
    )

@hjit('M(f,f,f)')
def coe_rotation_matrix_(inc, raan, argp):
    """
    Composite rotation Rz(raan) @ Rx(inc) @ Rz(argp).

    The left pair is multiplied first; all angles are in radians.
    """
    rot_raan = rotation_matrix_(raan, 2)
    rot_inc = rotation_matrix_(inc, 0)
    rot_argp = rotation_matrix_(argp, 2)
    return matmul_MM_(matmul_MM_(rot_raan, rot_inc), rot_argp)

@hjit('Tuple([V,V])(f,f,f,f,f,f,f)')
def coe2rv_(k, p, ecc, inc, raan, argp, nu):
    """
    Converts from classical orbital elements to state vectors.

    The perifocal position/velocity from ``rv_pqw_`` are right-multiplied by
    the transpose of ``coe_rotation_matrix_(inc, raan, argp)``. This is
    equivalent to applying that matrix to each column vector.
    Returns (r, v) as two 3-tuples.
    """
    rot_t = transpose_M_(coe_rotation_matrix_(inc, raan, argp))
    r_pqw, v_pqw = rv_pqw_(k, p, ecc, nu)
    return matmul_VM_(r_pqw, rot_t), matmul_VM_(v_pqw, rot_t)

# Testing for single orbit/state

In [5]:
# Call the scalar kernel directly on the first orbit (row 0 of mpc_orbits).
# It is skipped for `cuda`, presumably because CUDA device functions cannot
# be called from host code -- TODO confirm against poliastro.core.jit.
if TARGET != 'cuda':
    print(coe2rv_(K, *mpc_orbits[0, :]))

# Testing for **array** of orbits/states

In [6]:
# Output buffer: one row per orbit holding (rx, ry, rz, vx, vy, vz).
_rv = np.zeros((len(mpc_orbits), 6), dtype = 'f8')

# Position and velocity are views (not copies) into _rv, so they see
# whatever is later written to _rv.
r, v = _rv[:, :3], _rv[:, 3:]

# Inspect dtype, C-contiguity, rank and shape of each gufunc operand.
for item in (mpc_orbits, r, v, K):
    print(item.dtype, item.flags.c_contiguous, item.ndim, item.shape)

float64 True 2 (20250, 6)
float64 False 2 (20250, 3)
float64 False 2 (20250, 3)
float64 True 0 ()


In [7]:
@gjit(
    'void(f,f[:],f[:])',
    '(),(n)->(n)',
)
def coe2rv(k, cl, sv):
    """
    Converts from classical orbital elements to state vectors ON ARRAYS

    k : gravitational parameter (scalar per call)
    cl[0...5] : p, ecc, inc, raan, argp, nu (angles in radians)
    sv[0...5] : rx, ry, rz, vx, vy, vz (output, written in place)

    Both core dimensions share the symbol `n`, and exactly the first six
    entries are read/written. Nothing here checks that n >= 6, so callers
    must pass rows of length 6.
    """
    r, v = coe2rv_(k, cl[0], cl[1], cl[2], cl[3], cl[4], cl[5])
    sv[0], sv[1], sv[2] = r
    sv[3], sv[4], sv[5] = v

# Fill _rv in place (it is the gufunc's output operand). r and v are views
# into _rv, so displaying them shows the computed positions and velocities.
coe2rv(K, mpc_orbits, _rv)
r, v



(array([[-8.22172321e+07, -2.20850539e+08, -3.67986141e+07],
        [-1.68505479e+08,  3.97092826e+08, -8.34073051e+07],
        [ 2.28586647e+08, -5.12990316e+08, -5.81423124e+06],
        ...,
        [ 3.89940003e+07, -1.29757595e+08,  5.41316389e+07],
        [-1.25327990e+08, -1.72807704e+08, -4.90166215e+07],
        [ 1.00536243e+08,  3.29143864e+08,  1.21615287e+08]]),
 array([[ 18.78758291, -12.1967396 ,   1.65384504],
        [-16.1191056 ,   2.73889356,  -0.78254845],
        [  8.01970316,   6.87396875,  -1.63941873],
        ...,
        [ 30.09556541,   5.42495671,   5.21728375],
        [ 17.78412559,  -3.92195834,  -1.4502226 ],
        [-16.22194191,  -5.3500318 ,   4.5105067 ]]))

In [8]:
# State of the first orbit: position (km) and velocity (km/s)
print(r[0], v[0])

[-8.22172321e+07 -2.20850539e+08 -3.67986141e+07] [ 18.78758291 -12.1967396    1.65384504]
