Skip to content

Commit

Permalink
Merge 49ff4e0 into 416312a
Browse files Browse the repository at this point in the history
  • Loading branch information
Stéphane Caron committed Jul 25, 2023
2 parents 416312a + 49ff4e0 commit 1f119a1
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 176 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ All notable changes to this project will be documented in this file.
- MPCProblem: target state trajectory for stage state cost
- Plan: ``first_input`` getter
- Plan: ``is_empty`` property
- Started ``ltv_mpc.live_plots`` submodule
- Started ``ltv_mpc.systems`` submodule

### Changed
Expand Down
41 changes: 20 additions & 21 deletions examples/cart_pole.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@
from loop_rate_limiters import RateLimiter

from ltv_mpc import solve_mpc
from ltv_mpc.live_plots import CartPolePlot
from ltv_mpc.systems import CartPole

EXAMPLE_DURATION: float = 10.0 # seconds
NB_SUBSTEPS: int = 15 # number of integration substeps

params = CartPole.Parameters() # default cart-pole model


def parse_command_line_arguments() -> argparse.Namespace:
"""Parse command-line arguments."""
Expand Down Expand Up @@ -66,7 +65,9 @@ def parse_command_line_arguments() -> argparse.Namespace:
return parser.parse_args()


def get_target_states(state: np.ndarray, target_vel: float):
def get_target_states(
cart_pole: CartPole, state: np.ndarray, target_vel: float
):
"""Define the reference state trajectory over the receding horizon.
Args:
Expand All @@ -76,44 +77,42 @@ def get_target_states(state: np.ndarray, target_vel: float):
Returns:
Goal state at the end of the horizon.
"""
nx = CartPole.STATE_DIM
T = params.sampling_period
target_states = np.zeros((params.nb_timesteps + 1) * nx)
for k in range(params.nb_timesteps + 1):
nx = cart_pole.STATE_DIM
T = cart_pole.sampling_period
target_states = np.zeros((cart_pole.nb_timesteps + 1) * nx)
for k in range(cart_pole.nb_timesteps + 1):
target_states[k * nx] = state[0] + (k * T) * target_vel
target_states[k * nx + 2] = target_vel
return target_states


if __name__ == "__main__":
args = parse_command_line_arguments()
cart_pole = CartPole()
live_plot = CartPolePlot(cart_pole, order=args.plot)
mpc_problem = CartPole.build_mpc_problem(
params,
cart_pole,
terminal_cost_weight=10.0,
stage_state_cost_weight=1.0,
stage_input_cost_weight=1e-3,
)

cart_pole = CartPole(
params,
initial_state=np.zeros(CartPole.STATE_DIM),
)
cart_pole.init_live_plot(order=args.plot)

