In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from scipy.ndimage import gaussian_filter1d 
from frites.conn import conn_dfc, define_windows

from spikingDataUtilities import loadMATData, firingRate

# Exploratory checks: firing rates and smoothing

In [None]:
# Recording session to analyse. Set the BRAINHACK_SESSION environment variable to point
# at a local copy instead of editing this machine-specific default path.
SESSION_PATH = os.environ.get('BRAINHACK_SESSION', '/home/gabricasa/code_repo/BrainHack2026/session1.mat')
data = loadMATData(SESSION_PATH)

In [None]:
# Sanity check of the smoothing: raw HPC population rate (firingRate defaults) against a
# Gaussian-smoothed copy (sigma = 1.5 bins), shown over a short preview.
firing_rate = firingRate(data['spikes_hpc'][:, 0])
gaussian_firing_rate = gaussian_filter1d(firing_rate[:, 1], sigma=1.5, axis=0)

n_preview = 500  # number of bins shown in the preview
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(firing_rate[:n_preview, 0], firing_rate[:n_preview, 1], label='raw')
ax.plot(firing_rate[:n_preview, 0], gaussian_firing_rate[:n_preview], label='Gaussian (sigma = 1.5 bins)')
ax.set(title='HPC firing rate: raw vs smoothed', xlabel='Time (s)', ylabel='Firing rate')
ax.legend()
plt.show()

In [None]:
# Smoothed population firing rates of the three recorded regions (other firingRate settings default).
firing_rate_hpc, firing_rate_pfc, firing_rate_nr = (
    firingRate(data[spike_key][:, 0], smooth=True)
    for spike_key in ('spikes_hpc', 'spikes_pfc', 'spikes_nr')
)

In [None]:
# Smoothed HPC rate restricted to the first protocol epoch (treated as sleep 1 in this notebook).
s1_bounds = (data['protocol_times'][0][0], data['protocol_times'][0][-1])
time_mask = (firing_rate_hpc[:, 0] >= s1_bounds[0]) & (firing_rate_hpc[:, 0] <= s1_bounds[1])

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(firing_rate_hpc[time_mask, 0], firing_rate_hpc[time_mask, 1])
ax.set(title='HPC smoothed firing rate (first protocol epoch)', xlabel='Time (s)', ylabel='Firing rate')
plt.show()

# Sleep 1: peri-ripple activity in HPC, PFC and NR

In [None]:
# Smoothed rates of the three regions, restricted to the first protocol epoch (sleep 1).
firing_rate_hpc = firingRate(data['spikes_hpc'][:, 0], smooth=True)
firing_rate_pfc = firingRate(data['spikes_pfc'][:, 0], smooth=True)
firing_rate_nr = firingRate(data['spikes_nr'][:, 0], smooth=True)

s1_start = data['protocol_times'][0][0]
s1_end = data['protocol_times'][0][-1]

# One mask per region, built from that region's own time column. firingRate is called
# here without explicit start/stop, so nothing guarantees the three time axes coincide;
# reusing the HPC mask on PFC/NR could misalign rows or fail to index at all.
time_mask = (firing_rate_hpc[:, 0] >= s1_start) & (firing_rate_hpc[:, 0] <= s1_end)
pfc_mask = (firing_rate_pfc[:, 0] >= s1_start) & (firing_rate_pfc[:, 0] <= s1_end)
nr_mask = (firing_rate_nr[:, 0] >= s1_start) & (firing_rate_nr[:, 0] <= s1_end)

s1_hpc_rate = firing_rate_hpc[:, 1][time_mask]
s1_hpc_times = firing_rate_hpc[:, 0][time_mask]
s1_pfc_rate = firing_rate_pfc[:, 1][pfc_mask]
s1_pfc_times = firing_rate_pfc[:, 0][pfc_mask]
s1_nr_rate = firing_rate_nr[:, 1][nr_mask]
s1_nr_times = firing_rate_nr[:, 0][nr_mask]

In [None]:
# Control event times for the peri-ripple analysis: draw as many time points as there are
# ripples before the end of sleep 1, each uniformly within a randomly chosen sleep-1 REM
# interval, excluding the first and last REM interval.
s1_start = data['protocol_times'][0][0]
s1_end = data['protocol_times'][0][-1]

# Seeded generator so the control draw is reproducible across runs.
rng = np.random.default_rng(0)

# Indices of REM intervals (rows of [start, end]) lying inside sleep 1.
rem_in_s1 = np.flatnonzero((data['rem'][:, 0] >= s1_start) & (data['rem'][:, 1] <= s1_end))
# [1:-1] drops exactly the first and the last interval.
rem_candidates = rem_in_s1[1:-1]
if rem_candidates.size == 0:
    raise ValueError('Sleep 1 needs at least three REM intervals to draw control times.')

