# Package imports

In [None]:
import pickle
import os
from datetime import datetime

from open_ephys.analysis import Session

from matplotlib import pyplot as plt
plt.rcParams["font.family"] = "Arial"

import scipy
from scipy import signal
from scipy import interpolate
from scipy import ndimage
from scipy.optimize import curve_fit
from scipy import stats

import numpy as np
import pandas as pd

from utilities import *
from detect_mua import detect_MUA, get_spike_wave_params

# Recordings

In [None]:
# Array of dicts for recording data in format:
# NOTE(review): this is a template — every empty value must be filled in before
# the cell will run (as written it is a syntax error).
# Recordings with "thalamus": True also need "thalamus_channel" and
# "thalamus_sigma" keys, which load_data() reads for the thalamus probe.
RECORDINGS = [
    {
        "path": "",           # Path to recording (folder name, appended to ROOT)
        "age": ,              # Age of animal at recording (printed as P<age>)
        "thalamus": ,         # Does recording have thalamus probe?
        "striatum_channel": , # Striatum channel to use
        "cortex_channel": ,   # Cortex channel to use
        "recording": ,        # Index of the recording within the session folder
        "striatum_sigma": ,   # Std-dev multiplier for the run_burst_procedure threshold
        "cortex_sigma":       # Std-dev multiplier for the run_burst_procedure threshold
    }
]

# Process recordings in order of increasing age
RECORDINGS = sorted(RECORDINGS, key=lambda k: k['age']) 

# Open previously saved data

In [None]:
# Load results saved by earlier runs of this notebook via open_data().
# NOTE(review): both paths are placeholders — fill them in. The RMS-detection and
# MUA cells further down reassign these same variables when they are run.
rms_processed_recordings = open_data('') # Path to processed data (pickle file format)
processed_recordings_mua = open_data('') # Path to processed data (pickle file format)

# Burst detection

In [None]:
def detect_bursts_from_envelope(envelope, low_threshold):
    """Split an envelope into contiguous runs of samples above a threshold.

    Each run becomes ``Burst([start, end], samples)`` with an exclusive
    ``end``, so ``envelope[start:end]`` are exactly the samples stored in
    the burst.

    Args:
        envelope: Sequence (list or array) of envelope samples.
        low_threshold: Samples strictly greater than this belong to a burst.

    Returns:
        List of ``Burst`` objects in time order.
    """
    bursts = []

    # None (not False) marks "no open burst": a burst may legitimately start
    # at sample 0, and 0 is falsy.
    start_t = None
    burst = []

    for t, sample in enumerate(envelope):
        if sample > low_threshold:
            if start_t is None:
                start_t = t
            burst.append(sample)
        elif start_t is not None:
            # First sub-threshold sample closes the run (exclusive end)
            bursts.append(Burst([start_t, t], burst))
            start_t = None
            burst = []

    # A run still open at the end of the envelope includes the final sample
    if start_t is not None:
        bursts.append(Burst([start_t, len(envelope)], burst))

    return bursts
        

# Filter bursts shorter than minimum duration
def filter_short_bursts (bursts, minimum_duration):
    """Return the bursts lasting at least ``minimum_duration`` seconds.

    ``burst.time`` is ``[start, end]`` in samples. A burst exactly
    ``minimum_duration`` long is not shorter than the minimum, so it is kept.
    This matches the redefinition of this function in the RMS-detection cell,
    so results no longer depend on which cell ran last.
    """
    minimum_samples = minimum_duration * SAMPLING_RATE
    return [
        burst for burst in bursts
        if (burst.time[1] - burst.time[0]) >= minimum_samples
    ]

# Keep only bursts containing a run of closely spaced prominent peaks
def filter_minimum_peaks (bursts, threshold, minimum_peaks, max_interval=0.2):
    """Keep bursts with ``minimum_peaks`` consecutive peaks spaced under ``max_interval`` s.

    Peaks are found on ``|burst.data|`` with ``scipy.signal.find_peaks``,
    using ``threshold`` as the minimum prominence. A burst passes when
    ``minimum_peaks - 1`` consecutive inter-peak gaps are all shorter than
    ``max_interval`` seconds. Any longer gap resets the run.

    Args:
        bursts: Burst objects with a ``data`` attribute.
        threshold: Minimum peak prominence.
        minimum_peaks: Number of consecutive closely spaced peaks required.
        max_interval: Maximum gap in seconds between consecutive peaks
            (default 0.2, the previously hard-coded value).

    Returns:
        Bursts that pass, in their original order.
    """
    filtered_bursts = []
    required_gaps = minimum_peaks - 1

    for burst in bursts:
        peaks = signal.find_peaks(np.abs(burst.data), prominence=threshold)[0]
        if len(peaks) < minimum_peaks:
            continue

        run = 0
        for gap in np.diff(peaks) / SAMPLING_RATE:
            if run >= required_gaps:
                break
            # A long gap breaks the sequence, so the run must start over
            run = run + 1 if gap < max_interval else 0

        if run >= required_gaps:
            filtered_bursts.append(burst)

    return filtered_bursts

