# Package imports

In [None]:
import pickle, os

from open_ephys.analysis import Session

from matplotlib import pyplot as plt
plt.rcParams["font.family"] = "Arial"
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

import scipy
from scipy import signal
from scipy import ndimage
from scipy import stats

import numpy as np
import pandas as pd

from statsmodels.tsa.arima.model import ARIMA
from statsmodels.stats.descriptivestats import sign_test
import statsmodels.api as sm
from statsmodels.formula.api import ols

from utilities import *
from recording_data import RECORDINGS
from clustering import (
    feature_vector_labels, feature_vector_labels_full,
    get_feature_vectors, get_cluster_labels, sort_bursts_by_labels
)
from xcorr import correlate_template
from detect_mua import detect_MUA, get_spike_wave_params

# Open previously saved data

In [None]:
# Paths to the previously processed data (pickle file format); fill in before running.
RMS_PICKLE_PATH = ''
MUA_PICKLE_PATH = ''

rms_processed_recordings = open_data(RMS_PICKLE_PATH)
processed_recordings_mua = open_data(MUA_PICKLE_PATH)

# PCA clustering

In [None]:
# Cluster each region's bursts (via get_cluster_labels) and split them into the
# "ngb" and "sb" groups returned by sort_bursts_by_labels; the results are then
# exposed as region-prefixed globals used by later cells.
_cluster_results = {}
for key in ['thalamus', 'cortex', 'striatum']:
    all_bursts, all_features = get_feature_vectors (
        rms_processed_recordings,
        processed_recordings_mua,
        key=f'{key}_bursts', mua_key=f'{key}_mua'
    )

    labels, pca = get_cluster_labels(all_features)
    feature_list_ngb, feature_list_sb, bursts_ngb, bursts_sb = sort_bursts_by_labels(
        all_bursts, all_features, labels
    )
    _cluster_results[key] = (feature_list_ngb, feature_list_sb, bursts_ngb, bursts_sb)

thalamus_ngb_features, thalamus_sb_features, thalamus_ngb, thalamus_sb = _cluster_results['thalamus']
striatum_ngb_features, striatum_sb_features, striatum_ngb, striatum_sb = _cluster_results['striatum']
cortex_ngb_features, cortex_sb_features, cortex_ngb, cortex_sb = _cluster_results['cortex']

# Useful functions used throughout

