# Package imports

In [None]:
import pickle, os

from open_ephys.analysis import Session

from matplotlib import pyplot as plt
plt.rcParams["font.family"] = "Arial"
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

import scipy
from scipy import signal
from scipy import ndimage
from scipy import stats

import numpy as np
import pandas as pd
import pingouin as pg

from IPython.display import display

import statsmodels.api as sm
from statsmodels.formula.api import ols

from utilities import *
from recording_data import RECORDINGS
from clustering import (
    feature_vector_labels, feature_vector_labels_full,
    get_feature_vectors, get_cluster_labels, sort_bursts_by_labels
)

from detect_mua import detect_MUA, get_spike_wave_params

# Open previously saved data

In [None]:
# Load the two processed datasets via ``open_data`` (defined in ``utilities``).
# The paths are blank placeholders; point them at the processed-data pickle
# files before running. The first feeds the LFP/burst cells, the second the MUA cell.
rms_processed_recordings, processed_recordings_mua = (
    open_data(''),  # Path to processed data (pickle file format)
    open_data(''),  # Path to processed data (pickle file format)
)

# Plot overall 'aggregate' statistics

In [None]:
def plot_graph_by_age (data, ylabel, scatter, colors, labels, save=False, log=False, ax=False):
    """Plot the mean +/- standard error of a metric per age group, one line per brain area.

    Parameters
    ----------
    data : dict
        Nested mapping ``{brain_area: {age_label: sequence_of_values}}``. Areas
        are drawn in the dict's iteration order. The x ticks use the age keys
        of each area's dict in turn (the last area drawn wins), so all areas are
        expected to share the same age keys in the same order.
    ylabel : str
        Y-axis label.
    scatter : bool
        If True, also draw every individual value as a hollow, semi-transparent marker.
    colors, labels : sequence
        Line colour and legend label for each brain area, in the same order as ``data``.
    save : bool
        Not used by this function; kept for call compatibility.
    log : bool
        If True, switch the y axis to a logarithmic scale.
    ax : matplotlib Axes or False
        Axes to draw into. When falsy, a new figure is created and its current
        axes are used.
    """
    if not ax:
        plt.figure()
        ax = plt.gca()

    for brain_idx, data_by_age in enumerate(data.values()):
        groups = list(data_by_age.values())
        xpos = np.arange(len(groups))
        xlabels = list(data_by_age.keys())

        # Empty age groups are plotted as NaN (i.e. left as a gap).
        mean = [np.mean(d) if len(d) else np.nan for d in groups]
        # NOTE(review): np.std defaults to ddof=0 (population SD); the usual
        # sample-based s.e.m. would be np.std(d, ddof=1) / sqrt(n). Kept as-is
        # so existing figures are reproduced exactly.
        stderr = [np.std(d) / len(d)**0.5 if len(d) else np.nan for d in groups]

        if scatter:
            # One marker per individual value, placed at its age group's x position.
            xpos_scatter = [group_idx for group_idx, d in enumerate(groups) for _ in d]
            value_scatter = [val for d in groups for val in d]
            ax.plot(xpos_scatter, value_scatter, c=colors[brain_idx], lw=0,
                    marker='o', fillstyle='none', alpha=0.3)

        ax.errorbar(xpos, mean, yerr=stderr, capsize=5, c=colors[brain_idx], label=labels[brain_idx])
        ax.set_xticks(xpos)
        ax.set_xticklabels(xlabels)
        ax.set_xlabel('Age (days)')
        ax.set_ylabel(ylabel)

    handles, legend_labels = ax.get_legend_handles_labels()
    # errorbar() handles are containers whose first element is the data line;
    # using it keeps the error bars out of the legend.
    handles = [h[0] for h in handles]
    leg = ax.legend(handles, legend_labels, frameon=False, fontsize=15, bbox_to_anchor=(1, 1), loc='upper left')
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles; the old
    # name was removed in 3.9, so prefer the new one and fall back for old versions.
    legend_handles = getattr(leg, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    for legobj in legend_handles:
        legobj.set_linewidth(3.0)

    for item in ax.get_xticklabels() + ax.get_yticklabels():
        item.set_fontsize(12.5)
    for item in [ax.title, ax.xaxis.label, ax.yaxis.label] + leg.get_texts():
        item.set_fontsize(15)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if log:
        ax.set_yscale('log')
        
def get_anova_by_age (data):
    """Run a two-way between-subjects ANOVA of ``value`` on ``brain_area`` x ``age``.

    ``data`` is a ``{brain_area: {age_group: values}}`` accumulator (the shape
    produced by ``make_area_by_age_dict``). The values for striatum, thalamus and
    cortex in age groups 5-6, 7-8, 9-10 and 11-12 are flattened into a
    long-format table. The '>12' group is not included in the test.

    Returns the ANOVA table produced by ``pingouin.anova``.
    """
    included_ages = ('5-6', '7-8', '9-10', '11-12')  # '>12' is left out
    columns = {'brain_area': [], 'age': [], 'value': []}

    for area in ('striatum', 'thalamus', 'cortex'):
        for age in included_ages:
            values = list(data[area][age])
            columns['brain_area'].extend([area] * len(values))
            columns['age'].extend([age] * len(values))
            columns['value'].extend(values)

    return pg.anova(data=pd.DataFrame(columns), between=['brain_area', 'age'], dv='value')

def make_area_by_age_dict ():
    """Return an empty ``{brain_area: {age_group: []}}`` accumulator.

    Every (brain area, age group) pair gets its own fresh list, so appending to
    one never affects another.
    """
    age_groups = ('5-6', '7-8', '9-10', '11-12', '>12')
    return {
        area: {group: [] for group in age_groups}
        for area in ('striatum', 'thalamus', 'cortex')
    }

## MUA statistics

In [None]:
spikes_per_second = make_area_by_age_dict()
spike_filled_time = make_area_by_age_dict()

# Width (in seconds) of the consecutive windows used for the "filled time" metric.
SPIKE_FILL_WINDOW_S = 0.5


def _age_group(age):
    """Map a postnatal age in days to its age-bin label ('5-6', ..., '>12')."""
    if age < 7:
        return "5-6"
    if age < 9:
        return "7-8"
    if age < 11:
        return "9-10"
    if age < 13:
        return "11-12"
    return ">12"


def _percent_windows_with_spikes(spike_idx, n_samples, window):
    """Return the percentage of consecutive ``window``-sample windows that contain a spike.

    ``spike_idx`` holds sample indices into a recording of ``n_samples`` samples.
    The last window may be shorter than ``window``.
    """
    spike_train = np.zeros(n_samples)
    spike_train[spike_idx] = 1
    window_starts = range(0, n_samples, window)
    n_filled = sum(1 for start in window_starts if spike_train[start:start + window].any())
    return n_filled / len(window_starts) * 100


fill_window = int(SAMPLING_RATE * SPIKE_FILL_WINDOW_S)

for brain_area in ["striatum", "thalamus", "cortex"]:
    mua_key = brain_area + "_mua"
    for recording in processed_recordings_mua:
        # P16 recordings are excluded from every age-binned statistic in this notebook.
        if recording["age"] == 16:
            continue

        age = _age_group(recording["age"])

        if mua_key not in recording or not len(recording[mua_key]):
            continue

        channels = recording[mua_key]
        recording_length = recording["length"]

        # Spike rate averaged over all MUA channels of this area.
        spikes_per_second[brain_area][age].append(
            np.mean([len(channel) / recording_length for channel in channels])
        )

        # NOTE(review): unlike the spike rate above, filled time uses only the
        # first MUA channel - confirm this is intended.
        n_samples = int(recording_length * SAMPLING_RATE)
        spike_filled_time[brain_area][age].append(
            _percent_windows_with_spikes(channels[0], n_samples, fill_window)
        )

for metric, ylabel in (
    (spikes_per_second, 'Spikes per second'),
    (spike_filled_time, 'Spikes filled time (%)'),
):
    plot_graph_by_age (
        data=metric,
        ylabel=ylabel,
        scatter=True,
        colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
        labels=['Striatum', 'CL/Pf', 'Cortex'],
    )
    plt.show()
    display(get_anova_by_age(metric))

# LFP power statistics

In [None]:
# Accumulator for LFP band power (4-100 Hz, filled in below):
# {brain_area: {age_group: [one value per recording]}}.
LFP_power = make_area_by_age_dict()
        
def bandpower(x, fs, fmin, fmax):
    """Integrate the periodogram PSD of ``x`` over roughly ``[fmin, fmax]``.

    Parameters
    ----------
    x : array_like
        1-D signal. The slicing below assumes a single channel.
    fs : float
        Sampling frequency in Hz.
    fmin, fmax : float
        Band edges in Hz.

    Returns
    -------
    float
        Trapezoidal integral of the one-sided PSD from ``scipy.signal.periodogram``
        (defaults: boxcar window, constant detrend, density scaling). The range
        starts at the last frequency bin at or below ``fmin`` and runs up to, but
        not including, the last bin at or below ``fmax``.
    """
    # scipy.trapz / scipy.argmax were NumPy aliases that SciPy has removed.
    from scipy.integrate import trapezoid

    f, Pxx = scipy.signal.periodogram(x, fs=fs)
    # searchsorted(..., side='right') gives the first bin strictly above the edge,
    # which is what argmax(f > edge) gave. Unlike argmax, it returns len(f)
    # rather than 0 when no bin lies above the edge. The lower index is clamped
    # so a negative fmin no longer wraps around to -1.
    ind_min = max(int(np.searchsorted(f, fmin, side='right')) - 1, 0)
    ind_max = int(np.searchsorted(f, fmax, side='right')) - 1
    return trapezoid(Pxx[ind_min:ind_max], f[ind_min:ind_max])

def _age_group(age):
    """Map a postnatal age in days to its age-bin label ('5-6', ..., '>12')."""
    if age < 7:
        return "5-6"
    if age < 9:
        return "7-8"
    if age < 11:
        return "9-10"
    if age < 13:
        return "11-12"
    return ">12"


for recording_idx, recording in enumerate(rms_processed_recordings):
    print("P{} {}/{} {}".format(recording["age"], recording_idx+1, len(rms_processed_recordings), recording["path"]))

    # P16 recordings are excluded from every age-binned statistic in this notebook.
    if recording["age"] == 16:
        continue

    age = _age_group(recording["age"])

    # The same session path is used for every brain area of a recording, so it
    # is opened at most once, and only if some area actually has a channel.
    session = None

    for brain_area in ["striatum", "thalamus", "cortex"]:
        brain_channel = brain_area + "_channel"
        if brain_channel not in recording:
            continue
        print('\t{}'.format(brain_area))

        if session is None:
            session = Session(recording['root'] + recording["path"])

        samples = session.recordings[recording["recording"]].continuous[0].samples
        # Analysis window from 5 to 20 minutes into the recording.
        data_all = samples[get_slice_from_s(60*5, 60*20), recording[brain_channel]]

        LFP_power[brain_area][age].append(bandpower(data_all, SAMPLING_RATE, 4, 100))


plot_graph_by_age (
    data=LFP_power,
    # Raw string: '\m' is an invalid escape (SyntaxWarning on Python 3.12+); value unchanged.
    ylabel=r'LFP power ($\mathregular{μV^2}$)',
    scatter=True,
    colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
    labels=['Striatum', 'CL/Pf', 'Cortex']
)
plt.show()
display(get_anova_by_age(LFP_power))

# Burst statistics

In [None]:
amplitude = make_area_by_age_dict()
relative_amplitude = make_area_by_age_dict()
occurence = make_area_by_age_dict()
filled_time = make_area_by_age_dict()
duration = make_area_by_age_dict()

# Bursts longer than this (seconds) are skipped by the per-burst metrics below.
MAX_BURST_DURATION_S = 20


def _age_group(age):
    """Map a postnatal age in days to its age-bin label ('5-6', ..., '>12')."""
    if age < 7:
        return "5-6"
    if age < 9:
        return "7-8"
    if age < 11:
        return "9-10"
    if age < 13:
        return "11-12"
    return ">12"


for brain_area in ["striatum", "thalamus", "cortex"]:
    print(brain_area)
    bursts_key = brain_area + "_bursts"
    for idx, recording in enumerate(rms_processed_recordings):
        # P16 recordings are excluded from every age-binned statistic in this notebook.
        if recording["age"] == 16:
            continue

        age = _age_group(recording["age"])

        if bursts_key not in recording:
            continue
        print('\t{}/{}'.format(idx+1, len(rms_processed_recordings)))

        recording_length = recording["length"]
        baseline = recording[brain_area + "_baseline_amplitude"]

        amp_arr = []
        # Durations of the kept bursts. Their sum is also the "filled" time; the
        # original kept an identical second list for that.
        duration_arr = []

        for burst in recording[bursts_key]:
            # burst.time presumably holds (start, end) sample indices - TODO confirm.
            dur = (burst.time[1] - burst.time[0]) / SAMPLING_RATE
            if dur > MAX_BURST_DURATION_S:
                continue
            duration_arr.append(dur)

            # Peak-to-peak amplitude of the 4-100 Hz band-passed burst
            # (only computed for bursts that are kept).
            filtered = butter_bandpass_filter(burst.data, 4, 100)
            amp_arr.append(np.max(filtered) - np.min(filtered))

        if amp_arr:
            mean_amp = np.mean(amp_arr)
            amplitude[brain_area][age].append(mean_amp)
            relative_amplitude[brain_area][age].append(mean_amp / baseline)
            duration[brain_area][age].append(np.mean(duration_arr))

        # Percentage of the recording covered by the kept bursts.
        filled = (np.sum(duration_arr) / recording_length) * 100
        filled_time[brain_area][age].append(filled)

        # NOTE(review): occurrence counts *all* bursts, including those longer
        # than MAX_BURST_DURATION_S that the metrics above skip - confirm intended.
        occur = len(recording[bursts_key]) / recording_length
        occurence[brain_area][age].append(occur)


# Raw strings: '\m' is an invalid escape (SyntaxWarning on Python 3.12+); values unchanged.
for metric, ylabel in (
    (duration, 'Duration (s)'),
    (amplitude, 'Burst amplitude (μV)'),
    (relative_amplitude, r'Relative burst amplitude ($\mathregular{A/A_0}$)'),
    (occurence, r'Bursts $\mathregular{s^{-1}}$'),
    (filled_time, 'Burst filled time (%)'),
):
    plot_graph_by_age (
        data=metric,
        ylabel=ylabel,
        scatter=True,
        colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
        labels=['Striatum', 'CL/Pf', 'Cortex'],
    )
    plt.show()
    display(get_anova_by_age(metric))

In [None]:

# Re-display the burst statistics computed in the previous cell.
# Raw strings: '\m' is an invalid escape (SyntaxWarning on Python 3.12+); values unchanged.
for metric, ylabel in (
    (duration, 'Duration (s)'),
    (amplitude, 'Burst amplitude (μV)'),
    (relative_amplitude, r'Relative burst amplitude ($\mathregular{A/A_0}$)'),
    (occurence, r'Bursts $\mathregular{s^{-1}}$'),
    (filled_time, 'Burst filled time (%)'),
):
    plot_graph_by_age (
        data=metric,
        ylabel=ylabel,
        scatter=True,
        colors=[COLOR_STRIATUM, COLOR_THALAMUS, COLOR_CORTEX],
        labels=['Striatum', 'CL/Pf', 'Cortex'],
    )
    plt.show()
    display(get_anova_by_age(metric))

In [None]:
# Is burst incidence lower for thalamus at lower ages?

def cohen_d_for_welch (a, b, ddof=0):
    """Cohen's d for two groups, standardised by the root of the mean of their variances.

    Returns ``(mean(a) - mean(b)) / sqrt((var(a) + var(b)) / 2)``. This
    unweighted average does not assume equal group sizes, which makes it a
    reasonable companion to Welch's t-test.

    Parameters
    ----------
    a, b : array_like
        The two samples.
    ddof : int, optional
        Delta degrees of freedom passed to ``np.var``. The default 0 (population
        variance) preserves the original behaviour; pass 1 for sample variances.
    """
    var_a = np.var(a, ddof=ddof)
    var_b = np.var(b, ddof=ddof)
    return (np.mean(a) - np.mean(b)) / np.sqrt((var_a + var_b) / 2)

week1_occurence = {
    'striatum': [],
    'thalamus': [],
    'cortex': []
}

# Recordings from animals younger than this (postnatal days) count as "week 1".
WEEK1_AGE_LIMIT = 8

for brain_area in ["striatum", "thalamus", "cortex"]:
    print(brain_area)
    bursts_key = brain_area + "_bursts"
    for idx, recording in enumerate(rms_processed_recordings):
        # Keep only week-1 recordings. The previous test (`age < 8: continue`)
        # discarded exactly these and kept P8+ instead.
        if recording["age"] >= WEEK1_AGE_LIMIT:
            continue
        if bursts_key not in recording:
            continue
        print('\t{}/{}'.format(idx+1, len(rms_processed_recordings)))

        # Bursts per second. Elsewhere in this notebook recording["length"] is
        # used directly as seconds (burst occurrence, filled time), so it is not
        # divided by SAMPLING_RATE here. Doing so inflated every rate by a
        # factor of SAMPLING_RATE.
        occur = len(recording[bursts_key]) / recording["length"]
        week1_occurence[brain_area].append(occur)

for order, (area, short_name) in enumerate((('striatum', 'cp'), ('cortex', 'ctx'), ('thalamus', 'thal'))):
    print(('\n' if order else '') + 'Mean {} ='.format(short_name), np.mean(week1_occurence[area]))
    print('Std {} ='.format(short_name), np.std(week1_occurence[area]))

# Pairwise Welch's t-tests (unequal variances) with effect size and degrees of freedom.
for title, (area_a, area_b) in (
    ('Cp vs ctx', ('striatum', 'cortex')),
    ('Cp vs thal', ('striatum', 'thalamus')),
    ('Ctx vs thal', ('cortex', 'thalamus')),
):
    group_a, group_b = week1_occurence[area_a], week1_occurence[area_b]
    print('\n' + title)
    print(stats.ttest_ind(group_a, group_b, equal_var=False))
    print('Cohen\'s d =', cohen_d_for_welch(group_a, group_b))
    print('DOF = ', welch_dof(group_a, group_b))