From ce37a8abc537f13c532cb409e5d23156edfd8c52 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Fri, 6 Oct 2023 15:33:17 +0200 Subject: [PATCH 1/2] WIP intercept gamma --- skglm/datafits/single_task.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py index 1dfd79e85..cfb31b488 100644 --- a/skglm/datafits/single_task.py +++ b/skglm/datafits/single_task.py @@ -498,8 +498,8 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): grad += X_data[i] * (np.exp(Xw[idx_i]) - y[idx_i]) return grad / len(y) - def intercept_update_self(self, y, Xw): - pass + def intercept_update_step(self, y, Xw): + return np.mean(np.exp(Xw) - y) class Gamma(BaseDatafit): @@ -555,8 +555,8 @@ def gradient_scalar(self, X, y, w, Xw, j): def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): pass - def intercept_update_self(self, y, Xw): - pass + def intercept_update_step(self, y, Xw): + return 1 - np.mean(y * np.exp(Xw)) class Cox(BaseDatafit): From de5c34daa4efcb5c7ed627155329d6d2e721e457 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Tue, 10 Oct 2023 08:35:43 +0200 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> --- skglm/datafits/single_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py index cfb31b488..18be14b87 100644 --- a/skglm/datafits/single_task.py +++ b/skglm/datafits/single_task.py @@ -499,7 +499,7 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): return grad / len(y) def intercept_update_step(self, y, Xw): - return np.mean(np.exp(Xw) - y) + return np.sum(self.raw_grad(y, Xw)) class Gamma(BaseDatafit): @@ -556,7 +556,7 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): pass def intercept_update_step(self, y, Xw): - return 1 - np.mean(y * np.exp(Xw)) + return np.sum(self.raw_grad(y, Xw)) class Cox(BaseDatafit):