Skip to content

Commit

Permalink
Fixed bug with chosen groups
Browse files Browse the repository at this point in the history
  • Loading branch information
yngvem committed Apr 20, 2020
1 parent 47ad752 commit 0338a65
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/group_lasso/_group_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,23 +438,25 @@ def sparsity_mask(self):
)
return self.sparsity_mask_


def _get_chosen_coef_mask(self, coef_):
mean_abs_coef = abs(coef_.mean())
return np.abs(coef_) > 1e-10*mean_abs_coef


@property
def sparsity_mask_(self):
"""A boolean mask indicating whether features are used in prediction.
"""
if len(self.groups.shape) == 1 or self.groups.shape[1] == 1:
coef_ = self.coef_.mean(1)
else:
coef_ = self.coef_
mean_abs_coef = abs(coef_.mean())

return np.abs(coef_) > 1e-10 * mean_abs_coef
coef_ = self.coef_.mean(1)
return self._get_chosen_coef_mask(coef_)

@property
def chosen_groups_(self):
"""A set of the coosen group ids.
"""
return set(np.unique(self.groups[self.sparsity_mask_]))
sparsity_mask = self._get_chosen_coef_mask(self.coef_)
return set(np.unique(self.groups.ravel()[sparsity_mask.ravel()]))

def transform(self, X):
"""Remove columns corresponding to zero-valued coefficients.
Expand Down

0 comments on commit 0338a65

Please sign in to comment.