# Package imports

In [None]:
import pickle, os

from open_ephys.analysis import Session

from matplotlib import pyplot as plt

import scipy
from scipy import signal
from scipy import ndimage
from scipy import stats

import numpy as np
import pandas as pd

from statsmodels.formula.api import ols

from utilities import *
from recording_data import RECORDINGS
from clustering import (
    feature_vector_labels, feature_vector_labels_full,
    get_feature_vectors, get_cluster_labels, sort_bursts_by_labels
)

# Open previously saved data

In [None]:
# Paths to the previously processed recordings (pickle files); fill these in
# before running. NOTE(review): open_data presumably unpickles the file, so only
# point it at trusted data.
RMS_DATA_PATH = ''
MUA_DATA_PATH = ''

rms_processed_recordings = open_data(RMS_DATA_PATH)
processed_recordings_mua = open_data(MUA_DATA_PATH)

# PCA-based clustering of bursts in each region (NGB vs SB)

In [None]:
# Cluster the bursts of each region and keep both groups per region.
# sort_bursts_by_labels returns (ngb_features, sb_features, ngb_bursts, sb_bursts).
# The region order is fixed in case the clustering consumes global RNG state.
clustered_regions = {}
for region in ('thalamus', 'cortex', 'striatum'):
    region_bursts, region_features = get_feature_vectors(
        rms_processed_recordings,
        processed_recordings_mua,
        key=f'{region}_bursts',
        mua_key=f'{region}_mua',
    )
    region_labels, region_pca = get_cluster_labels(region_features)
    clustered_regions[region] = sort_bursts_by_labels(
        region_bursts, region_features, region_labels
    )

(thalamus_ngb_features, thalamus_sb_features,
 thalamus_ngb, thalamus_sb) = clustered_regions['thalamus']
(cortex_ngb_features, cortex_sb_features,
 cortex_ngb, cortex_sb) = clustered_regions['cortex']
(striatum_ngb_features, striatum_sb_features,
 striatum_ngb, striatum_sb) = clustered_regions['striatum']

# Feature vector distributions across groups

In [None]:
# Mean normalized PSD (with a ±SEM band) of the two cortical burst groups.
key = 'cortex'

all_bursts, all_features = get_feature_vectors(
    rms_processed_recordings,
    processed_recordings_mua,
    key=key+'_bursts', mua_key=key+'_mua'
)
all_labels, _ = get_cluster_labels(all_features)
feature_list_g1, feature_list_g2, bursts_g1, bursts_g2 = sort_bursts_by_labels(
    all_bursts, all_features, all_labels
)

# Keys are "<group>_f", "<group>_Pxx", "<group>_Pxx_stderr" for NGB then SB.
psd_data = {}
for group_name, group_bursts, color in (
    ('NGB', bursts_g1, 'tab:blue'),
    ('SB', bursts_g2, 'tab:red'),
):
    f, Pxx_mean, Pxx_stderr = get_mean_psd(group_bursts)
    plt.plot(f, Pxx_mean, label=group_name, c=color)
    # fill_between draws from its own colour cycle; pin it to the line colour so
    # each SEM band matches its curve.
    plt.fill_between(f, Pxx_mean+Pxx_stderr, Pxx_mean-Pxx_stderr, color=color, alpha=0.5)

    psd_data[f"{group_name}_f"] = f
    psd_data[f"{group_name}_Pxx"] = Pxx_mean
    psd_data[f"{group_name}_Pxx_stderr"] = Pxx_stderr

plt.xlabel('Frequency (Hz)')
# Raw string: '\m' is an invalid escape sequence in a regular string literal.
plt.ylabel(r'Relative power ($\mathregular{P/P_0}$)')

format_plot(plt.gca())
plt.show()

