In [None]:
# Author: Proloy das <pd640@nmr.mgh.harvard.edu>
# License: BSD (3-clause)
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline

In [None]:
import os
import numpy as np
import eelbrain
import mne
from mne.datasets import sample
from codetiming import Timer
from matplotlib import pyplot as plt
from eelbrain import save
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
mne.viz.set_browser_backend('matplotlib')

In [None]:
# Subject identifier and condition label (`condition` is not referenced by any
# other cell visible in this notebook).
subject = 'sample'
condition = 'Left Auditory'

# Root directory of the MNE "sample" dataset, as returned by `sample.data_path()`.
data_path = sample.data_path()
subjects_dir = os.path.join(data_path, 'subjects')

# Forward solution, evoked file (read later for its measurement info) and
# noise covariance, each given relative to the dataset root.
fwd_fname, ave_fname, cov_fname = (
    os.path.join(data_path, relative_name)
    for relative_name in (
        'MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif',
        'MEG/sample/sample_audvis-ave.fif',
        'MEG/sample/sample_audvis-cov.fif',
    )
)

## Simulate raw data

In [None]:
# Forward model (and the source space it carries) plus the noise covariance.
forward = mne.read_forward_solution(fwd_fname)
noise_cov = mne.read_cov(cov_fname)
src = forward['src']

# Reuse the measurement info stored in the evoked file, overriding its
# sampling frequency so the simulation runs at 100 Hz.
info = mne.io.read_info(ave_fname)
with info._unlock():  # unlock so that the assignment below is permitted
    info['sfreq'] = 100.
tstep = 1 / info['sfreq']  # sampling period in seconds

In [None]:
# Keep the EEG and stimulus channels only (MEG channels are excluded), then
# display the reduced measurement info.
eeg_indices = mne.pick_types(info, eeg=True, stim=True, meg=False)
info = mne.pick_info(info, sel=eeg_indices)
info

### Regions to activate
For demonstration purposes, we choose four regions of interest, drawn from both hemispheres.

| region             | hemi | activity |
|--------------------|------|----------|
|transversetemporal  |  lh  |   slow   |
|precentral          |  rh  |   slow   |
|inferiorparietal    |  rh  |   alpha  |
|caudalmiddlefrontal|  lh  |   alpha  |

Each ROI has an extent of 10 mm, grown from the center of the corresponding above-mentioned DKT atlas region.

In [None]:
# Atlas region names, paired element-wise with the hemisphere to use.
rois = ['transversetemporal', 'precentral', 'inferiorparietal', 'caudalmiddlefrontal']
hemis = ['lh', 'rh', 'rh', 'lh']
location = 'center'  # Use the center of the region as a seed.
extent = 10.  # Extent in mm of the region.

# Look up every "<roi>-<hemi>" label by regular expression and keep the
# first match returned for it.
selected_labels = []
for roi, hemi in zip(rois, hemis):
    matching_labels = mne.read_labels_from_annot(
        subject, regexp=f'{roi}-{hemi}', subjects_dir=subjects_dir)
    selected_labels.append(matching_labels[0])

# Restrict each selected label to the sources within `extent` of `location`.
labels = [
    mne.label.select_sources(subject, selected_label, location=location,
                             extent=extent, subjects_dir=subjects_dir)
    for selected_label in selected_labels
]

### Define the time course of the activity for each source of the region to activate.
Here we use AR models of two kinds: one slow (central frequency 1.6 Hz) and two fast (central frequencies 12 Hz and 10 Hz).
1. For the slow oscillations, one time course is simply a lagged version (by one sample) of the other.
2. For the fast oscillations, the two time courses are independent realizations drawn from two separate AR processes (12 Hz and 10 Hz).

In [None]:
from purdonlabmeeg._temporal_dynamics_utils.tests._generate_data import ARData

ntimes = int(np.round(info['sfreq'] * 20. * 10)) + 200


