In [None]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
from librosa import display
from scipy.signal import hilbert
import glob
import mne
from eelbrain import *
import os
import pandas as pd

# configure(n_workers=False)

#### Get envelope

In [None]:
def load_wav(file, sr):
    """Read an audio file with librosa and return only its samples.

    Parameters
    ----------
    file : str
        Path of the audio file to read.
    sr : int | None
        Sample rate passed straight through to ``librosa.load``.

    Returns
    -------
    numpy.ndarray
        The sample array returned by ``librosa.load`` (the returned rate is discarded).
    """
    samples, _loaded_rate = librosa.load(file, sr=sr)
    return samples

In [None]:
def compute_envelope(stim):
    """Return the amplitude envelope of ``stim``.

    The envelope is the magnitude of the analytic signal obtained with
    ``scipy.signal.hilbert``.

    Parameters
    ----------
    stim : array_like
        Signal samples.

    Returns
    -------
    numpy.ndarray
        Non-negative envelope, same length as ``stim``.
    """
    return np.abs(hilbert(stim))

In [None]:
def pad_envelope(envelope, dtype, n_pre=100, n_total=701):
    """Zero-pad an envelope to a fixed epoch length and cast it to ``dtype``.

    ``n_pre`` zeros are prepended and as many zeros as needed are appended so
    that the result has exactly ``n_total`` samples.

    Parameters
    ----------
    envelope : array_like
        1-D envelope samples.
    dtype : str | numpy.dtype
        Output dtype (e.g. ``'<f8'``).
    n_pre : int
        Number of leading zeros (default 100, the value previously hard-coded).
    n_total : int
        Length of the returned array (default 701, matching the 701-sample
        epochs used elsewhere in this notebook).

    Returns
    -------
    numpy.ndarray
        Array of length ``n_total`` and dtype ``dtype``.

    Raises
    ------
    ValueError
        If ``n_pre + len(envelope) > n_total``, i.e. the envelope does not fit.
    """
    envelope = np.asarray(envelope)
    n_post = n_total - (n_pre + len(envelope))
    # A negative trailing pad would make np.pad fail with an opaque message;
    # report the actual problem (stimulus too long for the epoch) instead.
    if n_post < 0:
        raise ValueError(
            f"envelope of {len(envelope)} samples does not fit into "
            f"{n_total} samples after {n_pre} leading zeros"
        )
    padded = np.pad(envelope, pad_width=(n_pre, n_post))
    return padded.astype(dtype)

In [None]:
def get_envelope(audio_file, sr):
    """Build the padded amplitude envelope of an audio file.

    Pipeline: ``load_wav`` -> ``compute_envelope`` (Hilbert magnitude) ->
    ``pad_envelope`` with dtype ``'<f8'`` (float64).

    Parameters
    ----------
    audio_file : str
        Path of the audio file.
    sr : int | None
        Sample rate forwarded to ``load_wav``.

    Returns
    -------
    numpy.ndarray
        Fixed-length float64 envelope as produced by ``pad_envelope``.
    """
    return pad_envelope(compute_envelope(load_wav(audio_file, sr=sr)), '<f8')

#### Create dataset

In [None]:
# Pair each subject's epochs file with its trial log.
# NOTE(review): pairing is positional (i-th sorted epochs file <-> i-th sorted log),
# so it relies on both folders containing exactly the same subjects — confirm.
epoch_files = sorted(glob.glob('./analysis/Revcor*-epo.fif'))
log_files = sorted(glob.glob('./log/trials_subj*.csv'))

# Epochs files excluded from the analysis. They are matched on the bare file name
# so the exclusion works whether glob returns '/' (POSIX) or '\\' (Windows) paths;
# the previous hard-coded './analysis\\...' key raised KeyError on POSIX.
EXCLUDED_EPOCH_FILES = {'Revcor0019-epo.fif'}

es_dict = {
    epoch_file: log_file
    for epoch_file, log_file in zip(epoch_files, log_files)
    if os.path.basename(epoch_file) not in EXCLUDED_EPOCH_FILES
}

In [None]:
tstep = 1. / 1000   # 1 kHz, the same rate passed as sr=1000 to get_envelope
n_times = 701       # samples per epoch (also the padded envelope length)
time = UTS(0, tstep, n_times)

