Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample_weight parameter in TLCenter, TLStretch and TLRotate #273

Merged
merged 12 commits into from Jan 16, 2024
1 change: 1 addition & 0 deletions doc/whatsnew.rst
Expand Up @@ -22,6 +22,7 @@ v0.6.dev

- Correct :func:`pyriemann.utils.distance.distance_wasserstein` and :func:`pyriemann.utils.distance.distance_kullback`, keeping only the real part. :pr:`267` by :user:`qbarthelemy`

- Add ``sample_weight`` parameter in TLCenter, TLStretch and TLRotate. :pr:`273` by :user:`apmellot`

v0.5 (Jun 2023)
---------------
Expand Down
80 changes: 57 additions & 23 deletions pyriemann/transfer/_estimators.py
Expand Up @@ -13,6 +13,7 @@
from ..utils.distance import distance
from ..utils.base import invsqrtm, powm, sqrtm
from ..utils.geodesic import geodesic
from ..utils.utils import check_weights
from ._rotate import _get_rotation_matrix
from ..classification import MDM, _check_metric
from ..preprocessing import Whitening
Expand Down Expand Up @@ -133,7 +134,7 @@ def __init__(self, target_domain, metric='riemann'):
self.target_domain = target_domain
self.metric = metric

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLCenter.

Calculate the mean of all matrices in each domain.
Expand All @@ -144,17 +145,24 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
self : TLCenter instance
The TLCenter instance.
"""
_, _, domains = decode_domains(X, y_enc)
n_matrices, _, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

self.recenter_ = {}
for d in np.unique(domains):
idx = domains == d
self.recenter_[d] = Whitening(metric=self.metric).fit(X[idx])
self.recenter_[d] = Whitening(metric=self.metric).fit(
X[idx], sample_weight=sample_weight[idx]
)
return self

def transform(self, X, y_enc=None):
Expand All @@ -175,7 +183,7 @@ def transform(self, X, y_enc=None):
# Used during inference, apply recenter from specified target domain.
return self.recenter_[self.target_domain].transform(X)

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLCenter and then transform data points.

Calculate the mean of all matrices in each domain and then recenter
Expand All @@ -192,15 +200,18 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
X : ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices with mean at the identity.
"""
# Used during fit, in pipeline
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_rct = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down Expand Up @@ -259,7 +270,7 @@ def __init__(self, target_domain, final_dispersion=1.0,
self.centered_data = centered_data
self.metric = metric

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLStretch.

Calculate the dispersion around the mean for each domain.
Expand All @@ -270,29 +281,37 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
self : TLStretch instance
The TLStretch instance.
"""

_, _, domains = decode_domains(X, y_enc)
n_dim = X[0].shape[1]
self._means = {}
self.dispersions_ = {}
n_matrices, n_channels, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

self._means, self.dispersions_ = {}, {}
for d in np.unique(domains):
idx = domains == d
sample_weight_d = check_weights(sample_weight[idx], np.sum(idx))
if self.centered_data:
self._means[d] = np.eye(n_dim)
self._means[d] = np.eye(n_channels)
else:
self._means[d] = mean_riemann(X[domains == d])
disp_domain = distance(
X[domains == d],
self._means[d] = mean_riemann(
X[idx], sample_weight=sample_weight_d
)
dist_domain = distance(
X[idx],
self._means[d],
metric=self.metric,
squared=True,
).mean()
self.dispersions_[d] = disp_domain
)
self.dispersions_[d] = (
sample_weight_d * np.squeeze(dist_domain)
).sum()

return self

Expand Down Expand Up @@ -342,7 +361,7 @@ def transform(self, X, y_enc=None):

return X_str

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLStretch and then transform data points.

Calculate the dispersion around the mean for each domain and then
Expand All @@ -359,6 +378,8 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
Expand All @@ -367,8 +388,9 @@ def fit_transform(self, X, y_enc):
"""

# used during fit, in pipeline
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_str = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down Expand Up @@ -450,7 +472,7 @@ def __init__(self, target_domain, weights=None, metric='euclid', n_jobs=1):
self.metric = metric
self.n_jobs = n_jobs

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLRotate.

Calculate the rotation matrices to transform each source domain into
Expand All @@ -462,6 +484,8 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
Expand All @@ -470,11 +494,14 @@ def fit(self, X, y_enc):
"""

_, _, domains = decode_domains(X, y_enc)
n_matrices, _, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

idx = domains == self.target_domain
X_target, y_target = X[idx], y_enc[idx]
M_target = np.stack([
mean_riemann(X_target[y_target == label])
mean_riemann(X_target[y_target == label],
sample_weight=sample_weight[idx][y_target == label])
for label in np.unique(y_target)
])

Expand All @@ -483,8 +510,12 @@ def fit(self, X, y_enc):
rotations = Parallel(n_jobs=self.n_jobs)(
delayed(_get_rotation_matrix)(
np.stack([
mean_riemann(X[domains == d][y_enc[domains == d] == label])
for label in np.unique(y_enc[domains == d])
mean_riemann(
X[domains == d][y_enc[domains == d] == label],
sample_weight=sample_weight[domains == d][
y_enc[domains == d] == label
]
) for label in np.unique(y_enc[domains == d])
]),
M_target,
self.weights,
Expand Down Expand Up @@ -519,7 +550,7 @@ def transform(self, X, y_enc=None):
# used during inference on target domain
return X

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLRotate and then transform data points.

Calculate the rotation matrix for matching each source domain to the
Expand All @@ -536,6 +567,8 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
Expand All @@ -544,8 +577,9 @@ def fit_transform(self, X, y_enc):
"""

# used during fit in pipeline, rotate each source domain
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_rot = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down