Skip to content

Commit

Permalink
Added mixin classes
Browse files Browse the repository at this point in the history
  • Loading branch information
yngvem committed Jul 16, 2019
1 parent 6b8a6bc commit f7e231e
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions src/group_lasso/_group_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
check_array,
check_consistent_length,
)
from sklearn.base import (
BaseEstimator,
TransformerMixin,
RegressorMixin,
ClassifierMixin,
)

from group_lasso._singular_values import find_largest_singular_value
from group_lasso._subsampling import subsample
Expand Down Expand Up @@ -51,7 +57,7 @@ def _add_intercept_col(X):
return np.concatenate([ones, X], axis=1)


class BaseGroupLasso(ABC):
class BaseGroupLasso(ABC, BaseEstimator, TransformerMixin):
"""
This class implements the Group Lasso [1] regularisation for optimisation
problems with Lipschitz continuous gradients, which is approximately
Expand Down Expand Up @@ -126,20 +132,6 @@ def __init__(
self.fit_intercept = fit_intercept
self.random_state = random_state

def get_params(self, deep=True):
return {
"groups": self.groups,
"reg": self.reg,
"n_iter": self.n_iter,
"tol": self.tol,
"subsampling_scheme": self.subsampling_scheme,
}

def set_params(self, **parameters):
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self

def _regularizer(self, w):
regularizer = 0
b, w = _split_intercept(w)
Expand Down Expand Up @@ -302,7 +294,7 @@ def _l2_grad(A, b, x):
return A.T @ (A @ x - b)


class GroupLasso(BaseGroupLasso):
class GroupLasso(BaseGroupLasso, RegressorMixin):
"""
This class implements the Group Lasso [1] regularisation for linear
regression with the mean squared penalty.
Expand Down Expand Up @@ -433,7 +425,7 @@ def _logistic_cross_entropy(X, y, w):
return -(y * np.log(p) + (1 - y) * np.log(1 - p))


class LogisticGroupLasso(BaseGroupLasso):
class LogisticGroupLasso(BaseGroupLasso, ClassifierMixin):
"""WARNING: Experimental.
"""

Expand Down

0 comments on commit f7e231e

Please sign in to comment.