In [None]:
import os
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

import numpy as np

from amygdala_spiking.data import load_bids_group, epoch_sig
from amygdala_spiking.preproc import preprocess, preprocess_group

from neurodsp.spectral import compute_spectrum
from neurodsp.utils.norm import normalize_variance
from neurodsp.plts import plot_time_series, plot_power_spectra

from fooof import FOOOFGroup
from fooof.objs.utils import combine_fooofs

from ndspflow.motif import MotifGroup
from ndspflow.optimize import refit, refit_group

## Motifs + EMD + Decoding
This notebook computes EMD-assisted (empirical mode decomposition) motifs from amygdala recordings, arrays of shape (n_participants, n_epochs, n_timepoints), and uses the motif waveforms to decode trial type, either neutral or aversive.

In [None]:
# Load every subject from the BIDS dataset in ./data_bids (relative to the
# current working directory)
bids_dir = os.path.abspath('data_bids')
raw = load_bids_group(bids_dir)

In [None]:
# Low-pass filter each subject's recordings.
# NOTE(review): f_range=500 is presumably the cutoff in Hz — confirm against preprocess_group.
lowpass_cutoff = 500
preproc_data = preprocess_group(raw, 'lowpass', f_range=lowpass_cutoff)

In [None]:
# Sampling rate and per-epoch trial labels, read from subject '01' only.
# NOTE(review): assumes every subject shares the same fs and epoch ordering — confirm.
subj_ref = preproc_data['01']
fs, epoch_types = subj_ref['fs'], subj_ref['epoch_types']

In [None]:
# Keep one channel per subject: 'mmAL1' if present, otherwise 'mAL1'.
# Subjects with neither channel are skipped.
channel_priority = ('mmAL1', 'mAL1')

sigs = []
for subj_key in preproc_data:
    _sigs = preproc_data[subj_key]['sigs']
    sig_key = next((key for key in channel_priority if key in _sigs.keys()), None)
    if sig_key is None:
        continue
    sigs.append(_sigs[sig_key])

sigs = np.array(sigs)

In [None]:
# Per-subject pipeline: spectral parameterization -> EMD refit -> motif extraction.
# Every list below gets one entry per subject, in the order of `sigs`.
f_range = (1, 100)  # frequency range for the spectra and the specparam fits

fgs = []          # FOOOFGroup fit on all of the subject's epochs
fgs_refit = []    # refit FOOOFGroup returned by refit_group (all epochs, before dropping)
imfs = []         # rebound below to the last subject's IMFs from refit_group
pe_masks = []     # per-epoch IMF masks returned by refit_group

drop_idxs = []    # epoch indices dropped because no IMF passed the mask
motif_group = []  # MotifGroup fit on the kept epochs

for sigs_subj in sigs:

    sigs_subj = normalize_variance(sigs_subj, variance=1)

    # Specparam
    freqs, powers = compute_spectrum(sigs_subj, fs, f_range=f_range)

    fg = FOOOFGroup(verbose=False)
    fg.fit(freqs, powers, freq_range=f_range, n_jobs=-1)
    fgs.append(fg)

    # Refit; pe_mask holds one mask per epoch that selects that epoch's IMFs
    fg_refit, imfs, pe_mask = refit_group(fg, sigs_subj, fs, f_range, power_thresh=0.1)
    fgs_refit.append(fg_refit)
    pe_masks.append(pe_mask)

    # Epochs with no selected IMF carry no motif; record their indices so the
    # epoch labels can be realigned with the kept epochs later
    drop_idx = [idx for idx, mask in enumerate(pe_mask) if not mask.any()]
    drop_idxs.append(drop_idx)
    dropped = set(drop_idx)

    # Remove imfs that aren't above 1/f, keeping only epochs with at least one
    imfs_filt = [imf[mask] for imf, mask in zip(imfs, pe_mask) if mask.any()]

    # Specparam models for the kept epochs only, aligned with imfs_filt
    fg_refit_filt = combine_fooofs(
        [fg_refit.get_fooof(idx) for idx in range(len(fg_refit)) if idx not in dropped]
    )

    # Compute motifs
    motif_epoch = MotifGroup(var_thresh=.01, max_clusters=4, random_state=0)
    motif_epoch.fit(fg_refit_filt, imfs_filt, fs, progress='tqdm.notebook')

    motif_group.append(motif_epoch)