In [None]:
def cohen_d_for_welch (a, b, ddof=0):
    """Cohen's d for two independent samples using the average-variance denominator.

    The standardiser is ``sqrt((var(a) + var(b)) / 2)``, which does not pool by
    sample size and so does not assume equal group sizes or variances.

    Parameters
    ----------
    a, b : array-like
        The two samples.
    ddof : int, optional
        Delta degrees of freedom passed to ``np.var``. The default ``0``
        (population variance) reproduces this notebook's previous results;
        ``1`` gives the sample-variance form.

    Returns
    -------
    float
        ``(mean(a) - mean(b)) / sqrt((var(a) + var(b)) / 2)``; positive when
        ``a`` has the larger mean.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    average_sd = np.sqrt((np.var(a, ddof=ddof) + np.var(b, ddof=ddof)) / 2)
    return (np.mean(a) - np.mean(b)) / average_sd

def align_overlapping_bursts (burst_a, burst_b):
    """Trim two bursts to the sample span they share.

    Each burst exposes ``time`` (start and end sample at positions 0 and 1) and
    ``data`` (samples starting at ``time[0]``).

    Returns
    -------
    tuple
        ``(trimmed data of burst_a, trimmed data of burst_b,
        get_xticks(slice(overlap_start, overlap_end)))``.
    """
    overlap_start = max(burst_a.time[0], burst_b.time[0])
    overlap_end = min(burst_a.time[1], burst_b.time[1])
    overlap_len = overlap_end - overlap_start

    # Offsets of the shared span inside each burst's own data array
    offset_a = overlap_start - burst_a.time[0]
    offset_b = overlap_start - burst_b.time[0]

    return (
        burst_a.data[offset_a:offset_a + overlap_len],
        burst_b.data[offset_b:offset_b + overlap_len],
        get_xticks(slice(overlap_start, overlap_end)),
    )

def get_threshold (burst_pairs, n_iterations=200, percentile=95, plot=False):
    """Estimate a per-frequency coherence significance threshold by shuffling.

    Only the FIRST pair in ``burst_pairs`` is used. Its two bursts are trimmed to
    their overlap, the samples of each signal are randomly permuted
    ``n_iterations`` times, and the magnitude-squared coherence (restricted to
    0-100 Hz) of every permuted pair is computed. The ``percentile``-th
    percentile across permutations is returned for each frequency bin.

    Parameters
    ----------
    burst_pairs : sequence of (burst, burst)
        Only ``burst_pairs[0]`` is read.
    n_iterations : int, optional
        Number of permutations (default 200).
    percentile : float, optional
        Percentile of the surrogate distribution to return (default 95).
    plot : bool, optional
        If True, plot every permuted pair. The original always did this, which
        produced ``n_iterations`` figures per call; it is now opt-in.

    Returns
    -------
    numpy.ndarray
        Threshold coherence per frequency bin in the 0-100 Hz range.
    """
    a_burst, b_burst = burst_pairs[0]
    a, b, _ = align_overlapping_bursts(a_burst, b_burst)
    # Shuffle copies: the trimmed arrays may be views into the bursts' data, and
    # in-place shuffling would otherwise scramble the stored recordings.
    a = np.array(a, copy=True)
    b = np.array(b, copy=True)

    window_size = 0.5  # seconds
    window = int(SAMPLING_RATE*window_size)
    overlap = int(SAMPLING_RATE*window_size*0.5)  # 50 % window overlap

    Cxy_arr = []
    for _ in range(n_iterations):
        np.random.shuffle(a)
        np.random.shuffle(b)

        f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
        f, Cxy = get_psd_in_range((f, Cxy), [0, 100])

        if plot:
            plt.plot(a)
            plt.plot(b)
            plt.show()

        Cxy_arr.append(Cxy)

    return np.percentile(np.array(Cxy_arr), percentile, axis=0)

def get_lags (burst_a_data, burst_b_data, freq, prewhiten):
    """Squared cross-correlation of two band-limited amplitude envelopes, within +/-100 ms.

    Both signals are band-pass filtered to ``freq``. Their Hilbert amplitude
    envelopes are mean-subtracted, optionally pre-whitened, and cross-correlated
    with ``correlate_template`` (mode 'full', naive normalisation).

    Parameters
    ----------
    burst_a_data, burst_b_data : array-like
        Raw signals of the two regions. Presumably equal length and time-aligned
        (e.g. from ``align_overlapping_bursts``) -- TODO confirm for callers.
    freq : (float, float)
        Low and high band-pass corner frequencies in Hz.
    prewhiten : bool
        If True, fit an ARIMA(1,0,1) model to envelope A and filter BOTH
        envelopes with that fitted model before correlating (Box-Jenkins
        pre-whitening). Previously only A was filtered, which leaves B's
        autocorrelation in place and can produce spurious peaks.

    Returns
    -------
    xpos : numpy.ndarray
        Lag of each returned value in ms (integer sample lags converted to ms).
    cross_corr_sq : numpy.ndarray
        Squared cross-correlation at those lags.
    """
    burst_a_data = butter_bandpass_filter(burst_a_data, freq[0], freq[1])
    burst_b_data = butter_bandpass_filter(burst_b_data, freq[0], freq[1])

    # Mean-subtracted instantaneous amplitude (Hilbert envelope)
    amplitude_a = np.abs(signal.hilbert(burst_a_data))
    amplitude_a = amplitude_a - np.mean(amplitude_a)
    amplitude_b = np.abs(signal.hilbert(burst_b_data))
    amplitude_b = amplitude_b - np.mean(amplitude_b)

    if prewhiten:
        fit_a = ARIMA(amplitude_a, order=(1,0,1)).fit()
        a = fit_a.resid
        # Apply A's fitted parameters (no refit) to B so both series go through
        # the same whitening filter.
        b = fit_a.apply(amplitude_b).resid
    else:
        a = amplitude_a
        b = amplitude_b

    cross_corr = np.asarray(
        correlate_template(a, b, mode='full', normalize='naive', demean=False)
    )

    # Keep +/-100 ms around the zero-lag index, clipped to the available range
    # (a negative slice start would otherwise index from the end).
    centre = int(len(cross_corr) * 0.5)
    hundred_ms = int(SAMPLING_RATE*0.1)
    lo = max(centre - hundred_ms, 0)
    hi = min(centre + hundred_ms, len(cross_corr))
    cross_corr = cross_corr[lo:hi]

    # Exact lag of every kept sample, in ms
    xpos = (np.arange(lo, hi) - centre) / SAMPLING_RATE * 1000

    return xpos, cross_corr**2

# Get co-occurring bursts across regions

In [None]:
def get_overlapping_bursts (bursts_a, bursts_b, bursts_c=None, max_t_diff=0.5, sampling_rate=None):
    """Find co-occurring bursts across two (or three) regions.

    A group of bursts co-occurs when the largest difference between their start
    times is below ``max_t_diff`` seconds AND below the length (in samples) of
    the shortest burst in the group.

    Parameters
    ----------
    bursts_a, bursts_b : iterable of bursts
        Each burst exposes ``time`` (start sample at index 0) and ``data``.
    bursts_c : iterable of bursts or None, optional
        When given (even if empty), triplets a/b/c are searched instead of
        pairs. An empty third region therefore yields no events; previously it
        silently fell back to returning a/b pairs.
    max_t_diff : float, optional
        Maximum start-time difference in seconds (default 0.5).
    sampling_rate : float or None, optional
        Samples per second; defaults to the module-level ``SAMPLING_RATE``.

    Returns
    -------
    list of list
        ``[burst_a, burst_b]`` (or ``[burst_a, burst_b, burst_c]``) per
        co-occurrence, in discovery order with duplicates removed.
    """
    if sampling_rate is None:
        sampling_rate = SAMPLING_RATE
    max_diff_samples = sampling_rate*max_t_diff

    def _co_occur(group):
        # max pairwise |start difference| == max(start) - min(start)
        starts = [burst.time[0] for burst in group]
        max_diff = max(starts) - min(starts)
        min_len = min(len(burst.data) for burst in group)
        return max_diff < max_diff_samples and max_diff < min_len

    matches = []
    seen = set()
    for burst_a in bursts_a:
        for burst_b in bursts_b:
            if bursts_c is not None:
                candidates = [[burst_a, burst_b, burst_c] for burst_c in bursts_c]
            else:
                candidates = [[burst_a, burst_b]]

            for group in candidates:
                if not _co_occur(group):
                    continue
                # Order-sensitive tuple of hashes: unlike the previous hash *sum*,
                # distinct groups cannot collide by adding up to the same value.
                group_key = tuple(hash(burst) for burst in group)
                if group_key in seen:
                    continue
                seen.add(group_key)
                matches.append(group)

    return matches

def get_nonoverlapping_bursts (bursts_a, bursts_b, burst_pairs):
    """Return the bursts of each region that appear in none of ``burst_pairs``.

    Membership is tested with ``in`` on every group of ``burst_pairs``.

    Returns
    -------
    tuple of (list, list)
        Unpaired bursts from ``bursts_a`` and from ``bursts_b``, in input order.
    """
    def _unpaired(bursts):
        return [
            burst for burst in bursts
            if not any(burst in pair for pair in burst_pairs)
        ]

    return _unpaired(bursts_a), _unpaired(bursts_b)

# Containers filled by the co-occurrence loop below. Every list is a fresh,
# independent object.
_AGE_GROUPS = ("5-6", "7-8", "9-10", "11-12")
_PAIR_TYPES = ("Thal-CP", "Ctx-CP", "Ctx-thal", "Ctx-thal-CP")
_REGIONS = ("thalamus", "striatum", "cortex")

cooccuring_events = {pair_type: [] for pair_type in _PAIR_TYPES}

cooccuring_events_by_age = {
    pair_type: {age: [] for age in _AGE_GROUPS} for pair_type in _PAIR_TYPES
}

nonoverlapping_events = {region: [] for region in _REGIONS}

proportion_cooccuring_events = {
    pair_type: {age: [] for age in _AGE_GROUPS} for pair_type in _PAIR_TYPES
}

proportion_nonoverlapping_events = {
    region: {age: [] for age in _AGE_GROUPS} for region in _REGIONS
}

def unique_list (l):
    """Return the elements of ``l`` in first-seen order, dropping ``==`` duplicates."""
    temp = []
    for i in l:
        if i not in temp:
            temp.append(i)
    return temp


for recording_idx, recording in enumerate(rms_processed_recordings):
    # Two-day age groups; recordings aged 13 or older are skipped
    if recording["age"] < 7:
        age = "5-6"
    elif recording["age"] < 9:
        age = "7-8"
    elif recording["age"] < 11:
        age = "9-10"
    elif recording["age"] < 13:
        age = "11-12"
    else:
        continue

    print("{}/{}".format(recording_idx+1, len(rms_processed_recordings)))

    has_thalamus = "thalamus_bursts" in recording
    cortex_bursts = recording["cortex_bursts"]
    striatum_bursts = recording["striatum_bursts"]

    Ctx_CP = get_overlapping_bursts (cortex_bursts, striatum_bursts)
    cooccuring_events["Ctx-CP"] += Ctx_CP
    cooccuring_events_by_age["Ctx-CP"][age] += Ctx_CP
    proportion_cooccuring_events["Ctx-CP"][age].append(
        len(Ctx_CP) / (len(cortex_bursts) + len(striatum_bursts))
    )
    # Every co-occurrence found in this recording; used below to find bursts that
    # co-occur with nothing. These lists were previously never filled, so every
    # non-overlap proportion came out as 0.
    recording_cooccurrences = list(Ctx_CP)

    if has_thalamus:
        thalamus_bursts = recording["thalamus_bursts"]

        Thal_CP = get_overlapping_bursts (thalamus_bursts, striatum_bursts)
        cooccuring_events["Thal-CP"] += Thal_CP
        cooccuring_events_by_age["Thal-CP"][age] += Thal_CP
        proportion_cooccuring_events["Thal-CP"][age].append(
            len(Thal_CP) / (len(thalamus_bursts) + len(striatum_bursts))
        )

        Ctx_thal = get_overlapping_bursts (cortex_bursts, thalamus_bursts)
        cooccuring_events["Ctx-thal"] += Ctx_thal
        cooccuring_events_by_age["Ctx-thal"][age] += Ctx_thal
        proportion_cooccuring_events["Ctx-thal"][age].append(
            len(Ctx_thal) / (len(cortex_bursts) + len(thalamus_bursts))
        )

        Ctx_thal_CP = get_overlapping_bursts (cortex_bursts, thalamus_bursts, striatum_bursts)
        cooccuring_events["Ctx-thal-CP"] += Ctx_thal_CP
        cooccuring_events_by_age["Ctx-thal-CP"][age] += Ctx_thal_CP
        proportion_cooccuring_events["Ctx-thal-CP"][age].append(
            len(Ctx_thal_CP) / (len(cortex_bursts) + len(thalamus_bursts) + len(striatum_bursts))
        )

        recording_cooccurrences += Thal_CP + Ctx_thal + Ctx_thal_CP

    nonoverlapping_events_ctx, nonoverlapping_events_cp = get_nonoverlapping_bursts(
        cortex_bursts, striatum_bursts, recording_cooccurrences
    )
    nonoverlapping_events["cortex"] += nonoverlapping_events_ctx
    nonoverlapping_events["striatum"] += nonoverlapping_events_cp

    proportion_nonoverlapping_events['cortex'][age].append(
        len(unique_list(nonoverlapping_events_ctx))/len(cortex_bursts)
    )
    proportion_nonoverlapping_events['striatum'][age].append(
        len(unique_list(nonoverlapping_events_cp))/len(striatum_bursts)
    )

    if has_thalamus:
        nonoverlapping_events_thal, _ = get_nonoverlapping_bursts(
            thalamus_bursts, [], recording_cooccurrences
        )
        nonoverlapping_events["thalamus"] += nonoverlapping_events_thal
        proportion_nonoverlapping_events['thalamus'][age].append(
            len(unique_list(nonoverlapping_events_thal))/len(thalamus_bursts)
        )

# Spike-spike cross-correlations

In [None]:
def get_spike_spike_crosscorr (burst_pair, shuffle):
    """Squared cross-correlation of smoothed MUA spike trains of two co-occurring bursts.

    The pair is trimmed to its overlap. Spikes are detected with
    ``detect_MUA(x, 4)`` in each burst and converted to binary spike trains,
    which are smoothed with a Gaussian kernel (sigma = 2 ms) and
    cross-correlated with ``correlate_template``.

    Parameters
    ----------
    burst_pair : (burst, burst)
    shuffle : bool
        If True, permute the second burst's binary spike train before smoothing
        (surrogate/null estimate). This argument was previously ignored.

    Returns
    -------
    numpy.ndarray or float
        Squared cross-correlation for the ``2 * ten_ms`` samples around zero lag
        (lags -10 ms .. +10 ms), or ``np.nan`` when either burst has fewer than
        30 detected spikes.
    """
    burst_a, burst_b = burst_pair
    a, b, _ = align_overlapping_bursts(burst_a, burst_b)

    spikes_a_times, _ = detect_MUA(a, 4, silent=True)
    spikes_b_times, _ = detect_MUA(b, 4, silent=True)

    if len(spikes_a_times) < 30 or len(spikes_b_times) < 30:
        return np.nan

    # Binary spike trains at the original sampling rate
    spikes_a = np.zeros(len(a))
    spikes_a[np.asarray(spikes_a_times, dtype=int)] = 1
    spikes_b = np.zeros(len(b))
    spikes_b[np.asarray(spikes_b_times, dtype=int)] = 1

    if shuffle:
        np.random.shuffle(spikes_b)

    sigma = int(SAMPLING_RATE/1000 * 2)  # 2 ms in samples
    spikes_a = ndimage.gaussian_filter1d(spikes_a, sigma)
    spikes_b = ndimage.gaussian_filter1d(spikes_b, sigma)

    cross_corr = correlate_template(spikes_a, spikes_b, mode='full', normalize='naive', demean=False)

    # Take only +/- 10 ms from 0 lag
    centre = int(len(cross_corr) * 0.5)
    ten_ms = int(SAMPLING_RATE/1000 * 10)
    cross_corr = cross_corr[centre-ten_ms:centre+ten_ms]

    # Previously returned an undefined ``swfs_b`` (NameError); callers expect
    # only the squared correlation.
    return cross_corr**2
    
# Store the squared spike-spike cross-correlation of each pair of co-occurring events
cross_corr_arr = []
for burst_pair_idx, burst_pair in enumerate(cooccuring_events["Ctx-CP"]):
    if (burst_pair_idx+1) % 100 == 0:
        print('Burst pair', burst_pair_idx+1)

    cross_corr = get_spike_spike_crosscorr (burst_pair, shuffle=False)
    # A scalar (np.nan) marks pairs with too few spikes; valid results are arrays
    if np.ndim(cross_corr) > 0:
        cross_corr_arr.append(cross_corr)

mean_cross_corr = np.mean(cross_corr_arr, axis=0)
err_cross_corr = np.std(cross_corr_arr, axis=0)/len(cross_corr_arr)**0.5

# Each correlation covers samples [centre - ten_ms, centre + ten_ms), so sample i
# has lag (i - ten_ms) samples. The previous linspace axis stretched these lags
# slightly.
ten_ms = int(SAMPLING_RATE/1000 * 10)
xpos = (np.arange(len(mean_cross_corr)) - ten_ms) / SAMPLING_RATE * 1000

plt.plot(xpos, mean_cross_corr)
plt.fill_between(xpos, mean_cross_corr+err_cross_corr, mean_cross_corr-err_cross_corr, alpha=0.2)
plt.xlabel('Lag in ms')
plt.ylabel('Squared cross-correlation')
format_plot(plt.gca(), legend=False)
plt.show()

def get_cross_correlation_lag_peak (xpos, cross_corr):
    """Return the ``xpos`` value at the tallest local maximum of ``cross_corr``.

    Local maxima come from ``scipy.signal.find_peaks``; ``np.nan`` is returned
    when there are none.
    """
    peak_idxs = signal.find_peaks(cross_corr)[0]
    if not len(peak_idxs):
        return np.nan

    heights = np.asarray(cross_corr)[peak_idxs]
    tallest = peak_idxs[np.argmax(heights)]
    return xpos[tallest]
    
# Lag (ms) of the strongest peak of every pair's cross-correlation; pairs with no
# peak are dropped.
lags = [
    lag
    for lag in (get_cross_correlation_lag_peak (xpos, cross_corr) for cross_corr in cross_corr_arr)
    if np.isfinite(lag)
]
mean_lag = np.mean(lags)

plt.hist(lags, density=True)
plt.title(f'Lag distribution, mean = {round(mean_lag, 2)}ms')
plt.xlabel('Lag in ms')
plt.ylabel('Density')
plt.show()

# One-sample t-test of the lags against zero (notebook cell output)
mean_lag, stats.ttest_1samp(lags, 0)

# Spike-LFP coherence

In [None]:
def ResampleLinear1D(original, targetLen):
    """Linearly resample a 1-D signal to ``targetLen`` evenly spaced points.

    The first and last outputs equal the first and last inputs; points in
    between are linearly interpolated. No anti-alias filtering is applied, so
    band-limit ``original`` before downsampling.

    The previous implementation used ``np.float``/``np.int``, which were removed
    in NumPy 1.24 (AttributeError); ``np.interp`` computes the same interpolation.

    Parameters
    ----------
    original : array-like, 1-D
        Input samples.
    targetLen : int
        Number of output samples.

    Returns
    -------
    numpy.ndarray
        Float array of length ``targetLen``.

    Raises
    ------
    ValueError
        If ``original`` is empty.
    """
    original = np.asarray(original, dtype=float)
    if original.size == 0:
        raise ValueError("cannot resample an empty signal")

    positions = np.linspace(0, len(original) - 1, num=targetLen)
    return np.interp(positions, np.arange(len(original)), original)

# Spike trains are binned at 2 ms: SAMPLING_RATE // 1000 * 2 samples per bin
# (60 samples at 30 kHz).
bin_width = SAMPLING_RATE//1000 *2
# Sampling frequency of the binned spike train (and of the resampled LFP).
# Derived from the bin width so it stays correct if SAMPLING_RATE changes
# (500 Hz for 2 ms bins).
fs = SAMPLING_RATE / bin_width
# Use a window of 0.5s
window = int(fs * 0.5)
# With 0 overlap
overlap = 0

def get_spike_lfp_coherence (burst_pair, shuffle):
    """Coherence between the binned spikes of one burst and the LFP of the other.

    The pair is trimmed to its overlap. MUA spikes are detected in the FIRST
    burst of the pair and binned into ``bin_width``-sample bins (sampling rate
    ``fs``). The SECOND burst is band-passed to 4-80 Hz and linearly resampled to
    the same number of points. Coherence is computed with ``nperseg=window`` and
    ``noverlap=overlap`` and restricted to 4-80 Hz.

    NOTE(review): ``bin_width``, ``fs``, ``window`` and ``overlap`` are module
    globals read at call time; the coherence-over-time cell rebinds ``window``
    and ``overlap``, so calls made after it use different values.

    Parameters
    ----------
    burst_pair : (burst, burst)
    shuffle : bool
        Permute the spike-count bins (surrogate/null estimate).

    Returns
    -------
    (f, Cxy) or (np.nan, np.nan)
        ``np.nan`` pair when the binned signal is shorter than ``window`` or
        fewer than 10 spikes were detected.
    """
    burst_a, burst_b = burst_pair
    # Note the swap: ``b`` holds the first burst (spikes), ``a`` the second (LFP)
    b, a, _ = align_overlapping_bursts(burst_a, burst_b)

    spikes_b_times, _ = detect_MUA(b, 4, silent=True)
    spikes_b, _ = np.histogram(spikes_b_times, np.arange(0, len(b)+1, bin_width))

    if shuffle:
        np.random.shuffle(spikes_b)

    # The LFP is resampled to len(spikes_b) points, so this is the same test as
    # before; doing it first avoids filtering inputs that are discarded anyway.
    if len(spikes_b) < window or len(spikes_b_times) < 10:
        return np.nan, np.nan

    # Filter in 4-80 Hz, then resample the FILTERED signal to the spike-bin rate.
    # The previous version resampled the raw ``a``, discarding the filter.
    lfp_a = butter_bandpass_filter(a, 4, 80)
    lfp_a = ResampleLinear1D(lfp_a, len(spikes_b))

    f, Cxy = signal.coherence(lfp_a, spikes_b, fs=fs, nperseg=window, noverlap=overlap)
    f, Cxy = get_psd_in_range((f, Cxy), [4, 80])
    return f, Cxy

# Store the coherence spectra of each pair of co-occurring events
Cxy_arr = []
f_valid = None  # frequency axis from a successful call (f is np.nan on failures)
for burst_pair_idx, burst_pair in enumerate(cooccuring_events["Ctx-CP"]):
    if (burst_pair_idx+1) % 100 == 0:
        print('Burst pair', burst_pair_idx+1)

    f, Cxy = get_spike_lfp_coherence (burst_pair, shuffle=False)
    if np.ndim(Cxy) > 0:
        Cxy_arr.append(Cxy)
        f_valid = f

# Null distribution: mean coherence across pairs after permuting spike bins,
# repeated 1000 times
Cxy_arr_shuffled = []
for shuffle_iteration in range(1000):
    print('Shuffle', shuffle_iteration + 1)
    Cxy_arr_shuffled_iteration = []

    for burst_pair in cooccuring_events["Ctx-CP"]:
        f, Cxy = get_spike_lfp_coherence (burst_pair, shuffle=True)
        if np.ndim(Cxy) > 0:
            Cxy_arr_shuffled_iteration.append(Cxy)
            f_valid = f

    Cxy_arr_shuffled.append( np.mean(Cxy_arr_shuffled_iteration, axis=0) )

# Plot against a valid frequency axis; the last call may have returned np.nan
f = f_valid

y = np.mean(Cxy_arr, axis=0)
err = np.std(Cxy_arr, axis=0)/len(Cxy_arr)**0.5
plt.plot(f, y, label='Observed (mean +/- SEM)')
plt.fill_between(f, y+err, y-err, alpha=0.5)

y = np.nanpercentile(Cxy_arr_shuffled, 95, axis=0)
plt.plot(f, y, label='Shuffled 95th percentile')

y = np.nanpercentile(Cxy_arr_shuffled, 5, axis=0)
plt.plot(f, y, label='Shuffled 5th percentile')

plt.legend()
plt.show()

# Proportion of co-occurring events

In [None]:
# Pool every co-occurrence type's per-recording proportions by age group, then
# compare consecutive age groups with Welch's t-test.
values = {age: [] for age in ('5-6', '7-8', '9-10', '11-12')}

for by_age in proportion_cooccuring_events.values():
    for age_key, proportions in by_age.items():
        values[age_key].extend(proportions)

for younger, older in (('5-6', '7-8'), ('7-8', '9-10'), ('9-10', '11-12')):
    print(' {} vs. {}'.format(younger, older), stats.ttest_ind(values[younger], values[older], equal_var=False))

In [None]:
# Aggregate proportions (over all ages): mean +/- SEM per co-occurrence type
bar_means = []
bar_stderrs = []
bar_labels = []

for comparison, by_age in proportion_cooccuring_events.items():
    data = [value for values_at_age in by_age.values() for value in values_at_age]
    mean = np.mean(data)
    # Standard error of the mean is std / sqrt(n); this previously divided by n
    stderr = np.std(data) / len(data)**0.5

    bar_means.append(mean)
    bar_stderrs.append(stderr)
    bar_labels.append(comparison)
xpos = np.arange(len(bar_means))

plt.bar(xpos, bar_means, yerr=bar_stderrs, capsize=5, color=["tab:blue", "tab:orange", 'tab:green', 'tab:red'])
plt.xticks(xpos, bar_labels)
plt.ylabel('Proportion of total events')
format_plot(plt.gca(), legend=False)
plt.show()

# Bonferroni correction for multiple comparisons. Each UNORDERED pair of
# co-occurrence types is tested once (previously every pair was tested and
# printed twice, as A vs B and B vs A), so the factor is C(k, 2): 6 for k = 4.
comparison_keys = list(proportion_cooccuring_events.keys())
correction_val = len(comparison_keys) * (len(comparison_keys) - 1) // 2
for idx_a, comparison_a in enumerate(comparison_keys):
    data_a = sum([val for val in proportion_cooccuring_events[comparison_a].values()], [])
    for comparison_b in comparison_keys[idx_a+1:]:
        data_b = sum([val for val in proportion_cooccuring_events[comparison_b].values()], [])

        t, p = stats.ttest_ind(data_a, data_b, equal_var=False)
        # A corrected p-value cannot exceed 1
        p_corrected = min(p*correction_val, 1.0)
        if p_corrected < 0.05:
            disp = '{} (mn={}, std={}) vs {} (mn={}, std={})\nt={}, p corrected={}, dof={}\n\n'.format(
                comparison_a, np.mean(data_a), np.std(data_a),
                comparison_b, np.mean(data_b), np.std(data_b),
                t, p_corrected, welch_dof(data_a, data_b)
            )
            print(disp)

            
# By age: one mean +/- SEM line per co-occurrence type, skipping empty age groups
for comparison, data_by_age in proportion_cooccuring_events.items():
    populated = [(label, data) for label, data in data_by_age.items() if len(data)]

    labels = [label for label, _ in populated]
    means = np.array([np.mean(data) for _, data in populated])
    stderrs = np.array([np.std(data) / len(data)**0.5 for _, data in populated])
    xpos = np.arange(len(labels))

    plt.plot(xpos, means, label=comparison)
    plt.fill_between(xpos, means-stderrs, means+stderrs, alpha=0.25)
    plt.xticks(xpos, labels)

plt.xlabel('Age')
plt.ylabel('Proportion of total events')
plt.legend(bbox_to_anchor=(1,1), loc="upper left")
format_plot(plt.gca(), legend=True)
plt.show()

print(twoway_anova(proportion_cooccuring_events, ['burst_pair', 'age', 'values']))

In [None]:
# Per-recording normalised frequencies, keyed by burst-type combination and then
# by region pair; every list is an independent object.
normalized_burst_frequency = {
    pair_type: {region_pair: [] for region_pair in ('Thal-CP', 'Ctx-CP', 'Ctx-thal')}
    for pair_type in ('both_ngb', 'both_sb', 'mixed')
}

In [None]:
key = "Ctx-CP"

for recording in rms_processed_recordings:
    if not "thalamus_bursts" in recording:
        continue
    
    recording_ngb = []
    recording_sb = []
    
    recording_ngb_pairs = []
    recording_sb_pairs = []
    recording_mixed_pairs = []

    # By burst type
    for burst_pair in cooccuring_events[key]:
        if not burst_pair[0] in recording["cortex_bursts"]:
            continue

        if False:
            ctx_burst, thal_burst = burst_pair
            if ctx_burst in cortex_ngb and thal_burst in thalamus_ngb:
                recording_ngb_pairs.append(burst_pair)
                recording_ngb.append(ctx_burst)
                recording_ngb.append(thal_burst)
            elif ctx_burst in cortex_sb and thal_burst in thalamus_sb:
                recording_sb_pairs.append(burst_pair)
                recording_sb.append(ctx_burst)
                recording_sb.append(thal_burst)
            elif ctx_burst in cortex_ngb and thal_burst in thalamus_sb:
                recording_mixed_pairs.append(burst_pair)
                recording_ngb.append(ctx_burst)
                recording_sb.append(thal_burst)
            elif ctx_burst in cortex_sb and thal_burst in thalamus_ngb:
                recording_mixed_pairs.append(burst_pair) 
                recording_sb.append(ctx_burst)
                recording_ngb.append(thal_burst)

        elif False:
            thal_burst, cp_burst = burst_pair
            if thal_burst in thalamus_ngb and cp_burst in striatum_ngb:
                recording_ngb_pairs.append(burst_pair)
                recording_ngb.append(thal_burst)
                recording_ngb.append(cp_burst)
            elif thal_burst in thalamus_sb and cp_burst in striatum_sb:
                recording_sb_pairs.append(burst_pair)
                recording_sb.append(thal_burst)
                recording_sb.append(cp_burst)
            elif thal_burst in thalamus_ngb and cp_burst in striatum_sb:
                recording_mixed_pairs.append(burst_pair)
                recording_ngb.append(thal_burst)
                recording_sb.append(cp_burst)
            elif thal_burst in thalamus_sb and cp_burst in striatum_ngb:
                recording_mixed_pairs.append(burst_pair) 
                recording_sb.append(thal_burst)
                recording_ngb.append(cp_burst)
        elif True:
            ctx_burst, cp_burst = burst_pair
            if ctx_burst in cortex_ngb and cp_burst in striatum_ngb:
                recording_ngb_pairs.append(burst_pair)
                recording_ngb.append(ctx_burst)
                recording_ngb.append(cp_burst)
            elif ctx_burst in cortex_sb and cp_burst in striatum_sb:
                recording_sb_pairs.append(burst_pair)
                recording_sb.append(ctx_burst)
                recording_sb.append(cp_burst)
            elif ctx_burst in cortex_ngb and cp_burst in striatum_sb:
                recording_mixed_pairs.append(burst_pair)
                recording_ngb.append(ctx_burst)
                recording_sb.append(cp_burst)
            elif ctx_burst in cortex_sb and cp_burst in striatum_ngb:
                recording_mixed_pairs.append(burst_pair) 
                recording_sb.append(ctx_burst)
                recording_ngb.append(cp_burst)

            
    if len(recording_ngb_pairs):
        normalized_burst_frequency['both_ngb'][key].append( len(recording_ngb_pairs) / len(recording_ngb) )
    if len(recording_sb_pairs):
        normalized_burst_frequency['both_sb'][key].append( len(recording_sb_pairs) / len(recording_sb) )
    if len(recording_mixed_pairs):
        normalized_burst_frequency['mixed'][key].append( len(recording_mixed_pairs) / (len(recording_ngb)+len(recording_sb)) )

# Welch's t-tests (with Cohen's d and degrees of freedom) between the three
# burst-type combinations for region pair ``key``.
ngb_pairs, sb_pairs, mixed_pairs = (
    normalized_burst_frequency[pair_type][key] for pair_type in ('both_ngb', 'both_sb', 'mixed')
)
groups = [ngb_pairs, sb_pairs, mixed_pairs]

xpos = np.arange(len(groups))
yvals = [np.mean(group) for group in groups]
yerrs = [np.std(group)/len(group)**0.5 for group in groups]

print('Mean =', yvals)
print('Std =', [np.std(group) for group in groups])

for title, first, second in (
    ('NGB vs SB', ngb_pairs, sb_pairs),
    ('\nNGB vs mixed', ngb_pairs, mixed_pairs),
    ('\nSB vs mixed', sb_pairs, mixed_pairs),
):
    print(title, stats.ttest_ind(first, second, equal_var=False))
    print('Cohen\'s d =', cohen_d_for_welch(first, second))
    print('DOF = ', welch_dof(first, second))

In [None]:
print(twoway_anova(normalized_burst_frequency, ['burst_pair', 'region', 'value']))

region_colors = ['tab:orange', 'tab:blue', 'tab:green']
fig = plt.figure(dpi=100)

# Three bars of width 1/3 per burst-type group, centred on x = 1, 3, 5
group_centres = np.array([1, 3, 5])
region_keys = ['Ctx-CP', 'Thal-CP', 'Ctx-thal']
for region_idx, (region_key, colour) in enumerate(zip(region_keys, region_colors)):
    per_type = [
        normalized_burst_frequency[burst_pair_key][region_key]
        for burst_pair_key in ['both_ngb', 'both_sb', 'mixed']
    ]
    mns = [np.mean(vals) for vals in per_type]
    errs = [np.std(vals)/len(vals)**0.5 for vals in per_type]

    xpos = group_centres + (region_idx-1)/3
    plt.bar(xpos, mns, yerr=errs, width=1/3, label=region_key, capsize=3, facecolor=colour)

plt.xticks([1, 3, 5], ['Both NGB', 'Both SB', 'NGB and SB'])
plt.ylabel('Normalized frequency')
plt.legend()
format_plot(plt.gca(), legend=True)
plt.show()

In [None]:
# Pool each burst-type combination's values across all region pairs, then
# compare the combinations with Welch's t-tests.
normalized_burst_frequency_pairs = {
    pair_type: [value for region_vals in by_region.values() for value in region_vals]
    for pair_type, by_region in normalized_burst_frequency.items()
}

ngb_pairs = normalized_burst_frequency_pairs['both_ngb']
sb_pairs = normalized_burst_frequency_pairs['both_sb']
mixed_pairs = normalized_burst_frequency_pairs['mixed']
pooled_groups = [ngb_pairs, sb_pairs, mixed_pairs]

print('Mean =', [np.mean(group) for group in pooled_groups])
print('Std =', [np.std(group) for group in pooled_groups])

for title, first, second in (
    ('NGB vs SB', ngb_pairs, sb_pairs),
    ('\nNGB vs mixed', ngb_pairs, mixed_pairs),
    ('\nSB vs mixed', sb_pairs, mixed_pairs),
):
    print(title, stats.ttest_ind(first, second, equal_var=False))
    print('Cohen\'s d =', cohen_d_for_welch(first, second))
    print('DOF = ', welch_dof(first, second))

# Cross-region spiking activity

In [None]:
def load_data (recording, channel):
    """Load one region's channel from an Open Ephys session.

    Reads ``samples[get_slice_from_s(0, 60*60), channel_index]`` from
    ``continuous[0]`` of recording ``recording["recording"]`` of the session at
    ``ROOT + recording["path"]`` (presumably the first hour -- confirm
    ``get_slice_from_s`` units).

    Parameters
    ----------
    recording : dict
        Recording entry; reads ``path``, ``recording``, ``<channel>_channel``
        and, for the thalamus, ``thalamus``.
    channel : {'striatum', 'thalamus', 'cortex'}

    Raises
    ------
    ValueError
        For an unknown ``channel`` (previously ``None`` was returned silently),
        or when the thalamic channel is requested for a recording whose
        ``thalamus`` entry is falsy (previously channel 0 was returned in its
        place).
    """
    if channel not in ('striatum', 'thalamus', 'cortex'):
        raise ValueError(
            f"unknown channel {channel!r}; expected 'striatum', 'thalamus' or 'cortex'"
        )
    if channel == 'thalamus' and not recording["thalamus"]:
        raise ValueError(f"recording {recording['path']!r} has no thalamus channel")

    # Validate before opening the session, which is the expensive step
    channel_n = recording[channel + "_channel"]
    data_range = get_slice_from_s(0, 60*60)
    session = Session(ROOT + recording["path"])

    return session.recordings[recording["recording"]].continuous[0].samples[data_range, channel_n]
    
def does_chunk_contain_bursts (bursts, chunk_times):
    """Return True if any burst's ``time`` span intersects the closed interval ``chunk_times``.

    Intervals that merely touch (share an endpoint) count as intersecting.
    """
    chunk_start, chunk_end = chunk_times

    def _intersects(burst):
        burst_start, burst_end = burst.time
        return max(chunk_start, burst_start) <= min(chunk_end, burst_end)

    return any(_intersects(burst) for burst in bursts)

In [None]:
# Get proportion of burst events that have co-occurring spiking elevated above baseline.
#
# For each co-occurring pair of type ``key`` in a recording, MUA spikes
# (detect_MUA threshold 5) are counted in both bursts. Each count is compared
# with the count in an equally long baseline window that ends 1 s before that
# burst's onset. A burst counts as "spiking" when its count exceeds ``thresh``
# times its baseline count. Per recording this stores
#     (#pairs where both bursts spike) / (#pairs where at least one does),
# or NaN when no pair spikes. Only recordings with thalamic bursts get an entry.

key = "Thal-CP"
thresh = 2.5
min_spks = 0


def _baseline_window(data, onset, length, gap):
    """Return ``data[onset-gap-length : onset-gap]``, or an empty slice if that window starts before sample 0."""
    start = onset - gap - length
    if start < 0:
        # A negative start would index from the END of the array and silently
        # yield unrelated samples instead of a pre-burst baseline.
        return data[:0]
    return data[start:onset - gap]


proportion_cooccuring_spikes = []

for recording_idx, recording in enumerate(rms_processed_recordings):
    print(recording['path'], recording_idx, '/', len(rms_processed_recordings))

    if not "thalamus_bursts" in recording:
        continue

    # NOTE: for key == "Thal-CP" the first region is the thalamus, so
    # ``cortex_data`` holds the thalamic channel (name kept from the original).
    cortex_data = load_data (recording, 'thalamus')
    striatum_data = load_data (recording, 'striatum')

    co_spiking_count = 0
    tot_count = 0
    one_s = SAMPLING_RATE

    # By burst type
    for burst_pair in cooccuring_events[key]:
        if not burst_pair[0] in recording["thalamus_bursts"]:
            continue

        ctx_burst, cp_burst = burst_pair

        baseline_ctx = _baseline_window(cortex_data, ctx_burst.time[0], len(ctx_burst.data), one_s)
        if not len(baseline_ctx):
            continue

        baseline_cp = _baseline_window(striatum_data, cp_burst.time[0], len(cp_burst.data), one_s)
        if not len(baseline_cp):
            continue

        spike_times_ctx, _ = detect_MUA (ctx_burst.data, 5, silent=True)
        spike_times_ctx_baseline, _ = detect_MUA (baseline_ctx, 5, silent=True)

        spike_times_cp, _ = detect_MUA (cp_burst.data, 5, silent=True)
        spike_times_cp_baseline, _ = detect_MUA (baseline_cp, 5, silent=True)

        if (len(spike_times_ctx)>=min_spks and len(spike_times_cp)>=min_spks):
            ctx_elevated = len(spike_times_ctx) > thresh*len(spike_times_ctx_baseline)
            cp_elevated = len(spike_times_cp) > thresh*len(spike_times_cp_baseline)
            if ctx_elevated and cp_elevated:
                co_spiking_count += 1
            if ctx_elevated or cp_elevated:
                tot_count += 1

    if tot_count:
        proportion_cooccuring_spikes.append(co_spiking_count/tot_count)
    else:
        proportion_cooccuring_spikes.append(np.nan)

print(np.nanmean(proportion_cooccuring_spikes))
print(proportion_cooccuring_spikes)

In [None]:
# Now plot burst results.
# proportion_cooccuring_spikes has one entry per recording WITH thalamic bursts
# (the previous cell skips the rest). It must be paired with that same subset;
# zipping with all recordings misassigned ages.

grouped_spks = {
    '5-6': [],
    '7-8': [],
    '9-10': [],
    '11-12': []
}
filtered_recordings = [r for r in rms_processed_recordings if 'thalamus_bursts' in r]

for prop_spks, recording in zip(proportion_cooccuring_spikes, filtered_recordings):
    prop_spks = prop_spks*100
    if recording['age'] == 5 or recording['age'] == 6:
        grouped_spks['5-6'].append(prop_spks)
    elif recording['age'] == 7 or recording['age'] == 8:
        grouped_spks['7-8'].append(prop_spks)
    elif recording['age'] == 9 or recording['age'] == 10:
        grouped_spks['9-10'].append(prop_spks)
    elif recording['age'] == 11 or recording['age'] == 12:
        grouped_spks['11-12'].append(prop_spks)

mn = [np.nanmean(g) for g in grouped_spks.values()]
# SEM over recordings with a defined proportion: NaN entries are excluded from n
er = [np.nanstd(g)/np.count_nonzero(~np.isnan(g))**0.5 for g in grouped_spks.values()]

plt.errorbar(grouped_spks.keys(), mn, er)
plt.show()

In [None]:
# Get proportion of non-burst events that have co-occurring spiking elevated above baseline.
#
# Same measure as for burst pairs, applied to consecutive chunks as long as this
# recording's mean thalamic burst. Chunks are skipped if they, or their baseline
# window, overlap any thalamic or striatal burst. The entry is NaN when no chunk
# spikes or when there are no thalamic bursts to size the chunks by.

key = "Thal-CP"
thresh = 2.5
min_spks = 0

proportion_cooccuring_spikes_nonburst = []

for recording_idx, recording in enumerate(rms_processed_recordings):
    if not "thalamus_bursts" in recording:
        continue

    print(recording['path'], recording_idx, '/', len(rms_processed_recordings))

    burst_durations = [b.time[1]-b.time[0] for b in recording["thalamus_bursts"]]
    chunk_len = np.mean(burst_durations)/SAMPLING_RATE if len(burst_durations) else 0
    chunk_samples = int(chunk_len*SAMPLING_RATE)
    if chunk_samples < 1:
        # No usable chunk length (previously int(nan) raised). Keep a NaN entry so
        # the results stay aligned with the recordings that have thalamic bursts.
        proportion_cooccuring_spikes_nonburst.append(np.nan)
        continue
    baseline_offset = int((chunk_len+1)*SAMPLING_RATE)

    # NOTE: ``cortex_data`` holds the thalamic channel (name kept from the original)
    print('\tLoading thalamus data')
    cortex_data = load_data (recording, 'thalamus')[:]
    print('\tLoading striatum data')
    striatum_data = load_data (recording, 'striatum')[:]

    co_spiking_count = 0
    tot_count = 0

    for t in range(0, len(cortex_data), chunk_samples):
        if t < baseline_offset:
            # The baseline would start before sample 0. A negative slice start
            # indexes from the end of the array and yields unrelated samples.
            continue

        ctx_chunk = cortex_data[t:t+chunk_samples]
        ctx_basel = cortex_data[t-baseline_offset:t-SAMPLING_RATE]

        cp_chunk = striatum_data[t:t+chunk_samples]
        cp_basel = striatum_data[t-baseline_offset:t-SAMPLING_RATE]

        if any(not len(d) for d in (ctx_chunk, ctx_basel, cp_chunk, cp_basel)):
            continue

        # Span covering the baseline window and the chunk itself
        checked_span = [t-(chunk_len+1)*SAMPLING_RATE, t+chunk_len*SAMPLING_RATE]
        if does_chunk_contain_bursts(recording["thalamus_bursts"], checked_span):
            continue
        if does_chunk_contain_bursts(recording["striatum_bursts"], checked_span):
            continue

        spike_times_ctx, _ = detect_MUA (ctx_chunk, 5, silent=True)
        spike_times_ctx_baseline, _ = detect_MUA (ctx_basel, 5, silent=True)

        spike_times_cp, _ = detect_MUA (cp_chunk, 5, silent=True)
        spike_times_cp_baseline, _ = detect_MUA (cp_basel, 5, silent=True)

        if (len(spike_times_ctx)>=min_spks and len(spike_times_cp)>=min_spks):
            ctx_elevated = len(spike_times_ctx) > thresh*len(spike_times_ctx_baseline)
            cp_elevated = len(spike_times_cp) > thresh*len(spike_times_cp_baseline)
            if ctx_elevated and cp_elevated:
                co_spiking_count += 1
            if ctx_elevated or cp_elevated:
                tot_count += 1

    if tot_count:
        proportion_cooccuring_spikes_nonburst.append(co_spiking_count/tot_count)
    else:
        proportion_cooccuring_spikes_nonburst.append(np.nan)

print(np.nanmean(proportion_cooccuring_spikes_nonburst))
print(proportion_cooccuring_spikes_nonburst)

In [None]:
# Now plot non-burst results (alongside the burst results ``mn``/``er`` from above).
# proportion_cooccuring_spikes_nonburst has one entry per recording with thalamic
# bursts, so it is paired with that subset.

grouped_spks_nonburst = {
    '5-6': [],
    '7-8': [],
    '9-10': [],
    '11-12': []
}
filtered_recordings = [r for r in rms_processed_recordings if 'thalamus_bursts' in r]

for prop_spks, recording in zip(proportion_cooccuring_spikes_nonburst, filtered_recordings):
    prop_spks = prop_spks*100
    if recording['age'] == 5 or recording['age'] == 6:
        grouped_spks_nonburst['5-6'].append(prop_spks)
    elif recording['age'] == 7 or recording['age'] == 8:
        grouped_spks_nonburst['7-8'].append(prop_spks)
    elif recording['age'] == 9 or recording['age'] == 10:
        grouped_spks_nonburst['9-10'].append(prop_spks)
    elif recording['age'] == 11 or recording['age'] == 12:
        grouped_spks_nonburst['11-12'].append(prop_spks)

mn_nonburst = [np.nanmean(g) for g in grouped_spks_nonburst.values()]
# SEM over recordings with a defined proportion: NaN entries are excluded from n
er_nonburst = [np.nanstd(g)/np.count_nonzero(~np.isnan(g))**0.5 for g in grouped_spks_nonburst.values()]

plt.errorbar(grouped_spks.keys(), mn, er, label='Burst events')
plt.errorbar(grouped_spks_nonburst.keys(), mn_nonburst, er_nonburst, label='Non-burst periods')
plt.legend()
plt.xlabel('Age (post-natal days)')
plt.ylabel('% events with spiking activity\nacross both regions')
format_plot(plt.gca(), legend=True)
plt.show()

mn, mn_nonburst

# Cross-spectral coherence over time

In [None]:
# Cross-spectral coherence over time

def align_overlapping_bursts (burst_a, burst_b):
    """Trim two bursts to the time window in which both are active.

    ``burst.time`` holds the (start, end) of each burst, and ``burst.data`` its
    samples beginning at ``burst.time[0]``. Time differences are used directly
    as sample offsets into ``data``.

    Returns
    -------
    tuple
        ``(burst_a_trimmed, burst_b_trimmed, xticks)``: the samples of each burst
        inside the overlap window (empty when the bursts do not overlap), and
        ``get_xticks`` evaluated on the overlap slice.
    """
    start_overlap = max(burst_a.time[0], burst_b.time[0])
    end_overlap = min(burst_a.time[1], burst_b.time[1])

    # Clamp to zero so non-overlapping bursts give empty segments. A negative
    # length would make the slice end negative, which Python interprets as an
    # offset from the end of the array, returning unrelated samples.
    overlap_len = max(0, end_overlap - start_overlap)
    end_overlap = start_overlap + overlap_len

    start_a = start_overlap - burst_a.time[0]
    start_b = start_overlap - burst_b.time[0]

    return (
        burst_a.data[start_a:start_a + overlap_len],
        burst_b.data[start_b:start_b + overlap_len],
        get_xticks(slice(start_overlap, end_overlap)),
    )

def _csc_age_group(age):
    """Return the age-bin label for a postnatal age in days.

    The bins are '5-6', '7-8', '9-10' and '11-12'. Returns None for ages
    outside P5-P12, so they are excluded rather than folded into an edge bin.
    """
    if 5 <= age < 7:
        return '5-6'
    if 7 <= age < 9:
        return '7-8'
    if 9 <= age < 11:
        return '9-10'
    if 11 <= age < 13:
        return '11-12'
    return None


csc_over_time_ngb = {
    '5-6': [],
    '7-8': [],
    '9-10': [],
    '11-12': []
}
csc_over_time_sb = {
    '5-6': [],
    '7-8': [],
    '9-10': [],
    '11-12': []
}

window_size = 0.5
window = int(SAMPLING_RATE*window_size)
overlap = int(SAMPLING_RATE*window_size*0)  # no overlap between Welch segments

for burst_pair_idx, (burst_a, burst_b) in enumerate(cooccuring_events["Ctx-thal"]):
    if not (hasattr(burst_a, 'age') and hasattr(burst_b, 'age')):
        continue

    age = _csc_age_group(burst_a.age)
    if age is None:
        continue

    a, b, _ = align_overlapping_bursts(burst_a, burst_b)

    if (burst_pair_idx+1) % 100 == 0:
        print('Burst pair', burst_pair_idx+1)

    # Reject short segments before filtering: they are discarded anyway, and
    # very short inputs can make the band-pass filter fail.
    # (Assumes butter_bandpass_filter preserves signal length.)
    if len(a)/SAMPLING_RATE < 0.5:
        continue

    a = butter_bandpass_filter(a, 4, 80)
    b = butter_bandpass_filter(b, 4, 80)

    f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
    f, Cxy = get_psd_in_range((f, Cxy), [4, 80])
    mean_csc = np.mean(Cxy)

    # Exclude pairs with implausibly high mean coherence.
    if mean_csc > 0.8:
        continue

    if burst_a in cortex_ngb and burst_b in thalamus_ngb:
        csc_over_time_ngb[age].append(mean_csc)
    elif burst_a in cortex_sb and burst_b in thalamus_sb:
        csc_over_time_sb[age].append(mean_csc)

x = np.arange(len(csc_over_time_ngb))
x_ticks = list(csc_over_time_ngb.keys())

y_mn_ngb = [np.mean(v) for v in csc_over_time_ngb.values()]
y_er_ngb = [np.std(v)/len(v)**0.5 for v in csc_over_time_ngb.values()]

y_mn_sb = [np.mean(v) for v in csc_over_time_sb.values()]
y_er_sb = [np.std(v)/len(v)**0.5 for v in csc_over_time_sb.values()]

plt.errorbar(x, y_mn_sb, yerr=y_er_sb, label='SB', c='tab:red')
plt.errorbar(x, y_mn_ngb, yerr=y_er_ngb, label='NGB', c='tab:blue')
plt.legend()
plt.xticks(x, x_ticks)
plt.xlabel('Age (postnatal days)')
plt.ylabel('Mean CSC')
format_plot(plt.gca(), legend=True)
plt.show()

csc_over_time_data = {
    'NGB': csc_over_time_ngb,
    'SB': csc_over_time_sb
}
print(twoway_anova(csc_over_time_data, ['burst_type', 'age', 'value']))

print('NGB 5-6 mean', np.mean(csc_over_time_ngb['5-6']), 'std', np.std(csc_over_time_ngb['5-6']))
print('NGB 11-12 mean', np.mean(csc_over_time_ngb['11-12']), 'std', np.std(csc_over_time_ngb['11-12']))
print('SB 5-6 mean', np.mean(csc_over_time_sb['5-6']), 'std', np.std(csc_over_time_sb['5-6']))
print('SB 11-12 mean', np.mean(csc_over_time_sb['11-12']), 'std', np.std(csc_over_time_sb['11-12']))

# Cross-spectral coherence

In [None]:
def align_overlapping_bursts_full (burst_a, burst_b, burst_a_channel_key, burst_b_channel_key):
    """Load raw samples for two bursts over the union of their time spans.

    Data are read from the Open Ephys session whose RECORDINGS entry has the
    same ``path`` as each burst's ``recording_path``. The channel is taken
    from that entry under ``burst_a_channel_key`` / ``burst_b_channel_key``.
    If several entries share a path, the last matching one is used.

    Returns
    -------
    tuple
        ``(data_a, data_b)`` sample slices for the two bursts.

    Raises
    ------
    ValueError
        If no RECORDINGS entry matches one of the bursts' recording paths.
    """
    time_slice = slice(
        min(burst_a.time[0], burst_b.time[0]),
        max(burst_a.time[1], burst_b.time[1]),
    )

    data_a = None
    data_b = None
    for recording in RECORDINGS:
        matches_a = recording["path"] == burst_a.recording_path
        matches_b = recording["path"] == burst_b.recording_path
        if not (matches_a or matches_b):
            continue

        # Open the session once even when both bursts come from this recording.
        session = Session(ROOT + recording["path"])
        samples = session.recordings[recording["recording"]].continuous[0].samples

        if matches_a:
            data_a = samples[time_slice, recording[burst_a_channel_key]]
        if matches_b:
            data_b = samples[time_slice, recording[burst_b_channel_key]]

    if data_a is None or data_b is None:
        raise ValueError(
            'No RECORDINGS entry found for recording path(s) of the given bursts'
        )
    return data_a, data_b

window_size = 0.5
window = int(SAMPLING_RATE*window_size)
overlap = int(SAMPLING_RATE*window_size*0)  # no overlap between Welch segments


def _band_coherence(a, b, nperseg=window, noverlap=overlap):
    """Return (f, Cxy): magnitude-squared coherence of a and b limited to 4-80 Hz.

    Used for both the real and the shuffled pairs, so they measure the same quantity.
    """
    f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=nperseg, noverlap=noverlap)
    return get_psd_in_range((f, Cxy), [4, 80])


def _nearest_percentile_index(values, q):
    """Return the index of the element that a nearest-rank q-th percentile of `values` selects.

    Equivalent to locating np.percentile(values, q, nearest) in `values`, but
    avoids the deprecated `interpolation` keyword and the float-equality lookup.
    """
    order = np.argsort(np.asarray(values), kind='stable')
    rank = int(np.around(q / 100 * (len(order) - 1)))
    return order[rank]


Cxy_arr_shuffled_ngb = []
Cxy_arr_ngb = []

Cxy_arr_shuffled_sb = []
Cxy_arr_sb = []

mean_coherence_ngb = []
mean_coherence_shuffled_ngb = []

mean_coherence_sb = []
mean_coherence_shuffled_sb = []

bursts_a_ngb = []
bursts_b_ngb = []

bursts_a_sb = []
bursts_b_sb = []

for burst_pair_idx, (burst_a, burst_b) in enumerate(cooccuring_events["Ctx-CP"]):
    if not (hasattr(burst_a, 'age') and hasattr(burst_b, 'age')):
        continue

    a, b, _ = align_overlapping_bursts(burst_a, burst_b)

    if (burst_pair_idx+1) % 100 == 0:
        print('Burst pair', burst_pair_idx+1)

    # Reject short segments before filtering (they are discarded anyway).
    if len(a)/SAMPLING_RATE < 0.5:
        continue

    a = butter_bandpass_filter(a, 4, 80)
    b = butter_bandpass_filter(b, 4, 80)

    if burst_a in cortex_ngb and burst_b in striatum_ngb:
        targets = (Cxy_arr_ngb, mean_coherence_ngb, bursts_a_ngb, bursts_b_ngb)
    elif burst_a in cortex_sb and burst_b in striatum_sb:
        targets = (Cxy_arr_sb, mean_coherence_sb, bursts_a_sb, bursts_b_sb)
    else:
        continue

    f, Cxy = _band_coherence(a, b)

    # Same exclusion rule for both burst types: drop pairs whose mean 4-80 Hz
    # coherence exceeds 0.8.
    if np.mean(Cxy) > 0.8:
        continue

    Cxy_list, mean_list, a_list, b_list = targets
    Cxy_list.append(Cxy)
    mean_list.append(np.mean(Cxy))
    a_list.append(a)
    b_list.append(b)

# Null distribution: permuting the samples of each signal independently destroys
# their temporal relationship. Each iteration stores the across-pair mean spectrum.
n_shuffles = 1000
for i in range(n_shuffles):
    if (i+1) % 10 == 0:
        print('Shuffle iteration', i+1)

    Cxy_shuffled_inner_ngb = [
        _band_coherence(np.random.permutation(a), np.random.permutation(b))[1]
        for a, b in zip(bursts_a_ngb, bursts_b_ngb)
    ]
    Cxy_shuffled_inner_sb = [
        _band_coherence(np.random.permutation(a), np.random.permutation(b))[1]
        for a, b in zip(bursts_a_sb, bursts_b_sb)
    ]

    Cxy_shuffled_inner_ngb_mean = np.mean(Cxy_shuffled_inner_ngb, axis=0)
    Cxy_shuffled_inner_sb_mean = np.mean(Cxy_shuffled_inner_sb, axis=0)

    Cxy_arr_shuffled_ngb.append(Cxy_shuffled_inner_ngb_mean)
    Cxy_arr_shuffled_sb.append(Cxy_shuffled_inner_sb_mean)

    mean_coherence_shuffled_ngb.append(np.mean(Cxy_shuffled_inner_ngb_mean))
    mean_coherence_shuffled_sb.append(np.mean(Cxy_shuffled_inner_sb_mean))

# Bar plot
y = [
    np.mean(mean_coherence_ngb),
    np.mean(mean_coherence_sb),
]
yerr = [
    np.std(mean_coherence_ngb)/len(mean_coherence_ngb)**0.5,
    np.std(mean_coherence_sb)/len(mean_coherence_sb)**0.5,
]
colors = [
    "tab:blue",
    "tab:red",
]
bars = plt.bar(["NGB", "SB"], y, yerr=yerr, capsize=5)
for bar_i, color in zip(bars, colors):
    bar_i.set_color(color)
plt.ylabel('Mean coherence')
format_plot(plt.gca(), legend=False)
plt.show()
print('ngb versus sb', stats.ttest_ind(mean_coherence_ngb, mean_coherence_sb, equal_var=False))
print('ngb', stats.ttest_ind(mean_coherence_ngb, mean_coherence_shuffled_ngb, equal_var=False))
print('sb', stats.ttest_ind(mean_coherence_sb, mean_coherence_shuffled_sb, equal_var=False))

# Coherence plot
mean_ngb = np.nanmean(Cxy_arr_ngb, axis=0)
stderr_ngb = np.nanstd(Cxy_arr_ngb, axis=0) / len(Cxy_arr_ngb)**0.5
mean_sb = np.nanmean(Cxy_arr_sb, axis=0)
stderr_sb = np.nanstd(Cxy_arr_sb, axis=0) / len(Cxy_arr_sb)**0.5

plt.plot(f, mean_ngb, label='NGB', c='tab:blue')
plt.fill_between(f, mean_ngb+stderr_ngb, mean_ngb-stderr_ngb, alpha=0.5, facecolor='tab:blue')

# Shuffled spectrum whose mean coherence is the 95th percentile of the null.
shuffle_idx_ngb = _nearest_percentile_index(mean_coherence_shuffled_ngb, 95)
shuffled_95_percentile_ngb = Cxy_arr_shuffled_ngb[shuffle_idx_ngb]
plt.plot(f, shuffled_95_percentile_ngb, '--', label='Shuffled NGB', c='tab:blue')

plt.xlabel('Frequency (Hz)')
plt.ylabel('Coherence')
plt.legend()
format_plot(plt.gca())
plt.show()

plt.plot(f, mean_sb, label='SB', c='tab:red')
plt.fill_between(f, mean_sb+stderr_sb, mean_sb-stderr_sb, alpha=0.5, facecolor='tab:red')

shuffle_idx_sb = _nearest_percentile_index(mean_coherence_shuffled_sb, 95)
shuffled_95_percentile_sb = Cxy_arr_shuffled_sb[shuffle_idx_sb]
plt.plot(f, shuffled_95_percentile_sb, '--', label='Shuffled SB', c='tab:red')

plt.xlabel('Frequency (Hz)')
plt.ylabel('Coherence')
plt.legend()
format_plot(plt.gca())
plt.show()

# Peak-coherence frequency for each burst type.
print('SB', f[np.argmax(mean_sb)])
print('NGB', f[np.argmax(mean_ngb)])

# Cross-correlation lag of instantaneous amplitudes

In [None]:
# Region pair (key into cooccuring_events) analysed in this cell.
key = "Ctx-CP"

def get_xcorr_null_dist (a, b, iterations=1000, min_shift_s=1.0, window_s=0.2):
    """Build a null distribution for the peak squared cross-correlation of a and b.

    On each iteration, ``b`` is circularly shifted by a random offset of at
    least ``min_shift_s`` seconds from either end. This breaks the temporal
    alignment with ``a`` while keeping both signals at full length, so the
    unnormalised correlation stays on the same scale as the observed
    statistic. The peak is taken over the central ``window_s`` seconds of
    lags, matching the window used for the observed value in the caller.

    Parameters
    ----------
    a, b : array-like
        Signals to correlate; assumed to be 1-D and of equal length.
    iterations : int
        Number of shifted surrogates to draw.
    min_shift_s : float
        Minimum circular shift, in seconds.
    window_s : float
        Width, in seconds, of the central lag window searched for the peak.

    Returns
    -------
    list
        Peak squared cross-correlation of each surrogate.

    Raises
    ------
    ValueError
        If the signals are too short to allow a shift of ``min_shift_s`` from both ends.
    """
    min_shift = int(min_shift_s*SAMPLING_RATE)
    if len(b) <= 2*min_shift:
        raise ValueError('Signals are too short for the requested minimum shift')
    centre_window = int(SAMPLING_RATE*window_s)

    max_xcorr = []
    for _ in range(iterations):
        shift = int(np.random.uniform(min_shift, len(b) - min_shift))
        b_shifted = np.roll(b, shift)

        cross_corr = correlate_template(
            a, b_shifted, mode='full', normalize=None, demean=False
        )**2

        start = (len(cross_corr) - centre_window)//2
        max_xcorr.append(np.max(cross_corr[start:start+centre_window]))
    return max_xcorr

max_lag = []

# NOTE(review): only the last four recordings are analysed — confirm this subset
# is intentional (it looks like a debugging restriction).
recordings_subset = rms_processed_recordings[-4:]

for recording_idx, recording in enumerate(recordings_subset):
    # Progress is reported against the subset actually being processed.
    print(recording['path'], recording_idx, '/', len(recordings_subset))

    region_a_data = load_data (recording, 'cortex')
    region_b_data = load_data (recording, 'striatum')

    # Pairs whose cortical burst belongs to this recording. Computed once and
    # reused for both the progress total and the loop.
    recording_pairs = [
        pair for pair in cooccuring_events[key] if pair[0] in recording["cortex_bursts"]
    ]

    for burst_count, (burst_a, burst_b) in enumerate(recording_pairs, start=1):
        print('\tBurst', burst_count, '/', len(recording_pairs))

        # Take both regions over the union of the two bursts' time spans.
        start_t = min(burst_a.time[0], burst_b.time[0])
        end_t = max(burst_a.time[1], burst_b.time[1])

        a_data = region_a_data[start_t:end_t]
        b_data = region_b_data[start_t:end_t]

        # Reject short segments before filtering (they are discarded anyway).
        if len(a_data) < SAMPLING_RATE*3:
            continue

        a_data = butter_bandpass_filter(a_data, 4, 16)
        b_data = butter_bandpass_filter(b_data, 4, 16)

        # Demeaned instantaneous amplitude (magnitude of the analytic signal).
        amplitude_a = np.abs(signal.hilbert(a_data))
        amplitude_a = amplitude_a - np.mean(amplitude_a)

        amplitude_b = np.abs(signal.hilbert(b_data))
        amplitude_b = amplitude_b - np.mean(amplitude_b)

        cross_corr = correlate_template(
            amplitude_a, amplitude_b, mode='full', normalize=None, demean=False
        )**2

        # Restrict to the central 200 ms of lags.
        two_hundred_ms = int(SAMPLING_RATE*0.2)
        start_t_xcorr = (len(cross_corr)-two_hundred_ms)//2

        # Lag (s) of each cross-correlation sample. Assumes a 'full' output of
        # len(a)+len(b)-1 samples for equal-length inputs, so zero lag sits at
        # index (len-1)//2.
        lags = (np.arange(len(cross_corr)) - (len(cross_corr)-1)//2)/SAMPLING_RATE
        lags_centre = lags[start_t_xcorr:start_t_xcorr+two_hundred_ms]
        cross_corr_centre = cross_corr[start_t_xcorr:start_t_xcorr+two_hundred_ms]

        max_xcorr = np.max(cross_corr_centre)
        null_dist = get_xcorr_null_dist(amplitude_a, amplitude_b)

        # Keep the peak lag only when the peak exceeds the null's 97.5th percentile.
        if max_xcorr > np.percentile(null_dist, 97.5):
            max_lag.append(lags_centre[np.argmax(cross_corr_centre)])

    plt.hist(max_lag)
    plt.show()
    print(np.mean(max_lag)*1000)  # mean lag in ms

# Cross-correlation lag

In [None]:
FIG_ROOT = ''

# Lag analysis for cortex/striatum NGB pairs up to P8: squared cross-correlation
# in a 16-40 Hz test band versus a 4-16 Hz control band.
window_size = 0.5
window = int(SAMPLING_RATE*window_size)
overlap = int(SAMPLING_RATE*window_size*0)  # no overlap between Welch segments

xpos_arr = None
cross_corr_arr_test = []
peaks_arr_test = []

cross_corr_arr_control = []
peaks_arr_control = []

bursts_a = []
bursts_b = []

# A plain list is sufficient. np.concatenate on a single list of burst pairs
# only adds the risk of NumPy trying to descend into the burst objects.
events = list(cooccuring_events['Ctx-CP'])

# (frequency band, per-pair cross-correlations, per-pair peak lags)
lag_bands = (
    ([16, 40], cross_corr_arr_test, peaks_arr_test),
    ([4, 16], cross_corr_arr_control, peaks_arr_control),
)

for burst_pair_idx, (burst_a, burst_b) in enumerate(events):
    if (burst_pair_idx+1) % 10 == 0:
        print('Burst pair {}/{}'.format(burst_pair_idx+1, len(events)))

    if not (burst_a in cortex_ngb and burst_b in striatum_ngb):
        continue
    if burst_a.age > 8:
        continue

    a, b, ticks = align_overlapping_bursts (burst_a, burst_b)

    # Stored for every aligned pair, including the short ones skipped below.
    bursts_a.append(a)
    bursts_b.append(b)

    if len(a) < 0.5*SAMPLING_RATE:
        continue

    # NOTE(review): unlike the other coherence cells, a and b are not band-pass
    # filtered before this exclusion check — confirm that is intended.
    f, Cxy = signal.coherence(a, b, fs=SAMPLING_RATE, nperseg=window, noverlap=overlap)
    f, Cxy = get_psd_in_range((f, Cxy), [4, 80])

    if np.mean(Cxy) > 0.8:
        continue

    for band, cross_corr_list, peak_list in lag_bands:
        xpos, cross_corr = get_lags(a, b, band, prewhiten=True)
        xpos_arr = xpos
        cross_corr_list.append(cross_corr)
        peak_list.append(xpos[np.argmax(cross_corr)])

means_test = np.mean(cross_corr_arr_test, axis=0)
stderrs_test = np.std(cross_corr_arr_test, axis=0) / len(cross_corr_arr_test)**0.5
means_control = np.mean(cross_corr_arr_control, axis=0)
stderrs_control = np.std(cross_corr_arr_control, axis=0) / len(cross_corr_arr_control)**0.5

# Line plot
plt.plot(xpos_arr, means_test, label='16-40 Hz', c='tab:purple')
plt.fill_between(xpos_arr, means_test-stderrs_test, means_test+stderrs_test, alpha=0.25, facecolor='tab:purple')
means_test_max = np.max(means_test)

plt.plot(xpos_arr, means_control, label='4-16 Hz', c='tab:cyan')
plt.fill_between(xpos_arr, means_control-stderrs_control, means_control+stderrs_control, alpha=0.25, facecolor='tab:cyan')
means_control_max = np.max(means_control)

plt.ylabel('Squared cross-correlation')
plt.xlabel('Lag (ms)')
plt.legend()
format_plot(plt.gca())
plt.show()

# Bar plot
means_test = np.mean(peaks_arr_test, axis=0)
stderrs_test = np.std(peaks_arr_test, axis=0) / len(peaks_arr_test)**0.5
means_control = np.mean(peaks_arr_control, axis=0)
stderrs_control = np.std(peaks_arr_control, axis=0) / len(peaks_arr_control)**0.5

bars = plt.bar([0, 1], [means_test, means_control], yerr=[stderrs_test, stderrs_control], capsize=5, facecolor='dimgray')
bars[0].set_color('tab:purple')
bars[1].set_color('tab:cyan')
plt.xticks([0, 1], ["16-40 Hz", "4-16 Hz"])
plt.ylabel('Lag (ms)')
format_plot(plt.gca(), legend=False)
plt.show()

# NOTE(review): peaks at +/-100 are excluded for the 16-40 Hz test but not for
# the 4-16 Hz test, the bar plot or the histogram — confirm this is intended.
print('16-40 Hz', stats.ttest_1samp([p for p in peaks_arr_test if p != 100 and p != -100], popmean=0))
print('4-16 Hz', stats.ttest_1samp(peaks_arr_control, popmean=0))

# Histogram
fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True)
axs[0].hist(peaks_arr_test, bins=100)
axs[0].set_title('16-40 Hz filtered, mean lag = {}'.format(np.mean(peaks_arr_test)))
axs[1].hist(peaks_arr_control, bins=100)
axs[1].set_title('4-16 Hz filtered, mean lag = {}'.format(np.mean(peaks_arr_control)))
# Apply the shared formatting to both panels, not just the current axes.
for ax in axs:
    format_plot(ax, legend=False)
plt.tight_layout()