# Imports

In [1]:
from datetime import datetime, timedelta
from math import atan2, sqrt
from random import random, seed
from typing import Union

import astropy.units as u
from astropy.time import Time, TimeDelta
import numpy as np
from typeguard import typechecked

# Data

In [2]:
# Simulation window in the UT1 time scale, sampled once per STEP_LENGTH.
START, STOP = (
    Time('2022-01-01 00:00', scale = 'ut1'),
    Time('2022-01-10 00:00', scale = 'ut1'),
)
STEP_LENGTH = 1 << u.d

# np.arange excludes STOP, so the grid covers the half-open interval [START, STOP).
TIMES = Time(np.arange(START, STOP, STEP_LENGTH))
TIMES

<Time object: scale='ut1' format='iso' value=['2022-01-01 00:00:00.000' '2022-01-02 00:00:00.000'
 '2022-01-03 00:00:00.000' '2022-01-04 00:00:00.000'
 '2022-01-05 00:00:00.000' '2022-01-06 00:00:00.000'
 '2022-01-07 00:00:00.000' '2022-01-08 00:00:00.000'
 '2022-01-09 00:00:00.000']>

In [3]:
FACTOR = 10
POINTS_LENGTH = 5
SEED = 42

# Seed both random generators used in this notebook (stdlib `random` and
# NumPy's global RNG) so Restart & Run All reproduces the same points and
# propagation results.
seed(SEED)
np.random.seed(SEED)

# POINTS_LENGTH random (x, y) positions, each coordinate uniform in [0, FACTOR) km.
POINTS = [
    {'x': random() * FACTOR << u.km, 'y': random() * FACTOR << u.km}
    for _ in range(POINTS_LENGTH)
]
POINTS

[{'x': <Quantity 6.25350266 km>, 'y': <Quantity 3.78524177 km>},
 {'x': <Quantity 3.35121602 km>, 'y': <Quantity 2.43147868 km>},
 {'x': <Quantity 9.44068878 km>, 'y': <Quantity 8.95284076 km>},
 {'x': <Quantity 0.39916973 km>, 'y': <Quantity 8.15307567 km>},
 {'x': <Quantity 0.92443308 km>, 'y': <Quantity 3.70176948 km>}]

# Proof of concept (PoC)

In [4]:
@typechecked
class State:
    """Position (x, y) of an object at a single epoch.

    `epoch`, `x` and `y` must all be scalars (0-d). `x` and `y` are checked by
    `astropy.units.quantity_input` to carry units convertible to km.
    """
    
    @u.quantity_input(x = u.km, y = u.km)
    def __init__(self, epoch: Time, x: u.Quantity, y: u.Quantity):
        assert epoch.ndim == x.ndim == y.ndim == 0
        self._epoch = epoch
        self._x = x
        self._y = y
    
    def __repr__(self) -> str:
        return f'<State epoch={self._epoch} x={self._x} y={self._y}>'
    
    def __eq__(self, other) -> bool:
        """Return True if `other` is a State with equal epoch, x and y.

        Objects that are not States compare unequal instead of raising
        AttributeError on the missing `epoch`/`x`/`y` attributes.
        """
        if not isinstance(other, State):
            return False
        # Time/Quantity comparisons yield NumPy booleans, which are not `bool`
        # instances; convert so the typeguard-checked `-> bool` holds.
        return bool(
            self._epoch == other.epoch
            and self._x == other.x
            and self._y == other.y
        )
    
    @property
    def epoch(self) -> Time:
        """Epoch of this state (scalar Time)."""
        return self._epoch
    
    @epoch.setter
    def epoch(self, value: Time):
        assert value.ndim == 0
        self._epoch = value
    
    @property
    def x(self) -> u.Quantity:
        """x coordinate (scalar length Quantity)."""
        return self._x
    
    @x.setter
    @u.quantity_input(value = u.km)
    def x(self, value: u.Quantity):
        assert value.ndim == 0
        self._x = value
    
    @property
    def y(self) -> u.Quantity:
        """y coordinate (scalar length Quantity)."""
        return self._y
    
    @y.setter
    @u.quantity_input(value = u.km)
    def y(self, value: u.Quantity):
        assert value.ndim == 0
        self._y = value
    
    def to_polar(self) -> tuple[u.Quantity, u.Quantity]:
        """Return (radius in km, angle in rad), with angle = atan2(y, x)."""
        return (
            sqrt(self._x.to_value(u.km) ** 2 + self._y.to_value(u.km) ** 2) << u.km,
            atan2(self._y.to_value(u.km), self._x.to_value(u.km)) << u.rad,
        )