In [None]:
def stack_motifs(motifs, max_len=2000):
    """Stack motif waveforms into one 2D array, aligned on each motif's maximum.

    Parameters
    ----------
    motifs : iterable
        Outer level yields one item per spectrum/trial. Each item yields per-peak
        objects with a ``.motif`` attribute. That attribute is either a list of 1D
        arrays (one per cluster) or a non-list placeholder (e.g. nan) when the
        peak has no motif.
    max_len : int, optional, default: 2000
        Width of the output array. Each motif is placed so that its maximum
        lands at column ``max_len // 2``.

    Returns
    -------
    motif_array : 2d array, shape (n_motifs, max_len)
        One row per non-empty motif cluster. Values outside the motif are nan.

    Raises
    ------
    ValueError
        If a motif does not fit within ``max_len`` columns once peak-centered.
    """
    motif_clusts = []

    # Iterate over group motif object, spectrum by spectrum, then peak by peak
    for motif_ind in motifs:
        for motif_peak in motif_ind:

            # Peaks without motifs hold a non-list placeholder; skip them
            if not isinstance(motif_peak.motif, list):
                continue

            for motif_clust in motif_peak.motif:
                # Skip clusters that are entirely nan
                if np.isnan(motif_clust).all():
                    continue
                motif_clusts.append(motif_clust)

    # 2d motif array, centered at peaks, padded with nans
    motif_array = np.full((len(motif_clusts), max_len), np.nan)
    midpoint = max_len // 2

    for idx, motif_clust in enumerate(motif_clusts):
        # nanargmax so a partially-nan motif is aligned on its real maximum
        start_idx = midpoint - int(np.nanargmax(motif_clust))
        end_idx = start_idx + len(motif_clust)

        if start_idx < 0 or end_idx > max_len:
            raise ValueError(
                f"Motif {idx} (length {len(motif_clust)}) does not fit in "
                f"max_len={max_len} when centered on its peak; increase max_len."
            )

        motif_array[idx, start_idx:end_idx] = motif_clust

    return motif_array

In [None]:
# Split every subject's motifs by trial type and stack each set into a
# peak-aligned array. motifs_neutral / motifs_aversive hold one array per subject.
# NOTE(review): `epoch_types` comes from subject '01' only, so this assumes all
# subjects share the same epoch order — confirm.
motifs_neutral = []
motifs_aversive = []

# stack_motifs only reads the motif objects, so no defensive copy is needed
for motif_epoch, drop_idx in zip(motif_group, drop_idxs):

    # Labels of this subject's kept epochs; must align with motif_epoch's trials
    dropped = set(drop_idx)
    labels = np.array([label for idx, label in enumerate(epoch_types) if idx not in dropped])

    _motifs_neutral = []
    _motifs_aversive = []

    for idx, motif_trial in enumerate(motif_epoch):
        # Sort on both labels explicitly; epochs of any other type are ignored
        if labels[idx] == 'Neutral':
            _motifs_neutral.append(motif_trial)
        elif labels[idx] == 'Aversive':
            _motifs_aversive.append(motif_trial)

    motifs_neutral.append(stack_motifs(_motifs_neutral))
    motifs_aversive.append(stack_motifs(_motifs_aversive))

In [None]:
# Peak-aligned neutral motifs from every subject, one line per motif
fig, ax = plt.subplots()
for _motif_group in motifs_neutral:
    for motif in _motif_group:
        ax.plot(motif)
