Skip to content
This repository has been archived by the owner on Oct 14, 2023. It is now read-only.

Added a set_frame method to OrbitPlotter2D #488

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 50 additions & 3 deletions src/poliastro/plotting.py
Expand Up @@ -249,6 +249,7 @@ def __init__(self):
self._attractor_data = {} # type: dict
self._attractor_radius = np.inf * u.km
self._color_cycle = cycle(plotly.colors.DEFAULT_PLOTLY_COLORS)
self._orbits = list(tuple()) # type: List[Tuple[Orbit, str, str]]

@property
def figure(self):
Expand Down Expand Up @@ -318,6 +319,9 @@ def plot(self, orbit, *, label=None, color=None):
if color is None:
color = next(self._color_cycle)

if (orbit, label, color) not in self._orbits:
self._orbits.append((orbit, label, color))

self.set_attractor(orbit.attractor)
self._redraw_attractor(orbit.r_p * 0.15) # Arbitrary threshold

Expand All @@ -326,7 +330,7 @@ def plot(self, orbit, *, label=None, color=None):

self._plot_trajectory(trajectory, label, color, True)
# Plot required 2D/3D shape in the position of the body
radius = min(self._attractor_radius * 0.5, (norm(orbit.r) - orbit.attractor.R) * 0.3) # Arbitrary thresholds
radius = min(self._attractor_radius * 0.5, (norm(orbit.r) - orbit.attractor.R) * 0.2) # Arbitrary thresholds
shape = self._plot_sphere(radius, color, label, center=orbit.r)
self._data.append(shape)

Expand Down Expand Up @@ -442,6 +446,7 @@ class OrbitPlotter2D(_BaseOrbitPlotter):

def __init__(self):
super().__init__()
self._frame = None
self._layout = Layout(
autosize=True,
xaxis=dict(
Expand All @@ -457,9 +462,44 @@ def __init__(self):
"shapes": []
})

def _project(self, rr):
rr_proj = rr - rr.dot(self._frame[2])[:, None] * self._frame[2]
x = rr_proj.dot(self._frame[0])
y = rr_proj.dot(self._frame[1])
return x, y

def set_frame(self, p_vec, q_vec, w_vec):
"""Sets perifocal frame.

Raises
------
ValueError
If the vectors are not a set of mutually orthogonal unit vectors.
"""
if not np.allclose([norm(v) for v in (p_vec, q_vec, w_vec)], 1):
raise ValueError("Vectors must be unit.")
elif not np.allclose([p_vec.dot(q_vec),
q_vec.dot(w_vec),
w_vec.dot(p_vec)], 0):
raise ValueError("Vectors must be mutually orthogonal.")
else:
self._frame = p_vec, q_vec, w_vec

if self._data:
self._redraw()

def _redraw(self):
self._data.clear()
self._layout["shapes"] = ()
self._attractor = None
for orbit, label, color in self._orbits:
self.plot(orbit=orbit, label=label, color=color)

def _plot_sphere(self, radius, color, name, center=[0, 0, 0] * u.km):
x_center, y_center = self._project(center[None])
z_center = center[2]
center = [x_center, y_center, z_center]
xx, yy = _generate_circle(radius, center)
x_center, y_center, z_center = center
trace = Scatter(x=xx.to(u.km).value, y=yy.to(u.km).value, mode='markers', line=dict(color=color, width=5,
dash='dash',), name=name)
self._layout["shapes"] += (
Expand All @@ -482,8 +522,10 @@ def _plot_sphere(self, radius, color, name, center=[0, 0, 0] * u.km):
return trace

def _plot_trajectory(self, trajectory, label, color, dashed):
rr = trajectory.represent_as(CartesianRepresentation).xyz.transpose()
x, y = self._project(rr)
trace = Scatter(
x=trajectory.x.to(u.km).value, y=trajectory.y.to(u.km).value,
x=x.to(u.km).value, y=y.to(u.km).value,
name=label,
line=dict(
color=color,
Expand All @@ -494,6 +536,11 @@ def _plot_trajectory(self, trajectory, label, color, dashed):
)
self._data.append(trace)

def plot(self, orbit, *, label=None, color=None):
if self._frame is None:
self.set_frame(*orbit.pqw())
super(OrbitPlotter2D, self).plot(orbit=orbit, label=label, color=color)


def plot_solar_system(outer=True, epoch=None):
"""
Expand Down
34 changes: 33 additions & 1 deletion src/poliastro/tests/test_plotting2d.py
@@ -1,12 +1,14 @@
import pytest

import astropy.units as u

import tempfile

from unittest import mock

from poliastro.examples import iss
from poliastro.plotting import OrbitPlotter2D
from poliastro.bodies import Earth, Mars, Sun
from poliastro.bodies import Earth, Mars, Sun, Jupiter
from poliastro.twobody.orbit import Orbit


Expand Down Expand Up @@ -95,3 +97,33 @@ def test_savefig_calls_prepare_plot(mock_prepare_plot, mock_export):

assert mock_export.call_count == 1
mock_prepare_plot.assert_called_once_with()


def test_set_frame():
op = OrbitPlotter2D()
p = [1, 0, 0] * u.one
q = [0, 1, 0] * u.one
w = [0, 0, 1] * u.one
op.set_frame(p, q, w)

assert op._frame == (p, q, w)


def test_redraw_makes_attractor_none():
op = OrbitPlotter2D()
earth = Orbit.from_body_ephem(Earth)
op.plot(earth)
op._redraw()
assert op._attractor_radius is not None


def test_set_frame_plots_same_colors():
earth = Orbit.from_body_ephem(Earth)
jupiter = Orbit.from_body_ephem(Jupiter)
op = OrbitPlotter2D()
op.plot(earth)
op.plot(jupiter)
colors1 = [orb[2] for orb in op._orbits]
op.set_frame(*jupiter.pqw())
colors2 = [orb[2] for orb in op._orbits]
assert colors1 == colors2