Skip to content

Commit

Permalink
various fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandrebarachant committed Jun 18, 2017
1 parent 4e8fad8 commit 92531c5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
25 changes: 14 additions & 11 deletions pyriemann/stats.py
Expand Up @@ -39,7 +39,7 @@ def unique_permutations(elements):
class BasePermutation():
"""Base object for permutations test"""

def test(self, X, y, verbose=True):
def test(self, X, y, groups=None, verbose=True):
"""Performs the permutation test
Parameters
Expand All @@ -48,7 +48,7 @@ def test(self, X, y, verbose=True):
The data to fit. Can be, for example a list, or an array at
least 2d.
y : array-like, optional, default: None
y : array-like
The target variable to try to predict in the case of
supervised learning.
Expand All @@ -64,15 +64,15 @@ def test(self, X, y, verbose=True):
X = self._initial_transform(X)

# get the non permuted score
self.scores_[0] = self.score(X, y)
self.scores_[0] = self.score(X, y, groups=groups)

if Npe <= self.n_perms:
print("Warning, number of unique permutations : %d" % Npe)
perms = unique_permutations(y)
ii = 0
for perm in perms:
if not numpy.array_equal(perm, y):
self.scores_[ii + 1] = self.score(X, perm)
self.scores_[ii + 1] = self.score(X, perm, groups=groups)
ii += 1
if verbose:
self._print_progress(ii)
Expand All @@ -81,7 +81,7 @@ def test(self, X, y, verbose=True):
rs = numpy.random.RandomState(self.random_state)
for ii in range(self.n_perms - 1):
perm = rs.permutation(y)
self.scores_[ii + 1] = self.score(X, perm)
self.scores_[ii + 1] = self.score(X, perm, groups=groups)
if verbose:
self._print_progress(ii)
if verbose:
Expand Down Expand Up @@ -127,7 +127,9 @@ def plot(self, nbins=10, range=None, axes=None):
y_max = axes.get_ylim()[1]
axes.plot([x_val, x_val], [0, y_max], '--r', lw=2)
x_max = axes.get_xlim()[1]
axes.text(x_max * 0.5, y_max * 0.8, 'p-value: %.3f' % self.p_value_)
x_min = axes.get_xlim()[0]
x_pos = x_min + ((x_max - x_min) * 0.25)
axes.text(x_pos, y_max * 0.8, 'p-value: %.3f' % self.p_value_)
axes.set_xlabel('Score')
axes.set_ylabel('Count')
return axes
Expand Down Expand Up @@ -197,7 +199,7 @@ def __init__(self, n_perms=100, model=MDM(), cv=3, scoring=None,
self.n_jobs = n_jobs
self.random_state = random_state

def score(self, X, y):
def score(self, X, y, groups=None):
"""Score one permutation.
Parameters
Expand All @@ -206,12 +208,13 @@ def score(self, X, y):
The data to fit. Can be, for example a list, or an array at
least 2d.
y : array-like, optional, default: None
y : array-like
The target variable to try to predict in the case of
supervised learning.
"""
score = cross_val_score(self.model, X, y, cv=self.cv,
n_jobs=self.n_jobs, scoring=self.scoring)
n_jobs=self.n_jobs, scoring=self.scoring,
groups=groups)
return score.mean()


Expand Down Expand Up @@ -300,7 +303,7 @@ def __init__(self, n_perms=100, metric='riemann', mode='pairwise',
self.random_state = random_state
self.estimator = estimator

def score(self, X, y):
def score(self, X, y, groups=None):
"""Score of a permutation.
Parameters
Expand All @@ -309,7 +312,7 @@ def score(self, X, y):
The data to fit. Can be, for example a list, or an array at
least 2d.
y : array-like, optional, default: None
y : array-like
The target variable to try to predict in the case of
supervised learning.
"""
Expand Down
6 changes: 4 additions & 2 deletions pyriemann/utils/viz.py
Expand Up @@ -23,7 +23,8 @@ def plot_confusion_matrix(targets, predictions, target_names,
return g


def plot_embedding(X, y=None, metric='riemann'):
def plot_embedding(X, y=None, metric='riemann',
title='Spectral embedding of covariances'):
"""Plot 2D embedding of covariance matrices using Diffusion maps."""
lapl = Embedding(n_components=2, metric=metric)
embd = lapl.fit_transform(X)
Expand All @@ -38,9 +39,10 @@ def plot_embedding(X, y=None, metric='riemann'):

ax.set_xlabel(r'$\varphi_1$', fontsize=16)
ax.set_ylabel(r'$\varphi_2$', fontsize=16)
ax.set_title('Spectral embedding of ERP recordings', fontsize=16)
ax.set_title(title, fontsize=16)
ax.grid(False)
ax.set_xticks([-1.0, -0.5, 0.0, +0.5, 1.0])
ax.set_yticks([-1.0, -0.5, 0.0, +0.5, 1.0])
ax.legend(list(np.unique(y)))

return fig
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
numpy
scipy
scikit-learn
scikit-learn>=0.18
pandas
joblib
seaborn

0 comments on commit 92531c5

Please sign in to comment.