
Faster covariance and cospectra calculation with einsum #80

Closed
dojeda opened this issue Mar 19, 2020 · 4 comments

@dojeda

dojeda commented Mar 19, 2020

I recently saw that there have been some improvements to the coherence calculation (9a6dbee), so I would like to propose some further improvements to the covariance and cospectra calculations.

In brief, I have observed that using einsum for these two operations gives a speed-up of one order of magnitude. Here is my current code, but let us discuss if and how we could add this to pyriemann:

from typing import Optional, Tuple

import numpy as np
from scipy.signal import get_window, signaltools


def covariances(x: np.ndarray) -> np.ndarray:
    """Calculate covariances on epoched data.

    Input dimensions must be (epoch, samples, channels).
    No mean subtraction is done, so the data are assumed zero-mean.
    """
    n = x.shape[1]
    # TODO: watch out for einsum, it does not promote dtypes!
    c = np.einsum('aji,ajk->aik', x, x) / (n - 1)
    return c
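
# Usage sketch (shapes assumed for illustration):
# >>> x = np.random.randn(20, 256, 8)  # 20 epochs, 256 samples, 8 channels
# >>> covariances(x).shape
# (20, 8, 8)
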
def cross_spectrum(x: np.ndarray,
                   nperseg: Optional[int] = None,
                   noverlap: Optional[int] = None, *,
                   fs: float = 1,
                   detrend: Optional[str] = 'constant',
                   window: str = 'boxcar',
                   #return_onesided: bool = True,
                   ) -> Tuple[np.ndarray, np.ndarray]:

    # x should be of shape (channel, sample)
    if x.ndim == 1:
        # when x is 1D, assume it is just samples; add a single channel
        x = x[np.newaxis, :]
    elif x.ndim == 2:
        # when x is 2D, no manipulation is needed
        pass
    else:
        raise ValueError('Expected 1D or 2D array')

    n_channels, n_samples = x.shape
    nperseg = nperseg or n_samples
    noverlap = noverlap or 0

    # create sliding epochs to get (epoch, channel, sample);
    # `epoch` is a sliding-window helper (a sketch is given after this listing)
    x_epoched = epoch(x, nperseg, nperseg - noverlap, axis=1)
    n_epochs = x_epoched.shape[0]

    # Handle detrending and window functions
    w = get_window(window, nperseg)
    x_epoched = x_epoched * w[np.newaxis, np.newaxis, :]
    if detrend is not None:
        x_epoched = signaltools.detrend(x_epoched, type=detrend, axis=2)

    # Apply FFT on x last dimension, X will be (epoch, channel, freq)
    freqs = np.fft.fftfreq(nperseg, 1 / fs)
    X = np.fft.fft(x_epoched)  # FFT over the last axis (samples)

    # Do an Einstein summation that is equivalent to the following commented code:
    #
    # ## Verbose implementation ##
    # Apply x multiplied by its complex conjugate for each frequency
    # This gives dimensions epoch, channel, channel, frequency
    # cxx = np.apply_along_axis(_xxh, 1, X)
    #
    # Reorder the axis to epoch, frequency, channel, channel
    # cxx = np.rollaxis(cxx, 3, start=1)
    #
    # Average over epochs, eliminating the epoch dimension to get frequency, channel, channel
    # cxx = cxx.mean(axis=0)
    # ## end of verbose implementation ##
    #
    #
    # Using np.einsum, we get one order of magnitude faster (10x faster!),
    # but it is more difficult to understand. First, let us understand what
    # np.einsum('i,j->ij', a, b) does:
    # It multiplies each element of a with each element of b,
    # filling an output matrix indexed by (i, j).
    # In other words: it computes the vector outer product of a and b.
    #
    # For example:
    # >>> x = np.arange(0, 3); y = np.arange(10, 13)
    # >>> x, y
    # (array([0, 1, 2]), array([10, 11, 12]))
    # >>> np.einsum('i,j->ij', x, y)
    # array([[ 0,  0,  0],
    #        [10, 11, 12],
    #        [20, 22, 24]])
    #
    # Another example:
    # >>> x = np.arange(3) + 1j
    # >>> x
    # array([0.+1.j, 1.+1.j, 2.+1.j])
    # >>> np.einsum('i,j->ij', x, x.conj())
    # array([[1.+0.j, 1.+1.j, 1.+2.j],
    #        [1.-1.j, 2.+0.j, 3.+1.j],
    #        [1.-2.j, 3.-1.j, 5.+0.j]])
    #
    # Back to our case: this is what we want to do on
    # an (I x J) array with I channels and J frequencies:
    # for each frequency, compute the outer product of x with its
    # conjugate (x @ x.conj().T):
    # np.einsum('ik,jk->kij', X, X.conj())
    #
    # for the 3D case (that is, with epochs)
    # np.einsum('ijl,ikl->iljk', X, X.conj())
    # or, for a more verbose approach, say the indices represent the following:
    # e: epoch
    # c: channel
    # f: frequency
    # h: channel on the conjugate transpose (this should be the same size as c)
    # Then, the operation can be rewritten as:
    # np.einsum('ecf,ehf->efch', X, X.conj())
    #
    # Finally, since the final step is to do a mean over the epochs, we can
    # sum the "e" axis (by dropping the "e" axis on the output) and divide by
    # the number of epochs:
    cxx = np.einsum('ecf,ehf->fch', X, X.conj()) / n_epochs

    return freqs, cxx
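
For completeness, the `epoch` helper used above is not part of numpy or scipy; here is a minimal sliding-window sketch that matches the call signature used in cross_spectrum (the exact semantics of the original helper are assumed, since it is not shown in this issue):

def epoch(x, size, step, axis=-1):
    """Sliding-window epoching sketch (assumed semantics): cut x into
    windows of `size` samples taken every `step` samples along `axis`,
    stacked on a new first axis."""
    n = x.shape[axis]
    starts = range(0, n - size + 1, step)
    return np.stack([np.take(x, np.arange(s, s + size), axis=axis)
                     for s in starts])

And a quick usage sketch of cross_spectrum, with shapes and sampling rate assumed for illustration:

x = np.random.randn(8, 1024)  # 8 channels, 1024 samples at 250 Hz
freqs, cxx = cross_spectrum(x, nperseg=256, noverlap=128, fs=250)
# cxx has shape (256, 8, 8): one Hermitian matrix per (two-sided) FFT bin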
@alexandrebarachant
Collaborator

Thanks,

It's true that the current implementation is not super optimized.
Most of the compute cost in pyriemann usually comes from eigenvalue decompositions and similar operations, so I never focused on optimizing the estimation functions.

I will look into it, but feel free to open a PR if you feel like it.

@dojeda
Author

dojeda commented Mar 19, 2020

I agree that the compute cost of the eigenvalue decomposition is currently one of the most important bottlenecks. I have never tried it, but I have been thinking for a while that a functools.lru_cache on the function that does the eigendecomposition (or on a wrapper around numpy) could speed things up considerably, at a memory cost. If I recall correctly from the last time I looked under the hood, each covariance matrix is decomposed again on every iteration of the mean covariance matrix estimation.

However, numpy arrays are not hashable (which is needed for the cache to work), so lru_cache needs to be contorted to use some other hash (fortunately, joblib has hash implementations for numpy arrays).
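
A minimal sketch of that idea (the dict-based memoization and the function name are hypothetical, not pyriemann or numpy API):

import numpy as np
import joblib

_eigh_cache = {}

def cached_eigh(x):
    # numpy arrays are not hashable, so key the cache on a joblib hash
    key = joblib.hash(x)
    if key not in _eigh_cache:
        _eigh_cache[key] = np.linalg.eigh(x)
    return _eigh_cache[key]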

I can look this up and come back if I have some benchmarks to share.

@qbarthelemy
Member

This speed-up is promising!

Some remarks:

  • I think that it would be very interesting to define a function cross_spectrum, and then to define co_spectrum and quad_spectrum as the real and imaginary parts of the result (see the sketch after this list). This would open the way to HPD matrix processing!

  • we should also distinguish between: (1) an instantaneous cross-spectrum computed on a single epoch (which is a rank-2 matrix), and (2) an averaged cross-spectrum over several epochs (to obtain a full-rank matrix, by Euclidean averaging of several instantaneous cross-spectra).
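
A sketch of both points, reusing the cross_spectrum proposed above; X below stands for the per-epoch FFT computed inside cross_spectrum, and all names are assumed for illustration:

freqs, cxx = cross_spectrum(x, nperseg=256)  # averaged over sliding epochs
co_spectrum = cxx.real    # real part: co-spectra
quad_spectrum = cxx.imag  # imaginary part: quadrature spectra

# keeping the epoch axis instead of averaging gives the instantaneous
# cross-spectra (one low-rank matrix per epoch and frequency):
# cxx_inst = np.einsum('ecf,ehf->efch', X, X.conj())
# cxx_avg = cxx_inst.mean(axis=0)  # full-rank after Euclidean averaging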

@qbarthelemy
Member

qbarthelemy commented Aug 11, 2022

For the covariances() function, I obtain this comparison between the two implementations.
Regarding the results, I may have missed something.

[benchmark figure issue_80: computation time vs. matrix dimension for the old and new implementations]

import timeit
import numpy as np
from matplotlib import pyplot as plt
from pyriemann.utils.covariance import covariances


def covariances_new(x):
    # x is (n_matrices, n_channels, n_times) as pyriemann expects; the
    # einsum subscripts are adapted from dojeda's version, which assumed
    # (epoch, samples, channels), so that the sum runs over the time axis
    n = x.shape[2]
    c = np.einsum('aij,akj->aik', x, x) / (n - 1)
    return c

n_mats, n_times = 10, 256
dims = [5, 10, 15, 20, 25, 30]
n_dims, n_reps = len(dims), 50

t_old, t_new = np.zeros((n_dims, n_reps)), np.zeros((n_dims, n_reps))

for i, n_dim in enumerate(dims):

    X = np.random.random((n_mats, n_dim, n_times))

    for j in range(n_reps):

        t0 = timeit.default_timer()
        covariances(X, estimator='scm')
        t_old[i, j] = timeit.default_timer() - t0

        t0 = timeit.default_timer()
        covariances_new(X)
        t_new[i, j] = timeit.default_timer() - t0

fig, ax = plt.subplots()
ax.errorbar(dims, t_old.mean(axis=-1), yerr=t_old.std(axis=-1), label='old')
ax.errorbar(dims, t_new.mean(axis=-1), yerr=t_new.std(axis=-1), label='new')
ax.set(xlabel='Matrix dimension', ylabel='Time (s)')
plt.legend(loc='upper left')
plt.show()
