Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add squared argument to distances #246

Merged
merged 4 commits into from May 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/whatsnew.rst
Expand Up @@ -27,6 +27,7 @@ v0.5.dev

- Fix :func:`pyriemann.utils.kernel.kernel_euclid` applied on non-symmetric matrices. :pr:`245` by :user:`qbarthelemy`

- Add argument ``squared`` to all distances. :pr:`246` by :user:`qbarthelemy`

v0.4 (Feb 2023)
---------------
Expand Down
9 changes: 6 additions & 3 deletions pyriemann/classification.py
Expand Up @@ -818,7 +818,7 @@ def _get_label(self, x, labs_unique):
for ip, p in enumerate(self.power_list):
for ill, ll in enumerate(labs_unique):
m[ip, ill] = distance(
x, self.covmeans_[p][ll], metric=self.metric) ** 2
x, self.covmeans_[p][ll], metric=self.metric, squared=True)

if self.method_label == 'sum_means':
ipmin = np.argmin(np.sum(m, axis=1))
Expand Down Expand Up @@ -861,8 +861,11 @@ def _predict_distances(self, X):
for ll in self.classes_:
m[p].append(
distance(
x, self.covmeans_[p][ll], metric=self.metric
) ** 2
x,
self.covmeans_[p][ll],
metric=self.metric,
squared=True,
)
)
pmin = min(m.items(), key=lambda x: np.sum(x[1]))[0]
dist.append(np.array(m[pmin]))
Expand Down
2 changes: 1 addition & 1 deletion pyriemann/datasets/simulated.py
Expand Up @@ -354,7 +354,7 @@ def make_outliers(n_matrices, mean, sigma, outlier_coeff=10,
for i in range(n_matrices):
Oi = generate_random_spd_matrix(n_dim=n_dim, random_state=random_state)
epsilon_num = outlier_coeff * sigma * n_dim
epsilon_den = distance_riemann(Oi, np.eye(n_dim))**2
epsilon_den = distance_riemann(Oi, np.eye(n_dim), squared=True)
epsilon = np.sqrt(epsilon_num / epsilon_den)
outliers[i] = mean_sqrt @ powm(Oi, epsilon) @ mean_sqrt

Expand Down
26 changes: 18 additions & 8 deletions pyriemann/stats.py
Expand Up @@ -359,7 +359,7 @@ def __init_transform(self, X):
if self.mode == 'ftest':
self.global_mean = mean_covariance(X, metric=self.mdm.metric_mean)
elif self.mode == 'pairwise':
X = pairwise_distance(X, metric=self.mdm.metric_dist)**2
X = pairwise_distance(X, metric=self.mdm.metric_dist, squared=True)
return X

def _score_ftest(self, X, y):
Expand All @@ -372,16 +372,23 @@ def _score_ftest(self, X, y):
between = 0
for ix, classe in enumerate(mdm.classes_):
di = distance(
covmeans[ix], self.global_mean, metric=mdm.metric_dist)**2
covmeans[ix],
self.global_mean,
metric=mdm.metric_dist,
squared=True,
)
between += np.sum(y == classe) * di
between /= (n_classes - 1)

# estimates within class variability
within = 0
for ix, classe in enumerate(mdm.classes_):
within += (distance(
X[y == classe], covmeans[ix], metric=mdm.metric_dist)
** 2).sum()
within += distance(
X[y == classe],
covmeans[ix],
metric=mdm.metric_dist,
squared=True,
).sum()
within /= (len(y) - n_classes)

score = between / within
Expand All @@ -400,9 +407,12 @@ def _score_ttest(self, X, y):

dist = 0
for ix, classe in enumerate(mdm.classes_):
di = (distance(
X[y == classe], covmeans[ix], metric=mdm.metric_dist)
** 2).mean()
di = distance(
X[y == classe],
covmeans[ix],
metric=mdm.metric_dist,
squared=True,
).mean()
dist += (di / np.sum(y == classe))
score = mean_dist / np.sqrt(dist)
return score
Expand Down
9 changes: 6 additions & 3 deletions pyriemann/transfer/_estimators.py
Expand Up @@ -286,9 +286,12 @@ def fit(self, X, y_enc):
self._means[d] = np.eye(n_dim)
else:
self._means[d] = mean_riemann(X[domains == d])
disp_domain = np.sum(distance(
X[domains == d], self._means[d], metric=self.metric
)**2)
disp_domain = distance(
X[domains == d],
self._means[d],
metric=self.metric,
squared=True,
).sum()
self.dispersions_[d] = disp_domain

return self
Expand Down