In [5]:
a = State(Time('2022-01-01 00:00', scale = 'ut1'), x = 2.0 << u.km, y = 3.0 << u.km)
# Show the state and its (radius, angle) polar form, one per line.
print(a, a.to_polar(), sep = '\n')

<State epoch=2022-01-01 00:00:00.000 x=2.0 km y=3.0 km>
(<Quantity 3.60555128 km>, <Quantity 0.98279372 rad>)


In [23]:
@typechecked
class Orbit:
    """A single orbit wrapping one scalar `State`.

    Propagation is a placeholder random walk. Each coordinate moves by a
    uniform random offset in [-0.5, 0.5) km per day propagated.
    """

    def __init__(self, state: State):
        self._state = state

    def __repr__(self) -> str:
        return f'<Orbit epoch={self._state.epoch} x={self._state.x} y={self._state.y}>'

    def propagate(self, timedelta: TimeDelta, inplace: bool = False) -> 'Union[Orbit, OrbitArray]':
        """Propagate this orbit by `timedelta`.

        Parameters
        ----------
        timedelta : TimeDelta
            Scalar or array of time steps.
        inplace : bool
            Only valid for a scalar `timedelta`. If True, update this orbit's
            state and return `self`; otherwise return a new `Orbit`.

        Returns
        -------
        Orbit or OrbitArray
            For a scalar `timedelta`, an `Orbit`. For an array `timedelta`, an
            `OrbitArray` with the same shape as `timedelta`, holding one
            independently propagated copy of this orbit per step.

        Raises
        ------
        AssertionError
            If `inplace` is True and `timedelta` is an array.
        """
        days = timedelta.to_value(u.d)

        if timedelta.ndim == 0:

            dx = (random() - 0.5) * days << u.km
            dy = (random() - 0.5) * days << u.km

            if inplace:
                # Rebind instead of `+=`: Quantity `+=` modifies the existing
                # object in place, which would also change any other State or
                # variable sharing it (e.g. `State(t, q, q)`, or the POINTS
                # quantities reused by `State(time, **position)`).
                self._state.epoch = self._state.epoch + timedelta
                self._state.x = self._state.x + dx
                self._state.y = self._state.y + dy
                return self
            return type(self)(State(
                epoch = self._state.epoch + timedelta,
                x = self._state.x + dx,
                y = self._state.y + dy,
            ))

        # An array of steps fans this single state out into an array of orbits,
        # which cannot be stored back into this scalar orbit.
        assert not inplace, 'inplace propagation requires a scalar timedelta'

        dx = (np.random.random(timedelta.shape) - 0.5) * days << u.km
        dy = (np.random.random(timedelta.shape) - 0.5) * days << u.km

        # Scalar epoch/x/y broadcast against the per-step arrays, so every
        # component ends up with `timedelta.shape` as StateArray requires.
        return OrbitArray(StateArray(
            epoch = self._state.epoch + timedelta,
            x = self._state.x + dx,
            y = self._state.y + dy,
        ))

    @property
    def state(self) -> State:
        """The underlying scalar State (shared, not copied)."""
        return self._state