def cohen_d_for_welch(a, b, ddof=0):
    """Cohen's d for two independent samples, pooling by the average variance.

    d = (mean(a) - mean(b)) / sqrt((var(a) + var(b)) / 2)

    Parameters
    ----------
    a, b : array-like of numbers
        The two samples.
    ddof : int, default 0
        Delta degrees of freedom for both variances. The default (population
        variance) reproduces the values reported elsewhere in this notebook;
        pass 1 for the sample variance that is conventional for Cohen's d.

    Returns
    -------
    float
        Positive when ``a`` has the larger mean.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    pooled_sd = np.sqrt((np.var(a, ddof=ddof) + np.var(b, ddof=ddof)) / 2)
    return (np.mean(a) - np.mean(b)) / pooled_sd

# Per-feature comparison of the cortical NGB and SB groups: normalized
# histograms plus a summary table of Welch's t-tests.
tabulated_data = {
    'feature': [],
    'NGB mean (SEM)': [],
    'SB mean (SEM)': [],
    'df': [],
    'Welch’s t': [],
    'p': [],
    'Cohen’s d': []
}

# Group sizes do not depend on the feature, so report them once.
# NOTE(review): assumes all_features supports boolean indexing (NumPy array).
print('N ngb =', len(feature_list_g1), 'sb = ', len(feature_list_g2), 'uc =', len(all_features[all_labels == 2]))

fig, axs = plt.subplots(nrows=3, ncols=3, figsize=[20,10])
axs = axs.reshape(-1)
assert len(feature_vector_labels_full) <= axs.size, 'more features than subplot panels'

for feature_idx, feature_label in enumerate(feature_vector_labels_full):
    f_g1 = [vec[feature_idx] for vec in feature_list_g1]
    f_g2 = [vec[feature_idx] for vec in feature_list_g2]
    ax = axs[feature_idx]

    # 1/n weights so each group's histogram sums to 1.
    weights = [np.ones(len(f_g1)) / float(len(f_g1)), np.ones(len(f_g2)) / float(len(f_g2))]

    ax.hist([f_g1, f_g2], label=['NGB', 'SB'], color=['tab:blue', 'tab:red'], weights=weights)
    ax.set_xlabel(feature_label)
    ax.set_ylabel('Density')
    ax.set_ylim(0,1)

    format_plot(ax, legend=False, size=[17.5, 20])

    # Compute each statistic once; reuse for the printout and the table.
    welch_result = stats.ttest_ind(f_g1, f_g2, equal_var=False)
    t, p = welch_result
    cohens_d = cohen_d_for_welch(f_g1, f_g2)
    dof = welch_dof(f_g1, f_g2)

    print('\n', feature_label)
    print('Mean ngb =', np.mean(f_g1), 'sb =', np.mean(f_g2))
    print('Std ngb =', np.std(f_g1), 'sb =', np.std(f_g2))
    print(welch_result)
    print('Cohen\'s d =', cohens_d)
    print('DOF = ', dof)

    tabulated_data['feature'].append(feature_label)
    tabulated_data['NGB mean (SEM)'].append(f'{np.mean(f_g1):.3g} ({np.std(f_g1)/len(f_g1)**0.5:.3g})')
    tabulated_data['SB mean (SEM)'].append(f'{np.mean(f_g2):.3g} ({np.std(f_g2)/len(f_g2)**0.5:.3g})')
    tabulated_data['df'].append(f'{dof:.3g}')
    tabulated_data['Welch’s t'].append(f'{t:.3g}')
    tabulated_data['p'].append(f'{p:.3g}')
    tabulated_data['Cohen’s d'].append(f'{cohens_d:.3g}')

plt.tight_layout()
plt.show()

pd.DataFrame(tabulated_data).set_index('feature')

# Burst group spike statistics

In [None]:
# Spike rate within striatal NGB vs SB bursts: spike count divided by
# feature_vec[0], for bursts with more than one spike.
# NOTE(review): assumes feature_vec[0] is the burst duration in seconds so the
# ratio is a rate in Hz — confirm against clustering.feature_vector_labels.
spike_rate_g1 = np.array([len(burst.spikes)/burst.feature_vec[0] for burst in striatum_ngb if len(burst.spikes)>1])
spike_rate_g2 = np.array([len(burst.spikes)/burst.feature_vec[0] for burst in striatum_sb if len(burst.spikes)>1])

print('NGB median', np.nanmedian(spike_rate_g1))
print('SB median', np.nanmedian(spike_rate_g2))

print('Spike rate:', stats.mannwhitneyu(spike_rate_g1, spike_rate_g2))

# Log-spaced bin edges starting at 0.1 Hz in 0.1-decade steps.
bins = 10**np.arange(-1, 2+0.1, 0.1)
log_bins = np.logspace(np.log10(bins[0]), np.log10(bins[-1]), len(bins))
for rates, label, color in ((spike_rate_g1, 'NGB', 'tab:blue'), (spike_rate_g2, 'SB', 'tab:red')):
    # 1/n weights so each group's histogram sums to 1.
    plt.hist(
        rates,
        label=label,
        color=color,
        weights=np.ones(len(rates)) / float(len(rates)),
        bins=log_bins,
        alpha=0.75
    )
plt.legend()
plt.xscale('log')
plt.ylabel('Density')
plt.xlabel('Log spike rate (Hz)')
plt.ylim(0,0.25)
plt.xlim(10**-1, 10**2)
format_plot(plt.gca(), legend=False, size=[17.5, 20])
plt.show()






# Number of spikes per striatal burst, NGB vs SB, including bursts with no spikes.
spike_n_g1 = np.array([len(burst.spikes) for burst in striatum_ngb])
spike_n_g2 = np.array([len(burst.spikes) for burst in striatum_sb])

# Bursts with at least one spike (a log-scaled histogram cannot show zeros).
nonzero_spike_n_g1 = spike_n_g1[spike_n_g1 > 0]
nonzero_spike_n_g2 = spike_n_g2[spike_n_g2 > 0]

# Zero-spike bursts, counted from the same striatal groups that serve as the
# denominators of the percentages below.
zero_spike_n_g1 = int(np.count_nonzero(spike_n_g1 == 0))
zero_spike_n_g2 = int(np.count_nonzero(spike_n_g2 == 0))

print('NGB median', np.median(spike_n_g1))
print('SB median', np.median(spike_n_g2))

print('NGB non-zero median', np.median(nonzero_spike_n_g1))
print('SB non-zero median', np.median(nonzero_spike_n_g2))
print('Spikes count:', stats.mannwhitneyu(nonzero_spike_n_g1, nonzero_spike_n_g2))

g1_pc = zero_spike_n_g1/len(spike_n_g1)*100 if len(spike_n_g1) else np.nan
g2_pc = zero_spike_n_g2/len(spike_n_g2)*100 if len(spike_n_g2) else np.nan

print('G1 zero percent', g1_pc)
print('G2 zero percent', g2_pc)

fig = plt.figure(figsize=[3,4])
plt.bar([0, 1], [g1_pc, g2_pc], color=['tab:blue', 'tab:red'])
plt.xticks([0, 1], ['NGB', 'SB'])
plt.ylabel('% events with 0 spikes')
plt.ylim(0,50)
format_plot(plt.gca(), legend=False, size=[17.5, 20])
plt.show()

# Log-spaced bin edges from 1 spike upward in quarter-decade steps.
bins = 10**np.arange(0, 3+0.25, 0.25)
log_bins = np.logspace(np.log10(bins[0]), np.log10(bins[-1]), len(bins))
for counts, label, color in ((nonzero_spike_n_g1, 'NGB', 'tab:blue'), (nonzero_spike_n_g2, 'SB', 'tab:red')):
    # 1/n weights so each group's histogram sums to 1.
    plt.hist(
        counts,
        label=label,
        color=color,
        weights=np.ones(len(counts)) / float(len(counts)),
        bins=log_bins,
        alpha=0.75
    )
plt.legend()
plt.ylabel('Density')
plt.xlabel('Log spikes per event')
plt.ylim(0,0.5)
plt.xlim(10**0, 10**3.0)
plt.xscale('log')
format_plot(plt.gca(), legend=True, size=[17.5, 20])
plt.show()

# Alpha-theta (4–12 Hz) vs. low-beta (12–20 Hz) power

In [None]:
def get_alpha_beta(bursts, alpha_band=(4, 12), beta_band=(12, 20)):
    """Mean normalized-PSD power of each burst in two frequency bands.

    Parameters
    ----------
    bursts : iterable of burst objects
        Bursts lacking a ``normalized_psd`` attribute are skipped, so the two
        returned lists always have the same length.
    alpha_band, beta_band : (low, high) pair in Hz
        Defaults are the 4-12 Hz and 12-20 Hz bands used in this notebook.
        How the band edges are treated is up to ``get_psd_in_range``.

    Returns
    -------
    (list, list)
        Per-burst mean power in ``alpha_band`` and in ``beta_band``.
    """
    alpha = []
    beta = []
    for burst in bursts:
        if not hasattr(burst, "normalized_psd"):
            continue

        f_burst, Pxx_burst = burst.normalized_psd

        _, Pxx_alpha = get_psd_in_range((f_burst, Pxx_burst), list(alpha_band))
        _, Pxx_beta = get_psd_in_range((f_burst, Pxx_burst), list(beta_band))

        alpha.append(np.mean(Pxx_alpha))
        beta.append(np.mean(Pxx_beta))

    return alpha, beta
 
# Mean 4-12 Hz vs 12-20 Hz relative power of NGB bursts in each region, then a
# between-region comparison of the per-burst (4-12 Hz minus 12-20 Hz) difference.
region_keys = ['ctx', 'cp', 'thal']
region_names = ['Cortex', 'Striatum', 'Thalamus']
region_groups = [cortex_ngb, striatum_ngb, thalamus_ngb]

alpha_mns = []
alpha_errs = []

beta_mns = []
beta_errs = []

band_diffs = {}

for key, group in zip(region_keys, region_groups):
    print(f'\n{key}')

    alpha, beta = get_alpha_beta(group)

    alpha_mns.append(np.mean(alpha))
    alpha_errs.append(np.std(alpha)/len(alpha)**0.5)
    beta_mns.append(np.mean(beta))
    beta_errs.append(np.std(beta)/len(beta)**0.5)

    # Per-burst band difference, reused for the between-region tests below.
    band_diffs[key] = np.array(alpha) - np.array(beta)

    print('Mean alpha-theta =', np.mean(alpha), 'low-beta =', np.mean(beta))
    print('Std alpha-theta =', np.std(alpha), 'low-beta =', np.std(beta))
    # NOTE(review): alpha and beta come from the same bursts, so a paired test
    # (stats.ttest_rel) may be more appropriate than Welch's unpaired test.
    print(stats.ttest_ind(alpha, beta, equal_var=False))
    print('Cohen\'s d =', cohen_d_for_welch(alpha, beta))
    print('DOF = ', welch_dof(alpha, beta))

x = np.arange(len(region_keys))
plt.bar(x-0.125, alpha_mns, yerr=alpha_errs, width=0.25, label='4-12 Hz')
plt.bar(x+0.125, beta_mns, yerr=beta_errs, width=0.25, label='12-20 Hz')
plt.xticks(x, region_names)
plt.legend()
format_plot(plt.gca(), legend=True, size=[17.5, 20])
plt.ylabel('Mean relative power ($P/P_{0}$)')
plt.show()

ctx_diff, cp_diff, thal_diff = band_diffs['ctx'], band_diffs['cp'], band_diffs['thal']

region_pairs = [('ctx', 'cp'), ('ctx', 'thal'), ('thal', 'cp')]
for a_key, b_key in region_pairs:
    print(f'\n{a_key}-{b_key}')

    a_diff, b_diff = band_diffs[a_key], band_diffs[b_key]
    t_stat, p_raw = stats.ttest_ind(a_diff, b_diff, equal_var=False)
    # Bonferroni correction over the pairwise comparisons, capped at 1.
    p_corrected = min(p_raw * len(region_pairs), 1.0)

    print(f'Mean diff {a_key} =', np.mean(a_diff), f'{b_key} =', np.mean(b_diff))
    print(f'Std diff {a_key} =', np.std(a_diff), f'{b_key} =', np.std(b_diff))
    print(t_stat, p_corrected)
    print('Cohen\'s d =', cohen_d_for_welch(a_diff, b_diff))
    print('DOF = ', welch_dof(a_diff, b_diff))

In [None]:
# Per-feature comparison of cortical vs thalamic (CL/Pf) NGB bursts.
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=[20,10])
axs = axs.reshape(-1)

for feature_idx, feature_label in enumerate(feature_vector_labels_full):
    cortex_g1 = [vec[feature_idx] for vec in cortex_ngb_features]
    thal_g1 = [vec[feature_idx] for vec in thalamus_ngb_features]
    ax = axs[feature_idx]

    # 1/n weights so each region's histogram sums to 1.
    weights = [np.ones(len(cortex_g1)) / float(len(cortex_g1)), np.ones(len(thal_g1)) / float(len(thal_g1))]

    # NOTE(review): the thalamic (CL/Pf) histogram is drawn with COLOR_STRIATUM;
    # confirm whether a thalamus colour constant was intended.
    ax.hist([cortex_g1, thal_g1], label=['Cortex', 'CL/Pf'], weights=weights, color=[COLOR_CORTEX, COLOR_STRIATUM])
    ax.set_xlabel(feature_label)
    ax.set_ylabel('Density')
    ax.set_ylim([0, 1])

    format_plot(ax, legend=False, size=[17.5, 20])

    print('\n', '', feature_label)
    print('Mean ctx =', np.mean(cortex_g1), 'thal =', np.mean(thal_g1))
    print('Std ctx =', np.std(cortex_g1), 'thal =', np.std(thal_g1))
    print(stats.ttest_ind(cortex_g1, thal_g1, equal_var=False))
    print('Cohen\'s d =', cohen_d_for_welch(cortex_g1, thal_g1))
    print('DOF = ', welch_dof(cortex_g1, thal_g1))

plt.tight_layout()
plt.show()

# F-ratio between burst groups and regions

In [None]:
# https://onlinelibrary.wiley.com/doi/epdf/10.1111/j.1442-9993.2001.01070.pp.x

def multivariate_f_ratio(a, b):
    """Pseudo-F ratio between two groups of feature vectors.

    The numerator is the between-group sum of squared Euclidean distances from
    each group centroid to the overall centroid, weighted by group size. With
    two groups, the between-group degrees of freedom is 1. The denominator is
    the within-group sum of squared distances to each group's own centroid,
    divided by the residual degrees of freedom N - 2.

    Parameters
    ----------
    a, b : array-like, shape (n_a, n_features) and (n_b, n_features)
        Feature vectors of the two groups (one row per observation).

    Returns
    -------
    float
        The F-ratio; larger values mean better-separated groups.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    a_group_means = a.mean(axis=0)
    b_group_means = b.mean(axis=0)
    overall_means = np.concatenate([a, b]).mean(axis=0)

    ss_between_group = (
        len(a) * np.sum((overall_means - a_group_means) ** 2)
        + len(b) * np.sum((overall_means - b_group_means) ** 2)
    )

    # Vectorized sum of ||centroid - row||^2 over every row of each group; this
    # runs inside the permutation loop, so avoiding per-row Python loops matters.
    ss_within_group = np.sum((a - a_group_means) ** 2) + np.sum((b - b_group_means) ** 2)
    ms_within_group = ss_within_group / (len(a) + len(b) - 2)

    return ss_between_group / ms_within_group

