Skip to content

Commit

Permalink
Slight changes to satisfy linter
Browse files Browse the repository at this point in the history
  • Loading branch information
yngvem committed Jul 18, 2020
1 parent 5264f0d commit abe165e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
10 changes: 4 additions & 6 deletions src/group_lasso/_fista.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,14 @@ def minimise(self, x0, n_iter=10, tol=1e-6, callback=None):
generalised_gradient = momentum_x.ravel() - new_optimal_x.ravel()
update_vector = new_optimal_x.ravel() - previous_x.ravel()
# Loss based restart criterion
if generalised_gradient.T@update_vector > 0:
if generalised_gradient.T@update_vector > self.smooth_loss(previous_x):
momentum_x = previous_x
momentum = 1
(
new_optimal_x,
new_momentum_x,
new_momentum,
) = self._update_step(
# fmt: off
new_optimal_x, new_momentum_x, new_momentum = self._update_step( # noqa: E501
previous_x, momentum_x, momentum, self.lipschitz
)
# fmt: on

# Backtracking line search
while self._continue_backtracking(
Expand Down
4 changes: 2 additions & 2 deletions src/group_lasso/_group_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _get_reg_vector(self, reg):
return reg

@abstractmethod
def _unregularised_loss(self, X, y, w): # pragma: nocover
def _unregularised_loss(self, X_aug, y, w): # pragma: nocover
"""The unregularised reconstruction loss.
"""
raise NotImplementedError
Expand Down Expand Up @@ -265,7 +265,7 @@ def loss(self, X, y):
return self._loss(X_aug, y, w)

@abstractmethod
def _estimate_lipschitz(self, X, y): # pragma: nocover
def _estimate_lipschitz(self, X_aug, y): # pragma: nocover
"""Compute Lipschitz bound for the gradient of the unregularised loss.
The Lipschitz bound is with respect to the coefficient vector or
Expand Down
2 changes: 1 addition & 1 deletion test/test_fista.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def prox(x, L):
def test_lipschitz_updates_with_small_initial_guess(
smooth_problem_1d, no_regulariser
):
f, df, lipschitz = smooth_problem_1d
f, df, _ = smooth_problem_1d
g, prox = no_regulariser

small_L = 0.01
Expand Down

0 comments on commit abe165e

Please sign in to comment.