## Import all the required packages

In [None]:
import pickle
import os
from datetime import datetime

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import numpy as np
import pandas as pd

import skfuzzy as fuzz

from open_ephys.analysis import Session

from spectrum import dpss, pmtm

from statsmodels.tsa.arima.model import ARIMA
from statsmodels.stats.descriptivestats import sign_test
import statsmodels.api as sm
from statsmodels.formula.api import ols

from scipy import signal
from scipy import interpolate
from scipy import cluster
from scipy import io
from scipy import ndimage
from scipy import stats
from scipy.optimize import curve_fit

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score

## Classes, utility functions and constants used throughout notebook

In [None]:
###########
# Classes #
###########

class Burst:
    """Container pairing a burst's sample-index bounds with its samples.

    time -- sequence unpacked into slice(); elsewhere in this file it is
            built as [start, stop] in sample indices
    data -- the samples belonging to the burst
    """

    def __init__(self, time, data):
        self.time, self.data = time, data

    def get_xticks(self):
        """Return the x-axis values for this burst's sample span."""
        # Inside the method body `get_xticks` resolves to the module-level
        # function of the same name, not to this method.
        span = slice(*self.time)
        return get_xticks(span)

############################
# Graph plotting functions #
############################

def format_plot (ax, legend=True, size=(12.5, 15)):
    """Apply the notebook's common styling to a Matplotlib axes.

    ax     -- axes to style
    legend -- when true, draw a frameless legend outside the top-right
              corner of `ax` and thicken its handles
    size   -- (tick/title font size, axis-label/legend font size)
    """
    small_font, large_font = size[0], size[1]

    if legend:
        # Create the legend on `ax` itself so the call also works when `ax`
        # is not the current pyplot axes.
        leg = ax.legend(frameon=False, bbox_to_anchor=(1, 1), loc='upper left')

        # `legendHandles` was renamed to `legend_handles` in newer Matplotlib
        # releases; support both spellings.
        handles = getattr(leg, 'legend_handles', None)
        if handles is None:
            handles = leg.legendHandles

        for handle in handles:
            handle.set_linewidth(2.0)
            handle.set_alpha(1)
        for text in leg.get_texts():
            text.set_fontsize(large_font)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(small_font)
    for item in (ax.xaxis.label, ax.yaxis.label):
        item.set_fontsize(large_font)