# When bursts were detected on a filtered or derived signal, use this function
# to re-cut each burst from `data` using the burst's saved time points
def get_unfiltered_bursts (bursts, data):
    """Return new Burst objects with the same time spans, sliced from ``data``."""
    return [Burst(b.time, data[slice(*b.time)]) for b in bursts]

def get_baseline_periods (bursts, data):
    """Return padded inter-burst baseline segments.

    For each pair of consecutive bursts whose gap exceeds four paddings, the
    returned segment runs from one padding (0.5 s) after the earlier burst
    ends to one padding before the later burst starts.

    NOTE: a later cell redefines ``get_baseline_periods`` with a different
    signature; whichever cell was executed last is the one in scope.

    Args:
        bursts: Time-ordered Burst objects (``time`` is ``[start, end]``).
        data: Signal to cut the baseline segments from.

    Returns:
        List of ``([start, end], data[start:end])`` tuples.
    """
    baseline_bursts = []

    padding = int(SAMPLING_RATE * 0.5)  # 0.5 seconds

    for curr_burst, next_burst in zip(bursts, bursts[1:]):
        curr_end = curr_burst.time[1]
        next_start = next_burst.time[0]

        if next_start - curr_end > padding * 4:
            # Baseline lies between the bursts: after the current one ends,
            # before the next one starts
            start_t = curr_end + padding
            end_t = next_start - padding
            baseline_bursts.append(([start_t, end_t], data[start_t:end_t]))

    return baseline_bursts

def combined_adjacent_bursts(bursts, data, temporal_distance):
    """Merge consecutive bursts separated by at most ``temporal_distance`` seconds.

    Each group of bursts whose gaps (next start minus previous end) are no
    more than ``temporal_distance * SAMPLING_RATE`` samples becomes one Burst
    spanning the group's first start to its last end. The merged burst's data
    is re-sliced from ``data``.

    Args:
        bursts: Time-ordered Burst objects.
        data: Signal the merged bursts are sliced from.
        temporal_distance: Maximum gap, in seconds, for merging.

    Returns:
        List of merged Burst objects. Empty input gives ``[]``, and a single
        burst gives a one-element list.
    """
    if not bursts:
        return []

    max_gap = temporal_distance * SAMPLING_RATE
    combined_bursts = []

    group_start, group_end = bursts[0].time[0], bursts[0].time[1]
    for burst in bursts[1:]:
        start, end = burst.time[0], burst.time[1]
        if start - group_end > max_gap:
            # Gap too large: close the current group and start a new one
            combined_bursts.append(Burst([group_start, group_end], data[group_start:group_end]))
            group_start = start
        group_end = end

    # Close the final group (covers the single-burst case as well)
    combined_bursts.append(Burst([group_start, group_end], data[group_start:group_end]))
    return combined_bursts


def hl_envelopes_idx(s, dmin=1, dmax=1):
    """Return indices of the lower and upper envelope points of a signal.

    Local minima and maxima are found from sign changes of the first
    difference. Within each chunk of ``dmin`` consecutive local minima only
    the lowest is kept; within each chunk of ``dmax`` local maxima only the
    highest is kept.

    Args:
        s: 1-D signal (list or array).
        dmin: Chunk size for thinning the local minima.
        dmax: Chunk size for thinning the local maxima.

    Returns:
        ``(lmin, lmax)``: integer index arrays into ``s``.
    """
    # Fancy indexing below (s[index_array]) requires an ndarray, not a list
    s = np.asarray(s)
    slope_sign_change = np.diff(np.sign(np.diff(s)))

    lmin = (slope_sign_change > 0).nonzero()[0] + 1  # local minima
    lmax = (slope_sign_change < 0).nonzero()[0] + 1  # local maxima

    # Global min of each dmin-chunk of local minima
    lmin = lmin[[i + np.argmin(s[lmin[i:i + dmin]]) for i in range(0, len(lmin), dmin)]]
    # Global max of each dmax-chunk of local maxima
    lmax = lmax[[i + np.argmax(s[lmax[i:i + dmax]]) for i in range(0, len(lmax), dmax)]]

    return lmin, lmax
            

# Combines the functions above into a single envelope-based burst-detection routine.

