Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544458648
  • Loading branch information
JAXopt authors committed Jun 29, 2023
1 parent 5ff06fe commit e6fbf5f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions jaxopt/_src/backtracking_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def update(
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list[Any] = [],
fun_kwargs: dict[str, Any] = {},
fun_args: list = [],
fun_kwargs: dict = {},
) -> base.LineSearchStep:
"""Performs one iteration of backtracking line search.
Expand Down
4 changes: 2 additions & 2 deletions jaxopt/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,8 @@ def run(self,
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list[Any] = [],
fun_kwargs: dict[str, Any] = {}) -> LineSearchStep:
fun_args: list = [],
fun_kwargs: dict = {}) -> LineSearchStep:

return super()._run(init_stepsize, params, value, grad, descent_direction,
fun_args, fun_kwargs)
2 changes: 1 addition & 1 deletion jaxopt/_src/broyden.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def ls_fun_with_aux(params, *args, **kwargs):
new_stepsize, ls_state = ls.run(init_stepsize,
params, value, None,
descent_direction,
*args, **kwargs)
fun_args=args, fun_kwargs=kwargs)
new_value, new_aux = ls_state.aux
new_params = ls_state.params
else:
Expand Down
8 changes: 4 additions & 4 deletions jaxopt/_src/hager_zhang_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ def init_state(
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list[Any] = [],
fun_kwargs: dict[str, Any] = {},
fun_args: list = [],
fun_kwargs: dict = {},
) -> HagerZhangLineSearchState:
"""Initialize the line search state.
Expand Down Expand Up @@ -419,8 +419,8 @@ def update(
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list[Any] = [],
fun_kwargs: dict[str, Any] = {},
fun_args: list = [],
fun_kwargs: dict = {},
) -> base.LineSearchStep:
"""Performs one iteration of Hager-Zhang line search.
Expand Down
8 changes: 4 additions & 4 deletions jaxopt/_src/zoom_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ def init_state(
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list[Any] = [],
fun_kwargs: dict[str, Any] = {},
fun_args: list = [],
fun_kwargs: dict = {},
) -> base.LineSearchStep:
"""Initialize the line search state by computing all relevant quantities and store it in the initial state.
Expand Down Expand Up @@ -682,8 +682,8 @@ def update(
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list[Any] = [],
fun_kwargs: dict[str, Any] = {},
fun_args: list = [],
fun_kwargs: dict = {},
) -> base.LineSearchStep:
"""Combines Algorithms 3.5 and 3.6 of [1].
Expand Down

0 comments on commit e6fbf5f

Please sign in to comment.