Skip to content

Commit

Permalink
Merge pull request #796 from Ipuch/starting_point
Browse files Browse the repository at this point in the history
duplicate_collocation_starting_point
  • Loading branch information
pariterre committed Nov 8, 2023
2 parents 88695f1 + c19cc82 commit 777c20c
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 23 deletions.
6 changes: 3 additions & 3 deletions bioptim/dynamics/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def __init__(self, ode: dict, ode_opt: dict):

self.method = ode_opt["method"]
self.degree = ode_opt["irk_polynomial_interpolation_degree"]
self.include_starting_collocation_point = ode_opt["include_starting_collocation_point"]
self.duplicate_collocation_starting_point = ode_opt["duplicate_collocation_starting_point"]
self.allow_free_variables = ode_opt["allow_free_variables"]

# Coefficients of the collocation equation
Expand Down Expand Up @@ -819,7 +819,7 @@ def _finish_init(self):
self.function = Function(
"integrator",
[
horzcat(*self.x_sym) if self.include_starting_collocation_point else horzcat(*self.x_sym[1:]),
horzcat(*self.x_sym) if self.duplicate_collocation_starting_point else horzcat(*self.x_sym[1:]),
self.u_sym,
self.param_sym,
self.s_sym,
Expand Down Expand Up @@ -911,7 +911,7 @@ def dxdt(

# Root-finding function, implicitly defines x_collocation_points as a function of x0 and p
time_sym = []
collocation_states = vertcat(*states[1:]) if self.include_starting_collocation_point else vertcat(*states[2:])
collocation_states = vertcat(*states[1:]) if self.duplicate_collocation_starting_point else vertcat(*states[2:])
vfcn = Function(
"vfcn",
[collocation_states, time_sym, states[0], controls, params, stochastic_variables],
Expand Down
12 changes: 6 additions & 6 deletions bioptim/dynamics/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class COLLOCATION(OdeSolverBase):
The method of interpolation ("legendre" or "radau")
defects_type: DefectType
The type of defect to use (DefectType.EXPLICIT or DefectType.IMPLICIT)
include_starting_collocation_point: bool
duplicate_collocation_starting_point: bool
Whether an additional collocation point should be added at the shooting node (this is typically used in SOCPs)
Methods
Expand All @@ -344,7 +344,7 @@ def __init__(
polynomial_degree: int = 4,
method: str = "legendre",
defects_type: DefectType = DefectType.EXPLICIT,
include_starting_collocation_point: bool = False,
duplicate_collocation_starting_point: bool = False,
):
"""
Parameters
Expand All @@ -355,8 +355,8 @@ def __init__(

super(OdeSolver.COLLOCATION, self).__init__()
self.polynomial_degree = polynomial_degree
self.include_starting_collocation_point = include_starting_collocation_point
self.n_cx = polynomial_degree + 3 if include_starting_collocation_point else polynomial_degree + 2
self.duplicate_collocation_starting_point = duplicate_collocation_starting_point
self.n_cx = polynomial_degree + 3 if duplicate_collocation_starting_point else polynomial_degree + 2
self.rk_integrator = COLLOCATION
self.method = method
self.defects_type = defects_type
Expand Down Expand Up @@ -384,7 +384,7 @@ def integrator(
"developers and ping @EveCharbie"
)

if self.include_starting_collocation_point:
if self.duplicate_collocation_starting_point:
x_unscaled = ([nlp.states.cx_start] + nlp.states.cx_intermediates_list,)
x_scaled = [nlp.states.scaled.cx_start] + nlp.states.scaled.cx_intermediates_list
else:
Expand Down Expand Up @@ -423,7 +423,7 @@ def integrator(
"irk_polynomial_interpolation_degree": self.polynomial_degree,
"method": self.method,
"defects_type": self.defects_type,
"include_starting_collocation_point": self.include_starting_collocation_point,
"duplicate_collocation_starting_point": self.duplicate_collocation_starting_point,
"allow_free_variables": allow_free_variables,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def prepare_socp(
ode_solver = OdeSolver.COLLOCATION(
polynomial_degree=socp_type.polynomial_degree,
method=socp_type.method,
include_starting_collocation_point=True,
duplicate_collocation_starting_point=True,
)

return OptimalControlProgram(
Expand Down
2 changes: 1 addition & 1 deletion bioptim/limits/penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ def state_continuity(penalty: PenaltyOption, controller: PenaltyController | lis
def first_collocation_point_equals_state(penalty: PenaltyOption, controller: PenaltyController | list):
"""
Insures that the first collocation helper is equal to the states at the shooting node.
This is a necessary constraint for COLLOCATION with include_starting_collocation_point.
This is a necessary constraint for COLLOCATION with duplicate_collocation_starting_point.
"""
collocation_helper = controller.states.cx_intermediates_list[0]
states = controller.states.cx_start
Expand Down
4 changes: 2 additions & 2 deletions bioptim/optimization/optimal_control_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def _declare_continuity(self) -> None:
ConstraintFcn.STATE_CONTINUITY, node=Node.ALL_SHOOTING, penalty_type=PenaltyType.INTERNAL
)
penalty.add_or_replace_to_penalty_pool(self, nlp)
if nlp.ode_solver.is_direct_collocation and nlp.ode_solver.include_starting_collocation_point:
if nlp.ode_solver.is_direct_collocation and nlp.ode_solver.duplicate_collocation_starting_point:
penalty = Constraint(
ConstraintFcn.FIRST_COLLOCATION_HELPER_EQUALS_STATE,
node=Node.ALL_SHOOTING,
Expand All @@ -968,7 +968,7 @@ def _declare_continuity(self) -> None:
ConstraintFcn.STATE_CONTINUITY, node=shooting_node, penalty_type=PenaltyType.INTERNAL
)
penalty.add_or_replace_to_penalty_pool(self, nlp)
if nlp.ode_solver.is_direct_collocation and nlp.ode_solver.include_starting_collocation_point:
if nlp.ode_solver.is_direct_collocation and nlp.ode_solver.duplicate_collocation_starting_point:
penalty = Constraint(
ConstraintFcn.FIRST_COLLOCATION_HELPER_EQUALS_STATE,
node=shooting_node,
Expand Down
16 changes: 8 additions & 8 deletions bioptim/optimization/solution/simplified_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ class SimplifiedNLP:
Generate time vector steps for a phase considering all the phase final time
_define_step_times(self, dynamics_step_time: list, ode_solver_steps: int,
keep_intermediate_points: bool = None, continuous: bool = True,
is_direct_collocation: bool = None, include_starting_collocation_point: bool = False) -> np.ndarray
is_direct_collocation: bool = None, duplicate_collocation_starting_point: bool = False) -> np.ndarray
Define the time steps for the integration of the whole phase
_define_step_times(self, dynamics_step_time: list, ode_solver_steps: int,
keep_intermediate_points: bool = None, continuous: bool = True,
is_direct_collocation: bool = None, include_starting_collocation_point: bool = False) -> np.ndarray
is_direct_collocation: bool = None, duplicate_collocation_starting_point: bool = False) -> np.ndarray
Define the time steps for the integration of the whole phase
_complete_controls(self, controls: dict[str, np.ndarray]) -> dict[str, np.ndarray]
Controls don't necessarily have dimensions that matches the states. This method aligns them
Expand Down Expand Up @@ -313,15 +313,15 @@ def _generate_time(
np.ndarray
"""
is_direct_collocation = self.ode_solver.is_direct_collocation
include_starting_collocation_point = False
duplicate_collocation_starting_point = False
if is_direct_collocation:
include_starting_collocation_point = self.ode_solver.include_starting_collocation_point
duplicate_collocation_starting_point = self.ode_solver.duplicate_collocation_starting_point

step_times = self._define_step_times(
dynamics_step_time=self.dynamics[0].step_time,
ode_solver_steps=self.ode_solver.steps,
is_direct_collocation=is_direct_collocation,
include_starting_collocation_point=include_starting_collocation_point,
duplicate_collocation_starting_point=duplicate_collocation_starting_point,
keep_intermediate_points=keep_intermediate_points,
continuous=shooting_type == Shooting.SINGLE,
)
Expand Down Expand Up @@ -356,7 +356,7 @@ def _define_step_times(
keep_intermediate_points: bool = None,
continuous: bool = True,
is_direct_collocation: bool = None,
include_starting_collocation_point: bool = False,
duplicate_collocation_starting_point: bool = False,
) -> np.ndarray:
"""
Define the time steps for the integration of the whole phase
Expand All @@ -375,7 +375,7 @@ def _define_step_times(
arrival node and the beginning of the next one are expected to be almost equal when the problem converged
is_direct_collocation: bool
If the ode solver is direct collocation
include_starting_collocation_point
duplicate_collocation_starting_point
If the ode solver is direct collocation and an additional collocation point at the shooting node was used
Returns
Expand All @@ -392,7 +392,7 @@ def _define_step_times(
if keep_intermediate_points:
step_times = np.array(dynamics_step_time + [1])

if include_starting_collocation_point:
if duplicate_collocation_starting_point:
step_times = np.array([0] + step_times)
else:
step_times = np.array(dynamics_step_time + [1])[[0, -1]]
Expand Down
4 changes: 2 additions & 2 deletions bioptim/optimization/stochastic_optimal_control_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _set_default_ode_solver(self):
return OdeSolver.COLLOCATION(
method=self.problem_type.method,
polynomial_degree=self.problem_type.polynomial_degree,
include_starting_collocation_point=True,
duplicate_collocation_starting_point=True,
)
else:
raise RuntimeError("Wrong choice of problem_type, you must choose one of the SocpType.")
Expand Down Expand Up @@ -372,7 +372,7 @@ def _check_has_no_ode_solver_defined(**kwargs):
"OdeSolver.COLLOCATION("
"method=problem_type.method, "
"polynomial_degree=problem_type.polynomial_degree, "
"include_starting_collocation_point=True"
"duplicate_collocation_starting_point=True"
")"
)

Expand Down

0 comments on commit 777c20c

Please sign in to comment.