def run_burst_procedure (data, minimum_peaks, minimum_duration, sigma):
    """Detect bursts in one channel using an amplitude-envelope threshold.

    Steps:
      1. Bandpass the signal to 1-100 Hz.
      2. Build an envelope: the larger of the upper envelope and the absolute
         lower envelope at every sample.
      3. Set the threshold to mean + ``sigma`` * std of the bandpassed signal,
         clipped to +/-1000.
      4. Take runs of envelope samples above the threshold as bursts.
      5. Merge bursts less than 0.2 s apart.
      6. Drop bursts whose chunk RMS reaches 1000.
      7. Drop bursts shorter than ``minimum_duration``.
      8. Re-cut the bursts from the bandpassed signal.
      9. Keep bursts with ``minimum_peaks`` closely spaced peaks.
     10. Call ``get_primary_frequency()`` on each burst.

    Args:
        data: Array of sampled data for one channel.
        minimum_peaks: Minimum number of closely spaced peaks per burst.
        minimum_duration: Minimum burst duration to keep, in seconds.
        sigma: How many standard deviations above the mean for the threshold.

    Returns:
        List of Burst objects.
    """
    # Bandpass filter signal
    data_bandpass = butter_bandpass_filter(data, lowcut=1, highcut=100)

    # Envelope: interpolate the upper/lower envelope points onto every sample,
    # then take the larger magnitude per sample. np.maximum replaces a
    # per-sample Python loop over the whole recording.
    low_idx, high_idx = hl_envelopes_idx(data_bandpass, dmin=30, dmax=30)
    x = np.arange(0, len(data_bandpass))
    high_env = np.interp(x, x[high_idx], data_bandpass[high_idx])
    low_env = np.interp(x, x[low_idx], data_bandpass[low_idx])
    envelope = np.maximum(high_env, np.abs(low_env))

    # Clip data to prevent skew of std from random outlier events
    data_clipped = np.clip(data_bandpass, a_min=-1000, a_max=1000)

    # Get thresholds
    mean, std = np.mean(data_clipped), np.std(data_clipped)
    burst_threshold = mean + std*sigma
    print('Calculated threshold ({}, sigma={})\n'.format(
        round(burst_threshold, 2),
        round(sigma, 2)
    ))

    bursts = detect_bursts_from_envelope(envelope, burst_threshold)
    print('\nBursts detected')

    # Merge bursts less than 0.2 s apart; merged data is sliced from the envelope
    bursts = combined_adjacent_bursts(bursts, envelope, temporal_distance=0.2)
    print('Combined adjacent bursts')

    # NOTE(review): filter_rms is defined in a later cell, which must have been run first
    bursts = filter_rms(bursts, max_rms=1000)
    print('\tFiltered RMS')

    bursts = filter_short_bursts(bursts, minimum_duration)
    print('Short bursts filtered')

    # Replace envelope samples with the bandpassed signal for each burst span
    bursts = get_unfiltered_bursts(bursts, data_bandpass)
    print('Got unfiltered bursts')

    bursts = filter_minimum_peaks(bursts, burst_threshold, minimum_peaks=minimum_peaks)
    print('Filtered minimum thresold points')

    for burst in bursts:
        burst.get_primary_frequency()
    print('Primary frequencies computed')

    return bursts

In [None]:
def load_data (recordings, get_mua=False, get_psd=False, max_recordings=17):
    """Run the envelope burst pipeline (plus optional MUA/PSD) on recordings.

    For each of the first ``max_recordings`` recordings, this loads the first
    hour of the configured striatum, thalamus (only when
    ``recording["thalamus"]`` is true) and cortex channels. For every area
    whose data is non-empty, it stores the following on a copy of the
    recording dict:

    - ``"<area>_bursts"``: output of ``run_burst_procedure``
    - ``"<area>_MUA"`` (if ``get_mua``): output of ``detect_MUA``
    - ``"<area>_PSD"`` (if ``get_psd``): output of ``multitaper_psd``
    - ``"length"``: number of striatum samples loaded

    Args:
        recordings: List of recording dicts (see RECORDINGS).
        get_mua: Also detect multi-unit activity.
        get_psd: Also compute a multitaper PSD of each whole channel.
        max_recordings: Process at most this many recordings (default 17,
            the previously hard-coded limit). ``None`` processes all of them.

    Returns:
        List of augmented recording dicts; the input dicts are not modified.
    """
    minimum_duration = 0.2
    minimum_peaks = 5
    data_range = get_slice_from_s(0, 60*60)  # first hour

    selected = recordings[:max_recordings]
    processed = []

    for idx, recording in enumerate(selected):
        # Progress is reported against the recordings actually processed
        print("\nRecording {} ({}/{})".format(
            recording["path"],
            idx+1,
            len(selected)
        ))

        session = Session(ROOT + recording["path"])
        samples = session.recordings[recording["recording"]].continuous[0].samples

        result = recording.copy()

        for area in ("striatum", "thalamus", "cortex"):
            # Only recordings with a thalamus probe have a thalamus channel
            if area == "thalamus" and not recording["thalamus"]:
                continue

            area_data = samples[data_range, recording[area + "_channel"]]
            if not len(area_data):
                continue

            if area == "striatum":
                result["length"] = len(area_data)

            result[area + "_bursts"] = run_burst_procedure(
                data=area_data,
                minimum_peaks=minimum_peaks,
                minimum_duration=minimum_duration,
                sigma=recording[area + "_sigma"]
            )
            if get_mua:
                result[area + "_MUA"] = detect_MUA(
                    data=area_data,
                    standard_deviations=5
                )
            if get_psd:
                result[area + "_PSD"] = multitaper_psd(
                    area_data,
                    NW=3, k=5, resample_freq=1000,
                    show_progress=True
                )

        processed.append(result)

    return processed

