Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample_weight parameter in TLCenter, TLStretch and TLRotate #273

Merged
merged 12 commits into from Jan 16, 2024
1 change: 1 addition & 0 deletions doc/whatsnew.rst
Expand Up @@ -22,6 +22,7 @@ v0.6.dev

- Correct :func:`pyriemann.utils.distance.distance_wasserstein` and :func:`pyriemann.utils.distance.distance_kullback`, keeping only real part. :pr:`267` by :user:`qbarthelemy`

- Add ``sample_weight`` parameter in TLCenter, TLStretch and TLRotate. :pr:`273` by :user:`apmellot`

v0.5 (Jun 2023)
---------------
Expand Down
91 changes: 62 additions & 29 deletions pyriemann/transfer/_estimators.py
Expand Up @@ -13,6 +13,7 @@
from ..utils.distance import distance
from ..utils.base import invsqrtm, powm, sqrtm
from ..utils.geodesic import geodesic
from ..utils.utils import check_weights
from ._rotate import _get_rotation_matrix
from ..classification import MDM, _check_metric
from ..preprocessing import Whitening
Expand Down Expand Up @@ -113,7 +114,7 @@ class TLCenter(BaseEstimator, TransformerMixin):
Attributes
----------
recenter_ : dict
Dictionary with key=domain_name and value=domain_mean
Dictionary with key=domain_name and value=domain_mean.

References
----------
Expand All @@ -133,7 +134,7 @@ def __init__(self, target_domain, metric='riemann'):
self.target_domain = target_domain
self.metric = metric

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLCenter.

Calculate the mean of all matrices in each domain.
Expand All @@ -144,17 +145,24 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
self : TLCenter instance
The TLCenter instance.
"""
_, _, domains = decode_domains(X, y_enc)
n_matrices, _, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

self.recenter_ = {}
for d in np.unique(domains):
idx = domains == d
self.recenter_[d] = Whitening(metric=self.metric).fit(X[idx])
self.recenter_[d] = Whitening(metric=self.metric).fit(
X[idx], sample_weight=sample_weight[idx]
)
return self

def transform(self, X, y_enc=None):
Expand All @@ -175,7 +183,7 @@ def transform(self, X, y_enc=None):
# Used during inference, apply recenter from specified target domain.
return self.recenter_[self.target_domain].transform(X)

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLCenter and then transform data points.

Calculate the mean of all matrices in each domain and then recenter
Expand All @@ -192,15 +200,18 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
X : ndarray, shape (n_matrices, n_classes)
Set of SPD matrices with mean in the Identity.
"""
# Used during fit, in pipeline
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_rct = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down Expand Up @@ -259,7 +270,7 @@ def __init__(self, target_domain, final_dispersion=1.0,
self.centered_data = centered_data
self.metric = metric

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLStretch.

Calculate the dispersion around the mean for each domain.
Expand All @@ -270,29 +281,35 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
self : TLStretch instance
The TLStretch instance.
"""

_, _, domains = decode_domains(X, y_enc)
n_dim = X[0].shape[1]
self._means = {}
self.dispersions_ = {}
n_matrices, n_channels, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

self._means, self.dispersions_ = {}, {}
for d in np.unique(domains):
idx = domains == d
sample_weight_d = check_weights(sample_weight[idx], np.sum(idx))
if self.centered_data:
self._means[d] = np.eye(n_dim)
self._means[d] = np.eye(n_channels)
else:
self._means[d] = mean_riemann(X[domains == d])
disp_domain = distance(
X[domains == d],
self._means[d] = mean_riemann(
X[idx], sample_weight=sample_weight_d
)
dist = distance(
X[idx],
self._means[d],
metric=self.metric,
squared=True,
).mean()
self.dispersions_[d] = disp_domain
)
self.dispersions_[d] = np.sum(sample_weight_d * np.squeeze(dist))

return self

Expand Down Expand Up @@ -342,7 +359,7 @@ def transform(self, X, y_enc=None):

return X_str

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLStretch and then transform data points.

Calculate the dispersion around the mean for each domain and then
Expand All @@ -359,6 +376,8 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
Expand All @@ -367,8 +386,9 @@ def fit_transform(self, X, y_enc):
"""

# used during fit, in pipeline
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_str = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down Expand Up @@ -450,7 +470,7 @@ def __init__(self, target_domain, weights=None, metric='euclid', n_jobs=1):
self.metric = metric
self.n_jobs = n_jobs

def fit(self, X, y_enc):
def fit(self, X, y_enc, sample_weight=None):
"""Fit TLRotate.

Calculate the rotations matrices to transform each source domain into
Expand All @@ -462,6 +482,8 @@ def fit(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
Expand All @@ -470,30 +492,38 @@ def fit(self, X, y_enc):
"""

_, _, domains = decode_domains(X, y_enc)
n_matrices, _, _ = X.shape
sample_weight = check_weights(sample_weight, n_matrices)

idx = domains == self.target_domain
X_target, y_target = X[idx], y_enc[idx]
M_target = np.stack([
mean_riemann(X_target[y_target == label])
mean_riemann(X_target[y_target == label],
sample_weight=sample_weight[idx][y_target == label])
for label in np.unique(y_target)
])

source_names = np.unique(domains)
source_names = source_names[source_names != self.target_domain]
source_domains = np.unique(domains)
source_domains = source_domains[source_domains != self.target_domain]
rotations = Parallel(n_jobs=self.n_jobs)(
delayed(_get_rotation_matrix)(
np.stack([
mean_riemann(X[domains == d][y_enc[domains == d] == label])
for label in np.unique(y_enc[domains == d])
mean_riemann(
X[domains == d][y_enc[domains == d] == label],
sample_weight=sample_weight[domains == d][
y_enc[domains == d] == label
]
) for label in np.unique(y_enc[domains == d])
]),
M_target,
self.weights,
metric=self.metric,
) for d in source_names
) for d in source_domains
)