# NOTE(review): assumes the 64 EEG channels left after dropping 'STI' are in the
# same order as the first 64 sensors of the easycap-M1 montage — confirm.
sensor = Sensor.from_montage('easycap-M1')[:64]

rows = []

for k, v in es_dict.items():
    # Subject number from the file name 'Revcor<NNNN>-epo.fif'. It is parsed from
    # the basename rather than fixed string offsets, so it does not depend on the
    # directory prefix, and it is computed once per file rather than per epoch.
    subject = int(os.path.basename(k)[len('Revcor'):].split('-', 1)[0])

    subj = mne.read_epochs(k)
    subj = subj.drop_channels('STI')
    # All epochs at once, indexed below as epochs_data[i] -> (channels, times).
    epochs_data = subj.get_data()

    df = pd.read_csv(v, encoding='latin')
    df = df[['stim_type_marker', 'stim_id_marker', 'sound_file']]
    # Keep only the log rows whose marker is in `subj.selection`, so that row i
    # should describe epoch i.
    # NOTE(review): assumes 'stim_id_marker' uses the same indexing as
    # `subj.selection` and that the log rows are in epoch order — confirm.
    df = df[df['stim_id_marker'].isin(subj.selection)]
    if len(df) < len(subj):
        raise ValueError(
            f"{v}: {len(df)} matching log rows for {len(subj)} epochs in {k}"
        )

    for i in range(len(subj)):
        # NOTE(review): mne returns EEG in SI units (volts) unless the data were
        # stored otherwise; the 'µV' label assumes they already are µV — confirm.
        eeg = NDVar(epochs_data[i].T, (time, sensor), name='EEG', info={'unit': 'µV'})

        sound = df['sound_file'].iloc[i]
        envelope = NDVar(get_envelope(sound, sr=1000), (time,), name='envelope')

        rows.append([subject, eeg, envelope])


ds = Dataset.from_caselist(['subject', 'eeg', 'envelope'], rows)
print(ds.summary())

#### Save dataset

In [None]:
# Save non-interactively to the exact path the "Load dataset" cell reads, so the
# notebook can be re-run end to end (Dataset.save() asks for a destination
# interactively instead).
os.makedirs('./datasets', exist_ok=True)
save.pickle(ds, './datasets/dataset.pickle')

#### Load dataset

In [None]:
# Restore the dataset previously pickled to disk.
DATASET_PATH = './datasets/dataset.pickle'
ds = load.unpickle(DATASET_PATH)

#### Compute TRF (forward model: predict EEG from the envelope)

In [None]:
# Rebuild the time axis and sensor dimension so this section also works when
# `ds` was loaded from disk instead of being built above.
n_times = 701
tstep = 1e-3
time = UTS(0, tstep, n_times)
sensor = Sensor.from_montage('easycap-M1')[:64]

In [None]:
# Forward model: EEG ('eeg') predicted from the envelope ('envelope'),
# lags 0 to 600 ms.
fit = boosting(
    'eeg', 'envelope', 0, 0.600,
    basis=0.050, ds=ds, delta=0.01, partitions=6,
)

# Butterfly plot of the scaled TRF, with the topography cursor at 200 ms.
p = plot.TopoButterfly(fit.h_scaled, w=6, h=2)
p.set_time(.200)

In [None]:
# Save the envelope -> EEG boosting result. `save.pickle` only writes the file
# (it returns None), so keep the fit itself in `env2eeg_fit`.
ENV2EEG_FIT_PATH = './results/env2eeg_fit.pickle'
os.makedirs(os.path.dirname(ENV2EEG_FIT_PATH), exist_ok=True)
save.pickle(fit, ENV2EEG_FIT_PATH)
env2eeg_fit = fit

In [None]:
# Predict EEG data from the amplitude envelope of one stimulus by convolving the
# forward TRF with it. `fit` must still be the envelope -> EEG model from
# "Compute TRF" (the previous `res` name was never defined). The "Test" section
# below reassigns `fit` to the EEG -> envelope model.
x = NDVar(get_envelope('./sounds/subj6/julie_neutral.0736.pitch_gain.wav', sr=1000), (time,))
y = convolve(fit.h_scaled, x)