processed_recordings = load_data(RECORDINGS, get_mua=False, get_psd=False)
save_data(processed_recordings, 'processed_recording_final')

In [None]:
# Find the baseline periods of each recording and their mean PSD.
#
# A baseline period is the 1-100 Hz bandpassed signal between the end of one burst
# and the start of the next (or the end of the recording, after the last burst).
# A period is kept only if none of its samples exceeds the burst threshold. If no
# period qualifies, all inter-burst periods are used instead.
# Periods longer than 0.5 s are trimmed to at most 10 s, and a PSD is taken of each.
# These PSDs are averaged, giving the mean baseline PSD for each brain area and recording.

def _inter_burst_segments(bursts, data, threshold=None):
    """Return the slices of ``data`` between each burst's end and the next burst's start.

    The segment after the last burst runs to the end of ``data``. If
    ``threshold`` is given, segments with any sample above it are skipped.
    """
    segments = []
    for burst_idx, burst in enumerate(bursts):
        if burst_idx + 1 == len(bursts):
            next_start = len(data)  # last burst: baseline lasts until end of recording
        else:
            next_start = bursts[burst_idx + 1].time[0]
        segment = data[burst.time[1]:next_start]
        if threshold is None or not np.any(segment > threshold):
            segments.append(segment)
    return segments


def get_baseline_periods (recording):
    """Store ``"<area>_baseline"`` = ``(f, mean Pxx)`` or ``False`` for each area with bursts.

    Modifies ``recording`` in place and returns it.
    """
    for brain_area in ["striatum", "thalamus", "cortex"]:
        burst_key = brain_area + "_bursts"
        if burst_key not in recording:
            continue

        baseline_key = brain_area + "_baseline"
        channel_n = recording[brain_area + "_channel"]
        sigma = recording[brain_area + "_sigma"]

        session = Session(ROOT + recording["path"])
        bursts = recording[burst_key]

        data = session.recordings[recording["recording"]].continuous[0].samples[DATA_RANGE_ALL, channel_n]
        data_bandpass = butter_bandpass_filter(data, lowcut=1, highcut=100)
        data_clipped = np.clip(data_bandpass, a_min=-1000, a_max=1000)
        print('Data filtered')

        burst_threshold = np.mean(data_clipped) + np.std(data_clipped)*sigma
        print('Threshold calculated ({})'.format(burst_threshold))

        baseline_data = _inter_burst_segments(bursts, data_bandpass, burst_threshold)

        # Try again without threshold
        if len(baseline_data) == 0:
            print('Trying again without threshold')
            baseline_data = _inter_burst_segments(bursts, data_bandpass)

        print('Baselines calculated')

        # PSD of every period longer than 0.5 s, trimmed to a maximum of 10 s
        baseline_psds = [
            multitaper_psd(baseline[get_slice_from_s(0, 10)])
            for baseline in baseline_data
            if len(baseline) / SAMPLING_RATE > 0.5
        ]

        # Take mean of all baseline PSDs; if there are none, leave it as False
        if len(baseline_psds):
            f = baseline_psds[0][0]
            Pxx = np.mean([psd[1] for psd in baseline_psds], axis=0)
            recording[baseline_key] = f, Pxx
        else:
            recording[baseline_key] = False
            print('None found')

    return recording

recordings_with_baselines = []
for recording_idx, recording in enumerate(processed_recordings):
    print("\nRecording {} ({}/{})".format(
        recording["path"],
        recording_idx+1,
        len(processed_recordings)
    ))
    recordings_with_baselines.append(get_baseline_periods(recording))

In [None]:
# For each burst, the primary frequency is the peak of the ratio between the burst
# PSD and the recording's baseline PSD (both restricted to 1-100 Hz).
#
# This sets the attribute "primary_frequency_baseline" on every burst in the striatum,
# thalamus and cortex burst lists of each recording. It is False when the ratio has no
# peak, or when the area has no baseline PSD.
# To run on other results, set INPUT_RECORDINGS to the variable holding the list of
# processed recordings.

INPUT_RECORDINGS = recordings_with_baselines