def _ar2_coeffs(freq, second_coeff):
    """Return the two AR coefficients ``[2*cos(2*pi*freq/sfreq), second_coeff]``.

    ``freq`` is the central frequency in Hz and ``sfreq`` is read from the
    notebook-level ``info``.
    """
    return [2*np.cos(2*np.pi*freq/info['sfreq']), second_coeff]


# Alternative parameterisation (not active, kept for reference): second
# coefficients of -0.9999999 (slow) and -0.9996 (both fast processes), with
# `ntimes + 36` samples for the slow process.
slow_data = ARData(ntimes + 1, noise_var=0.01,
                   coeffs=_ar2_coeffs(1.6, -0.993), num_prev=2)
fast_data = ARData(ntimes, noise_var=0.01,
                   coeffs=_ar2_coeffs(12, -0.965), num_prev=2)
another_fast_data = ARData(ntimes, noise_var=0.01,
                           coeffs=_ar2_coeffs(10., -0.96), num_prev=2)

# Discard the first 200 samples of every realization.
slow_y = slow_data.y[200:]
fast_y = fast_data.y[200:]
another_fast_y = another_fast_data.y[200:]

# Normalise each trace by its standard deviation and scale it to the target
# amplitude. The two slow traces are the same realization shifted by one
# sample with respect to each other.
slow_std = slow_y.std()
source_time_series1 = 15e-9 * slow_y[1:] / slow_std
source_time_series2 = 5e-9 * slow_y[:-1] / slow_std
source_time_series3 = 10e-9 * fast_y / fast_y.std()
source_time_series4 = 5e-9 * another_fast_y / another_fast_y.std()

# The collection handed to the simulator drops the first 20 samples of each
# trace; `source_time_series1` .. `source_time_series4` stay untrimmed.
source_time_serieses = [
    x[20:]
    for x in (source_time_series1, source_time_series2,
              source_time_series3, source_time_series4)
]


In [None]:
# Time axis in seconds; the untrimmed traces are assumed to hold
# `ntimes - 200` samples each.
tx = np.arange(ntimes-200) / info['sfreq']

labelled_traces = (
    (source_time_series1, 'leading slow data'),
    (source_time_series2, 'lagging slow data'),
    (source_time_series3, 'fast data'),
    (source_time_series4, 'another fast data'),
)

fig, ax = plt.subplots(figsize=(10, 3))
for trace, trace_label in labelled_traces:
    ax.plot(tx, trace, label=trace_label)
ax.set_ylim([-0.5e-7, 0.5e-7])
legend = ax.legend()
fig.savefig('source_time_courses.svg')

### Define when the activity occurs using events.
The first column is the sample of the event, the second is not used, and the third is the event id. Here the events occur every `ntimes // 10` samples (2020 samples at the 100 Hz sampling rate used here), starting at sample 100.

In [None]:
n_events = 10