# Number of ripples before the end of sleep 1 (count_nonzero works for any array shape).
rnd_size = np.count_nonzero(data['ripples'] < s1_end)

rnd_indexes = rng.choice(rem_candidates, size=rnd_size)
u = rng.uniform(0, 1, size=rnd_size)

starts = data['rem'][rnd_indexes, 0]
ends = data['rem'][rnd_indexes, 1]
test_times = starts + u * (ends - starts)

In [None]:
# Peri-event average of each region's sleep-1 rate around HPC ripples (black) versus
# around the REM control times test_times (red).
t_width = 1  # half-width of the window around each event, in seconds


def _peri_event_stack(times, rate, event_times, half_width):
    """Stack equal-length windows of `rate` centred on each event.

    Windows are cut by sample count around the sample nearest to each event, so every row
    has the same length. A time-mask cut can differ by one sample between events, and
    np.mean(axis=0) then fails on the ragged list. Events whose window would run past
    either end of `times` are dropped.

    Assumes `times` is ascending with a uniform bin width, inferred from its spacing.

    Returns:
        lags: (2 * half_n + 1,) offsets from the event, in the units of `times`.
        stack: (n_kept_events, 2 * half_n + 1) array of rate windows.
    Raises:
        ValueError: if `times` has fewer than two samples (bin width cannot be inferred).
    """
    times = np.asarray(times, dtype=float)
    rate = np.asarray(rate)
    event_times = np.ravel(event_times).astype(float)
    if times.size < 2:
        raise ValueError('need at least two samples to infer the bin width')

    dt = np.median(np.diff(times))
    half_n = int(round(half_width / dt))
    offsets = np.arange(-half_n, half_n + 1)

    # Nearest sample to each event; ties go to the earlier sample.
    right = np.clip(np.searchsorted(times, event_times), 1, times.size - 1)
    left = right - 1
    nearest = np.where(event_times - times[left] <= times[right] - event_times, left, right)

    keep = (nearest - half_n >= 0) & (nearest + half_n < times.size)
    stack = rate[nearest[keep][:, None] + offsets]
    return offsets * dt, stack


s1_end = data['protocol_times'][0][-1]
# Ripple times before the end of sleep 1 (flattened: data['ripples'] is indexed as 2-D elsewhere).
ripples_flat = np.ravel(data['ripples'])
s1_ripples = ripples_flat[ripples_flat < s1_end]

regions = {
    'HPC': (s1_hpc_times, s1_hpc_rate),
    'PFC': (s1_pfc_times, s1_pfc_rate),
    'NR': (s1_nr_times, s1_nr_rate),
}
ripple_avg, control_avg = {}, {}

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for axis, (region, (region_times, region_rate)) in zip(ax, regions.items()):
    lags, ripple_stack = _peri_event_stack(region_times, region_rate, s1_ripples, t_width)
    _, control_stack = _peri_event_stack(region_times, region_rate, test_times, t_width)
    # NaN trace (rather than a mean-of-empty warning) when no event has a full window.
    ripple_avg[region] = ripple_stack.mean(axis=0) if len(ripple_stack) else np.full(lags.size, np.nan)
    control_avg[region] = control_stack.mean(axis=0) if len(control_stack) else np.full(lags.size, np.nan)

    axis.plot(lags, ripple_avg[region], color='black', linewidth=2, label=f'ripples (n={len(ripple_stack)})')
    axis.plot(lags, control_avg[region], color='red', linewidth=2, label=f'REM control (n={len(control_stack)})')
    axis.axvline(0, color='grey', linestyle='--', linewidth=1)
    axis.set(title=region, xlabel='Time from event (s)')
    axis.legend()
ax[0].set_ylabel('Firing rate')
plt.show()

# Per-region names kept for interactive use.
avg_hpc, avg_pfc, avg_nr = ripple_avg['HPC'], ripple_avg['PFC'], ripple_avg['NR']
rnd_avg_hpc, rnd_avg_pfc, rnd_avg_nr = control_avg['HPC'], control_avg['PFC'], control_avg['NR']

# Dynamic functional connectivity around ripple events

In [None]:
# Load the session for the DFC analysis so this section can run on its own.
# Set the BRAINHACK_SESSION environment variable to override the default path.
SESSION_PATH = os.environ.get('BRAINHACK_SESSION', '/home/gabricasa/code_repo/BrainHack2026/session1.mat')
data = loadMATData(SESSION_PATH)