for recording_idx, recording in enumerate(INPUT_RECORDINGS):
    print("\nRecording {} ({}/{})".format(
        recording["path"],
        recording_idx+1,
        len(INPUT_RECORDINGS)
    ))

    for brain_area in ["striatum", "thalamus", "cortex"]:
        baseline_key = brain_area + "_baseline"
        burst_key = brain_area + "_bursts"

        if burst_key not in recording:
            continue

        baseline = recording.get(baseline_key, False)
        if baseline is False:
            # No baseline PSD was found for this area, so no ratio can be formed
            for burst in recording[burst_key]:
                burst.primary_frequency_baseline = False
            continue

        f_baseline, Pxx_baseline = get_psd_in_range(baseline, [1, 100])

        for burst in recording[burst_key]:
            f_burst, Pxx_burst = get_psd_in_range(multitaper_psd(burst.data), [1, 100])

            # Put the baseline on the burst's frequency grid before dividing.
            # This is an identity when the grids already match; it assumes the
            # frequencies are ascending (TODO confirm for multitaper_psd).
            Pxx_ratio = Pxx_burst / np.interp(f_burst, f_baseline, Pxx_baseline)

            # Anything above 1 indicates greater power for burst over baseline
            Pxx_clipped = np.clip(Pxx_ratio, a_min=1, a_max=None)

            # Primary frequency = frequency of the highest peak of the ratio
            peaks = signal.find_peaks(Pxx_clipped)[0]
            if len(peaks):
                best_peak = peaks[np.argmax(Pxx_clipped[peaks])]
                burst.primary_frequency_baseline = f_burst[best_peak]
            else:
                burst.primary_frequency_baseline = False

# Different burst detection method

In [None]:
# Method based on https://www.frontiersin.org/articles/10.3389/fncir.2014.00050/full#F1 (all code written independently)

def rms (data):
    """Return the root-mean-square of ``data``.

    The computation is done in float64, so integer samples (e.g. int16) do
    not overflow when squared. Empty input gives nan (numpy's mean of no
    values).
    """
    values = np.asarray(data, dtype=np.float64)
    return np.sqrt(np.mean(np.square(values)))

def rms_hist(data, window=0.2):
    """Histogram the RMS of consecutive ``window``-second chunks of ``data``.

    The histogram is density-normalised over
    ``[min RMS, min(99th percentile, 2000)]`` with ``int(95th percentile)``
    bins, but never fewer than one bin.

    Returns:
        ``(hist, bin_edges, rms_list)``. ``rms_list`` holds one RMS value per
        chunk; the last chunk may be shorter than ``window``.
    """
    chunk_size = int(SAMPLING_RATE*window)  # Convert window in s to samples
    rms_list = [rms(data[i:i + chunk_size]) for i in range(0, len(data), chunk_size)]

    # Bin count is tied to the 95th-percentile RMS. Keep it >= 1, because
    # np.histogram rejects 0 bins, which happens when all RMS values are < 1.
    n_bins = max(1, int(np.percentile(rms_list, 95)))

    # The range must not be inverted, which happens when every RMS exceeds the 2000 cap
    lower = min(rms_list)
    upper = max(lower, min(np.percentile(rms_list, 99), 2000))
    hist, bin_edges = np.histogram(rms_list, density=True, bins=n_bins, range=(lower, upper))

    return hist, bin_edges, rms_list

def plot_gaussian (bin_centres, hist, hist_fit):
    """Plot an RMS histogram with its fitted Gaussian and show the figure."""
    plt.figure()
    plt.plot(bin_centres, hist, label='RMS', c='tab:red')
    plt.plot(bin_centres, hist_fit, label='Fitted Gaussian', c='black')
    plt.xlabel('RMS')
    plt.ylabel('Density')

    leg = plt.legend(frameon=False, fontsize=15, bbox_to_anchor=(0.925, 1), loc='upper left')
    # Matplotlib 3.7 renamed `legendHandles` to `legend_handles`, and 3.9
    # removed the old name; support both.
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(2.0)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(12.5)
    for item in ([ax.xaxis.label, ax.yaxis.label] + ax.get_legend().get_texts()):
        item.set_fontsize(15)

    plt.show()

def fit_gaussian(hist, bin_edges, use_truncated_hist, plot=True):
    """Fit a Gaussian to an RMS histogram and derive a burst threshold.

    Args:
        hist: Histogram values (e.g. the density from ``rms_hist``).
        bin_edges: Histogram bin edges (``len(hist) + 1`` values).
        use_truncated_hist: If True, fit only bins whose centre is <= 100.
            Otherwise fit the whole histogram.
        plot: If True (the default), show the histogram and fit via
            ``plot_gaussian``.

    Returns:
        ``(mu, sigma, thresh)``: the fitted mean, the fitted standard
        deviation (non-negative), and the threshold. The threshold is the bin
        centre one bin past the first bin at or above ``mu + sigma`` where
        the fitted curve no longer exceeds the histogram (clamped to the last
        bin).

    Raises:
        RuntimeError: If ``curve_fit`` does not converge.
        IndexError: If no bin at/above ``mu + sigma`` has fit <= data.
    """
    hist = np.asarray(hist, dtype=float)
    bin_edges = np.asarray(bin_edges, dtype=float)
    bin_centres = (bin_edges[:-1] + bin_edges[1:])/2

    # Select the points used for the fit
    if use_truncated_hist:
        keep = bin_centres <= 100
        fit_centres, fit_hist = bin_centres[keep], hist[keep]
    else:
        fit_centres, fit_hist = bin_centres, hist

    peak = bin_centres[np.argmax(hist)]

    def gauss(x, *p):
        A, mu, sigma = p
        return A*np.exp(-(x-mu)**2/(2.*sigma**2))

    # Initial width: std of the bin centres weighted by floor(hist*1000). This is
    # the std of the centres repeated int(v*1000) times, without building that array.
    weights = np.floor(hist * 1000)
    sigma0 = 0.0
    if weights.sum() > 0:
        centre_mean = np.average(bin_centres, weights=weights)
        sigma0 = np.sqrt(np.average((bin_centres - centre_mean)**2, weights=weights))
    if not sigma0 > 0:
        # Degenerate histogram: fall back to the histogram span
        sigma0 = float(bin_edges[-1] - bin_edges[0]) or 1.0

    p0 = [np.max(hist), peak, sigma0]
    coeff, _ = curve_fit(gauss, fit_centres, fit_hist, p0=p0)
    _, mu, sigma = coeff
    # gauss is symmetric in sigma, so the fit may return it negative
    sigma = abs(sigma)

    hist_fit = gauss(bin_centres, *coeff)

    if plot:
        plot_gaussian(bin_centres, hist, hist_fit)

    # Threshold: first bin above mu + sigma where the fit stops exceeding the data
    smoothed_diff = hist_fit - hist
    start_idx = np.where(bin_centres >= mu + sigma)[0][0]
    thresh_idx = np.where(smoothed_diff[start_idx:] <= 0)[0][0] + start_idx + 1
    thresh = bin_centres[min(thresh_idx, len(bin_centres) - 1)]

    return mu, sigma, thresh