self.rotations_ = {}
for di, roti in zip(source_names, rotations):
self.rotations_[di] = roti
for d, rot in zip(source_domains, rotations):
self.rotations_[d] = rot

return self

Expand All @@ -519,7 +549,7 @@ def transform(self, X, y_enc=None):
# used during inference on target domain
return X

def fit_transform(self, X, y_enc):
def fit_transform(self, X, y_enc, sample_weight=None):
"""Fit TLRotate and then transform data points.

Calculate the rotation matrix for matching each source domain to the
Expand All @@ -536,6 +566,8 @@ def fit_transform(self, X, y_enc):
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. If None, it uses equal weights.

Returns
-------
Expand All @@ -544,8 +576,9 @@ def fit_transform(self, X, y_enc):
"""

# used during fit in pipeline, rotate each source domain
self.fit(X, y_enc)
self.fit(X, y_enc, sample_weight)
_, _, domains = decode_domains(X, y_enc)

X_rot = np.zeros_like(X)
for d in np.unique(domains):
idx = domains == d
Expand Down
76 changes: 58 additions & 18 deletions tests/test_transfer.py
Expand Up @@ -32,6 +32,7 @@
)
from pyriemann.utils.distance import distance, distance_riemann
from pyriemann.utils.mean import mean_covariance, mean_riemann
from pyriemann.utils.utils import check_weights

rndstate = 1234

Expand All @@ -55,23 +56,35 @@ def test_encode_decode_domains(rndstate):


@pytest.mark.parametrize("metric", ["riemann"])
def test_tlcenter(rndstate, metric):
@pytest.mark.parametrize("sample_weight", [True, False])
def test_tlcenter(rndstate, metric, sample_weight):
"""Test pipeline for recentering data to Identity"""
# check if the global mean of the domains is indeed Identity
rct = TLCenter(target_domain='target_domain', metric=metric)
X, y_enc = make_classification_transfer(
n_matrices=25, random_state=rndstate)
X_rct = rct.fit_transform(X, y_enc)
if sample_weight:
sample_weight_ = np.random.rand(len(y_enc))
else:
sample_weight_ = None
X_rct = rct.fit_transform(X, y_enc, sample_weight_)
_, _, domain = decode_domains(X_rct, y_enc)
for d in np.unique(domain):
Xd = X_rct[domain == d]
Md = mean_covariance(Xd, metric=metric)
idx = domain == d
Xd = X_rct[idx]
if sample_weight:
sample_weight_d = check_weights(sample_weight_[idx], np.sum(idx))
Md = mean_covariance(Xd, metric='riemann',
sample_weight=sample_weight_d)
else:
Md = mean_covariance(Xd, metric='riemann')
assert Md == pytest.approx(np.eye(2))