In [24]:
a = Orbit(State(Time('2022-01-01 00:00', scale = 'ut1'), 2.0 << u.km, 3.0 << u.km))
b = a.propagate(TimeDelta(1 << u.d), inplace = False)
# Orbit defines no __eq__, so `==` falls back to identity and agrees with `is`.
print(a, b, a is b, a == b, sep = '\n')

<Orbit epoch=2022-01-01 00:00:00.000 x=2.0 km y=3.0 km>
<Orbit epoch=2022-01-02 00:00:00.000 x=1.794312925825925 km y=2.719979738636332 km>
False
False


In [25]:
# In-place propagation mutates `a` and hands back the very same object.
c = a.propagate(TimeDelta(1 << u.d), inplace = True)
print(c, c is a, c == a, sep = '\n')

<Orbit epoch=2022-01-02 00:00:00.000 x=1.8764727077406333 km y=2.710528540562871 km>
True
True


In [8]:
@typechecked
class StateArray:
    """N-dimensional array of states stored as parallel `epoch`, `x`, `y` arrays.

    All three components must share the same shape. `x` and `y` are checked by
    `astropy.units.quantity_input` to carry units convertible to km.
    """

    @u.quantity_input(x = u.km, y = u.km)
    def __init__(self, epoch: Time, x: u.Quantity, y: u.Quantity):
        assert epoch.shape == x.shape == y.shape
        self._epoch = epoch
        self._x = x
        self._y = y

    def __repr__(self) -> str:
        # No __iter__ is defined, so iterating the flattened array falls back
        # to __getitem__(0), __getitem__(1), ... and yields scalar States.
        return  (
            f'<StateArray shape={self.shape} value=[\n'
            + '\n'.join([
                f' (epoch={state.epoch} x={state.x} y={state.y}),'
                for state in self.reshape(self.size)
            ])
            + '\n]>'
        )
    
    def __getitem__(self, idx) -> 'Union[StateArray, State]':
        """Index epoch, x and y with `idx`.

        Returns a scalar `State` when `idx` is an integer, or a tuple of
        integers (Python or NumPy), and the result is 0-d. Otherwise returns a
        `StateArray`.
        """
        target = type(self)(
            epoch = self._epoch[idx],
            x = self._x[idx],
            y = self._y[idx],
        )
        # Test the indexed result itself. The former `np.squeeze(...)` check
        # let e.g. idx=0 on a (2, 1) array through with shape (1,), which then
        # failed State's 0-d assertion.
        integer_types = (int, np.integer)
        is_integer_index = isinstance(idx, integer_types) or (
            isinstance(idx, tuple) and all(isinstance(item, integer_types) for item in idx)
        )
        if target.ndim == 0 and is_integer_index:
            return State(
                epoch = target.epoch,
                x = target.x,
                y = target.y,
            )
        return target

    def reshape(self, *args) -> 'StateArray':
        """Reshape all components; `args` are forwarded to each `.reshape`."""
        return type(self)(
            epoch = self._epoch.reshape(*args),
            x = self._x.reshape(*args),
            y = self._y.reshape(*args),
        )

    @property
    def epoch(self) -> Time:
        return self._epoch

    @property
    def x(self) -> u.Quantity:
        return self._x

    @property
    def y(self) -> u.Quantity:
        return self._y

    @property
    def ndim(self):
        return self._epoch.ndim

    @property
    def size(self):
        return self._epoch.size

    @property
    def shape(self):
        return self._epoch.shape

    def to_polar(self) -> tuple[u.Quantity, u.Quantity]:
        """Return element-wise (radius, angle), with angle = arctan2(y, x)."""
        return np.sqrt(self._x ** 2 + self._y ** 2), np.arctan2(self._y, self._x)

    @classmethod
    def from_states(cls, states):
        """Build a 1-D StateArray from a sequence of scalar States (x, y in km)."""
        return cls(
            epoch = Time([state.epoch for state in states]),
            x = u.Quantity([state.x for state in states], u.km),
            y = u.Quantity([state.y for state in states], u.km),
        )