def get_burst_events (data, rms_list, window=0.2, threshold=0):
    """Return one Burst per RMS window whose RMS exceeds ``threshold``.

    Window ``k`` covers samples ``[k*chunk, min(len(data), (k+1)*chunk))``,
    where ``chunk = int(SAMPLING_RATE*window)``.
    """
    chunk = int(SAMPLING_RATE * window)  # window length in samples
    n_samples = len(data)
    spans = (
        (k * chunk, min(n_samples, (k + 1) * chunk))
        for k, value in enumerate(rms_list)
        if value > threshold
    )
    return [Burst([lo, hi], data[lo:hi]) for lo, hi in spans]

def combine_adjacent_bursts(data, bursts, temporal_distance):
    """Merge consecutive bursts separated by at most ``temporal_distance`` seconds.

    Each group of bursts whose gaps (next start minus previous end) are no
    more than ``temporal_distance * SAMPLING_RATE`` samples becomes one Burst
    spanning the group's first start to its last end, with data re-sliced
    from ``data``.

    Args:
        data: Signal the merged bursts are sliced from.
        bursts: Time-ordered Burst objects.
        temporal_distance: Maximum gap, in seconds, for merging.

    Returns:
        List of merged Burst objects. Empty input gives ``[]``, and a single
        burst gives a one-element list.
    """
    if not bursts:
        return []

    max_gap = temporal_distance * SAMPLING_RATE
    combined_bursts = []

    group_start, group_end = bursts[0].time[0], bursts[0].time[1]
    for burst in bursts[1:]:
        start, end = burst.time[0], burst.time[1]
        if start - group_end > max_gap:
            # Gap too large: close the current group and start a new one
            combined_bursts.append(Burst([group_start, group_end], data[group_start:group_end]))
            group_start = start
        group_end = end

    # Close the final group (covers the single-burst case as well)
    combined_bursts.append(Burst([group_start, group_end], data[group_start:group_end]))
    return combined_bursts

def filter_short_bursts (bursts, minimum_duration):
    """Keep the bursts whose span (``time[1] - time[0]`` samples) is at least ``minimum_duration`` seconds."""
    threshold_samples = minimum_duration * SAMPLING_RATE
    kept = []
    for b in bursts:
        start, end = b.time[0], b.time[1]
        if end - start >= threshold_samples:
            kept.append(b)
    return kept

def filter_rms (bursts, max_rms, window=0.2):
    """Drop bursts whose RMS in any ``window``-second chunk reaches ``max_rms``.

    Args:
        bursts: Burst objects with a ``data`` attribute.
        max_rms: A burst is kept only if every chunk RMS is below this value.
        window: Chunk length in seconds (default 0.2, the previously
            hard-coded value).

    Returns:
        The bursts that pass. Bursts with no data are dropped, since they
        carry no signal and ``np.max`` of an empty list would raise.
    """
    chunk_size = int(SAMPLING_RATE*window)  # Convert window in s to samples
    filtered_bursts = []

    for burst in bursts:
        data = burst.data
        if len(data) == 0:
            continue
        chunk_rms = [rms(data[i:i + chunk_size]) for i in range(0, len(data), chunk_size)]
        if np.max(chunk_rms) < max_rms:
            filtered_bursts.append(burst)

    return filtered_bursts

def filter_min_peaks (bursts, min_peaks):
    """Keep the bursts for which ``get_peaks(burst.data)`` reports at least ``min_peaks`` peaks."""
    def _has_enough_peaks(candidate):
        _, peak_count = get_peaks(candidate.data)
        return peak_count >= min_peaks

    return [b for b in bursts if _has_enough_peaks(b)]

