Skip to content

Commit

Permalink
Add sample_weight parameter in TLCenter, TLStretch and TLRotate (#273)
Browse files Browse the repository at this point in the history
* add sample_weight

* check_weight for each domain

* complete docstring

* minor modifs

* test sample_weight

* sample_weight in recenter

* final modifs

* use same metric for rct and rot

* add rotation in simulated data

* fix random state

---------

Co-authored-by: qbarthelemy <q.barthelemy@gmail.com>
Co-authored-by: Pedro L. C. Rodrigues <pedro.rodrigues@melix.org>
  • Loading branch information
3 people committed Jan 16, 2024
1 parent 02fd6e0 commit b9148ec
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 49 deletions.
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -22,6 +22,8 @@ v0.6.dev

- Correct :func:`pyriemann.utils.distance.distance_wasserstein` and :func:`pyriemann.utils.distance.distance_kullback`, keeping only real part. :pr:`267` by :user:`qbarthelemy`

- Add ``sample_weight`` parameter in TLCenter, TLStretch and TLRotate. :pr:`273` by :user:`apmellot`

- Deprecate input ``covmats`` for mean functions, renamed into ``X``. :pr:`252` by :user:`qbarthelemy`

- Add support for complex covariance estimation for 'lwf', 'mcd', 'oas' and 'sch' estimators. :pr:`274` by :user:`gabelstein`
Expand Down
91 changes: 62 additions & 29 deletions pyriemann/transfer/_estimators.py
Expand Up @@ -13,6 +13,7 @@
from ..utils.distance import distance
from ..utils.base import invsqrtm, powm, sqrtm
from ..utils.geodesic import geodesic
from ..utils.utils import check_weights
from ._rotate import _get_rotation_matrix
from ..classification import MDM, _check_metric
from ..preprocessing import Whitening
Expand Down Expand Up @@ -113,7 +114,7 @@ class TLCenter(BaseEstimator, TransformerMixin):
Attributes
----------
recenter_ : dict
Dictionary with key=domain_name and value=domain_mean
Dictionary with key=domain_name and value=domain_mean.
References
----------
Expand All @@ -133,7 +134,7 @@ def __init__(self, target_domain, metric='riemann'):
self.target_domain = target_domain
self.metric = metric

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLCenter.
Calculate the mean of all matrices in each domain.
Expand All @@ -144,17 +145,24 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.
Returns
-------
self : TLCenter instance
The TLCenter instance.
"""
_, _, domains = decode_domains(X, y_enc)
n_matrices, _, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

self.recenter_ = {}
for d in np.unique(domains):
idx = domains == d
self.recenter_[d] = Whitening(metric=self.metric).fit(X[idx])
self.recenter_[d] = Whitening(metric=self.metric).fit(
X[idx], sample_weight=sample_weight[idx]
)
return self

def transform(self, X, y_enc=None):
Expand All @@ -175,7 +183,7 @@ def transform(self, X, y_enc=None):
# Used during inference, apply recenter from specified target domain.
return self.recenter_[self.target_domain].transform(X)

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLCenter and then transform data points.
Calculate the mean of all matrices in each domain and then recenter
Expand All @@ -192,15 +200,18 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.
Returns
-------
X : ndarray, shape (n_matrices, n_classes)
Set of SPD matrices with mean in the Identity.
"""
# Used during fit, in pipeline
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_rct = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down Expand Up @@ -259,7 +270,7 @@ def __init__(self, target_domain, final_dispersion=1.0,
self.centered_data = centered_data
self.metric = metric

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLStretch.
Calculate the dispersion around the mean for each domain.
Expand All @@ -270,29 +281,35 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.
Returns
-------
self : TLStretch instance
The TLStretch instance.
"""

_, _, domains = decode_domains(X, y_enc)
n_dim = X[0].shape[1]
self._means = {}
self.dispersions_ = {}
n_matrices, n_channels, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

self._means, self.dispersions_ = {}, {}
for d in np.unique(domains):
idx = domains == d
sample_weight_d = check_weights(sample_weight[idx], np.sum(idx))
if self.centered_data:
self._means[d] = np.eye(n_dim)
self._means[d] = np.eye(n_channels)
else:
self._means[d] = mean_riemann(X[domains == d])
disp_domain = distance(
X[domains == d],
self._means[d] = mean_riemann(
X[idx], sample_weight=sample_weight_d
)
dist = distance(
X[idx],
self._means[d],
metric=self.metric,
squared=True,
).mean()
self.dispersions_[d] = disp_domain
)
self.dispersions_[d] = np.sum(sample_weight_d * np.squeeze(dist))

return self

Expand Down Expand Up @@ -342,7 +359,7 @@ def transform(self, X, y_enc=None):

return X_str

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLStretch and then transform data points.
Calculate the dispersion around the mean for each domain and then
Expand All @@ -359,6 +376,8 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.
Returns
-------
Expand All @@ -367,8 +386,9 @@ def fit_transform(self, X, y_enc):
"""

# used during fit, in pipeline
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_str = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down Expand Up @@ -450,7 +470,7 @@ def __init__(self, target_domain, weights=None, metric='euclid', n_jobs=1):
self.metric = metric
self.n_jobs = n_jobs

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLRotate.
Calculate the rotations matrices to transform each source domain into
Expand All @@ -462,6 +482,8 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.
Returns
-------
Expand All @@ -470,30 +492,38 @@ def fit(self, X, y_enc):
"""

_, _, domains = decode_domains(X, y_enc)
n_matrices, _, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

idx = domains == self.target_domain
X_target, y_target = X[idx], y_enc[idx]
M_target = np.stack([
mean_riemann(X_target[y_target == label])
mean_riemann(X_target[y_target == label],
sample_weight=sample_weight[idx][y_target == label])
for label in np.unique(y_target)
])

source_names = np.unique(domains)
source_names = source_names[source_names != self.target_domain]
source_domains = np.unique(domains)
source_domains = source_domains[source_domains != self.target_domain]
rotations = Parallel(n_jobs=self.n_jobs)(
delayed(_get_rotation_matrix)(
np.stack([
mean_riemann(X[domains == d][y_enc[domains == d] == label])
for label in np.unique(y_enc[domains == d])
mean_riemann(
X[domains == d][y_enc[domains == d] == label],
sample_weight=sample_weight[domains == d][
y_enc[domains == d] == label
]
) for label in np.unique(y_enc[domains == d])
]),
M_target,
self.weights,
metric=self.metric,
) for d in source_names
) for d in source_domains
)

self.rotations_ = {}
for di, roti in zip(source_names, rotations):
self.rotations_[di] = roti
for d, rot in zip(source_domains, rotations):
self.rotations_[d] = rot

return self

Expand All @@ -519,7 +549,7 @@ def transform(self, X, y_enc=None):
# used during inference on target domain
return X

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLRotate and then transform data points.
Calculate the rotation matrix for matching each source domain to the
Expand All @@ -536,6 +566,8 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.
Returns
-------
Expand All @@ -544,8 +576,9 @@ def fit_transform(self, X, y_enc):
"""

# used during fit in pipeline, rotate each source domain
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_rot = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down

0 comments on commit b9148ec

Please sign in to comment.