def plot_graph_by_age (data, ylabel, scatter, colors, labels, save=False):
    """Plot mean +/- standard error against age, one line per brain area.

    data    -- mapping: brain area -> mapping: age -> sequence of values
    ylabel  -- y-axis label
    scatter -- when true, also draw every individual value as a faint dot
    colors  -- one colour per brain area, in the iteration order of `data`
    labels  -- one legend label per brain area, same order
    save    -- falsy, or a filename appended to FIG_ROOT for saving
    """
    for brain_idx, data_by_age in enumerate(data.values()):
        ages = list(data_by_age.keys())
        groups = list(data_by_age.values())
        xpos = np.arange(len(ages))

        mean = [np.mean(group) for group in groups]
        # Standard error of the mean (np.std default, i.e. ddof=0)
        stderr = [np.std(group)/len(group)**0.5 for group in groups]

        if scatter:
            xpos_scatter = [idx for idx, group in enumerate(groups) for _ in group]
            value_scatter = [val for group in groups for val in group]
            plt.plot(xpos_scatter, value_scatter, c=colors[brain_idx], lw=0, marker='o', alpha=0.25)

        plt.errorbar(xpos, mean, yerr=stderr, capsize=5, c=colors[brain_idx], label=labels[brain_idx])
        plt.xticks(xpos, ages)
        plt.xlabel('Age (days)')
        plt.ylabel(ylabel)

    ax = plt.gca()

    # Use a separate name so the `labels` argument is not shadowed
    handles, legend_labels = ax.get_legend_handles_labels()
    # Each errorbar handle is a container; its first element is the line
    handles = [h[0] for h in handles]
    leg = ax.legend(handles, legend_labels, frameon=False, bbox_to_anchor=(1, 1), loc='upper left')

    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib
    legend_handles = getattr(leg, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    for legobj in legend_handles:
        legobj.set_linewidth(2.0)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(12.5)
    for item in ([ax.xaxis.label, ax.yaxis.label] + leg.get_texts()):
        item.set_fontsize(15)

    if save:
        plt.savefig(FIG_ROOT + save, bbox_inches="tight")
    plt.show()
    
def plot_threshold (ax, data_range, threshold):
    """Draw horizontal lines at +threshold and -threshold across data_range.

    ax         -- axes to draw on
    data_range -- slice passed to get_xticks() to obtain the x values
    threshold  -- level of the upper line; the lower line is its negative
    """
    times = get_xticks(data_range)
    level = np.ones(len(times)) * threshold

    for line in (level, -level):
        ax.plot(times, line)
    
def plot_frequency_bands (ax):
    """Mark frequency-band boundaries with full-height orange vertical lines.

    Lines are drawn at x = 4, 8, 13, 30 and 80.
    NOTE(review): by the usual EEG convention these would delimit theta
    (4-8 Hz), alpha (8-13 Hz), beta (13-30 Hz) and gamma (30-80 Hz); the
    previous inline labels had alpha and theta the other way round - confirm.
    """
    for boundary in (4, 8, 13, 30, 80):
        ax.axvline(x=boundary, ymin=0, ymax=1, c='orange')
    
def plot_spectrogram (burst, ax=None):
    """Draw a spectrogram of burst.data, limited to 0-50 Hz on the y-axis.

    burst -- object with `.data` (samples) and `.time` (first element is
             the burst's starting sample index)
    ax    -- axes to draw on; a new figure is created when omitted
    """
    freq_lim = 50
    window_size = 0.5  # seconds per spectrogram segment

    window = int(SAMPLING_RATE*window_size)
    # 99 % overlap between consecutive segments
    overlap = int(SAMPLING_RATE*window_size*0.99)

    f, t, Sxx = signal.spectrogram(burst.data, SAMPLING_RATE, nperseg=window, noverlap=overlap)

    # Keep some rows above the displayed limit (up to freq_lim + 25)
    f_lim = np.where(f <= freq_lim+25)[0][-1]
    f = f[:f_lim]
    Sxx = Sxx[:f_lim, :]

    # Shift the time axis so it starts at the burst's start time in seconds
    t += burst.time[0]/SAMPLING_RATE

    # Identity comparison: `ax == None` may be overridden by the object
    if ax is None:
        fig, ax = plt.subplots()

    # Colour scale is capped at the 98th percentile of the power values
    ax.pcolormesh(t, f, ndimage.gaussian_filter(Sxx, sigma=0), shading='gouraud', vmax=np.percentile(Sxx.flatten(), 98))
    ax.set_ylim([0, freq_lim])
    ax.set_yticks(np.arange(0, freq_lim, 20))
    
def plot_psd (data, scale=False):
    """Plot the multitaper power spectral density of data below 100 Hz.

    data  -- signal handed to multitaper_psd()
    scale -- when true, rescale the plotted PSD with scale_0_to_1()
    """
    freqs, power = multitaper_psd(data)

    # Truncate at the last bin not exceeding 100 Hz (that bin is excluded)
    cutoff = np.where(freqs <= 100)[0][-1]
    freqs, power = freqs[:cutoff], power[:cutoff]

    if scale:
        power = scale_0_to_1(power)

    fig, ax = plt.subplots(figsize=[10,2])
    ax.plot(freqs, power)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power spectral density')
    ax.set_xlim([0, 100])
    plot_frequency_bands(ax)
    plt.show()
    
def plot_group_boxplots (data, labels, xticks, colors=('blue', 'orange'), title='', xlabel='', ylabel='', yscale='linear', save=False):
    """Draw side-by-side boxplots for one or more groups per x-tick.

    data   -- sequence of groups; each group is a sequence of datasets,
              one dataset per x-tick
    labels -- legend label per group
    xticks -- tick labels, one per x position
    colors -- colour per group (needs at least len(data) entries)
    title, xlabel, ylabel, yscale -- passed to the matching pyplot calls
    save   -- falsy, or a filename appended to FIG_ROOT for saving
    """
    def set_box_color(bp, color):
        for part in ('boxes', 'whiskers', 'caps', 'medians'):
            plt.setp(bp[part], color=color)

    fig = plt.figure()

    # X-ticks are 2 units apart. One or two groups keep the original layout
    # (offsets 0 or -0.4/+0.4, box width 0.6); more groups, which previously
    # failed because no offset was defined, are spread evenly and narrowed so
    # they stay inside their own tick's slot.
    n_groups = len(data)
    if n_groups <= 2:
        spacing, width = 0.8, 0.6
    else:
        spacing = 1.6 / n_groups
        width = 0.75 * spacing

    boxplots = []
    for idx, data_group in enumerate(data):
        shift = (idx - (n_groups - 1) / 2) * spacing

        positions = np.arange(len(data_group)) * 2 + shift
        flier = dict(marker='o', markerfacecolor='none', markeredgecolor=colors[idx], alpha=0.5)
        bp = plt.boxplot(data_group, positions=positions, widths=width, flierprops=flier)
        set_box_color(bp, colors[idx])

        boxplots.append(bp)
    plt.xticks(range(0, len(xticks) * 2, 2), xticks)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.yscale(yscale)

    leg = plt.legend([bp["boxes"][0] for bp in boxplots], labels, frameon=False, fontsize=15, bbox_to_anchor=(1, 1), loc='upper left')
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib
    legend_handles = getattr(leg, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    for legobj in legend_handles:
        legobj.set_linewidth(2.0)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(12.5)
    for item in ([ax.xaxis.label, ax.yaxis.label] + ax.get_legend().get_texts()):
        item.set_fontsize(15)

    if save:
        plt.savefig(FIG_ROOT + save, bbox_inches="tight")
    plt.show()

#####################
# Utility functions #
#####################

def get_spikes_in_range (data_range, spike_times):
    """Return the spike times lying inside data_range, both ends inclusive.

    data_range  -- slice whose .start and .stop bound the window
    spike_times -- sequence of spike times (converted to an ndarray)
    """
    spike_times = np.array(spike_times)
    lower, upper = data_range.start, data_range.stop

    inside = (spike_times >= lower) & (spike_times <= upper)
    return spike_times[np.nonzero(inside)[0]]

def get_slice_from_s (s_beg, s_end):
    """Convert a [s_beg, s_end) window in seconds to a slice of sample indices."""
    first_sample = int(s_beg*SAMPLING_RATE)
    last_sample = int(s_end*SAMPLING_RATE)
    return slice(first_sample, last_sample)

def get_xticks (slice_val):
    """Return the time in seconds of every sample index covered by slice_val.

    slice_val -- slice with integer .start and .stop (sample indices)

    The result has stop - start entries, one per sample in
    data[start:stop]. The earlier linspace(start, stop, num=stop-start)
    included `stop` itself, which stretched the spacing slightly above one
    sample and put the last tick one sample past the end of the slice.
    """
    start, stop = slice_val.start, slice_val.stop
    return np.arange(start, stop) / SAMPLING_RATE

def save_data (data, filename=None):
    """Pickle data into ROOT under a timestamped filename.

    data     -- any picklable object
    filename -- optional prefix; the file is named
                '<filename>_<timestamp>.pickle', or '<timestamp>.pickle'
                when no prefix is given
    """
    curr_time = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')

    if filename:
        basename = filename + '_' + curr_time + '.pickle'
    else:
        basename = curr_time + '.pickle'

    # Build the path the same way open_data() does, so a ROOT without a
    # trailing separator still points inside the directory.
    path = os.path.join(ROOT, basename)

    with open(path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    print('Saved data as', path)
    
def open_data (filename):
    """Load and return the pickled object stored at ROOT/filename."""
    # NOTE: unpickling can execute arbitrary code; only open files that
    # were produced by this notebook.
    with open(os.path.join(ROOT, filename), 'rb') as stream:
        return pickle.load(stream)
    
#########################################
# Signal processing and stats functions #
#########################################

def welch_dof(x,y):
    """Welch-Satterthwaite degrees of freedom for two independent samples.

    x, y -- sequences of observations (each needs at least two values)

    The approximation is defined on the unbiased sample variances, so
    ddof=1 is used; np.var's default (ddof=0) gives the population
    variance and yields a different value whenever len(x) != len(y).
    """
    n_x, n_y = len(x), len(y)
    # Squared standard error of each sample mean
    se2_x = np.var(x, ddof=1) / n_x
    se2_y = np.var(y, ddof=1) / n_y

    dof = (se2_x + se2_y)**2 / (se2_x**2 / (n_x-1) + se2_y**2 / (n_y-1))
    return dof

def get_peaks (data):
    """Find alternating peaks and troughs in data.

    Peaks and troughs are located with signal.find_peaks using a minimum
    spacing of 25 ms and a minimum prominence of half the signal's RMS.
    When several peaks (or several troughs) follow each other with no
    extremum of the other kind in between, only the one with the largest
    absolute value is kept.

    data -- 1-D ndarray (it is negated to find troughs)

    Returns (points, n_peaks): the sorted sample indices of the retained
    extrema, and how many of them are peaks.
    """
    min_prominence = 0.5*rms(data)
    min_lag = int(SAMPLING_RATE * 0.025)

    peaks = signal.find_peaks(data, distance=min_lag, prominence=min_prominence)[0]
    troughs = signal.find_peaks(-data, distance=min_lag, prominence=min_prominence)[0]

    combined = np.sort(np.concatenate([peaks, troughs]))

    # Rank = position of an extremum within the time-ordered `combined`.
    # np.isin replaces a per-point `in` scan over the peaks array.
    is_peak = np.isin(combined, peaks)
    peak_ranks = np.flatnonzero(is_peak).tolist()
    trough_ranks = np.flatnonzero(~is_peak).tolist()

    def get_max_height_from_ranks (ranks):
        heights = [abs(data[combined[rank]]) for rank in ranks]

        return ranks[np.argmax(heights)]

    def get_nonadjacent_ranks (ranks):
        # Consecutive ranks mean same-kind extrema with nothing of the other
        # kind between them; keep only the tallest of each such run.
        nonadjacent_ranks = []
        run = []
        for rank in ranks:
            if run and rank != run[-1] + 1:
                nonadjacent_ranks.append(get_max_height_from_ranks(run))
                run = []
            run.append(rank)
        if run:
            nonadjacent_ranks.append(get_max_height_from_ranks(run))

        return nonadjacent_ranks

    kept_peaks = get_nonadjacent_ranks(peak_ranks)
    kept_troughs = get_nonadjacent_ranks(trough_ranks)

    filtered_ranks = sorted(kept_peaks + kept_troughs)
    filtered_points = combined[filtered_ranks]

    return filtered_points, len(kept_peaks)

def twoway_anova (data, variable_labels):
    """Run a two-way ANOVA with interaction (type 2 sums of squares).

    data            -- nested mapping: factor-1 level -> factor-2 level ->
                       sequence of observations
    variable_labels -- (factor 1 name, factor 2 name, value name); used as
                       DataFrame column names and inside the model formula

    Returns the table produced by statsmodels' anova_lm.
    """
    factor_1_var, factor_2_var, value_var = variable_labels

    # Flatten the nested mapping into one (level 1, level 2, value) row per
    # observation
    rows = [
        (level_1, level_2, observation)
        for level_1, by_level_2 in data.items()
        for level_2, observations in by_level_2.items()
        for observation in observations
    ]

    df = pd.DataFrame({
        factor_1_var: [row[0] for row in rows],
        factor_2_var: [row[1] for row in rows],
        value_var: [row[2] for row in rows],
    })

    formula = "{value_var} ~ C({factor_1_var}) + C({factor_2_var}) + C({factor_1_var}):C({factor_2_var})".format(
        value_var=value_var,
        factor_1_var=factor_1_var,
        factor_2_var=factor_2_var
    )
    fitted = ols(formula, data=df).fit()

    return sm.stats.anova_lm(fitted, typ=2)

def get_mean_psd (bursts):
    """Average the normalized PSDs of bursts over the 0-40 Hz range.

    bursts -- iterable of objects; those without a `normalized_psd`
              attribute (a (frequencies, power) pair) are ignored

    Returns (f, mean, stderr). The frequency axis is taken from the last
    burst that has a PSD.
    """
    psds = [burst.normalized_psd for burst in bursts if hasattr(burst, "normalized_psd")]

    fs = None
    Pxxs = []
    for f_burst, Pxx_burst in psds:
        fs = f_burst
        Pxxs.append(Pxx_burst)

    n_psds = len(Pxxs)
    mean_full = np.mean(Pxxs, axis=0)
    stderr_full = np.std(Pxxs, axis=0) / n_psds**0.5

    band = [0, 40]
    f, Pxx_mean = get_psd_in_range((fs, mean_full), band)
    f, Pxx_stderr = get_psd_in_range((fs, stderr_full), band)

    return f, Pxx_mean, Pxx_stderr

def butter_bandpass_filter(data, lowcut, highcut, order=3):
    """Band-pass data between lowcut and highcut with a Butterworth filter.

    The cut-offs are divided by the Nyquist frequency (half of
    SAMPLING_RATE) before designing the filter; the filter is applied once,
    in the forward direction, as second-order sections.
    """
    nyquist = 0.5 * SAMPLING_RATE
    band = [lowcut / nyquist, highcut / nyquist]
    sections = signal.butter(order, band, analog=False, btype='bandpass', output='sos')
    return signal.sosfilt(sections, data)

def butter_lowpass_filter(data, lowcut, order=3):
    """Low-pass data at lowcut with a Butterworth filter.

    The cut-off is divided by the Nyquist frequency (half of
    SAMPLING_RATE); the filter is applied once, in the forward direction,
    as second-order sections.
    """
    nyquist = 0.5 * SAMPLING_RATE
    cutoff = lowcut / nyquist
    sections = signal.butter(order, cutoff, analog=False, btype='lowpass', output='sos')
    return signal.sosfilt(sections, data)

def butter_highpass_filter(data, highcut, order=3):
    """High-pass data at highcut with a Butterworth filter.

    The cut-off is divided by the Nyquist frequency (half of
    SAMPLING_RATE); the filter is applied once, in the forward direction,
    as second-order sections.
    """
    nyquist = 0.5 * SAMPLING_RATE
    cutoff = highcut / nyquist
    sections = signal.butter(order, cutoff, analog=False, btype='highpass', output='sos')
    return signal.sosfilt(sections, data)

def scale_0_to_1 (data):
    """Min-max rescale data so its minimum maps to 0 and its maximum to 1."""
    lowest = np.min(data)
    highest = np.max(data)
    return (data - lowest) / (highest - lowest)

def multitaper_psd (data, NW=3, k=5, resample_freq=1000, show_progress=False):
    # Estimate the power spectral density of `data` with the multitaper
    # method, averaged over sliding one-second windows.
    #
    # data           Signal sampled at SAMPLING_RATE
    # NW, k          Passed to dpss() to build the tapers
    # resample_freq  Rate (Hz) the signal is resampled to before analysis
    # show_progress  Print a message each time another minute is processed
    #
    # Returns (f, Pxx): frequency axis and window-averaged PSD, N//2 points each

    # Resample to resample_freq (1000 Hz by default) to speed up multitaper
    data = signal.resample(data, int(len(data)*resample_freq/SAMPLING_RATE))
    
    # How many seconds for each window (N = fs*duration)
    window = 1
    N = int(resample_freq*window)
    [tapers, eigen] = dpss(N, NW, k)
    
    Pxx_list = []
    
    last_progress = 0
    
    # Proceed through signal with windows of length N, advancing in steps
    # of N//10 samples (consecutive windows overlap by 90 %)
    for idx in range(0, len(data), N//10):
        if show_progress:
            # Whole minutes of (resampled) signal processed so far
            progress = (idx/resample_freq)//60
            if last_progress != progress:
                last_progress = progress
                print('Processed {} mins'.format(progress))
        
        y = data[idx:idx+N]
        
        # For a constant window, pad data with zeros to make sure
        # each window is of length N
        if len(y) < N:
            padding = N - len(y)
            y = np.concatenate( (y, np.zeros(padding)) )
        
        Sk_complex, weights, eigenvalues = pmtm(y, e=eigen, v=tapers, show=False)
        Sk = abs(Sk_complex)**2
        # Weight each taper's spectrum, then average across tapers
        Sk = np.mean(Sk * np.transpose(weights), axis=0)

        # Keep the first N//2 bins only
        Pxx_list.append( Sk[0:N//2] )
        
    # Get list of frequencies to accompany PSD
    # NOTE(review): this axis assumes pmtm returns exactly N frequency bins.
    # If pmtm zero-pads to a longer FFT (e.g. the next power of two), the
    # first N//2 bins would not reach the Nyquist frequency and this axis
    # would be slightly stretched - confirm against the spectrum docs.
    dt = 1.0/resample_freq
    f = np.linspace(0.0, 1.0/(2.0*dt), N//2)

    # Average the PSD over all windows
    Pxx = np.mean(Pxx_list, axis=0)
    
    return f, Pxx

def get_psd_in_range (PSD, freqs):
    """Cut a PSD down to the band [freqs[0], freqs[1]).

    PSD   -- (frequencies, power) pair of equally long 1-D arrays
    freqs -- (lower, upper) bounds; the bin at or just above `upper` is
             excluded

    Raises IndexError when no frequency reaches one of the bounds.
    """
    f, Pxx = PSD
    lower, upper = freqs[0], freqs[1]

    first = np.nonzero(f >= lower)[0][0]
    last = np.nonzero(f >= upper)[0][0]
    band = slice(first, last)

    return f[band], Pxx[band]

def get_mean_PSD_freq (PSD, max_freq=80):
    """Return the power-weighted mean frequency of a PSD below max_freq.

    PSD      -- (frequencies, power) pair of equally long 1-D arrays
    max_freq -- upper frequency limit; the last bin not exceeding it is
                itself excluded from the average (previously a hard-coded
                80 alongside an unused local of the same value)
    """
    f, Pxx = PSD

    max_freq_idx = np.where(f <= max_freq)[0][-1]
    f, Pxx = f[:max_freq_idx], Pxx[:max_freq_idx]

    return np.sum(f*Pxx)/np.sum(Pxx)

def get_relative_power (PSD):
    """Return the fraction of sub-100 Hz power in two frequency bands.

    PSD -- (frequencies, power) pair of equally long 1-D arrays; the
           frequencies must reach at least 100

    Returns (spindle_power, gamma_power): power in (8, 30] and in (30, 80],
    each divided by the total power of the bins below 100.
    """
    f, Pxx = PSD

    max_idx = np.where(f >= 100)[0][0]
    f = f[:max_idx]
    Pxx = Pxx[:max_idx]

    # Total is computed once, with np.sum rather than the much slower
    # builtin sum over an ndarray
    total_power = np.sum(Pxx)

    spindle_mask = np.logical_and(f > 8, f <= 30)
    spindle_power = np.sum(Pxx[spindle_mask]) / total_power

    gamma_mask = np.logical_and(f > 30, f <= 80)
    gamma_power = np.sum(Pxx[gamma_mask]) / total_power

    return spindle_power, gamma_power

def get_maximal_frequency (PSD):
    """Return the frequency with the highest power, ignoring the first bin.

    PSD -- (frequencies, power) pair; element 0 (the DC value) is skipped
    """
    f, Pxx = PSD

    # Search from index 1 onwards, then map back to the full-array index
    strongest = 1 + np.argmax(Pxx[1:])
    return f[strongest]

######################
# Constant variables #
######################

# Samples per second; used throughout to convert between seconds and sample indices
SAMPLING_RATE    = 30000
ROOT             = '/' # Where to load data (save_data also writes its pickles here)
FIG_ROOT         = '/' # Where to save figures
# Slice of sample indices covering the first 60*60 seconds (one hour)
DATA_RANGE_ALL   = get_slice_from_s(0, 60*60) 

# Colour names; presumably one plotting colour per brain region - confirm at call sites
COLOR_CORTEX   = 'tab:green'
COLOR_THALAMUS = 'tab:purple'
COLOR_STRIATUM = 'tab:orange'

## Data structure for recordings used in analysis

In [None]:
# One dict per recording session. Keys as consumed by load_data():
#   "path"              directory name appended to ROOT and opened with Session
#   "age"               used as the sort key below (presumably age in days - confirm)
#   "thalamus"          whether the thalamus channel is read and processed
#   "recording"         index into session.recordings
#   "<region>_channel"  column index into the continuous samples array
#   "<region>_sigma"    threshold multiplier passed as `sigma` to run_burst_procedure
RECORDINGS = [
    {
        "path": "2020-09-01_15-56-56",
        "age":  6,
        "thalamus": False,
        "striatum_channel": 2,
        "cortex_channel": 14,
        "recording": 1,
        "striatum_sigma": 2.125,
        "cortex_sigma": 2.125
    },
         {
        "path": "2020-09-30_13-40-33",
        "age": 10,
        "thalamus": False,
        "striatum_channel": 2,
        "cortex_channel": 14,
        "recording": 0,
        "striatum_sigma": 1.25,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2020-10-05_13-54-29",
        "age": 15,
        "thalamus": False,
        "striatum_channel": 2,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.75,
        "cortex_sigma": 1.625
    },
    {
        "path": "2020-10-28_15-19-52",
        "age": 10,
        "thalamus": False,
        "striatum_channel": 2,
        "cortex_channel": 0,
        "recording": 0,
        "striatum_sigma": 1.25,
        "cortex_sigma": 1.25
    },
    {
        "path": "2021-02-09_16-08-16",
        "age":  8,
        "thalamus": False,
        "striatum_channel": 2,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.75,
        "cortex_sigma": 1.75
    },
    {
        "path": "2021-02-10_15-32-43",
        "age":  9,
        "thalamus": False,
        "striatum_channel": 13,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 2,
        "cortex_sigma": 1.5
    },
    
    { 
        "path": "2021-02-11_15-57-01", 
        "age": 10, 
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.75, 
        "thalamus_sigma": 1.75,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2021-02-12_15-25-59", 
        "age": 11, 
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.25, 
        "thalamus_sigma": 1.25,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2021-03-01_11-16-08",
        "age":  5, 
        "thalamus": True,
        "striatum_channel": 13,
        "thalamus_channel": 18,
        "cortex_channel": 0,
        "recording": 0,
        "striatum_sigma": 2.875, 
        "thalamus_sigma": 3.5,
        "cortex_sigma": 2.75
    },
    { 
        "path": "2021-03-03_11-02-08",
        "age":  7,
        "thalamus": True,
        "striatum_channel": 2,
        "thalamus_channel": 29,
        "cortex_channel": 0,
        "recording": 0,
        "striatum_sigma": 1, 
        "thalamus_sigma": 0.55,
        "cortex_sigma": 0.75
    },
    { 
        "path": "2021-03-05_10-58-31", 
        "age":  9, 
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 0,
        "recording": 0,
        "striatum_sigma": 1.75, 
        "thalamus_sigma": 2.75,
        "cortex_sigma": 1.75
    },
    { 
        "path": "2021-03-08_10-50-18", 
        "age": 12,
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.25, 
        "thalamus_sigma": 1.75,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2021-03-10_10-59-57",
        "age": 14,
        "thalamus": True,
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.25, 
        "thalamus_sigma": 1.75,
        "cortex_sigma": 1.25
    },
    {
        "path": "2021-03-11_10-48-40",
        "age":  8, 
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 0,
        "recording": 1,
        "striatum_sigma": 2.5, 
        "thalamus_sigma": 3,
        "cortex_sigma": 2.5
    },
    { 
        "path": "2021-03-12_10-54-00", 
        "age":  9, 
        "thalamus": True, 
        "striatum_channel": 13, 
        "thalamus_channel": 29, 
        "cortex_channel": 15,
        "recording": 1,
        "striatum_sigma": 2.125,
        "thalamus_sigma": 2.75,
        "cortex_sigma": 2.125
    },
    {
        "path": "2021-03-15_10-52-05", 
        "age": 12,
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 0,
        "recording": 0,
        "striatum_sigma": 1.25, 
        "thalamus_sigma": 1.75,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2021-05-14_15-40-25",
        "age": 10, 
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 0,
        "recording": 0,
        "striatum_sigma": 1.25, 
        "thalamus_sigma": 1.75,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2021-05-21_15-10-39",
        "age": 12, 
        "thalamus": True, 
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 1.75, 
        "thalamus_sigma": 1.75,
        "cortex_sigma": 1.25
    },
    { 
        "path": "2021-05-26_15-12-59", 
        "age": 10, 
        "thalamus": True,
        "striatum_channel": 2,
        "thalamus_channel": 29,
        "cortex_channel": 15,
        "recording": 0,
        "striatum_sigma": 2.625,
        "thalamus_sigma": 2.375,
        "cortex_sigma": 2.125
    },
    # NOTE(review): the remaining entries define no "*_sigma" keys, and the
    # last two also lack "thalamus". load_data() reads those keys, so it
    # raises KeyError if one of these entries is among the recordings it
    # processes - confirm whether they are meant for a different analysis.
    { 
        "path": "2021-08-03_15-52-34", 
        "age": 5, 
        "thalamus": True,
        "striatum_channel": 2,
        "thalamus_channel": 29,
        "cortex_channel": 15,
        "recording": 1
    }, 
    {
        "path": "2021-08-09_10-14-01",
        "age": 6,
        "thalamus": True,
        "striatum_channel": 2,
        "thalamus_channel": 18,
        "cortex_channel": 15,
        "recording": 0
    },
    {
        "path": "2020-10-26_14-03-44",
        "age": 15,
        "recording": 0,
        "striatum_channel": 2,
        "thalamus_channel": 29,
        "cortex_channel": 15
    },
    {
        "path": "2021-06-25_11-20-59",
        "age": 16,
        "recording": 0,
        "striatum_channel": 2,
        "thalamus_channel": 29,
        "cortex_channel": 15
    }
]

# Sort by age (ascending; sorted() is stable, so entries of equal age keep
# their order from the list above). Note that this changes which entries
# fall inside any positional slice of RECORDINGS taken afterwards.
RECORDINGS = sorted(RECORDINGS, key=lambda k: k['age']) 

## Burst detection
### Envelope thresholding method

In [None]:
def detect_bursts_from_envelope(envelope, low_threshold):
    """Split an envelope into runs of samples above low_threshold.

    envelope      -- sequence of envelope values
    low_threshold -- a run starts at the first sample strictly above this
                     value and ends at the first sample at or below it (or
                     at the last sample of the envelope)

    Returns a list of Burst objects whose time is [start index, end index]
    and whose data holds the above-threshold samples of the run.
    """
    bursts = []

    # None (not False) marks "no open burst": a burst starting at index 0
    # is falsy as an integer and used to be treated as not started, which
    # shifted its start to index 1 or lost it altogether.
    start_t = None
    burst = []
    last_t = len(envelope) - 1

    for t, sample in enumerate(envelope):
        if sample > low_threshold:
            if start_t is None:
                start_t = t
            burst.append(sample)

        if start_t is not None and (sample <= low_threshold or t == last_t):
            bursts.append( Burst([start_t, t], burst) )
            start_t = None
            burst = []

    return bursts
        

# Filter bursts shorter than minimum duration
def filter_short_bursts (bursts, minimum_duration):
    """Keep only the bursts lasting strictly longer than minimum_duration.

    bursts           -- iterable of objects with `.time` = [start, end] in
                        sample indices
    minimum_duration -- threshold in seconds (multiplied by SAMPLING_RATE)
    """
    return [
        candidate for candidate in bursts
        if (candidate.time[1] - candidate.time[0]) > (minimum_duration*SAMPLING_RATE)
    ]

# Filter bursts that have fewer than minimum points above threshold
def filter_minimum_peaks (bursts, threshold, minimum_peaks, max_interval=0.2):
    """Keep bursts containing a run of closely spaced prominent peaks.

    bursts        -- iterable of objects with a `.data` sample array
    threshold     -- minimum prominence handed to signal.find_peaks, applied
                     to the absolute value of the burst data
    minimum_peaks -- number of peaks the run must contain
    max_interval  -- largest gap in seconds between neighbouring peaks of a
                     run (was a hard-coded 0.2)

    A burst passes when minimum_peaks - 1 consecutive inter-peak gaps are
    all shorter than max_interval. The previous version assigned `diff = 0`
    on a long gap, which had no effect, so gaps were counted cumulatively
    over the whole burst; the run counter is now reset instead.
    """
    required_gaps = minimum_peaks - 1
    filtered_bursts = []

    for burst in bursts:
        peaks = signal.find_peaks(np.abs(burst.data), prominence=threshold)[0]

        run_length = 0
        for gap in np.diff(peaks) / SAMPLING_RATE:
            if run_length >= required_gaps:
                break

            if gap < max_interval:
                run_length += 1
            else:
                # A long gap interrupts the run of closely spaced peaks
                run_length = 0

        if run_length >= required_gaps:
            filtered_bursts.append(burst)

    return filtered_bursts

# If bursts have been bandpass filtered, use this function
# to return unfiltered bursts using saved time points for each burst
def get_unfiltered_bursts (bursts, data):
    """Rebuild every burst from `data` using the burst's stored time span.

    bursts -- iterable of objects whose `.time` unpacks into slice()
    data   -- sequence to cut the new burst samples from

    Returns new Burst objects with the same time and data[slice(*time)].
    """
    return [
        Burst(original.time, data[slice(*original.time)])
        for original in bursts
    ]

def get_baseline_periods (bursts, data):
    """Extract the quiet stretches of data lying between consecutive bursts.

    bursts -- list of bursts in time order; each is either a Burst (with
              `.time` = [start, end]) or a ([start, end], data) tuple
    data   -- sequence to cut the baseline samples from

    A gap between two neighbouring bursts is used only when it is longer
    than 2 s; 0.5 s is trimmed from each side of it.

    Returns a list of ([start, end], samples) tuples.

    Fixes: bursts were subscripted as tuples, which fails for Burst objects,
    and the start/end of each period were swapped (start was taken from the
    next burst and end from the current one), so every slice came out empty.
    """
    def burst_time (burst):
        # Accept both Burst objects and ([start, end], data) tuples
        return burst.time if hasattr(burst, 'time') else burst[0]

    padding = int(SAMPLING_RATE * 0.5)
    baseline_bursts = []

    for curr_burst, next_burst in zip(bursts, bursts[1:]):
        curr_end = burst_time(curr_burst)[1]
        next_start = burst_time(next_burst)[0]

        if next_start - curr_end > (padding*4):
            start_t = curr_end + padding
            end_t = next_start - padding
            baseline_bursts.append( ([start_t, end_t], data[start_t:end_t]) )

    return baseline_bursts

def combined_adjacent_bursts(bursts, data, temporal_distance):
    """Merge bursts separated by no more than temporal_distance seconds.

    bursts            -- list of Burst objects in time order
    data              -- sequence the merged burst samples are cut from
    temporal_distance -- largest gap in seconds (end of one burst to start
                         of the next) that is still bridged

    Returns a list of Burst objects; each spans from the start of the first
    to the end of the last burst of its group, with data[start:end].

    A list holding a single burst now yields that burst; previously it
    returned an empty list because the loop body never ran.
    """
    if len(bursts) == 0:
        return []

    max_gap = temporal_distance*SAMPLING_RATE

    # Take first and last bursts of a group and combine
    def combine_bursts (group):
        start = group[0].time[0]
        end = group[-1].time[1]

        return Burst([start, end], data[start:end])

    combined_bursts = []
    group = [bursts[0]]

    for prev, burst in zip(bursts, bursts[1:]):
        # A gap longer than the allowed distance closes the current group
        if burst.time[0] - prev.time[1] > max_gap:
            combined_bursts.append(combine_bursts(group))
            group = []
        group.append(burst)

    combined_bursts.append(combine_bursts(group))

    return combined_bursts


def hl_envelopes_idx(s, dmin=1, dmax=1):
    """Return the indices of the lower and upper envelope points of s.

    s    -- 1-D ndarray
    dmin -- chunk size for the minima: the local minima are taken in
            consecutive chunks of dmin and the lowest of each chunk is kept
    dmax -- same for the maxima, keeping the highest of each chunk

    Returns (lmin, lmax), two index arrays into s.
    """
    # Sign changes of the first difference mark the local extrema
    turning = np.diff(np.sign(np.diff(s)))
    lmin = np.nonzero(turning > 0)[0] + 1
    lmax = np.nonzero(turning < 0)[0] + 1

    # Lowest local minimum within each dmin-sized chunk
    lowest = [
        first + np.argmin(s[lmin[first:first+dmin]])
        for first in range(0, len(lmin), dmin)
    ]
    # Highest local maximum within each dmax-sized chunk
    highest = [
        first + np.argmax(s[lmax[first:first+dmax]])
        for first in range(0, len(lmax), dmax)
    ]

    return lmin[lowest], lmax[highest]
            

# Combines all previous functions into a single routine
# Returns a list of Burst objects (time = [start, end] in samples, data = band-passed samples)

# data                 Array of sampled data
# minimum_peaks        Minimum number of closely spaced peaks a burst must contain
# minimum_duration     Minimum burst duration (s) to keep
# sigma                How many standard deviations above mean for threshold

def run_burst_procedure (data, minimum_peaks, minimum_duration, sigma):
    # Bandpass filter signal
    data_bandpass = butter_bandpass_filter(data, lowcut=1, highcut=100)

    # Get envelope of signal: interpolate between the chunked local maxima
    # (upper envelope) and minima (lower envelope)
    low_idx, high_idx = hl_envelopes_idx(data_bandpass, dmin=30, dmax=30)
    x = np.arange(0, len(data_bandpass))
    high_env = np.interp(x, x[high_idx], data_bandpass[high_idx])
    low_env = np.interp(x, x[low_idx], data_bandpass[low_idx])

    # Per sample, take the larger of the upper envelope and the magnitude
    # of the lower envelope. Vectorised: the earlier per-sample Python loop
    # had to visit every sample of the recording one at a time.
    mean_env = np.maximum(high_env, np.abs(low_env))

    # Clip data to prevent skew of std from random outlier events
    data_clipped = np.clip(data_bandpass, a_min=-1000, a_max=1000)

    # Get thresholds
    mean, std = np.mean(data_clipped), np.std(data_clipped)
    burst_threshold = mean + std*sigma
    print('Calculated threshold ({}, sigma={})\n'.format(
        round(burst_threshold, 2),
        round(sigma, 2)
    ))

    bursts = detect_bursts_from_envelope(mean_env, burst_threshold)
    print('\nBursts detected')

    bursts = combined_adjacent_bursts(bursts, mean_env, temporal_distance=0.2)
    print('Combined adjacent bursts')

    bursts = filter_rms(bursts, max_rms=1000)
    print('\tFiltered RMS')

    bursts = filter_short_bursts(bursts, minimum_duration)
    print('Short bursts filtered')

    # Swap the envelope samples held by each burst for the band-passed
    # signal over the same time span
    bursts = get_unfiltered_bursts(bursts, data_bandpass)
    print('Got unfiltered bursts')

    bursts = filter_minimum_peaks(bursts, burst_threshold, minimum_peaks=minimum_peaks)
    print('Filtered minimum thresold points')

    return bursts

In [None]:
def load_data (recordings, get_mua=False, get_psd=False, max_recordings=17):
    """Run burst detection on the striatum, thalamus and cortex channels.

    recordings     -- list of recording dicts (see RECORDINGS)
    get_mua        -- also store detect_MUA() output per region
    get_psd        -- also store multitaper_psd() output per region
    max_recordings -- only the first max_recordings entries are processed
                      (was a hard-coded 17)

    Returns a list with one copy of each processed recording dict, extended
    with "length" and, per region, "<region>_bursts" plus the optional
    "<region>_MUA" and "<region>_PSD" entries.

    Raises KeyError when a processed recording lacks a channel or sigma key
    for a region that is analysed.
    """
    minimum_duration = 0.2
    minimum_peaks = 10

    def process_region (recording, region, data):
        # Burst detection plus optional MUA / PSD for a single brain region
        recording[region + "_bursts"] = run_burst_procedure(
            data=data,
            minimum_peaks=minimum_peaks,
            minimum_duration=minimum_duration,
            sigma=recording[region + "_sigma"]
        )
        if get_mua:
            recording[region + "_MUA"] = detect_MUA(
                data=data
            )
        if get_psd:
            recording[region + "_PSD"] = multitaper_psd (
                data,
                NW=3, k=5, resample_freq=1000,
                show_progress=True
            )

    selected = recordings[:max_recordings]
    processed = []

    for idx, recording in enumerate(selected):
        print("\nRecording {} ({}/{})".format(
            recording["path"],
            idx+1,
            len(selected)
        ))

        data_range = get_slice_from_s(0, 60*60)
        # Entries without a "thalamus" key are treated as having none
        has_thalamus = recording.get("thalamus", False)

        session = Session(os.path.join(ROOT, recording["path"]))
        samples = session.recordings[recording["recording"]].continuous[0].samples

        striatum_data = samples[data_range, recording["striatum_channel"]]
        cortex_data = samples[data_range, recording["cortex_channel"]]
        # Only read the thalamus channel when the recording has one; the
        # placeholder read of channel 0 was never used.
        thalamus_data = samples[data_range, recording["thalamus_channel"]] if has_thalamus else []

        # Work on a copy so the caller's dict is left untouched
        recording = recording.copy()

        # Striatum
        if len(striatum_data):
            recording["length"] = len(striatum_data)
            process_region(recording, "striatum", striatum_data)

        # Thalamus
        if len(thalamus_data) and has_thalamus:
            process_region(recording, "thalamus", thalamus_data)

        # Cortex
        if len(cortex_data):
            process_region(recording, "cortex", cortex_data)

        processed.append(recording)

    return processed
        
# Run envelope-based burst detection over the recordings (MUA and PSD
# extraction disabled) and pickle the result under ROOT with a timestamped name
processed_recordings = load_data(RECORDINGS, get_mua=False, get_psd=False)
save_data(processed_recordings, 'processed_recordings_envelope')

## Burst detection
### RMS thresholding method

In [None]:
def rms (data):
    """Return the root-mean-square of data.

    data -- sequence or ndarray of numbers

    The samples are converted to float64 before squaring: squaring
    narrow-integer samples (e.g. int16) element by element wraps around
    once the square exceeds the type's range. The square is also computed
    in one vectorised step instead of a Python loop.
    """
    values = np.asarray(data, dtype=np.float64)
    mean = np.mean(values**2)
    root = mean**0.5
    return root

def rms_hist(data, window=0.2):
    """Histogram the RMS of consecutive, non-overlapping chunks of data.

    data   -- signal sampled at SAMPLING_RATE
    window -- chunk length in seconds

    Returns (hist, bin_edges, rms_list): the density-normalised histogram,
    its bin edges and the per-chunk RMS values. The number of bins is the
    integer part of the 85th percentile of the RMS values.
    """
    chunk_size = int(SAMPLING_RATE*window) # Convert window in s to samples

    rms_list = [
        rms(data[start:start+chunk_size])
        for start in range(0, len(data), chunk_size)
    ]

    n_bins = int(np.percentile(rms_list, 85))
    hist, bin_edges = np.histogram(rms_list, bins=n_bins, density=True)

    return hist, bin_edges, rms_list

def plot_gaussian (bin_centres, hist, hist_fit):
    """Plot an RMS histogram together with its fitted Gaussian.

    Parameters
    ----------
    bin_centres : x positions shared by both curves
    hist : histogram densities (drawn in red, labelled 'RMS')
    hist_fit : fitted curve evaluated at ``bin_centres`` (drawn in black)
    """
    plt.figure()
    plt.plot(bin_centres, hist, label='RMS', c='tab:red')
    plt.plot(bin_centres, hist_fit, label='Fitted Gaussian', c='black')
    plt.xlabel('RMS')
    plt.ylabel('Density')

    leg = plt.legend(frameon=False, fontsize=15, bbox_to_anchor=(0.925, 1), loc='upper left')
    # Newer Matplotlib exposes the handles as ``legend_handles``; the old
    # ``legendHandles`` spelling was deprecated and later removed. Support both.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(2.0)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Tick labels and title at 12.5 pt, axis labels and legend text at 15 pt
    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(12.5)
    for item in ([ax.xaxis.label, ax.yaxis.label] + ax.get_legend().get_texts()):
        item.set_fontsize(15)

    plt.show()

def fit_gaussian(hist, bin_edges, use_truncated_hist):
    """Fit a Gaussian to the leading part of an RMS histogram and plot it.

    Parameters
    ----------
    hist, bin_edges : histogram as returned by ``np.histogram``
    use_truncated_hist : bool
        If true, the histogram peak is searched only among the bins that lie
        before the last bin whose centre is <= 30; otherwise over all bins.

    Returns
    -------
    mu, sigma : fitted mean and (non-negative) standard deviation

    Raises
    ------
    Whatever ``curve_fit`` raises when the fit fails, or IndexError /
    ValueError when ``use_truncated_hist`` is set and no bin centre is <= 30.
    """
    bin_centres = (bin_edges[:-1] + bin_edges[1:])/2

    if use_truncated_hist:
        hist_30_idx = np.where(bin_centres <= 30)[0][-1]
        peak_idx = np.argmax(hist[:hist_30_idx])
    else:
        peak_idx = np.argmax(hist)
    peak = bin_centres[peak_idx]

    # Only the first max(4, 2*peak_idx) bins take part in the fit
    trimmed_idx = max(4, peak_idx*2)
    trimmed_bin_centres = bin_centres[:trimmed_idx]
    trimmed_hist = hist[:trimmed_idx]

    def gauss(x, *p):
        A, mu, sigma = p
        return A*np.exp(-(x-mu)**2/(2.*sigma**2))

    p0 = [1., peak, 1.]

    coeff, _ = curve_fit(gauss, trimmed_bin_centres, trimmed_hist, p0=p0)

    hist_fit = gauss(bin_centres, *coeff)

    plot_gaussian (bin_centres, hist, hist_fit)

    # The model depends on sigma only through sigma**2, so the optimiser may
    # converge on -sigma. Return the magnitude so that callers computing
    # mu + k*sigma always get a threshold above the mean.
    return coeff[1], abs(coeff[2])

def get_burst_events (data, rms_list, mu, sigma, window=0.2):
    """Turn supra-threshold RMS windows into ``Burst`` objects.

    A window is a burst candidate when its RMS exceeds ``mu + 2*sigma``.
    ``rms_list`` is expected to hold one value per consecutive ``window``-second
    chunk of ``data``; each candidate covers exactly its chunk (clipped to the
    end of ``data``).
    """
    threshold = mu + 2*sigma
    chunk_size = int(SAMPLING_RATE*window) # Convert window in s to samples
    n_samples = len(data)

    loud_chunks = [
        chunk_idx
        for chunk_idx, chunk_rms in enumerate(rms_list)
        if chunk_rms > threshold
    ]

    bursts = []
    for chunk_idx in loud_chunks:
        first = chunk_idx*chunk_size
        last = min(n_samples, (chunk_idx+1)*chunk_size)
        bursts.append(Burst([first, last], data[first:last]))

    return bursts

def combine_adjacent_bursts(data, bursts, temporal_distance):
    """Merge bursts that are separated by at most ``temporal_distance`` seconds.

    Consecutive bursts whose gap (start of the next minus end of the previous)
    is not greater than ``temporal_distance*SAMPLING_RATE`` samples are fused
    into one ``Burst`` spanning from the start of the first to the end of the
    last; its data is re-sliced from ``data``.

    Parameters
    ----------
    data : signal the burst times index into
    bursts : list of ``Burst`` ordered in time
    temporal_distance : maximum gap in seconds for two bursts to be merged

    Returns
    -------
    list of ``Burst``. An empty input gives an empty list; a single burst is
    returned as one burst (previously it was silently dropped because the
    merging loop never ran for a one-element list).
    """
    if not bursts:
        return []

    max_gap = temporal_distance*SAMPLING_RATE
    combined_bursts = []

    # Boundaries of the group currently being accumulated
    group_start = bursts[0].time[0]
    group_end = bursts[0].time[1]

    for burst in bursts[1:]:
        # Gap too large: close the current group and start a new one
        if burst.time[0] - group_end > max_gap:
            combined_bursts.append(Burst([group_start, group_end], data[group_start:group_end]))
            group_start = burst.time[0]
        group_end = burst.time[1]

    # The last group is always still open at this point
    combined_bursts.append(Burst([group_start, group_end], data[group_start:group_end]))

    return combined_bursts

def filter_short_bursts (bursts, minimum_duration):
    """Keep only bursts lasting at least ``minimum_duration`` seconds."""
    minimum_samples = minimum_duration*SAMPLING_RATE

    long_enough = []
    for burst in bursts:
        n_samples = burst.time[1] - burst.time[0]
        if n_samples >= minimum_samples:
            long_enough.append(burst)
    return long_enough

def filter_rms (bursts, max_rms):
    """Drop bursts containing any 0.2 s window whose RMS reaches ``max_rms``.

    Each burst's data is cut into consecutive 0.2 s chunks; the burst is kept
    only if the largest chunk RMS is strictly below ``max_rms``.
    """
    def peak_window_rms (samples):
        chunk_size = int(SAMPLING_RATE*0.2) # Convert window in s to samples
        return np.max([
            rms(samples[start:start+chunk_size])
            for start in range(0, len(samples), chunk_size)
        ])

    return [burst for burst in bursts if peak_window_rms(burst.data) < max_rms]

def filter_min_peaks (bursts, min_peaks):
    """Keep bursts for which ``get_peaks`` reports at least ``min_peaks``.

    Only the second value returned by ``get_peaks(burst.data)`` is used.
    """
    def peak_count (burst):
        _, n_peaks = get_peaks(burst.data)
        return n_peaks

    return [burst for burst in bursts if peak_count(burst) >= min_peaks]

def get_raw_bursts (data, bursts):
    """Rebuild every burst with samples sliced from ``data``.

    The time bounds of each burst are kept; only the attached data is replaced
    by ``data[start:end]``. New ``Burst`` objects are returned, so attributes
    set on the input bursts are not carried over.
    """
    return [
        Burst(burst.time, data[burst.time[0]:burst.time[1]])
        for burst in bursts
    ]

def get_baseline_power (data, bursts):
    """Average multitaper PSD of the quiet periods between bursts.

    A baseline period runs from the end of one burst to the start of the next
    (or to the end of ``data`` after the last burst); the stretch before the
    first burst is not used. Periods shorter than 1 s are ignored and the
    remaining ones are cut with ``get_slice_from_s(0, 5)``.

    Returns
    -------
    (f, Pxx, baseline_data) where ``f`` is the frequency axis of the first
    baseline PSD, ``Pxx`` the mean of all baseline PSDs and ``baseline_data``
    the list of trimmed baseline segments, or ``(False, False, False)`` when
    no baseline period qualifies.
    """
    baseline_data = []
    n_bursts = len(bursts)

    for burst_idx, burst in enumerate(bursts):
        # After the final burst the baseline extends to the end of the recording
        is_last_burst = (burst_idx+1 == n_bursts)
        baseline_end = len(data) if is_last_burst else bursts[burst_idx+1].time[0]

        segment = data[burst.time[1]:baseline_end]

        # Require at least 1 s of baseline
        if len(segment)/SAMPLING_RATE >= 1:
            # Trim with get_slice_from_s(0, 5)
            baseline_data.append(segment[get_slice_from_s(0, 5)])

    # One PSD per baseline segment
    baseline_psds = [multitaper_psd(segment) for segment in baseline_data]

    if not len(baseline_psds):
        # No usable baseline: signal that with False placeholders
        return False, False, False

    f = baseline_psds[0][0]
    Pxx = np.mean([psd[1] for psd in baseline_psds], axis=0)
    return f, Pxx, baseline_data
    
def get_baseline_amplitude (data):
    """Mean peak-to-peak amplitude of the 4-100 band-passed baseline segments.

    ``data`` is a list of baseline segments (or a falsy value when there are
    none, in which case ``False`` is returned).
    """
    if not (data and len(data)):
        return False

    amps = []
    for segment in data:
        filtered = butter_bandpass_filter(segment, 4, 100)
        peak_to_peak = np.max(filtered)-np.min(filtered)
        amps.append(peak_to_peak)

    return np.mean(amps, axis=0)
    
def get_normalized_psd (bursts, baseline_psd):
    """Attach a baseline-normalised PSD and primary frequency to each burst.

    For every burst the multitaper PSD (restricted to [0, 100]) is divided by
    the baseline PSD over the same range. The ratio is clipped from below at 1
    and the highest local maximum of the clipped ratio defines the primary
    frequency.

    Sets on each burst:
      * ``primary_frequency_baseline`` - frequency of that maximum, or
        ``False`` when the clipped ratio has no local maximum;
      * ``normalized_psd`` - ``(f, ratio)``; only set when a maximum exists.

    Returns a new list holding the same (mutated) burst objects.
    """
    normalized_psd_bursts = []

    for burst in bursts:
        # Restrict both spectra to the same range
        f_baseline, Pxx_baseline = get_psd_in_range(baseline_psd, [0, 100])
        burst_psd = multitaper_psd(burst.data)
        f_burst, Pxx_burst = get_psd_in_range(burst_psd, [0, 100])

        # Burst power relative to baseline power
        Pxx_normed = Pxx_burst/Pxx_baseline

        # Values above 1 mean more power during the burst than at baseline
        Pxx_clipped = np.clip(Pxx_normed, a_min=1, a_max=None)

        peak_positions = signal.find_peaks(Pxx_clipped)[0]
        if not len(peak_positions):
            burst.primary_frequency_baseline = False
        else:
            # Frequency of the first sample that equals the tallest peak value
            tallest = max(Pxx_clipped[peak_positions])
            tallest_idx = np.where(Pxx_clipped == tallest)[0]
            burst.normalized_psd = (f_burst, Pxx_normed)
            burst.primary_frequency_baseline = f_burst[tallest_idx][0]

        normalized_psd_bursts.append(burst)

    return normalized_psd_bursts

# RMS-threshold burst detection for every recording aged P12 or younger.
# For each brain area that has a channel in the recording, the result dict
# gains "<area>_bursts", "<area>_baseline", "<area>_baseline_amplitude" and
# "length" (number of samples analysed).
rms_processed_recordings = []
for recording_idx, recording in enumerate(RECORDINGS):
    print("P{} {}/{} {}".format(recording["age"], recording_idx+1, len(RECORDINGS), recording["path"]))

    # Recordings older than P12 are skipped entirely (not appended)
    if recording["age"] > 12:
        continue

    # Work on a shallow copy so the entries of RECORDINGS are never modified
    recording = recording.copy()

    for brain_area in ["cortex", "thalamus", "striatum"]:
        brain_channel = brain_area + "_channel"
        brain_bursts = brain_area + "_bursts"
        brain_baseline = brain_area + "_baseline"
        brain_baseline_amp = brain_area + "_baseline_amplitude"

        if brain_channel not in recording:
            continue
        print('\t{}'.format(brain_area))

        recording_n = recording["recording"]
        channel_n = recording[brain_channel]
        session = Session(ROOT + recording["path"])

        # At most the first hour of the selected channel
        data_all = session.recordings[recording_n].continuous[0].samples[get_slice_from_s(0, 60*60), channel_n]
        data_all_4_to_100 = butter_bandpass_filter(data_all, 4, 100)

        hist, bin_edges, rms_list = rms_hist(data_all_4_to_100, window=0.2)
        try:
            use_truncated_hist = recording["path"] != "2021-02-11_15-57-01" and recording["age"] > 9 and recording["age"] < 13
            mu, sigma = fit_gaussian(hist, bin_edges, use_truncated_hist=use_truncated_hist)
        except Exception as err:
            # Catch Exception rather than using a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still stop the run.
            print('Warning! Could not fit Gaussian!')
            print('\t{!r}'.format(err))
            continue
        print('\tComputed mean ({}) and sigma ({})'.format(mu, sigma))

        bursts = get_burst_events(data_all_4_to_100, rms_list, mu, sigma, window=0.2)
        print('\tBursts found')
        bursts = combine_adjacent_bursts(data_all_4_to_100, bursts, temporal_distance=0.2)
        print('\tAdjacent bursts combined')

        # Baseline is taken from the unfiltered signal, between the merged bursts
        baseline_f, baseline_Pxx, baseline_data = get_baseline_power(data_all, bursts)
        baseline_psd = baseline_f, baseline_Pxx
        print('\tBaseline periods found')
        baseline_amplitude = get_baseline_amplitude(baseline_data)
        print('\tBaseline amplitudes found')

        bursts = filter_short_bursts(bursts, minimum_duration=0.2)
        print('\tFiltered short bursts')
        bursts = filter_rms(bursts, max_rms=1000)
        print('\tFiltered RMS')
        bursts = filter_min_peaks(bursts, min_peaks=5)
        # Swap the band-passed burst data for the unfiltered samples
        bursts = get_raw_bursts (data_all, bursts)
        print('\tGot 1500 Hz data')

        bursts = get_normalized_psd(bursts, baseline_psd)
        print('\tGot normed PSDs\n')

        recording[brain_bursts] = bursts
        recording[brain_baseline] = baseline_psd
        recording[brain_baseline_amp] = baseline_amplitude
        recording["length"] = len(data_all)

    rms_processed_recordings.append(recording)

## MUA

In [None]:
def detect_MUA (data, standard_deviations=5, silent=False):
    """Detect multi-unit spikes as negative threshold crossings.

    The signal is band-passed (400-4000), and every sample below
    ``-standard_deviations * std`` of the filtered signal is a crossing.
    Each run of contiguous crossings is reduced to its last sample, which is
    reported as one spike time. Around every spike a waveform is cut out and
    re-centred on its trough.

    Parameters
    ----------
    data : 1-D signal
    standard_deviations : threshold in multiples of the filtered signal's std
    silent : suppress the threshold printout

    Returns
    -------
    spike_times_unique : list of sample indices, one per run of crossings
    spike_waveforms : list of trough-aligned waveforms; waveforms whose first
        15 samples never rise above 0 are discarded, so this list can be
        shorter than ``spike_times_unique``.
    """
    # Filter between 0.4-4kHz
    data = butter_bandpass_filter(data, 400, 4000)

    # Threshold magnitude; spikes are samples below -threshold
    threshold = standard_deviations*np.std(data)

    if not silent:
        print('\tCalculated MUA threshold ({})\n'.format(round(threshold, 2)))

    # Find points below threshold
    spike_times = np.where(data < -threshold)[0]

    # Several points of one spike waveform may be below threshold, so collapse
    # each run of contiguous samples (t, t+1, ...) to its last sample.
    # A run ends wherever the step to the next crossing is not exactly 1, and
    # the very last crossing always ends a run. (The previous loop never
    # emitted that final run when it was longer than one sample, and treated
    # a run ending at sample 0 as "no previous spike".)
    run_end_idx = np.flatnonzero(np.diff(spike_times) != 1)
    spike_times_unique = list(np.concatenate([spike_times[run_end_idx], spike_times[-1:]]))

    # How far to extract waveform in MUA signal from trough
    size = int(SAMPLING_RATE * 1/1000 * 1) # 1 ms

    # Extract waveforms and align to trough of waveform
    spike_waveforms = []
    for spike_t in spike_times_unique:
        t_start = max(spike_t - size, 0)
        t_end = min(spike_t + size, len(data))
        spike_waveform = data[t_start:t_end]

        trough_t = np.argmin(spike_waveform) + t_start
        trough_t_start = max(trough_t - size, 0)
        trough_t_end = min(trough_t + size, len(data))

        spike_waveform_aligned = data[trough_t_start:trough_t_end]
        # Keep only waveforms that are positive somewhere in their first 15
        # samples (this also implies the whole waveform has a positive maximum)
        if np.max(spike_waveform_aligned[:15]) > 0:
            spike_waveforms.append(spike_waveform_aligned)

    return spike_times_unique, spike_waveforms

In [None]:
# MUA detection for every recording and every brain area with a channel.
# Each result dict gains "<area>_mua" (the tuple returned by detect_MUA) and
# "length" in SECONDS (the RMS burst cell stores "length" in samples).
processed_recordings_mua = []

for recording_idx, recording in enumerate(RECORDINGS):
    print("P{} {}/{} {}".format(recording["age"], recording_idx+1, len(RECORDINGS), recording["path"]))

    # Shallow copy: keep the results out of the shared RECORDINGS entries so
    # other cells that read RECORDINGS do not pick up this cell's "length".
    recording = recording.copy()

    for brain_area in ["cortex", "striatum", "thalamus"]:
        brain_channel = brain_area + "_channel"
        brain_mua = brain_area + "_mua"

        if brain_channel not in recording:
            continue
        print('\t{}'.format(brain_area))

        recording_n = recording["recording"]
        channel_n = recording[brain_channel]
        session = Session(ROOT + recording["path"])

        # At most the first hour of the selected channel
        data_all = session.recordings[recording_n].continuous[0].samples[get_slice_from_s(0, 60*60), channel_n]

        recording["length"] = len(data_all) / SAMPLING_RATE
        recording[brain_mua] = detect_MUA(data_all)

    processed_recordings_mua.append(recording)

## MUA activity across developmental time periods

In [None]:
# Per-recording MUA statistics, grouped as table[brain_area][age_group] -> list.
# Every (area, age group) pair gets its own independent empty list.
spikes_per_second = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

spike_filled_time = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

# Fill spikes_per_second and spike_filled_time per brain area and age group,
# then run the two-way ANOVAs and plot both measures.
for brain_area in ["striatum", "thalamus", "cortex"]:
    # NOTE(review): `processed_recordings_mua_last` is not defined in the
    # neighbouring cells (the MUA cell builds `processed_recordings_mua`);
    # confirm which variable is intended before running this cell.
    for recording in processed_recordings_mua_last:
        # Age group of this recording
        if recording["age"] < 7:
            age = "5-6"
        elif recording["age"] < 9:
            age = "7-8"
        elif recording["age"] < 11:
            age = "9-10"
        elif recording["age"] < 13:
            age = "11-12"
        else:
            age = ">12"

        mua_key = brain_area + "_mua"
        if mua_key not in recording:
            continue

        # Rate is based on the number of retained waveforms (element 1 of the
        # MUA tuple); recording["length"] is read as seconds here.
        sps = len(recording[mua_key][1]) / recording["length"]
        spikes_per_second[brain_area][age].append(sps)

        # Percentage of 1 s chunks that contain at least one spike time.
        # Computed from the chunk index of each spike instead of filling a
        # dense array with one element per sample of the recording.
        sparse_spike_times = np.asarray(recording[mua_key][0], dtype=np.int64)
        n_samples = int(recording["length"]*SAMPLING_RATE)
        chunk_size = int(SAMPLING_RATE*1)

        filled_time = len(np.unique(sparse_spike_times // chunk_size))
        total_time = -(-n_samples // chunk_size)  # ceil: last chunk may be partial

        sft = (filled_time / total_time) * 100
        spike_filled_time[brain_area][age].append(sft)

print(twoway_anova(spikes_per_second, ['brain_area', 'age', 'spikes']))
print(twoway_anova(spike_filled_time, ['brain_area', 'age', 'spikes']))

# Raw string: '\m' is not a valid escape sequence in a normal string literal
plot_graph_by_age (
    data=spikes_per_second,
    ylabel=r'Spikes $\mathregular{s^{-1}}$',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=["Striatum", 'CL/Pf', 'Cortex']
)
plot_graph_by_age (
    data=spike_filled_time,
    ylabel='Spike filled time(%)',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=["Striatum", 'CL/Pf', 'Cortex'],
)

## LFP and burst statistics over developmental time periods

In [None]:
# LFP band power per recording, grouped as LFP_power[brain_area][age_group].
LFP_power = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

def bandpower(x, fs, fmin, fmax):
    """Integrate the periodogram of ``x`` over roughly [fmin, fmax).

    Parameters
    ----------
    x : 1-D signal
    fs : sampling frequency
    fmin, fmax : band limits, in the same unit as ``fs``

    Returns
    -------
    float
        Trapezoidal integral of the periodogram from the last frequency bin
        that is <= ``fmin`` up to (not including) the last bin <= ``fmax``.

    Notes
    -----
    Uses the ``signal`` submodule and NumPy directly: the file only imports
    SciPy submodules (never the top-level ``scipy`` name), and the NumPy
    aliases in the ``scipy`` namespace (``scipy.argmax``, ``scipy.trapz``)
    are deprecated.
    """
    f, Pxx = signal.periodogram(x, fs=fs)
    ind_min = np.argmax(f > fmin) - 1
    ind_max = np.argmax(f > fmax) - 1
    # NumPy 2.0 renamed trapz to trapezoid; use whichever is available
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(Pxx[ind_min: ind_max], f[ind_min: ind_max])

# Compute the 4-100 band power of (at most) the first 10 minutes of every
# available channel and file it under the recording's age group.
for recording_idx, recording in enumerate(RECORDINGS):
    print("P{} {}/{} {}".format(recording["age"], recording_idx+1, len(RECORDINGS), recording["path"]))

    # Age group of this recording
    if recording["age"] < 7:
        age = "5-6"
    elif recording["age"] < 9:
        age = "7-8"
    elif recording["age"] < 11:
        age = "9-10"
    elif recording["age"] < 13:
        age = "11-12"
    else:
        age = ">12"

    for brain_area in ["striatum", "thalamus", "cortex"]:
        brain_channel = brain_area + "_channel"

        # Areas without a channel in this recording are skipped
        if brain_channel not in recording:
            continue
        print('\t{}'.format(brain_area))

        recording_n = recording["recording"]
        channel_n = recording[brain_channel]
        session = Session(ROOT + recording["path"])

        data_all = session.recordings[recording_n].continuous[0].samples[get_slice_from_s(0, 60*10), channel_n]

        lp = bandpower(data_all, SAMPLING_RATE, 4, 100)
        LFP_power[brain_area][age].append(lp)

In [None]:
# Burst statistics per recording, each grouped as table[brain_area][age_group].
# Every table is built separately so that no list is shared between them.
amplitude = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

relative_amplitude = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

occurence = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

filled_time = {
    area: {age_group: [] for age_group in ["5-6", "7-8", "9-10", "11-12", ">12"]}
    for area in ["striatum", "thalamus", "cortex"]
}

# Per-recording burst statistics (amplitude, amplitude relative to baseline,
# occurrence per minute, percentage of time filled by bursts), grouped by
# brain area and age group, followed by plots and two-way ANOVAs.
for brain_area in ["striatum", "thalamus", "cortex"]:
    print(brain_area)
    for idx, recording in enumerate(rms_processed_recordings):
        # Age group of this recording
        if recording["age"] < 7:
            age = "5-6"
        elif recording["age"] < 9:
            age = "7-8"
        elif recording["age"] < 11:
            age = "9-10"
        elif recording["age"] < 13:
            age = "11-12"
        else:
            age = ">12"

        if brain_area + "_bursts" not in recording:
            continue
        print('\t{}/{}'.format(idx+1, len(rms_processed_recordings)))

        baseline = recording[brain_area + "_baseline_amplitude"]

        amp_arr = []
        filled_time_arr = []

        for burst in recording[brain_area + "_bursts"]:
            dur = burst.time[1] - burst.time[0]

            # Bursts longer than 20 s are excluded; check this before
            # filtering so no work is spent on bursts that are discarded.
            if dur/SAMPLING_RATE > 20:
                continue

            data = butter_bandpass_filter(burst.data, 4, 100)

            filled_time_arr.append(dur)
            amp_arr.append(np.max(data) - np.min(data))

        if len(amp_arr):
            mean_amp = np.mean(amp_arr)
            amplitude[brain_area][age].append(mean_amp)
            # The baseline amplitude is False when no baseline period was
            # found; dividing by it would put inf into the statistics.
            if baseline:
                relative_amplitude[brain_area][age].append(mean_amp/baseline)

        # recording["length"] is in samples here, as are the burst durations
        filled = (np.sum(filled_time_arr) / recording["length"]) * 100
        filled_time[brain_area][age].append(filled)

        # Bursts per minute; counts every burst, including those over 20 s
        occur = len(recording[brain_area + "_bursts"]) / (recording["length"] / (60*SAMPLING_RATE))
        occurence[brain_area][age].append(occur)

plot_graph_by_age (
    data=amplitude,
    ylabel='Burst amplitude (μV)',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=['Striatum', 'CL/Pf', 'Cortex']
)
print(twoway_anova(amplitude, ['brain_area', 'age', 'amplitude']))

# Raw strings below: '\m' is not a valid escape in a normal string literal
plot_graph_by_age (
    data=relative_amplitude,
    ylabel=r'Relative burst amplitude ($\mathregular{A/A_0}$)',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=['Striatum', 'CL/Pf', 'Cortex']
)
print(twoway_anova(relative_amplitude, ['brain_area', 'age', 'amplitude']))

plot_graph_by_age (
    data=occurence,
    ylabel=r'Bursts $\mathregular{min^{-1}}$',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=['Striatum', 'CL/Pf', 'Cortex']
)
print(twoway_anova(occurence, ['brain_area', 'age', 'occurence']))

plot_graph_by_age (
    data=filled_time,
    ylabel='Burst filled time (%)',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=['Striatum', 'CL/Pf', 'Cortex']
)
print(twoway_anova(filled_time, ['brain_area', 'age', 'filled_time']))


plot_graph_by_age (
    data=LFP_power,
    ylabel=r'LFP power ($\mathregular{μV^2}$)',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=['Striatum', 'CL/Pf', 'Cortex']
)
print(twoway_anova(LFP_power, ['brain_area', 'age', 'power']))

## PCA and clustering procedure

In [None]:
def rms (data):
    """Return the root-mean-square of ``data``.

    The samples are converted to float64 before squaring. Squaring the
    elements one by one in their native dtype silently overflows for narrow
    integer types (e.g. int16 samples larger than 181 in magnitude), and is
    also far slower than a single vectorised NumPy expression.

    Parameters
    ----------
    data : array-like of numbers

    Returns
    -------
    numpy.float64
        sqrt(mean(data**2)); NaN for empty input.
    """
    values = np.asarray(data, dtype=np.float64)
    return np.sqrt(np.mean(np.square(values)))
    
def get_feature_vector (burst):
    """Build the 8-element feature list describing one burst.

    The burst data is band-passed (4-100) first. ``burst.normalized_psd`` must
    already be set. Returns ``False`` when the burst has no data or when the
    mean inter-trough interval is NaN (fewer than two negative peaks);
    otherwise returns, in order:

        duration (s), max RMS over 0.2 s windows, most negative sample,
        flatness (min RMS / max RMS), largest sample-to-sample drop of the
        signal resampled to 500 samples per second, share of normalised power
        between the 16 and 40 bins, share of normalised power between the 4
        and 16 bins, mean interval between negative peaks (s).
    """
    if not len(burst.data):
        return False

    filtered = butter_bandpass_filter(burst.data, 4, 100)

    # Duration in seconds
    duration = (burst.time[1] - burst.time[0]) / SAMPLING_RATE

    # Most negative sample
    negative_peak = min(filtered)

    # RMS of consecutive 0.2 s windows
    window = int(SAMPLING_RATE*0.2)
    window_rms = [
        rms(filtered[start:start+window])
        for start in range(0, len(filtered), window)
    ]
    max_rms = np.max(window_rms)
    flatness = np.min(window_rms) / max_rms

    # Largest decrease between neighbouring samples after resampling
    n_resampled = int(len(filtered)/SAMPLING_RATE * 500)
    resampled = signal.resample(filtered, n_resampled)
    max_slope = np.max([
        earlier - later
        for earlier, later in zip(resampled[:-1], resampled[1:])
    ])

    # Relative spectral power, both fractions taken over everything from 4 up
    f_burst, Pxx_burst = burst.normalized_psd
    theta_idx = np.where(f_burst >= 4)[0][0]
    beta_idx = np.where(f_burst >= 16)[0][0]
    lgamma_idx = np.where(f_burst >= 40)[0][0]

    theta_power = np.sum(Pxx_burst[theta_idx:beta_idx]) / np.sum(Pxx_burst[theta_idx:])
    beta_lgamma_power = np.sum(Pxx_burst[beta_idx:lgamma_idx]) / np.sum(Pxx_burst[theta_idx:])

    # Mean spacing of the peaks that sit below zero
    peaks, _ = get_peaks(filtered)
    negative_peaks = [p for p in peaks if filtered[p] < 0]
    iti = np.mean(np.diff(negative_peaks)) / SAMPLING_RATE
    if np.isnan(iti):
        return False

    features = [duration, max_rms, negative_peak, flatness]
    features += [max_slope, beta_lgamma_power, theta_power, iti]
    return features

# Short axis labels, one per feature. The first eight follow the order of the
# list returned by get_feature_vector; the ninth ("Spike rate") corresponds to
# the value appended to each feature vector by the extraction loop below.
feature_vector_labels = [
    "Duration",
    "Max RMS",
    "Negative peak",
    "Flatness",
    "Max slope",
    "β-γ power",
    "θ-α power",
    "ITI",
    "Spike rate"
]
# Long-form labels (with units) in the same order as feature_vector_labels.
feature_vector_labels_full = [
    "Duration (s)",
    "Max RMS (μV)",
    "Negative peak (μV)",
    "Flatness",
    "Max slope",
    "Beta/low-gamma power",
    "Theta-alpha power",
    "Inter-trough-interval (s)",
    "Spikes $s^{-1}$"
]

# Collect a 9-element feature vector (8 burst features + spike rate) for every
# usable burst of the selected brain area.
all_bursts = []
all_features = []

key = "thalamus_bursts" # or "cortex_bursts", "striatum_bursts"
# MUA entry of the same brain area as `key`
mua_key = key.replace("_bursts", "_mua")

# rms_processed_recordings leaves out recordings older than P12 while
# processed_recordings_mua holds every recording, so the two lists cannot be
# matched by position. Match them on (path, recording number) instead.
mua_by_recording = {
    (rec["path"], rec["recording"]): rec
    for rec in processed_recordings_mua
}

for idx, recording in enumerate(rms_processed_recordings):
    print("{} (P{}) {}/{}".format(recording["path"], recording["age"], idx+1, len(rms_processed_recordings)))

    if key not in recording:
        continue

    mua_recording = mua_by_recording[(recording["path"], recording["recording"])]
    spike_times = np.asarray(mua_recording[mua_key][0])

    for burst_idx, burst in enumerate(recording[key]):
        if (burst_idx+1) % 20 == 0:
            print('\t{}/{}'.format(burst_idx+1, len(recording[key])))

        # Bursts without a primary frequency have no normalised PSD to use
        if not burst.primary_frequency_baseline:
            continue

        burst.feature_vec = get_feature_vector(burst)
        # feature_vec is False when no features could be computed;
        # feature_vec[0] is the duration in seconds (bursts >= 20 s are dropped)
        if burst.feature_vec and burst.feature_vec[0] < 20:
            # Spikes in range of burst (both bounds inclusive)
            in_range = (spike_times >= burst.time[0]) & (spike_times <= burst.time[1])
            spike_rate = np.count_nonzero(in_range) / burst.feature_vec[0]
            burst.feature_vec.append(spike_rate)

            burst.age = recording["age"]
            burst.recording_path = recording["path"]

            all_features.append(burst.feature_vec)
            all_bursts.append(burst)

In [None]:
# Fuzzy cluster with thresh
def fuzzy_cluster (samples, n_clusters, threshold=0.6):
    """Fuzzy c-means clustering with a membership cut-off.

    ``samples`` has one row per sample. Each sample is assigned to the cluster
    with its highest membership if that membership exceeds ``threshold``;
    otherwise it receives the extra label ``n_clusters`` (unclassified).

    Returns
    -------
    labels : integer array, one label per sample
    fpc : last value returned by ``fuzz.cmeans`` (plotted by the notebook as
        the fuzzy partition coefficient)
    """
    cntr, u, u0, d, jm, p, fpc = fuzz.cmeans(samples.T, n_clusters, 2, error=0.005, maxiter=10000, init=None)

    # u holds one row per cluster and one column per sample
    strongest_cluster = np.argmax(u, axis=0)
    strongest_membership = np.max(u, axis=0)
    unclassified = u.shape[0]

    labels = np.where(strongest_membership > threshold, strongest_cluster, unclassified)

    return labels, fpc
    
# Run PCA analysis using first 3 components
# (features are standardised with StandardScaler before the projection)
pca = PCA(n_components=3)
scaler = StandardScaler()
pc_feature_list = pca.fit_transform(scaler.fit_transform(all_features))

# Get optimal number of clusters
# Candidates are 2..5 clusters. NOTE: despite its name, `silhouette_scores`
# holds the second value returned by fuzzy_cluster (fpc), not silhouette scores.
n_clusters = np.arange(2, 6)
silhouette_scores = []

for n in n_clusters:
    _, fpc = fuzzy_cluster(pc_feature_list, n)
    silhouette_scores.append(fpc)

plt.plot(n_clusters, silhouette_scores, c='black')
plt.xticks(n_clusters)
plt.xlabel('N clusters')
plt.ylabel('Fuzzy partition coefficient')
format_plot(plt.gca(), legend=False)
plt.show()

print('FPC values', silhouette_scores)
# The cluster count with the highest FPC is used from here on
optimal_n_clusters = n_clusters[np.argmax(silhouette_scores)]

# Plot clusters (of first 3 components) and record, for every sample, its
# cluster name and its three principal-component coordinates.
cluster_data = {
    "cluster": [],
    "pc1": [],
    "pc2": [],
    "pc3": []
}

labels, _ = fuzzy_cluster(pc_feature_list, optimal_n_clusters)
fig = plt.figure(figsize=[5, 5], dpi=150)
ax = fig.add_subplot(111, projection='3d')
colors = ["red", "blue", "gray"]
# NOTE(review): the mapping below assumes two clusters (label 2 being the
# "unclassified" label) and that label 1 is NGB and label 0 is SB. Confirm it
# against the clustering result, since nothing here pins the label order.
for cluster_i in [1,0,2]:
    if cluster_i == 1:
        label = 'NGB'
        alpha=0.1
    elif cluster_i == 0:
        label = 'SB'
        alpha=0.1
    else:
        label = 'UC'
        alpha=0.5

    samples = pc_feature_list[labels == cluster_i]
    ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2], alpha=alpha, label=label, c=colors[cluster_i])

    for sample in samples:
        cluster_data["cluster"].append(label)
        cluster_data["pc1"].append(sample[0])
        cluster_data["pc2"].append(sample[1])
        # Third component (was previously filled with sample[1] by mistake)
        cluster_data["pc3"].append(sample[2])

ax.set_xlabel('PC1')
ax.set_xlim([None,4.5])
ax.set_ylabel('PC2')
ax.set_ylim([-4.5,None])
ax.set_zlabel('PC3', size=15)
ax.set_zlim([None,8])
ax.tick_params(axis='z', labelsize=12.5)
format_plot(ax)
plt.show()
    
# Visualize variance explained by each component
# A second PCA with as many components as there are feature labels is fitted
# on the standardised features; the cumulative explained-variance ratio is
# plotted and printed.
total_features = len(feature_vector_labels)
all_pca = PCA(n_components=total_features)
all_pca.fit(scaler.fit_transform(all_features))
plt.plot(np.cumsum(all_pca.explained_variance_ratio_), c='black')
# Tick positions are 0-based, tick labels show the 1-based component count
plt.xticks(np.arange(total_features), np.arange(total_features)+1)
plt.xlabel('N components')
plt.ylabel('Variance')
format_plot(plt.gca(), legend=False)
print('Explained variance', np.cumsum(all_pca.explained_variance_ratio_))


# Visualize relative contribution of components
# One list per feature; entry i of every list belongs to principal component
# i+1 (whose number is stored under "component").
component_contribution_data = {
    "component": [],
    "duration": [],
    "max_rms": [],
    "negative_peak": [],
    "flatness": [],
    "max_slope": [],
    "beta_lgamma_power": [],
    "theta_power": [],
    "iti": [],
    "spike_rate": []
}
# Key order matters: index 0 is "component", indices 1.. are the features in
# the same order as the coefficients of a PCA component.
component_contribution_keys = [k for k in component_contribution_data.keys()]

# One bar chart per component of the 3-component PCA (`pca`)
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=[15, 3.5])
axs = list(axs.flat)
for component_idx, component in enumerate(pca.components_):
    ax = axs[component_idx]
    xpos = np.arange(len(component))
    
    # Bars show the absolute value of each feature's coefficient
    ax.bar(xpos, np.abs(component), facecolor='dimgray')
    ax.set_xticks(xpos)
    ax.set_xticklabels(feature_vector_labels, rotation=45)
    ax.set_ylabel('Coefficient')
    ax.set_title('PC{}'.format(component_idx+1))
    ax.set_ylim([0, 0.8])
    format_plot(ax, legend=False)
    
    component_contribution_data["component"].append(component_idx+1)
    for coeff_idx, coeff in enumerate(component):
        # +1 skips the leading "component" key
        component_contribution_data[component_contribution_keys[coeff_idx+1]].append(np.abs(coeff))
    
# Hide any axes with an index at or beyond total_features (with three axes
# this only has an effect when there are fewer than three features)
for ax in axs[total_features:]:
    ax.set_visible(False)
plt.tight_layout()
plt.show()

## Multivariate F-ratio test

In [None]:
def multivariate_f_ratio (a, b):
    """F-ratio between two groups of feature vectors.

    ``a`` and ``b`` are sequences of equally long feature vectors (one row per
    observation). The statistic is

        F = SS_between / (SS_within / (len(a) + len(b) - 2))

    where SS_between is the group-size-weighted squared Euclidean distance of
    each group mean from the overall mean, and SS_within is the summed squared
    Euclidean distance of every observation from its own group mean.
    """
    mean_a = np.mean(a, axis=0)
    mean_b = np.mean(b, axis=0)
    grand_mean = np.mean(np.concatenate([a,b]), axis=0)

    # Between-group sum of squares
    between = 0
    for group, centre in ((a, mean_a), (b, mean_b)):
        between += (len(group) * np.linalg.norm(grand_mean-centre)**2)

    # Within-group sum of squares, scaled by its degrees of freedom
    within = 0
    for group, centre in ((a, mean_a), (b, mean_b)):
        for observation in group:
            within += (np.linalg.norm(centre-observation)**2)
    within = within / (len(a)+len(b) - 2)

    return between/within

def permute_multivariate_f_ratio (a, b, iterations):
    """Permutation test built on multivariate_f_ratio.

    The observed F-ratio of ``a`` versus ``b`` is compared with the F-ratios
    obtained after randomly reassigning the pooled samples to two groups of
    the original sizes.

    Parameters
    ----------
    a, b : sequences of feature vectors for the two groups
    iterations : number of random reassignments to draw

    Returns
    -------
    (p_val, test_stat, null_dist)
        p_val is the fraction of permuted F-ratios that are >= the observed
        one, test_stat is the observed F-ratio and null_dist is the list of
        permuted F-ratios.
    """
    pooled = np.concatenate([a, b])
    n_a = len(a)

    test_stat = multivariate_f_ratio(a, b)
    null_dist = []

    for iteration in range(1, iterations + 1):
        reshuffled = np.random.permutation(pooled)
        null_dist.append(multivariate_f_ratio(reshuffled[:n_a], reshuffled[n_a:]))

        # Progress report every 1000 permutations
        if iteration % 1000 == 0:
            print('Iteration', iteration)

    n_at_least_as_extreme = sum(1 for permuted_f in null_dist if permuted_f >= test_stat)
    p_val = n_at_least_as_extreme / len(null_dist)

    return p_val, test_stat, null_dist
    
# Feature-vector lists of the two burst clusters to compare with the
# F-ratio permutation test (e.g. cortex NGB versus thalamus NGB)
burst_group_a = cortex_ngb_features
burst_group_b = thalamus_ngb_features

p, f, null = permute_multivariate_f_ratio(burst_group_a, burst_group_b, iterations=10000)

# Tallest histogram bin sets the height of the marker for the observed F-value
bins = 100
max_val = np.histogram(null, bins=bins)[0].max()

plt.hist(null, bins=bins, facecolor='dimgray')
plt.plot([f, f], [0, max_val])
plt.ylabel('Count')
plt.xlabel('F-value')
format_plot(plt.gca(), legend=False)
plt.show()

# Denominator degrees of freedom reported alongside the statistic
within_group_dof = len(burst_group_a) + len(burst_group_b) - 2
print('p = {}, F({}, {}) = {}'.format(p, 1, within_group_dof, f))

## Developmental trajectory of burst properties

In [None]:
 def plot_group_lineplots (data, ylim, labels, xticks, colors=('blue', 'orange'), title='', xlabel='', ylabel='', yscale='linear', save=False):
    """Plot mean +/- SEM lines for one or more series over ordered categories.

    Parameters
    ----------
    data : sequence of series; each series is a sequence of groups (one per
        x position) and each group is a sequence of individual values.
    ylim : y-axis limits passed to plt.ylim.
    labels : legend labels, one per series.
    xticks : tick labels for the x positions.
    colors : colour per series (indexed in parallel with ``data``).
    title, xlabel, ylabel : text for the axes.
    yscale : scale name passed to plt.yscale.
    save : falsy to skip saving; otherwise a file name that is appended to
        FIG_ROOT and passed to plt.savefig.
    """
    plt.figure()

    lines = []
    for idx, data_group in enumerate(data):
        positions = np.arange(len(data_group))
        line = plt.errorbar(
            positions,
            [np.mean(group) for group in data_group],
            # Standard error of the mean uses the number of values in each
            # group, not the number of x positions
            yerr=[
                np.std(group) / (len(group) ** 0.5) if len(group) else np.nan
                for group in data_group
            ],
            capsize=5,
            c=colors[idx]
        )

        # Overlay the individual values of every group as faint markers
        for position, group in zip(positions, data_group):
            plt.plot([position] * len(group), group, c=colors[idx], lw=0, marker='o', alpha=0.25)

        lines.append(line)

    plt.xticks(np.arange(len(xticks)), xticks)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.yscale(yscale)
    plt.ylim(ylim)

    leg = plt.legend(lines, labels, frameon=False, fontsize=15, bbox_to_anchor=(1, 1), loc='upper left')
    # Legend.legendHandles was renamed to legend_handles in newer Matplotlib;
    # support both spellings
    legend_handles = getattr(leg, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    for legobj in legend_handles:
        legobj.set_linewidth(2.0)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(12.5)
    for item in ([ax.xaxis.label, ax.yaxis.label] + ax.get_legend().get_texts()):
        item.set_fontsize(15)

    if save:
        plt.savefig(FIG_ROOT + save, bbox_inches="tight")
    plt.show()

# GROUP is group of bursts identified earlier via PCA
# Maps the burst-type label used as a key in the result tables below to the
# collection of cortex bursts assigned to that type (cortex_ngb / cortex_sb
# are defined earlier in the notebook).
GROUP = {
    "NGB": cortex_ngb,
    "SB": cortex_sb
}
    
# Burst types and age bins shared by every per-property result table
_BURST_TYPES = ("NGB", "SB")
_AGE_GROUPS = ("5-6", "7-8", "9-10", "11-12")


def _empty_property_table ():
    """Return a fresh {burst type: {age group: []}} table with independent lists."""
    return {
        burst_type: {age_group_label: [] for age_group_label in _AGE_GROUPS}
        for burst_type in _BURST_TYPES
    }


# One table per burst property; each innermost list receives one value per recording
occurence = _empty_property_table()

amplitude = _empty_property_table()

duration = _empty_property_table()

spike_rate = _empty_property_table()

alphatheta_power = _empty_property_table()
betagamma_power = _empty_property_table()

alphatheta_peak = _empty_property_table()
betagamma_peak = _empty_property_table()


# Legend labels (burst types) and x tick labels (age groups) for the plots
box_labels = list(_BURST_TYPES)
age_group_labels = list(occurence["SB"])

# Fill the per-property tables (occurence, amplitude, duration, ...) with one
# value per recording, split by burst type (NGB / SB) and age group.
# Only cortex bursts are processed because GROUP holds the cortex clusters.
for brain_area in ["cortex"]:
    brain_area_bursts = brain_area + "_bursts"
    
    for recording_n, recording in enumerate(rms_processed_recordings):
        # Skip recordings without bursts detected for this brain area
        if not brain_area_bursts in recording:
            continue
            
        print("Processing {}, recording {}/{}".format(brain_area, recording_n+1, len(rms_processed_recordings)))

        # Bin the recording age; every age below 7 lands in "5-6" and
        # recordings aged 13 or more are skipped entirely
        age = recording["age"]    
        if age < 7:
            age_group = "5-6"
        elif age >= 7 and age < 9:
            age_group = "7-8"
        elif age >= 9 and age < 11:
            age_group = "9-10"
        elif age >= 11 and age < 13:
            age_group = "11-12"
        else:
            continue

        total_bursts_ngb = []
        total_bursts_sb = []
        
        # Per-burst values of this recording; averaged per burst type below
        amplitude_recording = {
            "NGB": [],
            "SB": []
        }
        duration_recording = {
            "NGB": [],
            "SB": []
        }
        spike_rate_recording = {
            "NGB": [],
            "SB": []
        }
        alphatheta_recording = {
            "NGB": [],
            "SB": []
        }
        betagamma_recording = {
            "NGB": [],
            "SB": []
        }
        alphatheta_peak_recording = {
            "NGB": [],
            "SB": []
        }
        betagamma_peak_recording = {
            "NGB": [],
            "SB": []
        }
        
        for burst_idx, burst in enumerate(recording[brain_area_bursts]):            
            # Classify the burst by cluster membership; bursts that belong to
            # neither cluster are ignored
            if (burst in GROUP['SB']):
                group_key = 'SB'
                total_bursts_sb.append(burst)
            elif burst in GROUP['NGB']:
                group_key = 'NGB'
                total_bursts_ngb.append(burst)
            else:
                continue

            # Peak-to-peak amplitude of the burst trace
            amplitude_recording[group_key].append(np.max(burst.data)-np.min(burst.data))
            # burst.time holds (start, end); dividing by SAMPLING_RATE presumably
            # converts samples to seconds -- TODO confirm units of burst.time
            duration_recording[group_key].append((burst.time[1] - burst.time[0])/SAMPLING_RATE)
            # Feature-vector positions 8 / 6 / 5 are assumed to be spike_rate /
            # theta_power / beta_lgamma_power -- TODO confirm against the code
            # that builds burst.feature_vec
            spike_rate_recording[group_key].append(burst.feature_vec[8])
            alphatheta_recording[group_key].append(burst.feature_vec[6])
            betagamma_recording[group_key].append(burst.feature_vec[5])
            
            f, Pxx = burst.normalized_psd

            # First frequency-bin index at or above 4, 16 and 40 (same units as f)
            theta_idx = np.where(f >= 4)[0][0] 
            beta_idx = np.where(f >= 16)[0][0]
            lgamma_idx = np.where(f >= 40)[0][0]

            # Frequency of the largest PSD value within [4, 16) and [16, 40)
            alphatheta_peak_ = f[np.argmax(Pxx[theta_idx:beta_idx])+theta_idx]
            betagamma_peak_ = f[np.argmax(Pxx[beta_idx:lgamma_idx])+beta_idx]

            alphatheta_peak_recording[group_key].append(alphatheta_peak_)
            betagamma_peak_recording[group_key].append(betagamma_peak_)
            
        # Burst count divided by (recording length / SAMPLING_RATE), times 60.
        # NOTE(review): if recording["length"] is in samples this is bursts per
        # minute, whereas the plot label further down reads "Bursts s^-1" --
        # confirm which unit is intended.
        occurence_ngb = len(total_bursts_ngb) / (recording["length"] / SAMPLING_RATE) * 60
        occurence_sb = len(total_bursts_sb) / (recording["length"] / SAMPLING_RATE) * 60
        
        occurence["NGB"][age_group].append(occurence_ngb)
        occurence["SB"][age_group].append(occurence_sb)
        
        # For every property, append the recording mean only when the recording
        # contributed at least one burst of that type.
        # NOTE(review): SB amplitude and duration use np.nanmean while all other
        # means use np.mean -- confirm whether the difference is intentional.
        if len(amplitude_recording["NGB"]):
            amplitude["NGB"][age_group].append(np.mean(amplitude_recording["NGB"]))
        if len(amplitude_recording["SB"]):
            amplitude["SB"][age_group].append(np.nanmean(amplitude_recording["SB"]))

        if len(duration_recording["NGB"]):
            duration["NGB"][age_group].append(np.mean(duration_recording["NGB"]))
        if len(duration_recording["SB"]):
            duration["SB"][age_group].append(np.nanmean(duration_recording["SB"]))
            
        if len(spike_rate_recording["NGB"]):
            spike_rate["NGB"][age_group].append(np.mean(spike_rate_recording["NGB"]))
        if len(spike_rate_recording["SB"]):
            spike_rate["SB"][age_group].append(np.mean(spike_rate_recording["SB"]))
        
        if len(alphatheta_recording["NGB"]):
            alphatheta_power["NGB"][age_group].append(np.mean(alphatheta_recording["NGB"]))
        if len(alphatheta_recording["SB"]):
            alphatheta_power["SB"][age_group].append(np.mean(alphatheta_recording["SB"]))
        
        if len(betagamma_recording["NGB"]):
            betagamma_power["NGB"][age_group].append(np.mean(betagamma_recording["NGB"]))
        if len(betagamma_recording["SB"]):
            betagamma_power["SB"][age_group].append(np.mean(betagamma_recording["SB"]))
        
        if len(alphatheta_peak_recording["NGB"]):
            alphatheta_peak["NGB"][age_group].append(np.mean(alphatheta_peak_recording["NGB"]))
        if len(alphatheta_peak_recording["SB"]):
            alphatheta_peak["SB"][age_group].append(np.mean(alphatheta_peak_recording["SB"]))
        
        if len(betagamma_peak_recording["NGB"]):
            betagamma_peak["NGB"][age_group].append(np.mean(betagamma_peak_recording["NGB"]))
        if len(betagamma_peak_recording["SB"]):
            betagamma_peak["SB"][age_group].append(np.mean(betagamma_peak_recording["SB"]))
        
# Parameters to pass to line plot function: one entry per burst property with
# its result table, y-axis label and y-axis limits.
# The mathtext labels are raw strings so that "\m" is not parsed as an
# (invalid) string escape sequence.
plot_parameters = [
    { "title": "occurence", "ylabel": r"Bursts $\mathregular{s^{-1}}$", "data": occurence, "ylim": [0, 0.16] },
    { "title": "amplitude", "ylabel": "Amplitude (μV)", "data": amplitude, "ylim": [200, 2000] },
    { "title": "duration", "ylabel": "Duration (s)", "data": duration, "ylim": [0, 10] },
    { "title": "spike_rate", "ylabel": r"Spikes $\mathregular{s^{-1}}$", "data": spike_rate, "ylim": [0, 32] },
    { "title": "alphatheta", "ylabel": "Relative alpha-theta power", "data": alphatheta_power, "ylim": [0, 0.6] },
    { "title": "betagamma", "ylabel": "Relative beta-gamma power", "data": betagamma_power, "ylim": [0.25, 0.65] },
    { "title": "alphatheta_peak", "ylabel": "Peak alpha-theta frequency", "data": alphatheta_peak, "ylim": [5, 15] },
    { "title": "betagamma_peak", "ylabel": "Peak beta-gamma frequency", "data": betagamma_peak, "ylim": [15, 30] }
]

# Line plots: for every property print mean (std) per burst type and age
# group, draw the line plot and print the two-way ANOVA table
for plot_param in plot_parameters:
    data = plot_param["data"]
    age_groups_data = []
    
    print(plot_param["title"])

    for brain_key, brain_area in data.items():
        print('\t', brain_key)
        age_groups = []
        for age_group_key, age_group in brain_area.items():
            print('\t {} ({})'.format(round(np.mean(age_group), 1), round(np.std(age_group), 2)))
            age_groups.append(age_group)
        age_groups_data.append(age_groups)

    # NOTE(review): the titles above are lower-case, so the comparison with
    # "Duration" / "Amplitude" never matches and every plot uses a linear
    # y-axis. Left unchanged because the duration limits start at 0, which a
    # log axis cannot show -- confirm the intended scale.
    plot_group_lineplots(
        data=age_groups_data,
        ylim=plot_param["ylim"],
        labels=box_labels,
        xticks=age_group_labels,
        colors=['tab:blue', 'tab:red'],
        xlabel="Age (days)",
        ylabel=plot_param["ylabel"],
        yscale=('log' if plot_param["title"] == "Duration" or plot_param["title"] == "Amplitude" else 'linear')
    )
    
    print(twoway_anova(data, ['burst_type', 'age', 'value']))

## Burst synchrony
### General functions
Used for finding co-occurring bursts across brain regions and computing their lags

In [None]:
def get_overlapping_bursts (bursts_a, bursts_b, bursts_c=None):
    """Find bursts from two (or three) regions that start close together.

    Bursts are grouped when the largest difference between their start times
    (``burst.time[0]``) is below both ``0.5 * SAMPLING_RATE`` and the length
    of the shortest burst's ``data``.

    Parameters
    ----------
    bursts_a, bursts_b : sequences of burst objects exposing ``time`` and ``data``
    bursts_c : optional third sequence. When given (not None), triples
        ``[a, b, c]`` are returned instead of pairs; an empty third sequence
        therefore yields no groups.

    Returns
    -------
    list of lists
        ``[burst_a, burst_b]`` or ``[burst_a, burst_b, burst_c]`` groups in
        the order they were found, each combination of objects reported once.
    """
    # maximum difference in start times considered overlapping
    max_t_diff = 0.5
    max_sample_diff = SAMPLING_RATE*max_t_diff

    overlapping = []
    # Combinations are de-duplicated by object identity. The previous
    # summed-hash key could collide for different combinations, and its
    # de-duplicated result was discarded before returning.
    seen = set()

    def _add_group (group):
        key = tuple(id(burst) for burst in group)
        if key not in seen:
            seen.add(key)
            overlapping.append(list(group))

    for burst_a in bursts_a:
        start_t_a = burst_a.time[0]

        for burst_b in bursts_b:
            start_t_b = burst_b.time[0]
            t_diff_ab = abs(start_t_a - start_t_b)
            min_len_ab = min(len(burst_a.data), len(burst_b.data))

            # If a and b are already too far apart, no third burst can make
            # the group valid, so the inner loop can be skipped
            if t_diff_ab >= max_sample_diff or t_diff_ab >= min_len_ab:
                continue

            if bursts_c is None:
                _add_group((burst_a, burst_b))
                continue

            for burst_c in bursts_c:
                start_t_c = burst_c.time[0]
                max_diff = max(
                    t_diff_ab,
                    abs(start_t_a - start_t_c),
                    abs(start_t_b - start_t_c)
                )
                min_len = min(min_len_ab, len(burst_c.data))

                if max_diff < max_sample_diff and max_diff < min_len:
                    _add_group((burst_a, burst_b, burst_c))

    return overlapping

def get_nonoverlapping_bursts (bursts_a, bursts_b, burst_pairs):
    """Return the bursts of each input list that appear in no burst pair.

    Parameters
    ----------
    bursts_a, bursts_b : sequences of bursts
    burst_pairs : sequence of burst groups (as returned by get_overlapping_bursts)

    Returns
    -------
    (bursts_a_nonoverlapping, bursts_b_nonoverlapping) : two lists that keep
        the input order
    """
    def _without_partner (bursts):
        # A burst is kept when no group contains it
        return [
            burst for burst in bursts
            if not any(burst in group for group in burst_pairs)
        ]

    return _without_partner(bursts_a), _without_partner(bursts_b)

def align_overlapping_bursts (burst_a, burst_b):
    """Trim two bursts to the time span they share.

    The shared span runs from the later of the two start times to the earlier
    of the two end times (``burst.time`` holds ``(start, end)``).

    Returns
    -------
    (burst_a_trimmed, burst_b_trimmed, xticks)
        The slices of each burst's ``data`` covering the shared span, and the
        value of ``get_xticks`` for that span.
    """
    overlap_start = max(burst_a.time[0], burst_b.time[0])
    overlap_stop = min(burst_a.time[1], burst_b.time[1])
    overlap_len = overlap_stop - overlap_start

    # Offset of the shared span inside each burst's own data
    offset_a = overlap_start - burst_a.time[0]
    offset_b = overlap_start - burst_b.time[0]

    trimmed_a = burst_a.data[offset_a:offset_a + overlap_len]
    trimmed_b = burst_b.data[offset_b:offset_b + overlap_len]

    return (
        trimmed_a,
        trimmed_b,
        get_xticks(slice(overlap_start, overlap_stop))
    )

def get_threshold (burst_pairs, n_iterations=1000, debug_plot=False):
    """Estimate a chance-level coherence threshold from shuffled data.

    The first pair in ``burst_pairs`` is aligned, both traces are randomly
    permuted in time ``n_iterations`` times, and the coherence of every
    shuffled pair is computed (0.5 s windows, 50 % overlap, restricted with
    ``get_psd_in_range`` to [0, 100]).

    Parameters
    ----------
    burst_pairs : sequence of (burst_a, burst_b); only the first pair is used.
    n_iterations : number of shuffles (default 1000).
    debug_plot : when True, plot every shuffled pair. Off by default because
        it opens one figure per iteration.

    Returns
    -------
    ndarray
        95th percentile of the shuffled coherence at every frequency.
    """
    a_burst, b_burst = burst_pairs[0]
    a, b, _xticks = align_overlapping_bursts(a_burst, b_burst)

    # Loop-invariant spectral settings
    window_size = 0.5
    window = int(SAMPLING_RATE*window_size)
    overlap = int(SAMPLING_RATE*window_size*0.5)

    Cxy_arr = []
    for _ in range(n_iterations):
        # permutation() returns shuffled copies; shuffling in place would
        # scramble the original burst data when the aligned traces are views
        a_shuffled = np.random.permutation(a)
        b_shuffled = np.random.permutation(b)

        f, Cxy = signal.coherence(a_shuffled, b_shuffled, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
        f, Cxy = get_psd_in_range((f, Cxy), [0, 100])

        if debug_plot:
            plt.plot(a_shuffled)
            plt.plot(b_shuffled)
            plt.show()

        Cxy_arr.append(Cxy)

    return np.percentile(np.array(Cxy_arr), 95, axis=0)

def get_lags (burst_a_data, burst_b_data, freq, prewhiten):
    """Squared cross-correlation of two bursts' amplitude envelopes.

    Both traces are passed through ``butter_bandpass_filter`` with
    ``freq[0]`` and ``freq[1]``, converted to instantaneous amplitude with the
    Hilbert transform, mean-subtracted, normalised and cross-correlated. Only
    lags within 100 ms of zero are returned.

    Parameters
    ----------
    burst_a_data, burst_b_data : 1-D traces, expected to have equal length
        (the zero-lag index is taken as the centre of the full correlation).
    freq : two-element sequence handed to ``butter_bandpass_filter``.
    prewhiten : when true, the envelope of burst a is replaced by the
        residuals of an ARIMA(0, 1, 1) model fitted to it.

    Returns
    -------
    (xpos, cross_corr_sq)
        Lags in milliseconds and the squared cross-correlation at those lags.
    """
    burst_a_data = butter_bandpass_filter(burst_a_data, freq[0], freq[1])
    burst_b_data = butter_bandpass_filter(burst_b_data, freq[0], freq[1])

    # Get instantaneous amplitude
    amplitude_a = np.abs(signal.hilbert(burst_a_data))
    amplitude_b = np.abs(signal.hilbert(burst_b_data))

    # Remove dc component
    amplitude_a = amplitude_a - np.mean(amplitude_a)
    amplitude_b = amplitude_b - np.mean(amplitude_b)

    # Pre-whiten
    if prewhiten:
        model_fit = ARIMA(amplitude_a, order=(0,1,1)).fit()
        a = model_fit.resid
    else:
        a = amplitude_a

    # Normalize cross-correlation. The normalised trace must be the one that
    # is correlated; previously it was stored under another name and the
    # un-normalised trace was used.
    a = (a - np.mean(a)) / (np.std(a) * len(a))
    amplitude_b = (amplitude_b - np.mean(amplitude_b)) / (np.std(amplitude_b))
    cross_corr = signal.correlate(a, amplitude_b)

    # Integer lag (in samples) of every element of the full correlation
    centre = int(len(cross_corr) * 0.5)
    lags = np.arange(len(cross_corr)) - centre

    # Take only +/- 100ms from 0 lag (half-open: -100 ms <= lag < +100 ms)
    hundred_ms = int(SAMPLING_RATE*0.1)
    in_window = (lags >= -hundred_ms) & (lags < hundred_ms)
    cross_corr = cross_corr[in_window]

    # Square results
    cross_corr_sq = cross_corr**2

    # Convert time points into ms using the exact integer lags, so that zero
    # lag maps to exactly 0 ms
    xpos = lags[in_window] / SAMPLING_RATE * 1000

    return xpos, cross_corr_sq

## Get proportion of synchronous events by age and region

In [None]:
# Region comparisons and age bins used by the synchrony result tables
_REGION_COMPARISONS = ("Thal-CP", "Ctx-CP", "Ctx-thal", "Ctx-thal-CP")
_SYNCHRONY_AGE_GROUPS = ("5-6", "7-8", "9-10", "11-12", ">12")

# Co-occurring burst groups pooled over all recordings, per comparison
cooccuring_events = {comparison: [] for comparison in _REGION_COMPARISONS}

# Same groups, additionally split by age group
cooccuring_events_by_age = {
    comparison: {age_label: [] for age_label in _SYNCHRONY_AGE_GROUPS}
    for comparison in _REGION_COMPARISONS
}

# Bursts without a partner, per brain region
nonoverlapping_events = {region: [] for region in ("thalamus", "striatum", "cortex")}

# One proportion per recording, per comparison and age group
proportion_cooccuring_events = {
    comparison: {age_label: [] for age_label in _SYNCHRONY_AGE_GROUPS}
    for comparison in _REGION_COMPARISONS
}

# For every recording: find co-occurring bursts for each region comparison,
# pool them (overall and per age group), collect the bursts without a partner
# and store the proportion of co-occurring events.
for recording_idx, recording in enumerate(rms_processed_recordings):
    # Bin the recording age; ages of 13 and above are kept here as ">12"
    if recording["age"] < 7:
        age = "5-6"
    elif recording["age"] >= 7 and recording["age"] < 9:
        age = "7-8"
    elif recording["age"] >= 9 and recording["age"] < 11:
        age = "9-10"
    elif recording["age"] >= 11 and recording["age"] < 13:
        age = "11-12"
    else:
        age = ">12"

    
    # Cortex versus striatum (CP); these two burst lists are read
    # unconditionally, so every recording must provide them
    Ctx_CP = get_overlapping_bursts (recording["cortex_bursts"], recording["striatum_bursts"])
    Ctx_nonoverlap, CP_nonoverlap = get_nonoverlapping_bursts (recording["cortex_bursts"], recording["striatum_bursts"], Ctx_CP)
    nonoverlapping_events["cortex"] += Ctx_nonoverlap
    nonoverlapping_events["striatum"] += CP_nonoverlap
    cooccuring_events["Ctx-CP"] += Ctx_CP
    cooccuring_events_by_age["Ctx-CP"][age] += Ctx_CP
    # Proportion = number of co-occurring groups / total bursts in both regions.
    # NOTE(review): raises ZeroDivisionError when both burst lists are empty.
    proportion_cooccuring_events["Ctx-CP"][age].append(
        len(Ctx_CP) / (len(recording["cortex_bursts"]) + len(recording["striatum_bursts"]))
    )
    
    # Comparisons involving the thalamus only run when thalamus bursts exist
    if "thalamus_bursts" in recording:
        Thal_CP = get_overlapping_bursts (recording["thalamus_bursts"], recording["striatum_bursts"])
        thal_nonoverlap, cp_nonoverlap = get_nonoverlapping_bursts (recording["thalamus_bursts"], recording["striatum_bursts"], Thal_CP)
        # NOTE(review): striatum (and, below, cortex and thalamus) bursts are
        # appended to nonoverlapping_events once per comparison, so the same
        # burst can be listed more than once -- confirm this is intended.
        nonoverlapping_events["thalamus"] += thal_nonoverlap
        nonoverlapping_events["striatum"] += cp_nonoverlap
        cooccuring_events["Thal-CP"] += Thal_CP
        cooccuring_events_by_age["Thal-CP"][age] += Thal_CP
        proportion_cooccuring_events["Thal-CP"][age].append(
            len(Thal_CP) / (len(recording["thalamus_bursts"]) + len(recording["striatum_bursts"]))
        )
        
        Ctx_thal = get_overlapping_bursts (recording["cortex_bursts"], recording["thalamus_bursts"])
        ctx_nonoverlap, thal_nonoverlap = get_nonoverlapping_bursts (recording["cortex_bursts"], recording["thalamus_bursts"], Ctx_thal)
        nonoverlapping_events["cortex"] += ctx_nonoverlap
        nonoverlapping_events["thalamus"] += thal_nonoverlap
        cooccuring_events["Ctx-thal"] += Ctx_thal
        cooccuring_events_by_age["Ctx-thal"][age] += Ctx_thal
        proportion_cooccuring_events["Ctx-thal"][age].append(
            len(Ctx_thal) / (len(recording["cortex_bursts"]) + len(recording["thalamus_bursts"]))
        )
        
        # Three-way co-occurrence; no non-overlapping bursts are collected here
        Ctx_thal_CP = get_overlapping_bursts (recording["cortex_bursts"], recording["thalamus_bursts"], recording["striatum_bursts"])
        cooccuring_events["Ctx-thal-CP"] += Ctx_thal_CP
        cooccuring_events_by_age["Ctx-thal-CP"][age] += Ctx_thal_CP
        proportion_cooccuring_events["Ctx-thal-CP"][age].append(
            len(Ctx_thal_CP) / (len(recording["cortex_bursts"]) + len(recording["thalamus_bursts"]) + len(recording["striatum_bursts"]))
        )

    # Progress report
    print("{}/{}".format(recording_idx+1, len(rms_processed_recordings)))

## Plot previous results

In [None]:
# Aggregate proportions (over all ages): one bar per region comparison showing
# the mean per-recording proportion of co-occurring events +/- standard error
bar_means = []
bar_stderrs = []
bar_labels = []

for comparison, proportions_by_age in proportion_cooccuring_events.items():
    # Pool the per-recording proportions of every age group
    data = [
        proportion
        for age_proportions in proportions_by_age.values()
        for proportion in age_proportions
    ]
    mean = np.mean(data)
    # Standard error of the mean: std / sqrt(n), as in the other plots of this
    # notebook (previously divided by n)
    stderr = np.std(data) / len(data)**0.5

    bar_means.append(mean)
    bar_stderrs.append(stderr)
    bar_labels.append(comparison)
xpos = np.arange(len(bar_means))

plt.bar(xpos, bar_means, yerr=bar_stderrs, capsize=5)
plt.xticks(xpos, bar_labels)
plt.ylabel('Proportion of total events')
plt.title('Proportion of co-occuring events')
plt.show()


# Per-recording fraction of thalamus-striatum co-occurring pairs in which both
# bursts are NGB, both are SB, or one of each ("mixed")
ngb_pairs = []
sb_pairs = []
mixed_pairs = []

for recording in rms_processed_recordings:
    # Only recordings with thalamus bursts can contribute Thal-CP pairs
    if not "thalamus_bursts" in recording:
        continue
    
    # Individual bursts of each type that take part in any classified pair
    recording_ngb = []
    recording_sb = []
    
    recording_ngb_pairs = []
    recording_sb_pairs = []
    recording_mixed_pairs = []

    # By burst type
    for burst_pair in cooccuring_events["Thal-CP"]:
        # cooccuring_events pools all recordings; keep only the pairs whose
        # thalamus burst belongs to the current recording
        if not burst_pair[0] in recording["thalamus_bursts"]:
            continue
            
        thal_burst, cp_burst = burst_pair
        # Pairs with a burst in neither cluster fall through all branches
        # and are ignored
        if thal_burst in thalamus_ngb and cp_burst in striatum_ngb:
            recording_ngb_pairs.append(burst_pair)
            recording_ngb.append(thal_burst)
            recording_ngb.append(cp_burst)
        elif thal_burst in thalamus_sb and cp_burst in striatum_sb:
            recording_sb_pairs.append(burst_pair)
            recording_sb.append(thal_burst)
            recording_sb.append(cp_burst)
        elif thal_burst in thalamus_ngb and cp_burst in striatum_sb:
            recording_mixed_pairs.append(burst_pair)
            recording_ngb.append(thal_burst)
            recording_sb.append(cp_burst)
        elif thal_burst in thalamus_sb and cp_burst in striatum_ngb:
            recording_mixed_pairs.append(burst_pair) 
            recording_sb.append(thal_burst)
            recording_ngb.append(cp_burst)
            
    # Normalise the pair counts by the number of participating bursts of the
    # relevant type(s); recordings without pairs of a kind contribute nothing
    if len(recording_ngb_pairs):
        ngb_pairs.append( len(recording_ngb_pairs) / len(recording_ngb) )
    if len(recording_sb_pairs):
        sb_pairs.append( len(recording_sb_pairs) / len(recording_sb) )
    if len(recording_mixed_pairs):
        mixed_pairs.append( len(recording_mixed_pairs) / (len(recording_ngb)+len(recording_sb)) )

# Bar plot of the mean normalised frequency per pair type +/- standard error
xpos = np.arange(3)
yvals = [np.mean(ngb_pairs), np.mean(sb_pairs), np.mean(mixed_pairs)]
yerrs = [
    np.std(ngb_pairs)/len(ngb_pairs)**0.5,
    np.std(sb_pairs)/len(sb_pairs)**0.5,
    np.std(mixed_pairs)/len(mixed_pairs)**0.5
]

labels = ["Both NGB", "Both SB", "SB and NGB"]
colors = ["tab:blue", "tab:red", 'dimgray']
bars = plt.bar(xpos, yvals, yerr=yerrs, capsize=5)
for idx, bar_i in enumerate(bars):
    bar_i.set_color(colors[idx])
plt.xticks(xpos, labels)
plt.ylabel('Normalized frequency')
format_plot(plt.gca(), legend=False)
plt.show()

# Pairwise comparisons with Welch's t-test (equal_var=False); effect size and
# degrees of freedom come from helpers defined elsewhere in the notebook
print('Mean =', yvals)
print('Std =', [np.std(ngb_pairs), np.std(sb_pairs), np.std(mixed_pairs)])
print('NGB vs SB', stats.ttest_ind(ngb_pairs, sb_pairs, equal_var=False))
print('Cohen\'s d =', cohen_d_for_welch(ngb_pairs, sb_pairs))
print('DOF = ', welch_dof(ngb_pairs, sb_pairs))

print('\nNGB vs mixed', stats.ttest_ind(ngb_pairs, mixed_pairs, equal_var=False))
print('Cohen\'s d =', cohen_d_for_welch(ngb_pairs, mixed_pairs))
print('DOF = ', welch_dof(ngb_pairs, mixed_pairs))

print('\nSB vs mixed', stats.ttest_ind(sb_pairs, mixed_pairs, equal_var=False))
print('Cohen\'s d =', cohen_d_for_welch(sb_pairs, mixed_pairs))
print('DOF = ', welch_dof(sb_pairs, mixed_pairs))
    
# By age: one line per region comparison showing the mean proportion of
# co-occurring events per age group, with a shaded +/- standard-error band.
# Age groups without any recording are left out of that comparison's line.
for comparison, data_by_age in proportion_cooccuring_events.items():
    populated = [(label, data) for label, data in data_by_age.items() if len(data)]

    labels = [label for label, _values in populated]
    means = np.array([np.mean(values) for _label, values in populated])
    stderrs = np.array([np.std(values) / len(values)**0.5 for _label, values in populated])
    xpos = np.arange(len(labels))

    lower_band = means-stderrs
    upper_band = means+stderrs

    plt.plot(xpos, means, label=comparison)
    plt.fill_between(xpos, lower_band, upper_band, alpha=0.25)
    plt.xticks(xpos, labels)

plt.title('Proportion of co-occuring events')
plt.xlabel('Age')
plt.ylabel('Proportion of total events')
plt.legend(bbox_to_anchor=(1,1), loc="upper left")
plt.show()

## Cross-spectral coherence

In [None]:
# Coherence settings: 0.5 s segments without overlap
window_size = 0.5
window = int(SAMPLING_RATE*window_size)
overlap = int(SAMPLING_RATE*window_size*0)

# Coherence spectra of real pairs (Cxy_arr_*) and of shuffled pairs
# (Cxy_arr_shuffled_*), plus their per-pair / per-shuffle means
Cxy_arr_shuffled_ngb = []
Cxy_arr_ngb = []

Cxy_arr_shuffled_sb = []
Cxy_arr_sb = []

mean_coherence_ngb = []
mean_coherence_shuffled_ngb = []

mean_coherence_sb = []
mean_coherence_shuffled_sb = []

# Filtered, aligned traces kept for the shuffle control
bursts_a_ngb = []
bursts_b_ngb = []

bursts_a_sb = []
bursts_b_sb = []

for burst_pair_idx, burst_pair in enumerate(cooccuring_events["Ctx-CP"]):
    burst_a, burst_b = burst_pair
    a, b, ticks = align_overlapping_bursts (burst_a, burst_b)

    if (burst_pair_idx+1) % 100 == 0:
        print('Burst pair', burst_pair_idx+1)

    # Skip overlaps shorter than 0.5 s before filtering (same order as the
    # lag-analysis cell), so discarded segments are never filtered
    if len(a)/SAMPLING_RATE < 0.5:
        continue

    a = butter_bandpass_filter(a, 4, 80)
    b = butter_bandpass_filter(b, 4, 80)

    f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
    f, Cxy = get_psd_in_range((f, Cxy), [4, 80])

    # Pairs with a mean coherence above 0.8 are excluded
    # (presumably shared artefacts -- TODO confirm)
    if np.mean(Cxy) > 0.8:
        continue

    # Keep only pairs in which both bursts belong to the same cluster type
    if burst_a in cortex_ngb and burst_b in striatum_ngb:
        Cxy_arr_ngb.append(Cxy)
        mean_coherence_ngb.append(np.mean(Cxy))
        bursts_a_ngb.append(a)
        bursts_b_ngb.append(b)
    elif burst_a in cortex_sb and burst_b in striatum_sb:
        Cxy_arr_sb.append(Cxy)
        mean_coherence_sb.append(np.mean(Cxy))
        bursts_a_sb.append(a)
        bursts_b_sb.append(b)

# Shuffle control: re-pair the cortex and striatum traces at random 1000 times
# and store the mean coherence spectrum of every shuffle
for i in range(1000):
    if (i+1) % 10 == 0:
        print('Shuffle iteration', i+1)

    # Permute indices rather than the lists of traces: the traces have
    # different lengths, and building an array from such a ragged list is not
    # supported by recent NumPy versions
    order_a_ngb = np.random.permutation(len(bursts_a_ngb))
    order_b_ngb = np.random.permutation(len(bursts_b_ngb))
    burst_pairs_shuffled_ngb = [
        (bursts_a_ngb[idx_a], bursts_b_ngb[idx_b])
        for idx_a, idx_b in zip(order_a_ngb, order_b_ngb)
    ]

    order_a_sb = np.random.permutation(len(bursts_a_sb))
    order_b_sb = np.random.permutation(len(bursts_b_sb))
    burst_pairs_shuffled_sb = [
        (bursts_a_sb[idx_a], bursts_b_sb[idx_b])
        for idx_a, idx_b in zip(order_a_sb, order_b_sb)
    ]

    Cxy_shuffled_inner_ngb = []
    Cxy_shuffled_inner_sb = []

    for burst_pair in burst_pairs_shuffled_ngb:
        a, b = burst_pair
        f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
        f, Cxy = get_psd_in_range((f, Cxy), [4, 80])

        Cxy_shuffled_inner_ngb.append(Cxy)
    for burst_pair in burst_pairs_shuffled_sb:
        a, b = burst_pair
        f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
        f, Cxy = get_psd_in_range((f, Cxy), [4, 80])

        Cxy_shuffled_inner_sb.append(Cxy)

    # Mean spectrum over all shuffled pairs of this iteration
    Cxy_shuffled_inner_ngb_mean = np.mean(Cxy_shuffled_inner_ngb, axis=0)
    Cxy_shuffled_inner_sb_mean = np.mean(Cxy_shuffled_inner_sb, axis=0)

    Cxy_arr_shuffled_ngb.append(Cxy_shuffled_inner_ngb_mean)
    Cxy_arr_shuffled_sb.append(Cxy_shuffled_inner_sb_mean)

    mean_coherence_shuffled_ngb.append(np.mean(Cxy_shuffled_inner_ngb_mean))
    mean_coherence_shuffled_sb.append(np.mean(Cxy_shuffled_inner_sb_mean))
    
# Bar plot: mean coherence of NGB and SB pairs +/- standard error
y = [
    np.mean(mean_coherence_ngb),
    np.mean(mean_coherence_sb)
]
yerr = [
    np.std(mean_coherence_ngb)/len(mean_coherence_ngb)**0.5,
    np.std(mean_coherence_sb)/len(mean_coherence_sb)**0.5
]
# One colour per bar. The comma matters: without it the two literals are
# concatenated into a single invalid colour name.
colors = [
    "tab:blue",
    "tab:red"
]
bars = plt.bar(["NGB", "SB"], y, yerr=yerr, capsize=5)
for idx, bar_i in enumerate(bars):
    bar_i.set_color(colors[idx])
plt.ylabel('Mean coherence')
format_plot(plt.gca(), legend=False)
plt.show()
# Welch's t-tests: NGB versus SB, and each type versus its shuffle control
print('ngb versus sb', stats.ttest_ind(mean_coherence_ngb, mean_coherence_sb, equal_var=False))
print('ngb', stats.ttest_ind(mean_coherence_ngb, mean_coherence_shuffled_ngb, equal_var=False))
print('sb', stats.ttest_ind(mean_coherence_sb, mean_coherence_shuffled_sb, equal_var=False))

# Coherence plot: mean spectrum +/- standard error per burst type
mean_ngb = np.mean(Cxy_arr_ngb, axis=0)
stderr_ngb = np.std(Cxy_arr_ngb, axis=0) / len(Cxy_arr_ngb)**0.5
mean_sb = np.mean(Cxy_arr_sb, axis=0)
# The SB standard error uses the number of SB pairs (previously the NGB count)
stderr_sb = np.std(Cxy_arr_sb, axis=0) / len(Cxy_arr_sb)**0.5

# Coherence plot NGB, with the 95th percentile of the shuffle control (dashed)
plt.plot(f, mean_ngb, label='NGB', c='tab:blue')
plt.fill_between(f, mean_ngb+stderr_ngb, mean_ngb-stderr_ngb, alpha=0.5, facecolor='tab:blue')
shuffled_95_percentile_ngb = np.percentile(Cxy_arr_shuffled_ngb, 95, axis=0)
plt.plot(f, shuffled_95_percentile_ngb, '--', label='Shuffled NGB', c='tab:blue')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Coherence')
plt.legend()
format_plot(plt.gca())
plt.show()

# Coherence plot SB, with the 95th percentile of the shuffle control (dashed)
plt.plot(f, mean_sb, label='SB', c='tab:red')
plt.fill_between(f, mean_sb+stderr_sb, mean_sb-stderr_sb, alpha=0.5, facecolor='tab:red')
shuffled_95_percentile_sb = np.percentile(Cxy_arr_shuffled_sb, 95, axis=0)
plt.plot(f, shuffled_95_percentile_sb, '--', label='Shuffled SB', c='tab:red')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Coherence')
plt.legend()
format_plot(plt.gca())
plt.show()

## Lag analysis

In [None]:
# Get cross-spectral coherence for burst pairs    
# Coherence settings: 0.5 s segments without overlap
window_size = 0.5
window = int(SAMPLING_RATE*window_size)
overlap = int(SAMPLING_RATE*window_size*0)

# Lag axis (ms) shared by all pairs; set from the last processed pair
xpos_arr = None
# Squared cross-correlation curves and the lag of their maximum, per pair
cross_corr_arr_test = []
peaks_arr_test = []

cross_corr_arr_control = []
peaks_arr_control = []

# Filtered traces of the analysed pairs (collected but not used further in this cell)
bursts_a = []
bursts_b = []

# NOTE: test_bands / control_bands are placeholders that must be set to
# [low, high] band edges before running; get_lags indexes freq[0] and freq[1],
# so the empty lists below raise an IndexError.
# The print statements at the end of this cell label the results '5-15 Hz'
# and '25-35 Hz' -- presumably the bands used originally; TODO confirm.
test_bands = [] # 'Test' bands where coherence is maximal
test_band_label = ''
control_bands = [] # 'Control' bands where coherence is low, as a comparison
control_band_label = ''

# Array of all cortex-striatum co-occurring pairs
events = np.concatenate([cooccuring_events['Ctx-CP']])

for burst_pair_idx, burst_pair in enumerate(events):
    burst_a, burst_b = burst_pair
    a, b, ticks = align_overlapping_bursts (burst_a, burst_b)
    
    if (burst_pair_idx+1) % 10 == 0:
        print( 'Burst pair {}/{}'.format(burst_pair_idx+1, len(events) ))

    # Only pairs in which both bursts are NGB are analysed
    if not (burst_a in cortex_ngb and burst_b in striatum_ngb):
        continue
        
    # Skip overlaps shorter than 0.5 s
    if len(a)/SAMPLING_RATE < 0.5:
        continue
        
    a = butter_bandpass_filter(a, 4, 80)
    b = butter_bandpass_filter(b, 4, 80)
    
    bursts_a.append(a)
    bursts_b.append(b)

    f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
    f, Cxy = get_psd_in_range((f, Cxy), [4, 80])
    
    # Pairs with a mean coherence above 0.8 are excluded, as in the coherence cell
    if np.mean(Cxy) > 0.8:
        continue
    
    # Envelope cross-correlation in the test band; the peak lag is the lag at
    # which the squared cross-correlation is largest
    xpos, cross_corr = get_lags(a, b, test_bands, prewhiten=True)
    xpos_arr = xpos
    cross_corr_arr_test.append(cross_corr)
    
    peaks_arr_test.append(xpos[np.argmax(cross_corr)])
    
    # Same analysis in the control band
    xpos, cross_corr = get_lags(a, b, control_bands, prewhiten=True)
    xpos_arr = xpos
    cross_corr_arr_control.append(cross_corr)
    peaks_arr_control.append(xpos[np.argmax(cross_corr)])

def _mean_and_stderr (samples):
    """Return (mean, standard error) of samples along the first axis."""
    return np.mean(samples, axis=0), np.std(samples, axis=0) / len(samples)**0.5


# Mean squared cross-correlation curve per band, with a +/- standard-error band
means_test, stderrs_test = _mean_and_stderr(cross_corr_arr_test)
means_control, stderrs_control = _mean_and_stderr(cross_corr_arr_control)

for curve_mean, curve_stderr, curve_label, curve_color in (
    (means_test, stderrs_test, test_band_label, 'tab:red'),
    (means_control, stderrs_control, control_band_label, 'tab:blue'),
):
    plt.plot(xpos_arr, curve_mean, label=curve_label, c=curve_color)
    plt.fill_between(xpos_arr, curve_mean-curve_stderr, curve_mean+curve_stderr, alpha=0.25, facecolor=curve_color)
plt.ylabel('Squared cross-correlation')
plt.xlabel('Lag (ms)')
plt.legend()
format_plot(plt.gca())
plt.show()

# Mean peak lag per band, with standard error
means_test, stderrs_test = _mean_and_stderr(peaks_arr_test)
means_control, stderrs_control = _mean_and_stderr(peaks_arr_control)

bar_positions = [0, 1]
bars = plt.bar(bar_positions, [means_test, means_control], yerr=[stderrs_test, stderrs_control], capsize=5, facecolor='dimgray')
bars[0].set_color('tab:red')
bars[1].set_color('tab:blue')
plt.xticks(bar_positions, [test_band_label, control_band_label])
plt.ylabel('Lag (ms)')
format_plot(plt.gca(), legend=False)
plt.show()

# One-sample t-tests of the peak lags against a lag of zero
print('5-15 Hz', stats.ttest_1samp(peaks_arr_test, popmean=0))
print('25-35 Hz', stats.ttest_1samp(peaks_arr_control, popmean=0))