In [9]:
# Pair each epoch with a random point; zip stops at the shorter input (POINTS).
states = list(
    State(epoch, **position)
    for epoch, position in zip(TIMES, POINTS)
)
states

[<State epoch=2022-01-01 00:00:00.000 x=6.253502657298325 km y=3.7852417656020876 km>,
 <State epoch=2022-01-02 00:00:00.000 x=3.3512160233674617 km y=2.4314786801575385 km>,
 <State epoch=2022-01-03 00:00:00.000 x=9.440688784123175 km y=8.952840758946511 km>,
 <State epoch=2022-01-04 00:00:00.000 x=0.39916973060293404 km y=8.153075673625192 km>,
 <State epoch=2022-01-05 00:00:00.000 x=0.9244330754053787 km y=3.701769484864422 km>]

In [10]:
statearray = StateArray.from_states(states)
# First four states laid out as a 2 x 2 grid.
a = statearray[:4].reshape((2, 2))
a

<StateArray shape=(2, 2) value=[
 (epoch=2022-01-01 00:00:00.000 x=6.253502657298325 km y=3.7852417656020876 km),
 (epoch=2022-01-02 00:00:00.000 x=3.3512160233674617 km y=2.4314786801575385 km),
 (epoch=2022-01-03 00:00:00.000 x=9.440688784123175 km y=8.952840758946511 km),
 (epoch=2022-01-04 00:00:00.000 x=0.39916973060293404 km y=8.153075673625192 km),
]>

In [11]:
a[(1, 1)]  # an all-integer index selects one element and yields a scalar State

<State epoch=2022-01-04 00:00:00.000 x=0.39916973060293404 km y=8.153075673625192 km>

In [12]:
StateArray.to_polar(statearray)  # element-wise (radius, angle) Quantity arrays

(<Quantity [ 7.30988035,  4.1403789 , 13.01076333,  8.16284138,  3.81545198] km>,
 <Quantity [0.54430705, 0.62766937, 0.75888157, 1.52187598, 1.32607429] rad>)

In [13]:
@typechecked
class OrbitArray:
    """N-dimensional array of orbits backed by a single `StateArray`."""
    
    def __init__(self, statearray: StateArray):
        self._statearray = statearray
    
    def __repr__(self) -> str:
        # Use `self.size` (always an int). `np.multiply.reduce(())` returns the
        # float 1.0 for a 0-d array, and `reshape` rejects that.
        return  (
            f'<OrbitArray shape={self.shape} value=[\n'
            + '\n'.join([
                f' (epoch={state.epoch} x={state.x} y={state.y}),'
                for state in self._statearray.reshape(self.size)
            ])
            + '\n]>'
        )
    
    def __getitem__(self, idx) -> 'Union[OrbitArray, Orbit]':
        """Index the orbits; a scalar State result is wrapped in an `Orbit`."""
        target = self._statearray[idx]
        if isinstance(target, State):
            return Orbit(state = target)
        return OrbitArray(statearray = target)

    def propagate(self, timedelta: TimeDelta, inplace: bool = False) -> 'OrbitArray':
        """Propagate every orbit by `timedelta` with a placeholder random walk.

        Each element's x and y move by a uniform random offset in
        [-0.5, 0.5) km per day of that element's step.

        Parameters
        ----------
        timedelta : TimeDelta
            Either a scalar step applied to every orbit, or an array of steps
            with exactly `self.shape`.
        inplace : bool
            If True, write the new values into this array's existing epoch/x/y
            storage and return `self`; otherwise return a new `OrbitArray`.

        Raises
        ------
        AssertionError
            If an array `timedelta` does not have `self.shape`.
        """
        if timedelta.ndim != 0:
            assert timedelta.shape == self.shape  # TODO allow better broadcasting logic
        
        # Per-element step length in days. A scalar step is broadcast to every
        # orbit, so `timedelta` itself stays a TimeDelta as annotated.
        days = np.broadcast_to(timedelta.to_value(u.d), self.shape)
        dx = (np.random.random(self.shape) - 0.5) * days << u.km
        dy = (np.random.random(self.shape) - 0.5) * days << u.km
        
        if inplace:
            self._statearray.epoch[:] += timedelta
            self._statearray.x[:] += dx
            self._statearray.y[:] += dy
            return self
        else:
            return type(self)(StateArray(
                epoch = self._statearray.epoch + timedelta,
                x = self._statearray.x + dx,
                y = self._statearray.y + dy,
            ))

    def reshape(self, *args) -> 'OrbitArray':
        """Reshape the underlying StateArray; `args` are forwarded unchanged."""
        return type(self)(statearray = self._statearray.reshape(*args))

    @property
    def statearray(self) -> StateArray:
        return self._statearray

    @property
    def ndim(self):
        return self._statearray.ndim

    @property
    def shape(self):
        return self._statearray.shape

    @property
    def size(self):
        return self._statearray.size

    @classmethod
    def from_orbits(cls, orbits):
        """Build a 1-D OrbitArray from a sequence of scalar Orbits."""
        return cls(
            statearray = StateArray.from_states([orbit.state for orbit in orbits]),
        )