plot.UTS(y, '.sensor')

In [None]:
# Predicted EEG next to the recorded EEG in `ds`, both shown at 300 ms.
t_topo = 0.3
plot.TopoButterfly(y, t=t_topo)
plot.TopoButterfly(ds['eeg'], t=t_topo)

#### Test: decode the envelope from EEG (backward model, subject 0007)

In [None]:
# Time axis (701 samples at 1 kHz, starting at 0 s) and the 64-sensor montage,
# redefined so the "Test" section can run on its own.
tstep, n_times = 1. / 1000, 701
time = UTS(0, tstep, n_times)
sensor = Sensor.from_montage('easycap-M1')[:64]

In [None]:
# Single-subject dataset (subject 0007). Note: this overwrites `ds`.
subj = mne.read_epochs('./analysis/Revcor0007-epo.fif')
subj.drop_channels('STI')  # mne drops in place; the return value is not needed

df = pd.read_csv('./log/trials_subj0007_211026_16.27.csv', encoding='latin')
df = df[['stim_type_marker', 'stim_id_marker', 'sound_file']]
# Log rows whose marker is in `subj.selection`; row i is then used for epoch i.
keep = df['stim_id_marker'].isin(subj.selection)
df = df[keep]

# Every epoch at once; epochs_data[i] is the (channels, times) array of epoch i.
epochs_data = subj.get_data()
rows = []
for i in range(len(subj)):
    # NOTE(review): the 'µV' label assumes the stored data are in µV — confirm.
    eeg = NDVar(epochs_data[i].T, (time, sensor), name='EEG', info={'unit': 'µV'})
    sound = df['sound_file'].iloc[i]
    envelope = NDVar(get_envelope(sound, sr=1000), (time,), name='envelope')
    rows.append([eeg, envelope])

ds = Dataset.from_caselist(['eeg', 'envelope'], rows)
print(ds.summary())

In [None]:
# Summary of the current `ds` (repeats the output of the cell above).
ds_summary = ds.summary()
print(ds_summary)

In [None]:
# Backward model (decoder): reconstruct the envelope from the EEG.
# The EEG response follows the stimulus, so the envelope at time t must be predicted
# from EEG at t ... t + 600 ms, i.e. negative lags (tstart=-0.600, tstop=0).
# With lags 0 ... 0.600 the decoder would only use EEG recorded *before* each
# envelope sample.
fit = boosting('envelope', 'eeg', -0.600, 0, basis=0.050, ds=ds, partitions=2)

In [None]:
# Save the EEG -> envelope boosting result. `save.pickle` only writes the file
# (it returns None), so keep the fit itself in `eeg2env_fit`.
EEG2ENV_FIT_PATH = './results/eeg2env_fit.pickle'
os.makedirs(os.path.dirname(EEG2ENV_FIT_PATH), exist_ok=True)
save.pickle(fit, EEG2ENV_FIT_PATH)
eeg2env_fit = fit

In [None]:
# Envelope of a subject-6 stimulus, shown for comparison with the decoded envelope.
# NOTE(review): presumably the sound behind the epoch decoded below — confirm.
envelope_wav = './sounds/subj6/julie_neutral.0736.pitch_gain.wav'
actual_envelope = NDVar(get_envelope(envelope_wav, sr=1000), (time,))
plot.UTS(actual_envelope)

In [None]:
# Subject-6 epochs selected with the '14' key, with the 'STI' channel removed.
# NOTE(review): the decoder above was fit on subject 0007 — confirm the
# cross-subject prediction is intended.
std_epoch_file = './analysis/Revcor0006-epo.fif'
std_epoch = mne.read_epochs(std_epoch_file, verbose=False)['14']
std_epoch.drop_channels('STI')

In [None]:
# Predict the amplitude envelope from the first selected subject-6 epoch by
# convolving the EEG -> envelope TRF (`fit`) with its EEG.
eeg_epoch = std_epoch.get_data()[0]  # first epoch, transposed below to (time, sensor)
x = NDVar(eeg_epoch.T, (time, sensor))
y = convolve(fit.h_scaled, x, ds=ds)
for shown in (x, y):
    print(shown)
plot.UTS(y)