Skip to content

Commit

Permalink
Merge pull request #767 from pariterre/master
Browse files Browse the repository at this point in the history
Fixed custom_plot
  • Loading branch information
pariterre committed Sep 27, 2023
2 parents 941a16f + d7ae7bb commit 4d7909a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
8 changes: 4 additions & 4 deletions bioptim/examples/getting_started/custom_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,18 @@ def prepare_ocp(
)

# Add my lovely new plot
ocp.add_plot("My New Extra Plot", lambda t, x, u, p: custom_plot_callback(x, [0, 1, 3]), plot_type=PlotType.PLOT)
ocp.add_plot("My New Extra Plot", lambda t, x, u, p, s: custom_plot_callback(x, [0, 1, 3]), plot_type=PlotType.PLOT)
ocp.add_plot( # This one combines to the previous one as they have the same name
"My New Extra Plot",
lambda t, x, u, p: custom_plot_callback(x, [1, 3]),
lambda t, x, u, p, s: custom_plot_callback(x, [1, 3]),
plot_type=PlotType.STEP,
axes_idx=[1, 2],
)
ocp.add_plot(
"My Second New Extra Plot",
lambda t, x, u, p: custom_plot_callback(x, [1, 3]),
lambda t, x, u, p, s: custom_plot_callback(x, [2, 1]),
plot_type=PlotType.INTEGRATED,
axes_idx=[1, 2],
axes_idx=[0, 2],
)

return ocp
Expand Down
17 changes: 12 additions & 5 deletions bioptim/gui/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def legend_without_duplicate_labels(ax):
)
nlp.plot[key].phase_mappings = BiMapping(to_first=range(size), to_second=range(size))
else:
size = len(nlp.plot[key].phase_mappings.to_second.map_idx)
size = max(nlp.plot[key].phase_mappings.to_second.map_idx) + 1
if key not in variable_sizes[i]:
variable_sizes[i][key] = size
else:
Expand Down Expand Up @@ -432,7 +432,7 @@ def legend_without_duplicate_labels(ax):
continue

mapping_to_first_index = nlp.plot[variable].phase_mappings.to_first.map_idx
mapping_range_index = list(range(len(nlp.plot[variable].phase_mappings.to_second.map_idx)))
mapping_range_index = list(range(max(nlp.plot[variable].phase_mappings.to_second.map_idx) + 1))
for ctr in mapping_range_index:
ax = axes[ctr]
if ctr in mapping_to_first_index:
Expand Down Expand Up @@ -804,10 +804,13 @@ def update_data(self, v: dict):
y_tp[:, :] = val
all_y.append(y_tp)

for idx in range(len(self.plot_func[key][i].phase_mappings.to_second.map_idx)):
for idx in range(max(self.plot_func[key][i].phase_mappings.to_second.map_idx) + 1):
y_tp = []
for y in all_y:
y_tp.append(y[idx, :])
if idx in self.plot_func[key][i].phase_mappings.to_second.map_idx:
for y in all_y:
y_tp.append(y[idx, :])
else:
y_tp = None
self.__append_to_ydata([y_tp])

elif self.plot_func[key][i].type == PlotType.POINT:
Expand Down Expand Up @@ -1057,9 +1060,13 @@ def __update_axes(self):
"""
Update the plotted data from ydata
"""

assert len(self.plots) == len(self.ydata)
for i, plot in enumerate(self.plots):
y = self.ydata[i]
if y is None:
# Jump the plots which are empty
y = (np.nan,) * len(plot[2])

if plot[0] == PlotType.INTEGRATED:
for cmp, p in enumerate(plot[2]):
Expand Down

0 comments on commit 4d7909a

Please sign in to comment.