# Regions of interest, in the row order used for task_rates below
roi_order = ['HPC', 'PFC', 'NR']

# Protocol epoch to analyse, looked up by name in data['protocol_names']
task_to_study = 'sleep1'
task_index = data['protocol_names'].index(task_to_study)

bin_lenght = 0.005  # 5 ms bin size for firing rate
t_start = data['protocol_times'][task_index][0]
t_end = data['protocol_times'][task_index][-1]

# Firing rates for each region with the same bin size, restricted to the task window
firing_rate_hpc = firingRate(data['spikes_hpc'][:, 0], start=t_start, stop=t_end, bin_size=bin_lenght, smooth=True)
firing_rate_pfc = firingRate(data['spikes_pfc'][:, 0], start=t_start, stop=t_end, bin_size=bin_lenght, smooth=True)
firing_rate_nr = firingRate(data['spikes_nr'][:, 0], start=t_start, stop=t_end, bin_size=bin_lenght, smooth=True)

# The HPC time column is used as the time axis of every region further down,
# so the three time axes must be identical.
assert np.array_equal(firing_rate_hpc[:, 0], firing_rate_pfc[:, 0]), 'HPC and PFC time axes differ'
assert np.array_equal(firing_rate_hpc[:, 0], firing_rate_nr[:, 0]), 'HPC and NR time axes differ'

# (n_roi, n_samples), rows in roi_order
task_rates = np.array([firing_rate_hpc[:, 1], firing_rate_pfc[:, 1], firing_rate_nr[:, 1]])

# Ripples inside the studied task window only. An upper bound alone would also keep the
# ripples of every earlier protocol whenever task_to_study is not the first one.
ripple_times_flat = np.ravel(data['ripples'])
ripples_array = ripple_times_flat[(ripple_times_flat >= t_start) & (ripple_times_flat < t_end)]

print(np.shape(firing_rate_hpc[:, 0]), np.shape(task_rates), np.shape(ripples_array))

In [None]:
# Number of bins with zero firing rate in each region, out of the total bin count.
for region, rate_array in (('HPC', firing_rate_hpc), ('PFC', firing_rate_pfc), ('NR', firing_rate_nr)):
    rates = rate_array[:, 1]
    print(f'{region}: {rates.size - np.count_nonzero(rates)} zero-rate bins out of {rates.size}')

In [None]:
def plt_firing_chunks(firing_rate_array, t_spans, roi, bin_size):
    """Plot a firing-rate trace as a sequence of consecutive, equal-duration figures.

    Parameters:
        firing_rate_array: 2-D array with bin times in column 0 and firing rates in
            column 1 (the layout returned by firingRate).
        t_spans: duration of each chunk, in the units of the time column.
        roi: region label used in the figure titles.
        bin_size: bin width of the trace; converts t_spans into a number of samples.

    Chunks are counted from the first sample, not from t = 0, so a trace whose time axis
    starts later does not produce trailing empty figures. Returns None.
    """
    n_samples = len(firing_rate_array)
    if n_samples == 0:
        return
    # Samples per chunk: round() avoids float truncation of t_spans / bin_size, and
    # max(1, ...) guarantees progress when a span is shorter than one bin.
    chunk_len = max(1, int(round(t_spans / bin_size)))

    for start in range(0, n_samples, chunk_len):
        chunk = firing_rate_array[start:start + chunk_len]
        fig, ax = plt.subplots()
        ax.plot(chunk[:, 0], chunk[:, 1])
        ax.set_title(f'Firing rate chunk {roi} ({chunk[0, 0]:.1f}-{chunk[-1, 0]:.1f} s)')
        ax.set_xlabel('Time (s)')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across many chunks

In [None]:
# Page through the HPC rate of the studied task in consecutive 50 s figures.
time_spans = 50  # seconds per figure
plt_firing_chunks(firing_rate_hpc, t_spans=time_spans, roi='HPC', bin_size=bin_lenght)

In [None]:
# Visually compare the three regions over the first time_spans seconds of the task;
# the full-length overlay is too dense to read at 5 ms bins.
time_spans = 15  # seconds shown

fig, ax = plt.subplots(figsize=(12, 4))
for label, rate_array in (('HPC', firing_rate_hpc), ('PFC', firing_rate_pfc), ('NR', firing_rate_nr)):
    in_window = rate_array[:, 0] < rate_array[0, 0] + time_spans
    ax.plot(rate_array[in_window, 0], rate_array[in_window, 1], label=label)
ax.set(title=f'Smoothed firing rates, first {time_spans} s', xlabel='Time (s)', ylabel='Firing rate')
ax.legend()
plt.show()