def permute_multivariate_f_ratio(a, b, iterations, rng=None):
    """Permutation test of multivariate_f_ratio between two groups.

    Group membership is shuffled ``iterations`` times (rows of the combined
    data are permuted; the first len(a) rows form the shuffled first group).
    The F-ratio of each shuffle builds a null distribution.

    Parameters
    ----------
    a, b : array-like, shape (n_a, n_features) and (n_b, n_features)
    iterations : int
        Number of permutations; must be >= 1.
    rng : numpy.random.Generator, optional
        Source of the shuffles. When None (default) the global NumPy random
        state is used, as before. Pass e.g. ``np.random.default_rng(seed)`` for
        a reproducible null distribution.

    Returns
    -------
    p_val : float
        Fraction of permuted F-ratios >= the observed one. This estimator can
        be exactly 0; (count + 1) / (iterations + 1) is the usual conservative
        alternative.
    test_stat : float
        Observed F-ratio.
    null_dist : list of float
        F-ratio of every permutation.

    Raises
    ------
    ValueError
        If ``iterations`` < 1.
    """
    if iterations < 1:
        raise ValueError('iterations must be >= 1')

    permute = np.random.permutation if rng is None else rng.permutation
    combined = np.concatenate([a, b])
    n_a = len(a)

    test_stat = multivariate_f_ratio(a, b)
    null_dist = []

    for i in range(iterations):
        shuffled = permute(combined)
        null_dist.append(multivariate_f_ratio(shuffled[:n_a], shuffled[n_a:]))

        if (i+1) % 1000 == 0:
            print('Iteration', i+1)

    # Compare as an array so this does not rely on NumPy's reflected comparison
    # between a Python list and a NumPy scalar.
    p_val = np.count_nonzero(np.asarray(null_dist) >= test_stat) / len(null_dist)

    return p_val, test_stat, null_dist
    
