In [None]:
# -
from bokeh.plotting import figure
from bokeh.io import show, output_notebook
from bokeh.layouts import column, row, gridplot
from bokeh.models import Range1d
from bokeh.io import export_png
output_notebook()

# -
from fooof import FOOOF

import csv
import os
import h5py
import numpy as np

from tqdm import tqdm
from glob import glob
from pprint import pprint

from scipy.io import loadmat
from scipy.signal import medfilt
from scipy.signal import resample
from scipy.signal import welch, gaussian
from scipy.signal import butter, lfilter


# --
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter in (b, a) form.

    Parameters
    ----------
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz; edges are normalised by the Nyquist rate fs/2.
    order : int
        Filter order passed to ``scipy.signal.butter``.

    Returns
    -------
    (b, a)
        Numerator and denominator polynomial coefficients.
    """
    nyquist = fs / 2.0
    band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band, btype='band')
    return numerator, denominator


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass `data` with a Butterworth filter designed by `butter_bandpass`.

    The filter is applied once with ``scipy.signal.lfilter`` (causal, so the
    output carries the filter's phase delay), along lfilter's default axis.

    Returns
    -------
    ndarray
        The filtered signal.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(*coefficients, data)

# -
def plot_grid_psd(x, m, fs, show_plot=False):
    """Plot the log10 PSD of ``x[m]`` as a small, axis-labelled bokeh figure.

    Parameters
    ----------
    x : indexable
        Data container; ``x[m]`` is the signal passed to `create_psd`.
    m : index
        Which element of `x` to plot.
    fs : float
        Sampling rate of ``x[m]``, forwarded to `create_psd`.
    show_plot : bool
        If True, display the figure with ``show`` and return None.

    Returns
    -------
    bokeh figure or None
        The figure when `show_plot` is False, otherwise None.
    """
    freqs, psd = create_psd(x[m], fs)

    # `width`/`height` are accepted by Bokeh 2.x and 3.x; the old
    # `plot_width`/`plot_height` keywords were removed in Bokeh 3.0.
    p = figure(width=150, height=150)
    p.line(freqs, np.log10(psd), color="black", alpha=1)
    p.xaxis.axis_label = 'Freq (Hz)'
    p.yaxis.axis_label = 'Log power (AU)'

    # Fixed ranges so that grids of these small plots are directly comparable.
    p.x_range = Range1d(1, 50)
    p.y_range = Range1d(-3, 5)

    # Minimal chrome for small-multiple layouts.
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    p.toolbar_location = None

    if show_plot:
        show(p)
    else:
        return p

# -
def random_fooof(X,
                 num_windows,
                 window_length,
                 good_channels,
                 fs=1000,
                 alpha_range=(8, 16),
                 fit_range=(3, 40),
                 i_min=10000,
                 max_iterations=5000,
                 save_results=False):
    """FOOOF random channels and windows of X, keeping alpha-band peaks.

    Repeatedly draws a random channel from `good_channels` and a random
    window of `window_length` samples on it. Windows whose std exceeds
    5x the std of all of X are rejected as artifacts. For each accepted
    window a FOOOF model is fit to its PSD, and the fitted peak closest to
    the centre of `alpha_range` is recorded if it lies inside that range.

    Parameters
    ----------
    X : ndarray
        Data indexed as ``X[:, c]`` for channel ``c`` (samples along axis 0).
    num_windows : int
        Number of windows with an in-band peak to collect.
    window_length : int or float
        Window length in samples (cast to int).
    good_channels : sequence of int
        Channel indices to sample from. Not modified.
    fs : float
        Sampling rate of X, forwarded to `create_psd`.
    alpha_range : (float, float)
        Inclusive range a peak centre must fall in to be recorded.
    fit_range : (float, float)
        Frequency range passed to ``FOOOF.fit``.
    i_min : int
        Earliest allowed window start sample.
    max_iterations : int
        Total budget of random draws (channels and windows together).
    save_results : bool
        Unused; kept for backward compatibility.

    Returns
    -------
    list of tuple
        ``(m, c, i, j, center, power, bw)`` per recorded window. It may hold
        fewer than `num_windows` rows if `max_iterations` is exhausted.

    Raises
    ------
    ValueError
        If no window fits between `i_min` and the end of the data, or if
        `good_channels` is empty.
    """
    fm = FOOOF(peak_width_limits=[2.0, 12.0])

    # Reference scale for the artifact check: std over all channels/samples.
    std_ref = X.std()
    qc_threshold = 5 * std_ref
    mean_alpha = np.mean(alpha_range)

    window_length = int(window_length)
    # Exclusive upper bound for a window start so that i + window_length fits.
    i_max = X.shape[0] - window_length
    if i_min >= i_max:
        raise ValueError("i_min must be less than {}".format(i_max))

    # Sample from a copy so the caller's list is never reordered.
    channels = np.asarray(good_channels)
    if channels.size == 0:
        raise ValueError("good_channels must not be empty")

    m = 0  # Recorded window count
    k = 0  # Total random-draw count, bounded by max_iterations

    results = []
    while m < num_windows:
        k += 1
        if k > max_iterations:
            break

        # Sample a random channel from the good ones
        c = int(np.random.choice(channels))
        x = X[:, c]

        # Draw windows on this channel until one passes QC or the budget ends.
        window = None
        while k < max_iterations:
            k += 1
            i = int(np.random.randint(i_min, i_max))
            j = i + window_length
            # Basic QC: reject the *window* if it has large-amplitude artifacts.
            if x[i:j].std() <= qc_threshold:
                window = (i, j)
                break
        if window is None:
            # Budget exhausted without a clean window; don't fit an unchecked one.
            break
        i, j = window

        # FOOOF x in the window
        freqs, psd = create_psd(x[i:j], fs)
        fm.fit(freqs, psd, fit_range)

        # peak_params_ rows are (center, power, bandwidth); reshape so that an
        # empty result still yields a (0, 3) array.
        peaks = np.asarray(fm.peak_params_, dtype=float).reshape(-1, 3)
        if peaks.shape[0] == 0:
            continue
        centers, powers, bws = peaks[:, 0], peaks[:, 1], peaks[:, 2]

        # Keep the peak closest to the middle of the alpha band, if in band.
        idx = int(np.abs(centers - mean_alpha).argmin())
        closest_peak = centers[idx]
        if alpha_range[0] <= closest_peak <= alpha_range[1]:
            results.append((m, c, i, j, closest_peak, powers[idx], bws[idx]))
            m += 1

    return results

def save_fooof_results(name, results):
    """Write `random_fooof` rows to a comma-separated CSV at path `name`.

    A header row ``m, c, i, j, center, power, bw`` precedes the data; each
    element of `results` becomes one CSV row. Any existing file is overwritten.
    """
    columns = ("m", "c", "i", "j", "center", "power", "bw")

    with open(name, 'w', newline='') as handle:
        out = csv.writer(handle, delimiter=',')
        out.writerow(columns)
        out.writerows(results)

  
def create_psd(lfp, inrate, outrate=1000):
    """Calculate a Welch PSD from LFP/EEG data.

    The signal is first resampled (FFT-based ``scipy.signal.resample``)
    from `inrate` to `outrate` if they differ. Then a one-sided density PSD is
    computed with 1-second Hann segments (``nperseg = outrate``), 50 % overlap
    and linear detrending.

    Parameters
    ----------
    lfp : array_like
        Signal, resampled along axis 0.
    inrate : float
        Sampling rate of `lfp` in Hz.
    outrate : int
        Sampling rate used for the PSD; also the segment length in samples.

    Returns
    -------
    (freqs, psd)
        As returned by ``scipy.signal.welch``.
    """
    lfp = np.array(lfp)

    if inrate != outrate:
        lfp = resample(lfp, int(lfp.shape[0] * outrate / inrate))

    # Integer segment sizes; welch expects ints for nperseg/noverlap.
    nperseg = int(outrate)

    # 'hann' is the canonical get_window name ('hanning' is a legacy spelling
    # of the same window).
    return welch(
        lfp,
        fs=outrate,
        window='hann',
        nperseg=nperseg,
        noverlap=nperseg // 2,
        nfft=None,
        detrend='linear',
        return_onesided=True,
        scaling='density')

def create_times(t, dt):
    """Return sample times ``0, dt, 2*dt, ...`` for a recording of duration `t`.

    Parameters
    ----------
    t : float
        Total duration (same units as `dt`).
    dt : float
        Sampling interval.

    Returns
    -------
    ndarray
        ``round(t / dt)`` times spaced exactly `dt` apart, starting at 0.
        The end time `t` itself is excluded, matching one timestamp per sample.
    """
    # round() rather than int(): t/dt can land just below an integer
    # (e.g. 0.3 / 0.1 == 2.999...), which int() would truncate.
    n_steps = int(round(t / dt))

    return np.arange(n_steps) * dt

# Data path

In [None]:
# Directory holding the converted .mat recordings. Set SMITH_DATA_PATH to
# override the author's local default without editing the notebook.
DATA_PATH = os.environ.get("SMITH_DATA_PATH", "/Users/type/Data/Smith/data")

# Good files

In [None]:
# Recordings judged usable; each is a converted .mat file under DATA_PATH.
files = [
    "Bo130408_s6ae_fixblank_active_0001_converted_ns2.mat",
    "Bo130408_s6ae_fixblank_active_0002_converted_ns2.mat",
    "Bo130409_s7ae_fixblank_active_0003_converted_ns2.mat",
    "Wi130116_s51ae_fixblank_active_0001_converted_ns2.mat",
    "Bo130404_s4ae_fixblank_active_0002_converted_ns2.mat",
    "Bo130405_s5ae_fixblank_active_0001_converted_ns2.mat",
    "Bo130405_s5ae_fixblank_active_0002_converted_ns2.mat",
    "Bo130405_s5ae_fixblank_active_0003_converted_ns2.mat",
    "Bo130418_s12ae_fixblank_active_0002_converted_ns2.mat",
    "Wi121219_s43ae_fixblank_active_0001_converted_ns2.mat",
    "Wi121219_s43ae_fixblank_active_0002_converted_ns2.mat",
    "Wi130129_s55ae_fixblank_active_0001_converted_ns2.mat",
    "Wi130129_s55ae_fixblank_active_0002_converted_ns2.mat",
    "Wi130205_s58ae_fixblank_active_0001_converted_ns2.mat",
    "Wi130205_s58ae_fixblank_active_0002_converted_ns2.mat",
    "Wi130207_s59ae_fixblank_active_0001_converted_ns2.mat",
    "Wi130207_s59ae_fixblank_active_0002_converted_ns2.mat",
    "Wi130207_s59ae_fixblank_active_0003_converted_ns2.mat",
    "Wi130208_s60ae_fixblank_active_0001_converted_ns2.mat",
    "Wi130211_s61ae_fixblank_active_0001_converted_ns2.mat",
    "Wi130212_s62ae_fixblank_active_0001_converted_ns2.mat",
]

# Choose a file

In [None]:
# Position in `files` of the single recording explored step by step below.
i = 0
print("Running {}".format(files[i]))

# Load the handle

In [None]:
# Open read-only. Before h5py 3.0 the default mode was append ('a'), which
# opens for writing and silently creates an empty file if the path is wrong.
fi = h5py.File(os.path.join(DATA_PATH, files[i]), 'r')
pprint(list(fi.keys()))

# Create time

In [None]:
# `Dataset.value` was removed in h5py 3.0; `[()]` reads the whole dataset.
# np.squeeze turns a size-1 array (presumably 1x1 as stored by MATLAB --
# confirm) into a 0-d array that float() accepts without a deprecation warning.
fs = float(np.squeeze(fi['Fs'][()]))
n_samples = float(np.squeeze(fi['nSamples'][()]))
T = n_samples / fs
print("Experiment time T {}".format(T))

times = create_times(T, 1/fs)
print("Times: {}".format(times[:20]))

# Get data and print its shape

In [None]:
# Read the full data array into memory (`Dataset.value` was removed in h5py 3.0).
X = fi['data'][()]
print("Data shape: {}".format(X.shape))

# FOOOF!

In [None]:
# Quick smoke run: 10 windows of 10 s each over channels 0-94.
results = random_fooof(X, num_windows=10, window_length=10 * fs,
                       good_channels=list(range(95)), fs=fs)

In [None]:
# Spot-check the first rows, then persist the smoke-run results.
pprint(results[:3])
save_fooof_results(name="test.csv", results=results)

- Results look sane on a spot check of the first three rows.

# Process good files

- 6000 windows per file (about 63 per electrode on average across the 95 channels)

### LFP

In [None]:
# Run random FOOOF on every good file and write one CSV per file.
n_window = 6000       # windows with an in-band alpha peak to collect per file
window_seconds = 10   # window length in seconds
# NOTE(review): channels 0-94 are used for every file -- confirm this is the
# good-channel set for all recordings.
channels = list(range(0, 95))

for file_name in files:
    # Extract name, drop extension
    fi_name = os.path.splitext(file_name)[0]
    print(fi_name)

    # Load *this* file's data and sampling rate. Previously the loop reused
    # the X/fs loaded earlier for files[i], so every CSV described one file.
    with h5py.File(os.path.join(DATA_PATH, file_name), 'r') as handle:
        file_fs = float(np.squeeze(handle['Fs'][()]))
        file_X = handle['data'][()]

    # Do random FOOOF
    window_length = window_seconds * file_fs
    results = random_fooof(file_X, n_window, window_length, channels, fs=file_fs)

    # Save the result next to the source data
    save_fooof_results(
        "{}_segments.csv".format(os.path.join(DATA_PATH, fi_name)),
        results)