Skip to content
This repository has been archived by the owner on Oct 14, 2023. It is now read-only.

Commit

Permalink
Keep trajectories on redraw, fix #518
Browse files Browse the repository at this point in the history
  • Loading branch information
astrojuanlu committed Jan 7, 2019
1 parent 33393d1 commit 6cf19d8
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 36 deletions.
107 changes: 71 additions & 36 deletions src/poliastro/plotting/static.py
@@ -1,3 +1,6 @@
from collections import namedtuple
from typing import List

import matplotlib as mpl
import numpy as np
from astropy import units as u
Expand All @@ -9,6 +12,10 @@
from poliastro.util import norm


class _Trajectory(namedtuple("_Trajectory", ["trajectory", "state", "label", "color"])):
pass


class StaticOrbitPlotter(object):
"""StaticOrbitPlotter class.
Expand Down Expand Up @@ -43,11 +50,11 @@ def __init__(self, ax=None, num_points=150, dark=False):
self._frame = None
self._attractor = None
self._attractor_radius = np.inf * u.km
self._orbits = list(tuple()) # type: List[Tuple[Orbit, str, str]]
self._trajectories = [] # type: List[_Trajectory]

@property
def orbits(self):
return self._orbits
def trajectories(self):
return self._trajectories

def set_frame(self, p_vec, q_vec, w_vec):
"""Sets perifocal frame.
Expand All @@ -64,40 +71,51 @@ def set_frame(self, p_vec, q_vec, w_vec):
else:
self._frame = p_vec, q_vec, w_vec

if self._orbits:
if self._trajectories:
self._redraw()

def _redraw(self):
for artist in self.ax.lines + self.ax.collections:
artist.remove()
self._attractor = None
for orbit, label, color in self._orbits:
self.plot(orbit, label, color)

for trajectory, state, label, color in self._trajectories:
self._plot(trajectory, state, label, color)

self.ax.relim()
self.ax.autoscale()

def _plot_trajectory(self, trajectory, color=None):
rr = trajectory.represent_as(CartesianRepresentation).xyz.transpose()
x, y = self._project(rr)
lines = self.ax.plot(x.to(u.km).value, y.to(u.km).value, "--", color=color)

return lines

def plot_trajectory(self, trajectory, *, label=None, color=None):
"""Plots a precomputed trajectory.
Parameters
----------
trajectory : ~astropy.coordinates.BaseRepresentation, ~astropy.coordinates.BaseCoordinateFrame
Trajectory to plot.
label : str, optional
Label.
color : str, optional
Color string.
"""
lines = []
rr = trajectory.represent_as(CartesianRepresentation).xyz.transpose()
x, y = self._project(rr)
a, = self.ax.plot(
x.to(u.km).value, y.to(u.km).value, "--", color=color, label=label
)
lines.append(a)
lines = self._plot_trajectory(trajectory, color)

if label:
a.set_label(label)
lines[0].set_label(label)
self.ax.legend(
loc="upper left", bbox_to_anchor=(1.05, 1.015), title="Names and epochs"
)

self._trajectories.append(
_Trajectory(trajectory, None, label, lines[0].get_color())
)

return lines

def set_attractor(self, attractor):
Expand Down Expand Up @@ -137,34 +155,30 @@ def _redraw_attractor(self, min_radius=0 * u.km):
mpl.patches.Circle((0, 0), self._attractor_radius.value, lw=0, color=color)
)

def plot(self, orbit, label=None, color=None, method=mean_motion):
"""Plots state and osculating orbit in their plane.
"""
if not self._frame:
self.set_frame(*orbit.pqw())

self.set_attractor(orbit.attractor)
self._redraw_attractor(orbit.r_p * 0.15) # Arbitrary Threshhold
positions = orbit.sample(self.num_points, method)

x0, y0 = self._project(orbit.r[None])
# Plot current position
l, = self.ax.plot(x0.to(u.km).value, y0.to(u.km).value, "o", mew=0, color=color)
def _plot(self, trajectory, state=None, label=None, color=None):
lines = self._plot_trajectory(trajectory, color)

if (orbit, label, l.get_color()) not in self._orbits:
self._orbits.append((orbit, label, l.get_color()))
if state is not None:
x0, y0 = self._project(state[None])

lines = self.plot_trajectory(trajectory=positions, color=l.get_color())
lines.append(l)
# Plot current position
l, = self.ax.plot(
x0.to(u.km).value,
y0.to(u.km).value,
"o",
mew=0,
color=lines[0].get_color(),
)
lines.append(l)

if label:
# This will apply the label to either the point or the osculating
# orbit depending on the last plotted line, as they share variable
if not self.ax.get_legend():
size = self.ax.figure.get_size_inches() + [8, 0]
self.ax.figure.set_size_inches(size)
label = generate_label(orbit, label)
l.set_label(label)

# This will apply the label to either the point or the osculating
# orbit depending on the last plotted line
lines[-1].set_label(label)
self.ax.legend(
loc="upper left", bbox_to_anchor=(1.05, 1.015), title="Names and epochs"
)
Expand All @@ -174,3 +188,24 @@ def plot(self, orbit, label=None, color=None, method=mean_motion):
self.ax.set_aspect(1)

return lines

def plot(self, orbit, label=None, color=None, method=mean_motion):
"""Plots state and osculating orbit in their plane.
"""
if not self._frame:
self.set_frame(*orbit.pqw())

self.set_attractor(orbit.attractor)
self._redraw_attractor(orbit.r_p * 0.15) # Arbitrary Threshhold
positions = orbit.sample(self.num_points, method)

if label:
label = generate_label(orbit, label)

lines = self._plot(positions, orbit.r, label, color)

self._trajectories.append(
_Trajectory(positions, orbit.r, label, lines[0].get_color())
)

return lines
16 changes: 16 additions & 0 deletions src/poliastro/tests/tests_plotting/test_static.py
Expand Up @@ -100,3 +100,19 @@ def test_set_frame_plots_same_colors():
op.set_frame(*jupiter.pqw())
colors2 = [orb[2] for orb in op.trajectories]
assert colors1 == colors2


def test_redraw_keeps_trajectories():
# See https://github.com/poliastro/poliastro/issues/518
op = StaticOrbitPlotter()
earth = Orbit.from_body_ephem(Earth)
mars = Orbit.from_body_ephem(Mars)
trajectory = earth.sample()
op.plot(mars, label="Mars")
op.plot_trajectory(trajectory, label="Earth")

assert len(op.trajectories) == 2

op.set_frame(*mars.pqw())

assert len(op.trajectories) == 2

0 comments on commit 6cf19d8

Please sign in to comment.