In [None]:
# Reshape data to 3D array: (n_trials, n_roi, n_times)
# Consider hippocampal ripple events to be trials
def reshape_data_around_ripples(hpc_firing_times, firing_rates, ripple_times, roi, bin_size, t_window=1.0):
    """
    Cut a fixed-length window around each ripple and stack the windows as an xarray
    DataArray with dims (trials, roi, times), for information-based measures.
    Beware that all regions must share the same time axis (hpc_firing_times).

    Input:
        hpc_firing_times: ascending bin times of the hippocampal firing rate, used as the
            time axis of every region.
        firing_rates: array of shape (n_roi, n_samples); row i is the firing rate of roi[i].
        ripple_times: ripple event times (any shape, flattened), same units as hpc_firing_times.
        roi: region labels, one per row of firing_rates.
        bin_size: bin width of the firing rates, same units as hpc_firing_times.
        t_window: half-width of the window kept on each side of a ripple (same units).
    Output:
        xr.DataArray of shape (n_trials, n_roi, 2 * round(t_window / bin_size) + 1).
        One trial per ripple whose whole window fits inside the recording. Ripples closer
        than t_window to either end are skipped, so n_trials can be smaller than
        len(ripple_times). The "times" coordinate runs from 0 to 2 * t_window with the
        ripple at t_window; "trials" is 0..n_trials-1.
    """
    hpc_firing_times = np.asarray(hpc_firing_times, dtype=float)
    firing_rates = np.asarray(firing_rates)
    ripple_times = np.ravel(ripple_times).astype(float)
    n_samples = hpc_firing_times.size

    # Half-window in samples. round() instead of int() so float error in
    # t_window / bin_size cannot silently drop a sample.
    half_n = int(round(t_window / bin_size))
    offsets = np.arange(-half_n, half_n + 1)
    # Built from the sample count rather than np.arange with a float step, so its length
    # always equals the number of samples per trial.
    times = (offsets + half_n) * bin_size

    # Sample nearest to each ripple; ties resolve to the earlier sample (as np.argmin does).
    if n_samples >= 2:
        right = np.clip(np.searchsorted(hpc_firing_times, ripple_times), 1, n_samples - 1)
        left = right - 1
        closer_left = (ripple_times - hpc_firing_times[left]) <= (hpc_firing_times[right] - ripple_times)
        nearest_idx = np.where(closer_left, left, right)
    else:
        nearest_idx = np.zeros(ripple_times.size, dtype=int)

    # Keep only ripples whose whole window lies inside the recording. Checking sample
    # indices (instead of comparing times with t_window) stays correct when the time axis
    # does not start at 0, and prevents a negative slice start from wrapping around.
    keep = (nearest_idx - half_n >= 0) & (nearest_idx + half_n < n_samples)
    centres = nearest_idx[keep]

    # (n_roi, n_trials, n_times) -> (n_trials, n_roi, n_times)
    data_3d = firing_rates[:, centres[:, None] + offsets].transpose(1, 0, 2)

    # One label per kept trial; labelling every input ripple makes the DataArray
    # construction fail as soon as any ripple is skipped.
    trials = np.arange(data_3d.shape[0])

    return xr.DataArray(data_3d, dims=('trials', 'roi', 'times'),
                        coords=(trials, roi, times))

In [None]:
# One trial per ripple of the studied task, laid out as (trials, roi, times)
reshaped_data = reshape_data_around_ripples(
    firing_rate_hpc[:, 0],
    task_rates,
    ripples_array,
    roi_order,
    bin_size=bin_lenght,
)
print(reshaped_data.shape)

In [None]:
# Compute dynamic functional connectivity on a sliding window over the ripple trials
slwin_len = 0.5   # window of 500 ms
slwin_step = 0.02  # step of 20 ms (480 ms overlap)

# Use the time axis stored in the trials so the window sample indices always match the
# trial length. A hand-built np.arange(-t_window, t_window, bin_lenght) excludes its
# endpoint (the trials include both ends) and had to be kept in sync with t_window by hand.
times = reshaped_data['times'].values

sl_win = define_windows(times, slwin_len=slwin_len, slwin_step=slwin_step)[0]

# compute the DFC on sliding windows
dfc = conn_dfc(reshaped_data, times='times', roi='roi', win_sample=sl_win)

In [None]:
# Average the sliding-window DFC across trials (ripples), then draw one line per entry
# of the 'roi' dimension against window time.
dfc_m = dfc.mean('trials').squeeze()

dfc_m.plot.line(x='times', hue='roi')
plt.title(dfc.name)
plt.ylabel('DFC')
plt.show()