dt = params.sampling_period / NB_SUBSTEPS
dt = cart_pole.sampling_period / NB_SUBSTEPS
rate = RateLimiter(frequency=1.0 / (args.slowdown * dt), warn=False)
for t in np.arange(0.0, EXAMPLE_DURATION, params.sampling_period):
state = np.zeros(cart_pole.STATE_DIM)
for t in np.arange(0.0, EXAMPLE_DURATION, cart_pole.sampling_period):
target_vel = 0.5 + (np.cos(t / 2.0) if args.tv_vel else 0.0)
target_states = get_target_states(cart_pole.state, target_vel)
mpc_problem.update_initial_state(cart_pole.state)
target_states = get_target_states(cart_pole, state, target_vel)
mpc_problem.update_initial_state(state)
mpc_problem.update_goal_state(target_states[-CartPole.STATE_DIM :])
mpc_problem.update_target_states(target_states[: -CartPole.STATE_DIM])
plan = solve_mpc(mpc_problem, solver=args.solver)
for step in range(NB_SUBSTEPS):
cart_pole.step(plan.first_input, dt)
cart_pole.update_live_plot(
plan,
state = cart_pole.integrate(state, plan.first_input, dt)
live_plot.update(
plan=plan,
plan_time=t,
state=state,
state_time=t + step * dt,
)
rate.sleep()
Expand Down
4 changes: 4 additions & 0 deletions ltv_mpc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,9 @@ class ProblemDefinitionError(LTVMPCException):
"""Problem definition is incorrect."""


class PlanError(LTVMPCException):
"""Plan is not correct."""


class StateError(LTVMPCException):
"""Report an ill-formed state."""
24 changes: 24 additions & 0 deletions ltv_mpc/live_plots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2023 Inria
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Set of system-specific live plots provided for reference and examples."""

from .cart_pole_plot import CartPolePlot

__all__ = [
"CartPolePlot",
]
147 changes: 147 additions & 0 deletions ltv_mpc/live_plots/cart_pole_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2023 Inria
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Live plot for the cart-pole system."""

import numpy as np

from ..exceptions import PlanError
from ..plan import Plan
from ..systems import CartPole
from .live_plot import LivePlot


class CartPolePlot:
"""Live plot for the cart-pole system."""

live_plot: LivePlot
cart_pole: CartPole
lhs_index: int
rhs_index: int

def __init__(self, cart_pole: CartPole, order: str) -> None:
"""Initialize live plot.
Args:
cart_pole: Cart-pole system.
order: Order of things to plot, "positions" or "velocities".
"""
lhs_index = 0 if order == "positions" else 2
rhs_index = 1 if order == "positions" else 3
ps = "" if order == "positions" else "/s"
T = cart_pole.sampling_period
live_plot = LivePlot(
xlim=(0.0, cart_pole.horizon_duration + T),
ylim=(-0.5, 1.0),
ylim2=(-1.0, 1.0),
)
live_plot.add_line("lhs", "b-")
live_plot.axis.set_ylabel(f"Ground {order} [m{ps}]", color="b")
live_plot.axis.tick_params(axis="y", labelcolor="b")
live_plot.add_rhs_line("rhs", "g-")
if live_plot.rhs_axis is not None: # help mypy
label = f"Angular {order} [rad{ps}]"
live_plot.rhs_axis.set_ylabel(label, color="g")
live_plot.rhs_axis.tick_params(axis="y", labelcolor="g")
live_plot.add_line("lhs_cur", "bo", lw=2)
live_plot.add_line("lhs_goal", "b--", lw=1)
live_plot.add_rhs_line("rhs_goal", "g--", lw=1)
live_plot.add_rhs_line("rhs_cur", "go", lw=2)
self.cart_pole = cart_pole
self.lhs_index = lhs_index
self.live_plot = live_plot
self.rhs_index = rhs_index

def update_plan(self, plan: Plan, plan_time: float) -> None:
"""Update live-plot from plan.
Args:
plan: Solution to the MPC problem.
plan_time: Time corresponding to the initial state.
"""
if plan.states is None:
raise PlanError("No state trajectory in plan")
X = plan.states
t = plan_time
horizon_duration = self.cart_pole.horizon_duration
nb_timesteps = self.cart_pole.nb_timesteps
trange = np.linspace(t, t + horizon_duration, nb_timesteps + 1)
self.live_plot.update_line("lhs", trange, X[:, self.lhs_index])
self.live_plot.update_line("rhs", trange, X[:, self.rhs_index])
if (
plan.problem.target_states is None
or plan.problem.goal_state is None
):
return
self.live_plot.update_line(
"lhs_goal",
trange,
np.hstack(
[
plan.problem.target_states[self.lhs_index :: 4],
plan.problem.goal_state[self.lhs_index],
]
),
)
self.live_plot.update_line(
"rhs_goal",
trange,
np.hstack(
[
plan.problem.target_states[self.rhs_index :: 4],
plan.problem.goal_state[self.rhs_index],
]
),
)

def update_state(self, state: np.ndarray, state_time: float):
"""Update live-plot from current state.
Args:
state: Current state of the system.
state_time: Time corresponding to the state.
"""
horizon_duration = self.cart_pole.horizon_duration
T = self.cart_pole.sampling_period
if state_time >= T:
t2 = state_time - T
self.live_plot.axis.set_xlim(t2, t2 + horizon_duration + T)
self.live_plot.update_line(
"lhs_cur", [state_time], [state[self.lhs_index]]
)
self.live_plot.update_line(
"rhs_cur", [state_time], [state[self.rhs_index]]
)

def update(
self,
plan: Plan,
plan_time: float,
state: np.ndarray,
state_time: float,
) -> None:
"""Plot plan resulting from the MPC problem.
Args:
plan: Solution to the MPC problem.
plan_time: Time of the beginning of the receding horizon.
state: Current state.
state_time: Time of the current state.
"""
self.update_plan(plan, plan_time)
self.update_state(state, state_time)
self.live_plot.update()
2 changes: 1 addition & 1 deletion ltv_mpc/utils.py → ltv_mpc/live_plots/live_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import matplotlib
from matplotlib import pyplot as plt

from .exceptions import LTVMPCException
from ..exceptions import LTVMPCException


class LivePlot:
Expand Down
Loading

0 comments on commit 1f119a1

Please sign in to comment.