# Observed F-ratio between striatal NGB and SB feature vectors, tested against
# 1000 random relabelings.
p, f, null = permute_multivariate_f_ratio(striatum_ngb_features, striatum_sb_features, iterations=1000)

bins = 100
null_counts, null_edges = np.histogram(null, bins=bins)
max_val = np.max(null_counts)

plt.hist(null, bins=bins, facecolor='dimgray')
# Vertical marker at the observed statistic, as tall as the tallest null bin.
plt.plot([f, f], [0, max_val])
plt.xlabel('F-value')
plt.ylabel('Count')
format_plot(plt.gca(), legend=False)
plt.show()

residual_dof = len(striatum_ngb_features) + len(striatum_sb_features) - 2
print(f'p = {p}, F(1, {residual_dof}) = {f}')

In [None]:
# Broken-x-axis view of the permutation test: null distribution on the left,
# observed F-ratio on the right. Both panels span PANEL_WIDTH so their
# x-scales match.
PANEL_WIDTH = 14.5

bins = 100
max_val = np.max(np.histogram(null, bins=bins)[0])

fig, (ax, ax2) = plt.subplots(1, 2, sharey=True, facecolor='w')

ax.hist(null, bins=bins, facecolor='dimgray')
ax2.plot([f, f], [0, max_val], c='tab:orange')

