From cc5cc2ce9f2627a86891692161fa9174af7df961 Mon Sep 17 00:00:00 2001 From: Shreyas Bapat Date: Sun, 4 Nov 2018 03:55:39 +0530 Subject: [PATCH 1/2] Added a set_frame method to OrbitPlotter2D. Fixes #483 and #480 --- src/poliastro/plotting.py | 53 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/src/poliastro/plotting.py b/src/poliastro/plotting.py index 11959e8db..d48c0e502 100644 --- a/src/poliastro/plotting.py +++ b/src/poliastro/plotting.py @@ -249,6 +249,7 @@ def __init__(self): self._attractor_data = {} # type: dict self._attractor_radius = np.inf * u.km self._color_cycle = cycle(plotly.colors.DEFAULT_PLOTLY_COLORS) + self._orbits = list(tuple()) # type: List[Tuple[Orbit, str, str]] @property def figure(self): @@ -318,6 +319,9 @@ def plot(self, orbit, *, label=None, color=None): if color is None: color = next(self._color_cycle) + if (orbit, label, color) not in self._orbits: + self._orbits.append((orbit, label, color)) + self.set_attractor(orbit.attractor) self._redraw_attractor(orbit.r_p * 0.15) # Arbitrary threshold @@ -326,7 +330,7 @@ def plot(self, orbit, *, label=None, color=None): self._plot_trajectory(trajectory, label, color, True) # Plot required 2D/3D shape in the position of the body - radius = min(self._attractor_radius * 0.5, (norm(orbit.r) - orbit.attractor.R) * 0.3) # Arbitrary thresholds + radius = min(self._attractor_radius * 0.5, (norm(orbit.r) - orbit.attractor.R) * 0.2) # Arbitrary thresholds shape = self._plot_sphere(radius, color, label, center=orbit.r) self._data.append(shape) @@ -442,6 +446,7 @@ class OrbitPlotter2D(_BaseOrbitPlotter): def __init__(self): super().__init__() + self._frame = None self._layout = Layout( autosize=True, xaxis=dict( @@ -457,9 +462,44 @@ def __init__(self): "shapes": [] }) + def _project(self, rr): + rr_proj = rr - rr.dot(self._frame[2])[:, None] * self._frame[2] + x = rr_proj.dot(self._frame[0]) + y = rr_proj.dot(self._frame[1]) + return x, y + + def set_frame(self, p_vec, q_vec, w_vec): + """Sets perifocal frame. + + Raises + ------ + ValueError + If the vectors are not a set of mutually orthogonal unit vectors. + """ + if not np.allclose([norm(v) for v in (p_vec, q_vec, w_vec)], 1): + raise ValueError("Vectors must be unit.") + elif not np.allclose([p_vec.dot(q_vec), + q_vec.dot(w_vec), + w_vec.dot(p_vec)], 0): + raise ValueError("Vectors must be mutually orthogonal.") + else: + self._frame = p_vec, q_vec, w_vec + + if self._data: + self._redraw() + + def _redraw(self): + self._data.clear() + self._layout["shapes"] = () + self._attractor = None + for orbit, label, color in self._orbits: + self.plot(orbit=orbit, label=label, color=color) + def _plot_sphere(self, radius, color, name, center=[0, 0, 0] * u.km): + x_center, y_center = self._project(center[None]) + z_center = center[2] + center = [x_center, y_center, z_center] xx, yy = _generate_circle(radius, center) - x_center, y_center, z_center = center trace = Scatter(x=xx.to(u.km).value, y=yy.to(u.km).value, mode='markers', line=dict(color=color, width=5, dash='dash',), name=name) self._layout["shapes"] += ( @@ -482,8 +522,10 @@ def _plot_sphere(self, radius, color, name, center=[0, 0, 0] * u.km): return trace def _plot_trajectory(self, trajectory, label, color, dashed): + rr = trajectory.represent_as(CartesianRepresentation).xyz.transpose() + x, y = self._project(rr) trace = Scatter( - x=trajectory.x.to(u.km).value, y=trajectory.y.to(u.km).value, + x=x.to(u.km).value, y=y.to(u.km).value, name=label, line=dict( color=color, @@ -494,6 +536,11 @@ def _plot_trajectory(self, trajectory, label, color, dashed): ) self._data.append(trace) + def plot(self, orbit, *, label=None, color=None): + if self._frame is None: + self.set_frame(*orbit.pqw()) + super(OrbitPlotter2D, self).plot(orbit=orbit, label=label, color=color) + def plot_solar_system(outer=True, epoch=None): """ From 6e1bb4145cb50b7929daa54cd0ba3b060cf96400 Mon Sep 17 00:00:00 2001 From: Shreyas Bapat Date: Sun, 4 Nov 2018 04:04:02 +0530 Subject: [PATCH 2/2] Added tests for set_frame in OrbitPlotter2D --- src/poliastro/tests/test_plotting2d.py | 34 +++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/poliastro/tests/test_plotting2d.py b/src/poliastro/tests/test_plotting2d.py index 0d778df2a..d2a3153cb 100644 --- a/src/poliastro/tests/test_plotting2d.py +++ b/src/poliastro/tests/test_plotting2d.py @@ -1,12 +1,14 @@ import pytest +import astropy.units as u + import tempfile from unittest import mock from poliastro.examples import iss from poliastro.plotting import OrbitPlotter2D -from poliastro.bodies import Earth, Mars, Sun +from poliastro.bodies import Earth, Mars, Sun, Jupiter from poliastro.twobody.orbit import Orbit @@ -95,3 +97,33 @@ def test_savefig_calls_prepare_plot(mock_prepare_plot, mock_export): assert mock_export.call_count == 1 mock_prepare_plot.assert_called_once_with() + + +def test_set_frame(): + op = OrbitPlotter2D() + p = [1, 0, 0] * u.one + q = [0, 1, 0] * u.one + w = [0, 0, 1] * u.one + op.set_frame(p, q, w) + + assert op._frame == (p, q, w) + + +def test_redraw_makes_attractor_none(): + op = OrbitPlotter2D() + earth = Orbit.from_body_ephem(Earth) + op.plot(earth) + op._redraw() + assert op._attractor_radius is not None + + +def test_set_frame_plots_same_colors(): + earth = Orbit.from_body_ephem(Earth) + jupiter = Orbit.from_body_ephem(Jupiter) + op = OrbitPlotter2D() + op.plot(earth) + op.plot(jupiter) + colors1 = [orb[2] for orb in op._orbits] + op.set_frame(*jupiter.pqw()) + colors2 = [orb[2] for orb in op._orbits] + assert colors1 == colors2