In [14]:
a = OrbitArray(statearray = statearray)  # wraps (does not copy) the StateArray
a

<OrbitArray shape=(5,) value=[
 (epoch=2022-01-01 00:00:00.000 x=6.253502657298325 km y=3.7852417656020876 km),
 (epoch=2022-01-02 00:00:00.000 x=3.3512160233674617 km y=2.4314786801575385 km),
 (epoch=2022-01-03 00:00:00.000 x=9.440688784123175 km y=8.952840758946511 km),
 (epoch=2022-01-04 00:00:00.000 x=0.39916973060293404 km y=8.153075673625192 km),
 (epoch=2022-01-05 00:00:00.000 x=0.9244330754053787 km y=3.701769484864422 km),
]>

In [15]:
b = a[:4].reshape((2, 2))  # first four orbits as a 2 x 2 grid
b

<OrbitArray shape=(2, 2) value=[
 (epoch=2022-01-01 00:00:00.000 x=6.253502657298325 km y=3.7852417656020876 km),
 (epoch=2022-01-02 00:00:00.000 x=3.3512160233674617 km y=2.4314786801575385 km),
 (epoch=2022-01-03 00:00:00.000 x=9.440688784123175 km y=8.952840758946511 km),
 (epoch=2022-01-04 00:00:00.000 x=0.39916973060293404 km y=8.153075673625192 km),
]>

In [16]:
# Same 7-day step for every orbit in the grid; `b` itself is left unchanged.
c = b.propagate(TimeDelta(7 << u.d), inplace = False)
(b, c)

(<OrbitArray shape=(2, 2) value=[
  (epoch=2022-01-01 00:00:00.000 x=6.253502657298325 km y=3.7852417656020876 km),
  (epoch=2022-01-02 00:00:00.000 x=3.3512160233674617 km y=2.4314786801575385 km),
  (epoch=2022-01-03 00:00:00.000 x=9.440688784123175 km y=8.952840758946511 km),
  (epoch=2022-01-04 00:00:00.000 x=0.39916973060293404 km y=8.153075673625192 km),
 ]>,
 <OrbitArray shape=(2, 2) value=[
  (epoch=2022-01-08 00:00:00.000 x=3.5816864142090945 km y=5.885919960010017 km),
  (epoch=2022-01-09 00:00:00.000 x=4.6142463041484145 km y=1.6241837450638048 km),
  (epoch=2022-01-10 00:00:00.000 x=10.431136438674688 km y=12.21448501292679 km),
  (epoch=2022-01-11 00:00:00.000 x=-2.472579439256678 km y=9.555661842465295 km),
 ]>)