ax.set_xlim(0, PANEL_WIDTH)
# Centre the right panel on the observed statistic instead of a hard-coded value.
ax2.set_xlim(f - PANEL_WIDTH / 2, f + PANEL_WIDTH / 2)
# Headroom above the tallest null bin, derived from the data so the plot stays
# readable for any number of permutations.
ax.set_ylim(0, max_val * 1.1)

ax.set_xlabel('F-value')
ax.set_ylabel('Count')

# Hide the spines between ax and ax2.
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax.yaxis.tick_left()
# Must be a boolean: the string 'off' is truthy and would turn the labels on.
ax.tick_params(labelright=False)
ax2.yaxis.tick_right()

format_plot(ax, legend=False)
format_plot(ax2, legend=False)

# Diagonal "break" marks at the bottom of the gap, in axes coordinates.
break_mark = .015
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
ax.plot((1-break_mark, 1+break_mark), (-break_mark, +break_mark), **kwargs)

kwargs.update(transform=ax2.transAxes)
ax2.plot((-break_mark, +break_mark), (-break_mark, +break_mark), **kwargs)

plt.show()

# Burst properties across age groups, for each burst group

In [None]:
def plot_group_lineplots(data, ylim, labels, xticks, colors=('blue', 'orange'), title='', xlabel='', ylabel='', yscale='linear', save=False):
    """Plot each series' per-category mean ± SEM, with the individual values.

    Parameters
    ----------
    data : list
        One entry per series (drawn in ``colors[i]`` and labelled ``labels[i]``).
        Each entry is a list holding one sequence of values per x position.
    ylim : (bottom, top)
        Passed to ``plt.ylim``; ``None`` entries keep autoscaling.
    labels : list of str
        Legend label per series.
    xticks : list of str
        Tick label per x position.
    colors : sequence of colours
        Colour per series.
    title, xlabel, ylabel : str
        Figure text.
    yscale : str
        Passed to ``plt.yscale`` (e.g. 'linear', 'log').
    save : False or str
        When a string, the figure is also saved to ``FIG_ROOT + save``.
    """
    plt.figure()

    lines = []
    for series_idx, series in enumerate(data):
        color = colors[series_idx]
        positions = np.arange(len(series))

        means = [np.mean(values) for values in series]
        # Standard error of each category's mean uses that category's own sample
        # size (not the number of categories); empty categories get NaN.
        sems = [
            np.std(values) / np.sqrt(len(values)) if len(values) else np.nan
            for values in series
        ]
        line = plt.errorbar(positions, means, yerr=sems, capsize=5, c=color)

        # Individual values, drawn faintly at their category's x position.
        for pos, values in zip(positions, series):
            plt.plot([pos] * len(values), values, c=color, lw=0, marker='o', alpha=0.25)

        lines.append(line)

    plt.xticks(np.arange(len(xticks)), xticks)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.yscale(yscale)
    plt.ylim(ylim)

    leg = plt.legend(lines, labels, frameon=False, fontsize=15, bbox_to_anchor=(1, 1), loc='upper left')
    # Legend.legendHandles was renamed legend_handles (Matplotlib 3.7) and the old
    # name was later removed; support both.
    handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
    for handle in handles:
        handle.set_linewidth(2.0)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for item in ([ax.title] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(12.5)
    for item in ([ax.xaxis.label, ax.yaxis.label] + ax.get_legend().get_texts()):
        item.set_fontsize(15)

    if save:
        plt.savefig(FIG_ROOT + save, bbox_inches="tight")
    plt.show()

# Cortical burst groups from the clustering cell above. A recording's burst
# counts as NGB or SB below according to membership in these lists.
GROUP = dict(
    NGB=cortex_ngb,
    SB=cortex_sb,
)
    
# Containers filled by the aggregation loop below.
# Feature tables map burst type -> age group -> list of per-recording means.
# *_csv dicts hold long-format columns (one row per recording and burst type).
_AGE_GROUP_KEYS = ("5-6", "7-8", "9-10", "11-12")


def _empty_feature_table():
    """Fresh {"NGB": {age_group: []}, "SB": {age_group: []}} with independent lists."""
    return {
        burst_type: {age_key: [] for age_key in _AGE_GROUP_KEYS}
        for burst_type in ("NGB", "SB")
    }


def _empty_csv_columns(value_column):
    """Fresh long-format column dict whose last column is ``value_column``."""
    return {"burst_type": [], "age": [], "path": [], value_column: []}


occurence = _empty_feature_table()
occurence_csv = _empty_csv_columns("bursts_per_second")

amplitude = _empty_feature_table()
amplitude_csv = _empty_csv_columns("amplitude")

duration = _empty_feature_table()
duration_csv = _empty_csv_columns("duration")

spike_rate = _empty_feature_table()
spike_rate_csv = _empty_csv_columns("spike_rate")

alphatheta_power = _empty_feature_table()
betagamma_power = _empty_feature_table()

alphatheta_peak = _empty_feature_table()
betagamma_peak = _empty_feature_table()

# Legend labels for the burst types, and x-axis labels in the tables' age-group order.
box_labels = ["NGB", "SB"]
age_group_labels = [age_key for age_key in occurence["SB"]]

def _age_group_for(age):
    """Age-group key for an age in days ("5-6" also covers anything below 7), or None if >= 13."""
    if age < 7:
        return "5-6"
    if age < 9:
        return "7-8"
    if age < 11:
        return "9-10"
    if age < 13:
        return "11-12"
    return None


def _append_csv_row(csv_columns, burst_type, recording, value_column, value):
    """Append one long-format row: lower-case burst type, age, path and value."""
    csv_columns["burst_type"].append(burst_type.lower())
    csv_columns["age"].append(recording["age"])
    csv_columns["path"].append(recording["path"])
    csv_columns[value_column].append(value)


# Per-burst quantities collected for every recording.
_PER_BURST_FEATURES = (
    "amplitude", "duration", "spike_rate",
    "alphatheta", "betagamma", "alphatheta_peak", "betagamma_peak",
)

for brain_area in ["cortex"]:
    brain_area_bursts = brain_area + "_bursts"

    for recording_n, recording in enumerate(rms_processed_recordings):
        if brain_area_bursts not in recording:
            continue

        print("Processing {}, recording {}/{}".format(brain_area, recording_n+1, len(rms_processed_recordings)))

        age_group = _age_group_for(recording["age"])
        if age_group is None:
            continue

        burst_counts = {"NGB": 0, "SB": 0}
        per_burst = {name: {"NGB": [], "SB": []} for name in _PER_BURST_FEATURES}

        for burst in recording[brain_area_bursts]:
            if burst in GROUP['SB']:
                group_key = 'SB'
            elif burst in GROUP['NGB']:
                group_key = 'NGB'
            else:
                continue
            burst_counts[group_key] += 1

            per_burst["amplitude"][group_key].append(np.max(burst.data) - np.min(burst.data))
            per_burst["duration"][group_key].append((burst.time[1] - burst.time[0]) / SAMPLING_RATE)
            # NOTE(review): feature_vec indices 8 / 6 / 5 are assumed to be spike
            # rate, alpha-theta power and beta-gamma power — confirm against
            # clustering.feature_vector_labels.
            per_burst["spike_rate"][group_key].append(burst.feature_vec[8])
            per_burst["alphatheta"][group_key].append(burst.feature_vec[6])
            per_burst["betagamma"][group_key].append(burst.feature_vec[5])

            # Skip the peak-frequency features for bursts without a PSD, as
            # get_alpha_beta does, instead of raising AttributeError.
            if not hasattr(burst, "normalized_psd"):
                continue
            psd_f, psd_pxx = burst.normalized_psd

            # First bins at or above 4, 16 and 40 Hz; peaks are searched in
            # [4, 16) and [16, 40) Hz (assuming psd_f is ascending).
            theta_idx = np.where(psd_f >= 4)[0][0]
            beta_idx = np.where(psd_f >= 16)[0][0]
            lgamma_idx = np.where(psd_f >= 40)[0][0]

            per_burst["alphatheta_peak"][group_key].append(psd_f[np.argmax(psd_pxx[theta_idx:beta_idx]) + theta_idx])
            per_burst["betagamma_peak"][group_key].append(psd_f[np.argmax(psd_pxx[beta_idx:lgamma_idx]) + beta_idx])

        # Occurrence is bursts per MINUTE (count / seconds * 60); the CSV column
        # keeps its historical name "bursts_per_second".
        recording_seconds = recording["length"] / SAMPLING_RATE
        for group_key in ("NGB", "SB"):
            rate = burst_counts[group_key] / recording_seconds * 60
            occurence[group_key][age_group].append(rate)
            _append_csv_row(occurence_csv, group_key, recording, "bursts_per_second", rate)

        # Per-recording mean of each feature. nanmean is used for every feature
        # and both burst types, so the plotted values and the CSV rows agree and a
        # single NaN burst does not blank out a recording's mean.
        for name, table, csv_columns, csv_value_column in (
            ("amplitude", amplitude, amplitude_csv, "amplitude"),
            ("duration", duration, duration_csv, "duration"),
            ("spike_rate", spike_rate, spike_rate_csv, "spike_rate"),
            ("alphatheta", alphatheta_power, None, None),
            ("betagamma", betagamma_power, None, None),
            ("alphatheta_peak", alphatheta_peak, None, None),
            ("betagamma_peak", betagamma_peak, None, None),
        ):
            for group_key in ("NGB", "SB"):
                samples = per_burst[name][group_key]
                if not samples:
                    continue
                recording_mean = np.nanmean(samples)
                table[group_key][age_group].append(recording_mean)
                if csv_columns is not None:
                    _append_csv_row(csv_columns, group_key, recording, csv_value_column, recording_mean)

# Parameters for each per-age line plot. "yscale" is optional (default 'linear').
# Occurrence values are computed as bursts per minute (count / seconds * 60).
plot_parameters = [
    { "title": "occurence", "ylabel": r"Bursts $\mathregular{min^{-1}}$", "data": occurence, "ylim": [None, None] },
    { "title": "amplitude", "ylabel": "Amplitude (μV)", "data": amplitude, "ylim": [200, 2000], "yscale": "log" },
    # Bottom limit left to autoscaling: a 0 limit is not valid on a log axis.
    { "title": "duration", "ylabel": "Duration (s)", "data": duration, "ylim": [None, 10], "yscale": "log" },
    { "title": "alphatheta", "ylabel": "Relative alpha-theta power", "data": alphatheta_power, "ylim": [0, 0.6] },
    { "title": "alphatheta_peak", "ylabel": "Peak alpha-theta frequency", "data": alphatheta_peak, "ylim": [5, 15] },
    { "title": "betagamma", "ylabel": "Relative beta-gamma power", "data": betagamma_power, "ylim": [0.25, 0.65] },
    { "title": "betagamma_peak", "ylabel": "Peak beta-gamma frequency", "data": betagamma_peak, "ylim": [15, 30] }
]

# Summary table: mean (SEM) per burst type and age group, one row per feature.
tabulated_data = {'feature': []}
for age_key in age_group_labels:
    for burst_type in box_labels:
        tabulated_data[f'{burst_type} mean (SEM) P{age_key}'] = []

for plot_param in plot_parameters:
    data = plot_param["data"]
    tabulated_data['feature'].append(plot_param['title'])

    age_groups_data = []
    for burst_type, values_by_age in data.items():
        print('\t', burst_type)

        group_means = []
        group_stds = []
        age_groups = []
        for age_key, values in values_by_age.items():
            sem = np.std(values) / len(values)**0.5
            tabulated_data[f'{burst_type} mean (SEM) P{age_key}'].append(
                f'{round(np.mean(values), 1)} ({round(sem, 2)})'
            )

            age_groups.append(values)
            group_means.append(np.mean(values))
            group_stds.append(np.std(values))

        # Console summary: average over age groups of the per-age mean and SD.
        print(np.mean(group_means), np.mean(group_stds))
        age_groups_data.append(age_groups)

    plot_group_lineplots(
        data=age_groups_data,
        ylim=plot_param["ylim"],
        labels=box_labels,
        xticks=age_group_labels,
        colors=['tab:blue', 'tab:red'],
        xlabel="Age (days)",
        ylabel=plot_param["ylabel"],
        yscale=plot_param.get("yscale", "linear")
    )

    print(twoway_anova(data, ['burst_type', 'age', 'value']))

display(pd.DataFrame(tabulated_data).set_index('feature'))