@pytest.mark.parametrize("centered_data", [True, False])
@pytest.mark.parametrize("metric", ["riemann"])
def test_tlstretch(rndstate, centered_data, metric):
@pytest.mark.parametrize("sample_weight", [True, False])
def test_tlstretch(rndstate, centered_data, metric, sample_weight):
"""Test pipeline for stretching data"""
# check if the dispersion of the dataset indeed decreases to 1
tlstr = TLStretch(
Expand All @@ -82,40 +95,67 @@ def test_tlstretch(rndstate, centered_data, metric):
)
X, y_enc = make_classification_transfer(
n_matrices=25, class_disp=2.0, random_state=rndstate)

if sample_weight:
sample_weight_ = np.random.rand(len(y_enc))
else:
sample_weight_ = None
if centered_data: # ensure that data is indeed centered on each domain
tlrct = TLCenter(target_domain='target_domain', metric=metric)
X = tlrct.fit_transform(X, y_enc)
X_str = tlstr.fit_transform(X, y_enc)
X = tlrct.fit_transform(X, y_enc, sample_weight=sample_weight_)

X_str = tlstr.fit_transform(X, y_enc, sample_weight=sample_weight_)

_, _, domain = decode_domains(X_str, y_enc)
for d in np.unique(domain):
Xd = X_str[domain == d]
Md = mean_riemann(Xd)
disp = np.mean(distance(Xd, Md, metric=metric, squared=True))
idx = domain == d
Xd = X_str[idx]
if sample_weight:
sample_weight_d = check_weights(sample_weight_[idx], np.sum(idx))
Md = mean_riemann(Xd, sample_weight=sample_weight_d)
dist = distance(Xd, Md, metric=metric, squared=True)
disp = np.sum(sample_weight_d * np.squeeze(dist))
else:
Md = mean_riemann(Xd)
disp = np.mean(distance(Xd, Md, metric=metric, squared=True))
assert np.isclose(disp, 1.0)


@pytest.mark.parametrize("metric", ["euclid", "riemann"])
def test_tlrotate(rndstate, metric):
"""Test pipeline for rotating the datasets"""
@pytest.mark.parametrize("sample_weight", [True, False])
def test_tlrotate_fit_transform(rndstate, metric, sample_weight):
"""Test fit_transform method for rotating the datasets"""
# check if the distance between the classes of each domain is reduced
X, y_enc = make_classification_transfer(
n_matrices=50, class_sep=3, class_disp=1.0, random_state=rndstate)

if sample_weight:
sample_weight_ = np.random.rand(len(y_enc))
else:
sample_weight_ = None
sample_weight_ = check_weights(sample_weight_, len(y_enc))

rct = TLCenter(target_domain='target_domain')
X_rct = rct.fit_transform(X, y_enc)
X_rct = rct.fit_transform(X, y_enc, sample_weight_)
rot = TLRotate(target_domain='target_domain', metric=metric)
X_rot = rot.fit_transform(X_rct, y_enc)
X_rot = rot.fit_transform(X_rct, y_enc, sample_weight=sample_weight_)

_, y, domain = decode_domains(X_rot, y_enc)
for label in np.unique(y):
d = 'source_domain'
M_rct_label_source = mean_riemann(
X_rct[domain == d][y[domain == d] == label])
X_rct[domain == d][y[domain == d] == label],
sample_weight=sample_weight_[domain == d][y[domain == d] == label])
M_rot_label_source = mean_riemann(
X_rot[domain == d][y[domain == d] == label])
X_rot[domain == d][y[domain == d] == label],
sample_weight=sample_weight_[domain == d][y[domain == d] == label])
d = 'target_domain'
M_rct_label_target = mean_riemann(
X_rct[domain == d][y[domain == d] == label])
X_rct[domain == d][y[domain == d] == label],
sample_weight=sample_weight_[domain == d][y[domain == d] == label])
M_rot_label_target = mean_riemann(
X_rot[domain == d][y[domain == d] == label])
X_rot[domain == d][y[domain == d] == label],
sample_weight=sample_weight_[domain == d][y[domain == d] == label])
d_rct = distance_riemann(M_rct_label_source, M_rct_label_target)
d_rot = distance_riemann(M_rot_label_source, M_rot_label_target)
assert d_rot <= d_rct
Expand Down