In [None]:
# Author: Proloy das <proloyd94@gmail.com>
# License: BSD (3-clause)
%matplotlib notebook

In [None]:
import numpy as np
import os
# import eelbrain
import mne
from mne.datasets import sample
from matplotlib import pyplot as plt
# from eelbrain import save

In [None]:
# Locate the MNE "sample" dataset and the forward / evoked / covariance files used below.
data_path = sample.data_path()
subjects_dir = os.path.join(data_path, 'subjects')
fwd_fname, ave_fname, cov_fname = (
    os.path.join(data_path, rel_path)
    for rel_path in ('MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif',
                     'MEG/sample/sample_audvis-ave.fif',
                     'MEG/sample/sample_audvis-cov.fif'))
subject = 'sample'
condition = 'Left Auditory'  # NOTE(review): not used by any cell below

## Simulate raw data

In [None]:
# Forward model (with its source space) and noise covariance of the sample subject.
forward = mne.read_forward_solution(fwd_fname)
src = forward['src']
noise_cov = mne.read_cov(cov_fname)

# Reuse the sample recording's measurement info as the template for the
# simulation, but simulate at a lower sampling rate.
sim_sfreq = 250.
info = mne.io.read_info(ave_fname)
info['bads'] = ['EEG 053']
with info._unlock():  # 'sfreq' is write-protected in mne.Info; unlock to overwrite it
    info['sfreq'] = sim_sfreq
tstep = 1 / sim_sfreq  # seconds per sample

### Region to activate<br>
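To select the regions to activate, we grow two regions of interest, each seeded at the center of an anatomical label: the left transverse temporal gyrus (`transversetemporal-lh`) and the right precentral gyrus (`precentral-rh`).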

In [None]:
# Anatomical regions to activate, paired element-wise with the hemispheres:
# transversetemporal-lh and precentral-rh.
rois = ['transversetemporal', 'precentral']
hemis = ['lh', 'rh']


def _read_unique_label(roi, hemi):
    """Read the one annotation label of ``subject`` matching ``f'{roi}-{hemi}'``.

    Raises
    ------
    ValueError
        If the pattern matches no label or more than one label, instead of
        silently picking the first match.
    """
    matches = mne.read_labels_from_annot(
        subject, regexp=f'{roi}-{hemi}', subjects_dir=subjects_dir)
    if len(matches) != 1:
        raise ValueError(f'Expected exactly one label matching {roi}-{hemi}, '
                         f'found {[lab.name for lab in matches]}')
    return matches[0]


selected_labels = [_read_unique_label(roi, hemi) for roi, hemi in zip(rois, hemis)]

location = 'center'  # Use the center of the region as a seed.
extent = 20.  # Extent (radius) in mm of the grown region.
labels = [mne.label.select_sources(subject, selected_label, location=location,
                                   extent=extent, subjects_dir=subjects_dir)
          for selected_label in selected_labels]

### Define the time course of the activity for each source of the region to activate.
Here we use two second-order autoregressive (AR) processes: a slow oscillation near 1.6 Hz and a fast oscillation near 12 Hz.

In [None]:
from purdonlabmeeg._temporal_dynamics_utils.tests._generate_data import ARData

n_burn_in = 200  # leading samples generated and then discarded (start-up transient)
ntimes = int(np.round(info['sfreq'] * 20 * 10.))  # 20 * 10 = 200 s worth of samples
n_total = ntimes + n_burn_in


def _ar2_oscillator(freq, second_coeff, noise_var):
    """Build an ARData series with two lag coefficients (num_prev=2).

    The first coefficient is 2*cos(2*pi*freq/sfreq), so this is presumably an
    AR(2) oscillator peaking near ``freq`` Hz; ``second_coeff`` (close to -1)
    sets how sharp that peak is. TODO confirm against ARData's definition.
    """
    return ARData(n_total, noise_var=noise_var,
                  coeffs=[2*np.cos(2*np.pi*freq/info['sfreq']), second_coeff],
                  num_prev=2)


def _scaled(ar, amplitude):
    """Drop the burn-in and rescale the series so its std equals ``amplitude``."""
    kept = ar.y[n_burn_in:]
    return amplitude * kept / kept.std()


# Generation order (slow, then fast) is kept so the random draws are unchanged.
slow_data = _ar2_oscillator(1.6, -0.99983, noise_var=0.1)
fast_data = _ar2_oscillator(12, -0.985, noise_var=0.01)

source_time_series1 = _scaled(slow_data, 2e-9)
source_time_series2 = _scaled(fast_data, 1.5e-9)
source_time_serieses = (source_time_series1, source_time_series2)

In [None]:
# Plot both simulated source waveforms on a common time axis (in seconds).
tx = np.arange(ntimes) / info['sfreq']
fig, ax = plt.subplots(figsize=(10, 3))
# Labels name each series by the oscillation frequency its AR coefficients encode.
ax.plot(tx, source_time_series1, label='slow source (1.6 Hz)')
ax.plot(tx, source_time_series2, alpha=0.5, label='fast source (12 Hz)')
ax.set_ylim([-0.5e-8, 0.5e-8])
ax.set(xlabel='Time (s)', ylabel='Amplitude', title='Simulated source time courses')
ax.legend()
fig.savefig('source_time_courses.svg')

### Define when the activity occurs using events.<br>
The first column is the sample index of the event onset, the second is not used, and the third is the event id. Here 10 events occur every `ntimes // 10` samples (every 20 s at 250 Hz), starting at sample 100.

In [None]:
n_events = 10
events = np.zeros((n_events, 3), dtype=int)
# Column 0: onset sample of each event, spaced ntimes // n_events samples apart
# starting at sample 100. Column 1 is unused. Column 2: event id.
events[:, 0] = 100 + (ntimes // n_events) * np.arange(n_events)
events[:, 2] = 1  # All events have the same id.

### Create simulated source activity.<br>
Here we use a `SourceSimulator`, whose `add_data` method is the key step: it specifies where (label), what (source_time_series), and when (events) an event type occurs.

In [None]:
# Each waveform is ntimes samples long, but consecutive events are only
# ntimes // 10 samples apart. Passing the whole waveform at every event would
# sum up to 10 overlapping, time-shifted copies of the same process. Instead,
# split it into one consecutive chunk per event. SourceSimulator.add_data
# accepts a list with one waveform per event, so event k carries chunk k and
# the chunks line up back-to-back.
source_simulator = mne.simulation.SourceSimulator(src, tstep=tstep)
for label, source_time_series in zip(labels, source_time_serieses):
    per_event_waveforms = np.array_split(source_time_series, len(events))
    source_simulator.add_data(label, per_event_waveforms, events)

### Project the source time series to sensor space and add some noise.<br>
The source simulator can be given directly to the simulate_raw function.

In [None]:
# Project the sources to the sensors and keep only the EEG channels.
# pick_types excludes info['bads'] (here 'EEG 053') by default, so no explicit
# drop_channels is needed.
raw = (mne.simulation.simulate_raw(info, source_simulator, forward=forward)
       .pick_types(eeg=True))
# Add sensor noise drawn from the sample noise covariance. iir_filter gives IIR
# denominator coefficients that color the noise (presumably a one-pole
# low-pass here); random_state fixes the draw.
raw = mne.simulation.add_noise(raw, noise_cov, iir_filter=[10, -9], random_state=0)

### Extract the epochs and inspect their power spectra

In [None]:
psd_kwargs = dict(bandwidth=0.5, adaptive=True, low_bias=True)  # multitaper PSD settings
# High-pass at 0.5 Hz and downsample to 100 Hz. The event onsets are sample
# indices at the simulation rate (250 Hz), so they must be resampled together
# with the data. Otherwise they point to the wrong times, and some fall past the
# end of the shorter resampled recording. A new name keeps the original
# `events` intact for re-runs of the simulation cells.
raw, events_resampled = (raw.pick_types(eeg=True)
                         .filter(.5, None)
                         .resample(100, events=events))
epochs = mne.Epochs(raw, events_resampled, event_id=1, tmin=0., tmax=10.0,
                    baseline=None, preload=True)
fig = epochs.compute_psd(method='multitaper', **psd_kwargs).plot()

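The epochs above already span only the time window of interest (0–10 s after each event) and contain only the EEG channels. Next, plot the noise covariance used for the simulation on those EEG channels.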

In [None]:
# Plot the simulation's noise covariance using the (EEG-only) measurement info of the epochs.
noise_cov.plot(info=epochs.info)

In [None]:
from purdonlabmeeg import OCACV

# Send MNE's log output to 'debug-3.log' at DEBUG verbosity.
# NOTE(review): both settings are global and stay active for every later cell.
mne.set_log_file('debug-3.log')
mne.set_log_level('DEBUG')

# Fit on a copy, presumably to shield `epochs` (used later) from any in-place changes.
this_epochs = epochs.copy()
oca_fit_params = dict(ar_order=13, pca_whiten=True, scalar_alpha=True)
ocacv = OCACV(
    n_oscillations=[1, 2, 3, 4],  # candidate numbers of oscillations (presumably compared by CV)
    n_pca_components=0.99,
    noise_cov=noise_cov,
    max_iter=50,
    fit_params=oca_fit_params,
)
ocacv.fit(this_epochs)

## Let's look at the results now.
Below are a few useful ways to inspect the results of the fit.

In [None]:
# Plot the cross-validation summary of the OCACV fit (presumably one entry per
# value of n_oscillations given when constructing OCACV).
fig = ocacv.plot_cv()
fig.show()
# NOTE(review): the meaning of the positional False is defined by
# purdonlabmeeg's OCACV.crossvalidate — confirm what it toggles.
ocacv.crossvalidate(False)

In [None]:
# Topographies of the fitted OCA components, with a colorbar; plot_phase=False
# presumably suppresses the phase maps.
fig = ocacv.plot_components(plot_phase=False, colorbar=True)

Compare the OCA topomaps with traditional topomaps, made by averaging the PSD within canonical frequency bands.

In [None]:
# Conventional comparison: multitaper PSD of the epochs, shown as topomaps
# averaged within MNE's default canonical frequency bands.
epochs_psd = epochs.compute_psd(method='multitaper', bandwidth=1.)
fig = epochs_psd.plot_topomap()

In [None]:
# Use the matplotlib data browser for the time-series plot below.
mne.viz.set_browser_backend('matplotlib')

# Time courses of the fitted OCA sources for each epoch, and their spectra.
sources = ocacv.get_sources(epochs)
fig = sources.plot(picks='all')
sources_psd = sources.compute_psd(picks='all')
fig = sources_psd.plot(picks='all')

In [None]:
# Noise covariance returned by the fitted model (presumably estimated during
# the OCA fit); compare it against the simulation's noise covariance plotted earlier.
cov = ocacv.get_fitted_noise_cov()
cov.plot(info=ocacv.info)