# One event every `ntimes // 10` samples, the first one at sample 100.
event_samples = 100 + (ntimes // 10) * np.arange(n_events)
events = np.column_stack((
    event_samples,                      # column 0: event sample
    np.zeros(n_events, dtype=np.int_),  # column 1: left at zero
    np.ones(n_events, dtype=np.int_),   # column 2: all events have the same id
)).astype(np.int_)

### Simulated activity creation (kinda easy).
Here we use a `SourceSimulator`, whose `add_data` method is the key. It allows us to specify where (label), what (source_time_series), and when (events) an event type will occur.

In [None]:
# Attach one waveform per ROI. Each waveform is registered with a single
# event whose sample (first column) is 0.
source_simulator = mne.simulation.SourceSimulator(src, tstep=tstep)
for label, source_time_series in zip(labels, source_time_serieses):
    source_simulator.add_data(
        label,
        waveform=source_time_series,
        events=np.array([[0, 100, 1],]),
    )

# Source estimate between samples 100 and `ntimes + 100`.
stc = source_simulator.get_stc(start_sample=100, stop_sample=ntimes+100)

def summarize(x, axis):
    """Return the sum of squared values of ``x`` along ``axis``.

    Parameters
    ----------
    x : numpy.ndarray
        Values to aggregate.
    axis : int
        Axis over which the squared values are summed.

    Returns
    -------
    numpy.ndarray or scalar
        ``x`` squared element-wise and summed over ``axis``.
    """
    return np.square(x).sum(axis=axis)
# Aggregate the source estimate into bins of width 10 (in the units that
# `stc.bin` expects - presumably seconds; confirm), reducing every bin with
# `summarize`. `stc` is rebound here, so re-running this line bins the
# already-binned estimate again.
stc = stc.bin(10, func=summarize)


In [None]:
# initial_time = 5.4
# brain = stc.plot(subjects_dir=subjects_dir, hemi='both', initial_time=initial_time,
#                  clim=dict(kind='value', lims=[1e-9, 6e-9, 2e-8]), alpha=1.0,
#                  smoothing_steps=7)

The noise is also kinda important, so we visualize it before moving forward.

In [None]:
# Add 20% of the diagonal onto the covariance and divide everything by 1.2:
# the diagonal entries end up (numerically almost) unchanged while the
# off-diagonal entries are scaled by 1/1.2. The update is done in place, so
# re-running this cell applies it once more.
# Fully diagonal alternative (not active):
# noise_cov.data[:] *= np.eye(noise_cov.data.shape[0])
cov_diagonal = np.diag(np.diag(noise_cov.data))
noise_cov.data[:] = (noise_cov.data + 0.2 * cov_diagonal) / 1.2

# fig = noise_cov.plot(info)

Now we are ready to project the source time series to sensor space and add some noise.<br>

In [None]:
# Project the simulated source activity to the sensors through the forward model.
raw = mne.simulation.simulate_raw(info, stc=source_simulator, forward=forward)
# cov = mne.make_ad_hoc_cov(raw.info)

# Keep a copy taken before noise is added, then add noise drawn according to
# `noise_cov` (fixed random state).
raw_orig = raw.copy()
mne.simulation.add_noise(raw, cov=noise_cov, random_state=0)

### extract the epochs and form evoked object

In [None]:
# Disabled alternatives:
# events = mne.find_events(raw)
# raw = raw.filter(1., None)

# Average EEG reference, added as a projection.
raw.set_eeg_reference(ref_channels='average', projection=True)

# 15-second epochs starting at each event (id 1), without baseline correction.
epochs = mne.Epochs(raw, events=events, event_id=1, tmin=-0.0, tmax=15.0,
                    baseline=None)
epochs.load_data()


In [None]:
# Spectra of the epochs from index 4 onwards: first the PSD plot, then the
# PSD topomap. Each figure is written to its own SVG file.
later_epochs = epochs[4:]

fig = later_epochs.plot_psd()
fig.savefig('psd plot.svg')

fig = later_epochs.plot_psd_topomap()
fig.savefig('psd plot topomap.svg')

### VVI: Crop timepoints of interest, and pick only the EEG channels!

In [None]:
# Keep only the EEG channels of the epochs, then browse them.
epochs = epochs.pick_types(meg=False, eeg=True)
# epochs = epochs.drop_channels('EEG 052')
fig = epochs.plot()

Once again, let's look at the noise covariance to remind ourselves that it is _not diagonal_.

In [None]:
# Plot the noise covariance used for the simulation, using the measurement
# info attached to `epochs`.
fig = noise_cov.plot(epochs.info)
# fig.savefig('noise-cov.svg')

In [None]:
from purdonlabmeeg.oca import OCA, OCACV

# Candidate numbers of oscillations handed to OCACV, and the options
# forwarded to it through `fit_params`.
candidate_n_oscillations = [2, 3, 4, 5, 6, 8, 10]
fit_params = dict(ar_order=13, pca_whiten=False, scalar_alpha=True)

oca = OCACV(
    n_oscillations=candidate_n_oscillations,
    n_pca_components=.9999,
    noise_cov=noise_cov,
    fit_params=fit_params,
    max_iter=100,
)
# Fit on the two epochs at indices 4 and 5.
oca.fit(epochs[4:6])

# Earlier call style, kept for reference (not active):
# for ii in [0]: # range(0, len(epochs)):
#     ocacv = OCA.fit(epochs[4:6], 10, picks=None, start=None, stop=None,
#                        max_iter=50, initial_guess=None,
#                        scalar_alpha=True, update_sigma2=True,
#                        tol=1e-6, verbose=None, ar_order=7)
#     save.pickle(ocacv, f'results/oca-10s-epoch{ii}-cv')

In [None]:
# Figure produced by the fitted model's `plot_cv`.
# NOTE(review): presumably a score per candidate number of oscillations - confirm.
fig = oca.plot_cv()

In [None]:
# Display the `freq` attribute of the fitted oscillators (private attribute access).
# NOTE(review): presumably the estimated center frequencies - confirm units.
oca._oscillators_.freq


In [None]:
# Covariance estimated directly from `epochs`, plotted with the notebook-level `info`.
mne.cov.compute_covariance(epochs).plot(info)

In [None]:
# Noise covariance as returned by the fitted model, plotted with the notebook-level `info`.
oca.get_fitted_noise_cov().plot(info)

In [None]:
# Obtain the model's sources for epochs 4 and 5, the same slice passed to `oca.fit`.
sources = oca.get_sources(epochs[4:6])

In [None]:
# Power spectral density of all channels of `sources`.
fig =  sources.plot_psd(picks='all')

In [None]:
# Browse all channels of `sources`; the 'misc' channel type is plotted with scaling 1.
fig = sources.plot(picks='all', scalings={'misc': 1})

In [None]:
# Residual: the epochs minus what `oca.apply` returns for them. The
# subtraction is done in place on a copy, so `epochs` itself is untouched.
recon_epochs = oca.apply(epochs)
res = epochs.copy()
np.subtract(res._data, recon_epochs._data, out=res._data)

# Spectrum of the residual.
fig = res.plot_psd()

In [None]:
# The leftover `%debug` post-mortem call was removed from this cell: it needs
# an interactive prompt (or a prior traceback) and therefore gets in the way of
# an unattended "Restart Kernel -> Run All". Re-insert `%debug` temporarily,
# directly after a failing cell, when interactive debugging is needed.

Okay, now that we have fitted OCA, how many osc components do you think OCA will recover? 2, 3, or 4 or more? Let's find out, shall we?

Wait, what? Why 3? Shouldn't there be 4? Think twice: how many independent time courses were there?

Let's look at the loading matrices, i.e., the topomaps, now.

In [None]:
# `ocacv` is never assigned in this notebook (its only assignment sits in
# commented-out code), so the original line raised NameError on a fresh
# kernel. The fitted model object is `oca`.
# NOTE(review): confirm that the OCACV instance exposes `plot_topomaps`.
fig = oca.plot_topomaps(plot_phase=False, colorbar=True)
fig.savefig('oca-topomaps.svg')

What about the recovered time courses?

In [None]:
# `ocacv` is never assigned in this notebook (its only assignment sits in
# commented-out code), so the original line raised NameError on a fresh
# kernel. The fitted model object is `oca`.
# NOTE(review): confirm that the OCACV instance exposes `plot_sources`.
fig = oca.plot_sources(epochs, scalings={'misc': 5e-1})
fig.savefig('oca-tc.svg')

How is the free energy doing?

In [None]:
# Free-energy traces: the fitted model first, then the first two entries of
# its `_rest_ocas`.
# The original read these from `ocacv`, a name that is never assigned in this
# notebook (NameError on a fresh kernel); the fitted model object is `oca`.
# The loop variable is deliberately not called `oca`, so that the fitted
# model is not overwritten for any cell executed afterwards.
# NOTE(review): confirm that the OCACV instance exposes `_free_energy` and
# `_rest_ocas`.
fig, ax = plt.subplots()
ax.plot(oca._free_energy)
for rest_oca in oca._rest_ocas[:2]:
    ax.plot(rest_oca._free_energy)
fig.savefig('oca-convg.svg')

And last, but not least, how did the noise covariance learning go? 😲

In [None]:
# Plot and save the noise covariance estimated by the fitted model.
# The original read it from `ocacv`, a name that is never assigned in this
# notebook (NameError on a fresh kernel). This uses the same call as the
# earlier cell that plots `oca.get_fitted_noise_cov()` with `info`.
# The plot call returns more than one figure; as in the original, the first
# one is the one written to disk.
fig = oca.get_fitted_noise_cov().plot(info)
fig[0].savefig('oca-noise-cov-est.svg')

In [None]:
# # WIP

# import matplotlib.pyplot as plt
# from matplotlib.gridspec import GridSpec


# def format_axes(fig):
#     for i, ax in enumerate(fig.axes):
# #         ax.text(0.5, 0.5, "ax%d" % (i+1), va="center", ha="center")
#         ax.tick_params(labelbottom=False, labelleft=False)

# fig = plt.figure(constrained_layout=True)

# gs = GridSpec(4, 3, figure=fig)
# ax00 = fig.add_subplot(gs[0,0], projection='3d')
# ax01 = fig.add_subplot(gs[0,1], projection='3d')
# ax02 = fig.add_subplot(gs[0,2], projection='3d')
# ax1 = fig.add_subplot(gs[1, :])
# # identical to ax1 = plt.subplot(gs.new_subplotspec((0, 0), colspan=3))
# ax2 = fig.add_subplot(gs[2, :-1])
# ax3 = fig.add_subplot(gs[2:, -1])
# ax4 = fig.add_subplot(gs[-1, 0])
# ax5 = fig.add_subplot(gs[-1, -2])

# # Time courses
# # brain = stc.plot(subjects_dir=subjects_dir, hemi='lh', initial_time=initial_time,
# #                  clim=dict(kind='value', lims=[1e-9, 6e-9, 2e-8]), alpha=1.0,
# #                  smoothing_steps=7, backend='matplotlib')

# tx = np.arange(ntimes) / info['sfreq']
# ax1.plot(tx, source_time_series1, label='leading slow data')
# ax1.plot(tx, source_time_series2, label='lagging slow data')
# ax1.plot(tx, source_time_series3, alpha=0.7, label='fast data')
# ax1.plot(tx, source_time_series4, alpha=0.5, label='another fast data')
# ax1.set_ylim([-0.5e-7, 0.5e-7])
# legend = ax1.legend()
# ax1.set_xlim([20, 40])

# fig.suptitle("GridSpec")
# format_axes(fig)

# plt.show()

In [None]:
## WIP
# from scipy import sparse
# import numpy as np
# pca, n_pca = oca._pca_dict['pca'], oca._pca_dict['n_pca']
# noise_cov = sparse.block_diag((oca._noise_var,
#                                 np.diag(pca.explained_variance_[n_pca:]))).toarray()
# noise_cov = pca.inverse_transform(pca.inverse_transform(noise_cov.T).T)
# cov = mne.Covariance(noise_cov * (ocacv._data_scale ** 2), oca.info.ch_names, 
#                      bads=None, projs=[], nfree=1, 
#                      eig=None, eigvec=None, method='custom',
#                      loglik=None, verbose=None)
# # epoch.info
# cov.plot(oca.info)
# fig.savefig('oca-noise-cov-est.svg')

Now, who wants to use OCA?? 