def get_raw_bursts (data, bursts):
    """Return new Bursts with the same time spans, re-sliced from ``data``."""
    raw = []
    for b in bursts:
        start, end = b.time[0], b.time[1]
        raw.append(Burst(b.time, data[start:end]))
    return raw

def get_baseline_power (data, bursts):
    """Mean PSD of the inter-burst (baseline) periods of ``data``.

    A baseline period runs from a burst's end to the next burst's start, or
    to the end of ``data`` after the last burst. Periods shorter than 1 s are
    skipped; longer ones are trimmed to their first 10 s.

    Returns:
        ``(f, Pxx, baseline_data)``: the frequencies and mean power of the
        baseline PSDs, plus the list of trimmed baseline periods. Returns
        ``(False, False, False)`` when no period is long enough.
    """
    # Get baseline periods
    baseline_data = []
    for burst_idx, burst in enumerate(bursts):
        # If on last burst, baseline will last until end of recording
        is_last = burst_idx + 1 == len(bursts)
        next_start = len(data) if is_last else bursts[burst_idx + 1].time[0]

        baseline_period = data[burst.time[1]:next_start]
        # Minimum length of 1 s; trim to a maximum of 10 s
        if len(baseline_period) / SAMPLING_RATE >= 1:
            baseline_data.append(baseline_period[get_slice_from_s(0, 10)])

    # If no baseline periods, just leave it as 'False'
    if not baseline_data:
        return False, False, False

    baseline_psds = [multitaper_psd(baseline_period) for baseline_period in baseline_data]
    f = baseline_psds[0][0]
    # Periods of different lengths may give PSDs on different frequency grids
    # (depends on multitaper_psd — TODO confirm), which np.mean cannot stack.
    # Put each PSD on the first grid; this is a no-op when the grids already match.
    Pxx = np.mean([np.interp(f, psd[0], psd[1]) for psd in baseline_psds], axis=0)
    return f, Pxx, baseline_data
    
def get_baseline_amplitude (data):
    """Mean peak-to-peak amplitude of the 4-100 Hz bandpassed baseline periods.

    ``data`` is the list of baseline periods from ``get_baseline_power``, or
    False. Returns False when ``data`` is falsy or empty.
    """
    if not (data and len(data)):
        return False

    peak_to_peak = []
    for period in data:
        filtered = butter_bandpass_filter(period, 4, 100)
        peak_to_peak.append(np.max(filtered) - np.min(filtered))
    return np.mean(peak_to_peak, axis=0)
    
def get_normalized_psd (bursts, baseline_psd):
    """Attach baseline-normalised PSDs and primary frequencies to bursts.

    For every burst, the multitaper PSD (0-100 Hz) is divided by the
    baseline PSD. If the ratio (clipped below at 1) has peaks:

    - ``burst.primary_frequency_baseline`` is set to the frequency of the
      highest peak.
    - ``burst.normalized_psd`` is set to ``(f, ratio)``.

    Otherwise, or when there is no baseline (``baseline_psd`` is
    ``(False, False)`` as produced from ``get_baseline_power``),
    ``primary_frequency_baseline`` is set to False.

    Returns:
        A new list holding the same (modified in place) burst objects.
    """
    has_baseline = baseline_psd is not False and baseline_psd[0] is not False
    if has_baseline:
        # Loop-invariant: restrict the baseline to 0-100 Hz once
        f_baseline, Pxx_baseline = get_psd_in_range(baseline_psd, [0, 100])

    normalized_psd_bursts = []

    for burst in bursts:
        if not has_baseline:
            burst.primary_frequency_baseline = False
            normalized_psd_bursts.append(burst)
            continue

        f_burst, Pxx_burst = get_psd_in_range(multitaper_psd(burst.data), [0, 100])

        # Ratio of burst to baseline power, with the baseline put on the burst's
        # frequency grid (identity when the grids match; assumes ascending f)
        Pxx_normed = Pxx_burst / np.interp(f_burst, f_baseline, Pxx_baseline)

        # Anything above 1 indicates greater power for burst over baseline
        Pxx_clipped = np.clip(Pxx_normed, a_min=1, a_max=None)

        # Find max peak (primary freq) in burst/baseline power ratio
        peaks = signal.find_peaks(Pxx_clipped)[0]
        if len(peaks):
            best_peak = peaks[np.argmax(Pxx_clipped[peaks])]
            burst.normalized_psd = (f_burst, Pxx_normed)
            burst.primary_frequency_baseline = f_burst[best_peak]
        else:
            burst.primary_frequency_baseline = False

        normalized_psd_bursts.append(burst)

    return normalized_psd_bursts

