# Multivariate Temporal Dictionary Learning for EEG
In this notebook, we reimplement experiment 1 of the article : 
Q. Barthélemy, C. Gouy-Pailler, Y. Isaac, A. Souloumiac, A. Larue, J.I. Mars,
Multivariate temporal dictionary learning for EEG,
Journal of Neuroscience Methods,
Volume 215, Issue 1,
2013,
Pages 19-28,
ISSN 0165-0270,
https://doi.org/10.1016/j.jneumeth.2013.02.001.
(https://www.sciencedirect.com/science/article/pii/S0165027013000666)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
# files handling
import pickle
from os import listdir
#from os.path import exists
#from scipy.io import loadmat
import mne

# Dataset

The dataset used in this experiment is dataset 2.a of the 2008 BCI competition. A description of this dataset can be found here : http://bbci.de/competition/iv/desc_2a.pdf . Although this dataset was made for a classification task, we will use it for dictionnary learning on EEG signals.

### Load dataset
For the preprocessing, we follow the article, and apply a bandwidth filter [8hz, 30Hz]. We also discard 3 channels : 'EOG-left', 'EOG-right', and 'EOG-central', as these channels are not useful to our task.

In [39]:
def load_dataset(preprocessing=True):
    train = []
    test = []
    path = './dataset/'
    items = listdir(path)
    for item in items :
        #Load session
        raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
        # print(raw.info.ch_names)
        
        # Apply bandwith filter
        if preprocessing:
            raw = raw.filter(l_freq=8, h_freq=30, picks=None)
        # Convert to ndarray and discard EOG channels
        raw_eeg = raw.get_data(picks=
                                       ['EEG-Fz', 'EEG-0', 'EEG-1', 'EEG-2', 'EEG-3', 'EEG-4', 'EEG-5', 
                                        'EEG-C3', 'EEG-6', 'EEG-Cz', 'EEG-7', 'EEG-C4', 'EEG-8', 'EEG-9',
                                        'EEG-10', 'EEG-11', 'EEG-12', 'EEG-13', 'EEG-14', 'EEG-Pz', 'EEG-15', 'EEG-16'])
        # transpose to fit convention dim(signal) = (n_features, n_dims)
        raw_eeg = raw_eeg.T
        #raw_eeg.plot(n_channels=22)
        # Store in corresponding list
        if item[-5:] == 'T.gdf':
            train.append(raw_eeg)
        if item[-5:] == 'E.gdf':
            test.append(raw_eeg)
    return(train, test)

In [40]:
train, test = load_dataset(preprocessing=True)
print(train[0].shape)
print(train[0].max())
print(train[0].min())

  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw=mne.io.read_raw_gdf(path+item, verbose=False, preload=True)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  ra

(672528, 22)
0.0006457132104828345
-0.0006461001398773491


# Multivariate dictionary learning

In this section, we will use the multivariate dictionary learning (M-DLA) presented in the article. This algorithm is the alternation of two phases : 
- Multivariate sparse approximation, using multivariate orthogonal matching pursuit (M-OMP)
- Multivariate dictionary update, using a maximum likelihood criterion.

Sparse approximation

In [72]:
def place_in_signal(toplace, final_length, shift):
    extended_signal = np.zeros((final_length, toplace.shape[1]))
    extended_signal[shift : shift + toplace.shape[0], :] = toplace
    return extended_signal

def M_OMP(signal, compact_dict, n_nonzero_coefs, verbose=False):
    """
    Parameters:
    signal: array of shape (n_samples, n_dims)
        it's the data for one session, for one patient.
    compact_dict: list of arrays
        compact dictionary among which the members of the active
        dictionary will be chosen. Each element is a convolution kernel,
        i.e. an array of shape (k, n_dims), where k <= n_features and is
        kernel specific. The algorithm normalizes the kernels.
    n_nonzero_coefs: int
        number of non-zero coordinates in the final sparse coding
        vector.
    
    Returns:
    
    """
    # Initialization
    n_kernels = len(compact_dict)
    k_max_len = np.max([k.shape[0] for k in compact_dict])
    k_min_len = np.min([k.shape[0] for k in compact_dict])
    n_samples, n_dims = signal.shape
    flat_signal = signal.flatten().T
    residual = signal.copy()
    # Initialize active compct dict : [amplitude, offset, kernel]*n_nonzero_coefs
    active_comp_dict = np.zeros((n_nonzero_coefs, 3)) -1
    atoms = np.zeros((n_nonzero_coefs, n_samples, n_dims))-1
    
    
    # Main loop
    for atoms_in_estimate in range(n_nonzero_coefs):
        # print("atoms_in_estimate", atoms_in_estimate)
        # Compute correlations : array of size (n_kernels, max shift)
        # Pour la coord du shift j'ai recopié le git mais j'ai pas trop compris pourquoi ils prennent ça
        correlation = np.zeros((n_kernels,  max(n_samples, k_max_len) - min(n_samples, k_min_len) + 1))
        for l in range(n_kernels):
            # Sum the correlations over each dimension
            corr = 0
            for dim in range(n_dims):
                corr += np.correlate(residual[:, dim], compact_dict[l][:, dim], mode='valid')
            correlation[l, :len(corr)] = corr
        
        # Select kernel to add to active dictionary
        new_kernel, new_shift = np.unravel_index(np.argmax(np.abs(correlation)), correlation.shape)
        # PEUT ETRE AJOUTER UN BOUT POUR VERIFIER QUE CE KERNEL NA PAS ETE UTILISE
        
        # Update active dictionary : comp dictionary , atoms, and flattened dictionary
        new_amplitude = correlation[new_kernel, new_shift]  # Est ce quon en a besoin?
        active_comp_dict[atoms_in_estimate] = np.array([new_amplitude, new_shift, new_kernel])
        new_atom = place_in_signal(compact_dict[new_kernel], n_samples, new_shift)
        atoms[atoms_in_estimate] = new_atom
        if atoms_in_estimate == 0:
            D = new_atom.flatten().T.reshape(n_samples*n_dims, 1)
        elif atoms_in_estimate == 1:
            #D = D.flatten().T
            D = np.stack((D.flatten(), new_atom.flatten()), axis=-1)
        else:
            D = np.column_stack((D, new_atom.flatten()))

        # Compute new sparse encoding vector
        x = np.linalg.lstsq(D, flat_signal)[0]
        
        # Compute estimated signal and residual
        estimate = np.zeros((n_samples, n_dims))
        for k in range(atoms_in_estimate+1):
            estimate = estimate + x[k]*atoms[k]
        residual = signal - estimate
    if verbose:
        print("x", x)
        print("x.shape", x.shape)
        print("active_comp_dict", active_comp_dict)
        print("residual", residual)
    return x, active_comp_dict, residual

In [71]:
signal = np.random.rand(10, 3)
compact_dict = [np.random.rand(5, 3) for _ in range(7)]
n_nonzero_coefs = 6

M_OMP(signal, compact_dict, n_nonzero_coefs)



(array([-0.0890238 ,  0.32864756,  0.4446297 , -0.46276909,  0.39215957,
         0.75528642]),
 array([[ 4.09270633,  3.        ,  0.        ],
        [ 2.92969458,  0.        ,  2.        ],
        [ 1.58596475,  5.        ,  2.        ],
        [-0.54876689,  4.        ,  1.        ],
        [ 0.61686695,  1.        ,  5.        ],
        [ 0.58290207,  3.        ,  6.        ]]),
 array([[-0.02836013, -0.09217805,  0.17079153],
        [ 0.20517847, -0.14106902,  0.14714718],
        [ 0.42925784,  0.05654619,  0.02090071],
        [ 0.06816559,  0.24323707, -0.09050803],
        [-0.35083821, -0.01870044,  0.04608032],
        [-0.1547321 , -0.10709754,  0.00530149],
        [-0.11893497, -0.00765195,  0.09108222],
        [-0.11326426,  0.15152987,  0.12228309],
        [ 0.10708774,  0.10000868,  0.1830327 ],
        [-0.38578795,  0.10959943,  0.02524831]]))

Dictionary update

[0. 0. 0. 0. 0. 0.]


Dictionary learning