ax.set(xlim=(600, 1400), ylim=(-2, 2), title='Neutral motifs',
       xlabel='Sample (peak-aligned)', ylabel='Amplitude (a.u.)');

In [None]:
# Peak-aligned aversive motifs from every subject, one line per motif
fig, ax = plt.subplots()
for _motif_group in motifs_aversive:
    for motif in _motif_group:
        ax.plot(motif)
ax.set(xlim=(600, 1400), ylim=(-2, 2), title='Aversive motifs',
       xlabel='Sample (peak-aligned)', ylabel='Amplitude (a.u.)');

In [None]:
# Overlay both conditions on shared axes: neutral in black, aversive in red.
# Low alpha shows where motifs concentrate.
fig, ax = plt.subplots()
for motif_groups, color in ((motifs_neutral, 'k'), (motifs_aversive, 'r')):
    for _motif_group in motif_groups:
        for motif in _motif_group:
            ax.plot(motif, color=color, alpha=.1)

# Same window as the per-condition plots so the figures are comparable
ax.set(xlim=(600, 1400), ylim=(-2, 2),
       title='Neutral (black) vs aversive (red) motifs',
       xlabel='Sample (peak-aligned)', ylabel='Amplitude (a.u.)');

### SVM

In [None]:
from tslearn.svm import TimeSeriesSVC

In [None]:
def reshape_arr(motifs):
    """Left-align each row of a nan-padded 2D motif array.

    Each row is rolled so that its first non-nan sample moves to column 0. The
    leading nans wrap around to the end of the row. Rows that are entirely nan
    are left unchanged.

    Parameters
    ----------
    motifs : 2d array
        One motif per row, padded with nan.

    Returns
    -------
    motifs_reshape : 2d array
        Left-aligned copy of ``motifs``. The input is not modified.
    """
    motifs_reshape = motifs.copy()

    for idx, row in enumerate(motifs_reshape):
        valid = np.flatnonzero(~np.isnan(row))

        # All-nan row: nothing to align (previously raised IndexError)
        if valid.size == 0:
            continue

        motifs_reshape[idx] = np.roll(row, -valid[0])

    return motifs_reshape

In [None]:
# Pool motifs across subjects per condition, then left-align every waveform so
# its first non-nan sample sits at column 0
_motifs_neutral = reshape_arr(np.vstack(motifs_neutral))
_motifs_aversive = reshape_arr(np.vstack(motifs_aversive))

In [None]:
# Pool both conditions into one array: neutral rows first, then aversive rows
_motifs = np.concatenate([_motifs_neutral, _motifs_aversive], axis=0)

In [None]:
# Trim the trailing all-nan columns left after left-aligning the motifs, then
# add a trailing dimension so the array is (n_series, n_timepoints, 1) for tslearn.
# Accepts the 2D array or an already-trimmed 3D one, so re-running is safe.
motifs_2d = _motifs[..., 0] if _motifs.ndim == 3 else _motifs

# Keep everything up to the last column that holds any data (the old version
# raised IndexError when no all-nan column existed)
has_data = ~np.isnan(motifs_2d).all(axis=0)
n_keep = np.flatnonzero(has_data)[-1] + 1 if has_data.any() else 0

_motifs = motifs_2d[:, :n_keep, np.newaxis]

In [None]:
# Class labels in the stacking order of `_motifs`: neutral rows first (0.),
# then aversive rows (1.)
n_neutral = len(_motifs_neutral)
labels = np.where(np.arange(len(_motifs)) < n_neutral, 0., 1.)
labels

In [None]:
# SVM with a global alignment kernel (GAK) on the nan-padded motifs.
# NOTE(review): assumes tslearn treats trailing nans as variable-length padding — confirm.
# TODO: this fits on all motifs; there is no held-out evaluation yet, so
# decoding accuracy is not measured here.
svm_kwargs = dict(kernel="gak", gamma=0.1)
clf = TimeSeriesSVC(**svm_kwargs)
clf.fit(_motifs, labels)