# Run the RMS-based burst detection on each recording in RECORDINGS_NEW.
# NOTE(review): RECORDINGS_NEW is not defined in this part of the notebook — confirm its source.
rms_processed_recordings = []
for recording_idx, recording in enumerate(RECORDINGS_NEW):    
    print("P{} {}/{} {}".format(recording["age"], recording_idx+1, len(RECORDINGS_NEW), recording["path"]))
        
    for brain_area in ["cortex"]: #, "thalamus", "striatum"]:
        brain_channel = brain_area + "_channel"    
        brain_bursts = brain_area + "_bursts"
        brain_baseline = brain_area + "_baseline"
        brain_baseline_amp = brain_area + "_baseline_amplitude"
        
        # Skip areas this recording has no channel for
        if not brain_channel in recording:
            continue
        else:
            print('\t{}'.format(brain_area))
            
        recording_n = recording["recording"]
        channel_n = recording[brain_channel]
        session = Session(ROOT + recording["path"])

        # First hour of the channel; detection runs on the 4-100 Hz bandpassed copy
        data_all = session.recordings[recording_n].continuous[0].samples[get_slice_from_s(0, 60*60), channel_n]
        data_all_4_to_100 = butter_bandpass_filter(data_all, 4, 100)
        
        # Threshold from a Gaussian fit to the histogram of 0.2 s-window RMS values,
        # floored at mu + 3*sigma. If the fit raises, the error is printed and the
        # area is skipped.
        hist, bin_edges, rms_list = rms_hist(data_all_4_to_100, window=0.2)
        try:
            mu, sigma, thresh = fit_gaussian(hist, bin_edges, use_truncated_hist=False)
            thresh = max(thresh, mu+3*sigma)
        except Exception as e:
            print(e)
            print('Warning! Could not fit Gaussian!')
            continue
        print('\tComputed mean ({}) and sigma ({}), threshold ({})'.format(mu, sigma, thresh))

        # Windows above threshold become bursts; windows within 0.2 s are merged
        bursts = get_burst_events(data_all_4_to_100, rms_list, window=0.2, threshold=thresh)
        print('\tBursts found')
        bursts = combine_adjacent_bursts(data_all_4_to_100, bursts, temporal_distance=0.2)
        print('\tAdjacent bursts combined')
        
        # Baseline PSD and amplitude come from the non-bandpassed data between bursts
        baseline_f, baseline_Pxx, baseline_data = get_baseline_power(data_all, bursts)
        baseline_psd = baseline_f, baseline_Pxx
        print('\tBaseline periods found')
        baseline_amplitude = get_baseline_amplitude(baseline_data)
        print('\tBaseline amplitudes found')
        
        bursts = filter_short_bursts(bursts, minimum_duration=0.2)
        print('\tFiltered short bursts')
        bursts = filter_rms(bursts, max_rms=1000)
        print('\tFiltered RMS')
        bursts = filter_min_peaks(bursts, min_peaks=5)            
        # Re-cut the surviving bursts from data_all (not the bandpassed copy).
        # NOTE(review): data_all is not resampled in this cell — confirm the "1500 Hz" in the message.
        bursts = get_raw_bursts (data_all, bursts)
        print('\tGot 1500 Hz data')
        
        bursts = get_normalized_psd(bursts, baseline_psd)
        print('\tGot normed PSDs\n')
        
        # Results are written to a copy, so RECORDINGS_NEW entries are not modified.
        # "length" is in samples.
        recording = recording.copy()
        recording[brain_bursts] = bursts
        recording[brain_baseline] = baseline_psd
        recording[brain_baseline_amp] = baseline_amplitude
        recording["length"] = len(data_all)
        
    rms_processed_recordings.append(recording)

    # Saved after every recording (acts as a checkpoint); the path is a placeholder
    save_data(rms_processed_recordings, '') # Path to save

# MUA detection

In [None]:
# Detect multi-unit activity (MUA) on the cortex channel of the last five recordings.
# Results are stored on copies of the recording dicts, keyed by recording path.
# NOTE: "length" here is in seconds, whereas the burst pipelines store it in samples.
processed_recordings_mua = {}

mua_recordings = RECORDINGS[-5:]
for recording_idx, recording in enumerate(mua_recordings):
    # Progress is reported against the recordings actually processed
    print("P{} {}/{} {}".format(recording["age"], recording_idx+1, len(mua_recordings), recording["path"]))

    # Work on a copy so the shared RECORDINGS entries are not modified
    recording = recording.copy()

    for brain_area in ["cortex"]: #, "striatum", "thalamus"]:
        brain_channel = brain_area + "_channel"
        brain_mua = brain_area + "_mua"

        if brain_channel not in recording:
            continue
        print('\t{}'.format(brain_area))

        recording_n = recording["recording"]
        channel_n = recording[brain_channel]
        session = Session(ROOT + recording["path"])

        # First hour of the channel
        data_all = session.recordings[recording_n].continuous[0].samples[get_slice_from_s(0, 60*60), channel_n]

        recording["length"] = len(data_all) / SAMPLING_RATE
        recording[brain_mua] = detect_MUA(data_all, 5)

    processed_recordings_mua[recording['path']] = recording