In [None]:
import os
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

from scipy.interpolate import interp1d
from scipy.signal import butter, filtfilt, find_peaks
from scipy.ndimage import gaussian_filter1d

from sklearn.preprocessing import scale
import seaborn as sns
from tqdm import tqdm

In [None]:
# Acquisition of data
# File names are inconsistent, so every file in `folder` is classified by
# substring matching on its name (data type, condition, context, timepoint,
# subject). A file that cannot be classified raises NameError.
file_data = {"Filename":[], "DataType":[], "Condition":[], "Context":[], "Timepoint":[],  "Subject":[]}
folder = "collected_data"
# os.listdir() returns entries in arbitrary order; sorting keeps the row order
# of file_df (and therefore any later `.values[0]` pick) reproducible.
files = sorted(os.listdir(folder))

for file in files:
    if os.path.isdir(os.path.join(folder, file)):
        continue
    if ".DS_Store" in file:  # unwanted macOS system file
        print(f"Skipping system file: {file}")
        continue

    # assign each file to a data type; anything unrecognised is treated as a
    # video timestamp file
    if ("DLC" in file) and ("filtered" in file):
        data_type = "dlc"
    elif ("DeepCut" in file) and ("filter" not in file):
        data_type = "dlc_unfiltered"
    elif "alldata" in file:
        data_type = "photometry"
    else:
        data_type = "video_timestamps"

    # get the group name, context, and test time point
    if ("CON" in file) or ("Con" in file):
        condition = "con"  # control
    elif "Run" in file:
        condition = "run"  # running
    elif "IRR" in file:
        condition = "irr"  # irradiation
    else:
        raise NameError(f"Cannot determine condition from file name: {file}")

    # "empty" is tested first, so the later branches only see non-"empty"
    # names. (The previous `~("empty" in file)` was a bitwise NOT on a bool,
    # which is always truthy, so it never restricted the "B" branch.)
    if "empty" in file:
        context = "homecage"
    elif "A" in file:
        context = "A"
    elif "B" in file:
        context = "B"
    else:
        raise NameError(f"Cannot determine context from file name: {file}")

    if ("recent" in file) or ("Recent" in file):
        timepoint = "recent"
    elif ("inter" in file) or ("Inter" in file):
        timepoint = "inter"
    elif ("remote" in file) or ("Remote" in file):
        timepoint = "remote"
    else:
        raise NameError(f"Cannot determine timepoint from file name: {file}")

    # Reset per file so a name without a subject tag cannot silently inherit
    # the subject of the previously processed file.
    subject = None
    # Ascending order matters: "M1" is a substring of "M12", so the last
    # (largest) matching tag wins.
    for n in range(1, 20):
        subject_tag = f"M{n}"
        if subject_tag in file:
            subject = subject_tag
    if subject is None:
        raise NameError(f"Cannot determine subject from file name: {file}")

    file_data["Filename"].append(os.path.join(folder, file))
    file_data["DataType"].append(data_type)
    file_data["Condition"].append(condition)
    file_data["Context"].append(context)
    file_data["Timepoint"].append(timepoint)
    file_data["Subject"].append(subject)

file_df = pd.DataFrame(file_data)
pd.set_option('display.max_rows', None)
file_df

In [None]:
# Peak extraction
def filter_channel_data(pd_data, led_state, region="Region0G", sigma=1., ts=None):
    """Resample one LED channel onto a common time base and smooth it.

    Rows of ``pd_data`` whose ``LedState`` equals ``led_state`` are selected,
    their ``region`` values are linearly interpolated at the times in ``ts``
    (extrapolating outside the recorded range), and the resampled trace is
    smoothed with a 1-D Gaussian filter.

    Parameters
    ----------
    pd_data : pandas.DataFrame
        Table with ``LedState``, ``Timestamp`` and ``region`` columns.
    led_state : int or str
        Value of ``LedState`` that identifies the channel.
    region : str, optional
        Name of the column holding the signal.
    sigma : float, optional
        Standard deviation of the Gaussian kernel, in samples of ``ts``.
    ts : array-like, optional
        Time points to resample onto. When omitted, the module-level
        variable ``ts`` is used, which keeps existing calls working.

    Returns
    -------
    numpy.ndarray
        Smoothed signal with the same shape as ``ts``.

    Raises
    ------
    NameError
        If ``ts`` is not passed and no module-level ``ts`` exists.
    """
    if ts is None:
        # Backward-compatible fallback to the notebook-global time base.
        try:
            ts = globals()["ts"]
        except KeyError:
            raise NameError("name 'ts' is not defined; pass ts= explicitly") from None

    # Select the channel rows once and reuse them for both columns.
    channel_rows = pd_data.query("LedState=={}".format(str(led_state)))
    resample = interp1d(
        channel_rows["Timestamp"].values,
        channel_rows[region].values,
        fill_value="extrapolate",
    )
    return gaussian_filter1d(resample(ts), sigma)

# specify mouse ID
condition = "con"
timepoint = "recent"
subject = "M1"
fs = 10  # resampling rate: the common time base below has a step of 1/fs

# NOTE(review): `start` and `end` are not defined in any visible cell, so this
# cell fails with NameError on a fresh kernel. They are multiplied by `fs` to
# obtain sample indices, so presumably they are window bounds in the units of
# the Timestamp column and must be integers - confirm and define them here.
for context in ["homecage", "A", "B"]:
    session_files = file_df.query("Condition == @condition and Context == @context and Timepoint == @timepoint and Subject == @subject")

    photometry_files = session_files.query("DataType == 'photometry'")["Filename"].values
    if len(photometry_files) == 0:
        # Skip with a message instead of failing with an opaque IndexError.
        print(f"No photometry file for {condition, subject, timepoint, context}; skipping")
        continue
    data = pd.read_csv(photometry_files[0])
    ts = np.arange(data["Timestamp"].values[0], data["Timestamp"].values[-1], 1/fs)

    # filter the original photometry data in each channel
    # (filter_channel_data reads the module-level `ts` assigned just above,
    # so these calls must happen before `ts` is cropped below)
    ch1_interp = filter_channel_data(data, led_state=1)
    ch2_interp = filter_channel_data(data, led_state=2)

    # crop to the [start, end) window, then high-pass filter the normalised
    # signal (ch2 - ch1) / ch1 with a 2nd-order Butterworth filter at 0.05 Hz
    ts, ch1_interp, ch2_interp = ts[start*fs:end*fs], ch1_interp[start*fs:end*fs], ch2_interp[start*fs:end*fs]
    ts_rel = ts - ts[0]
    # use the configured sampling rate rather than a second hard-coded 10
    b, a = butter(2, 0.05, btype="high", fs=fs)
    norm = filtfilt(b, a, (ch2_interp - ch1_interp) / ch1_interp)

    # peak extraction: two-pass threshold of median + 2*MAD. The second pass
    # recomputes the statistics on samples below the first threshold only.
    mad = np.median(np.abs(norm - np.median(norm)))
    thresh1 = np.median(norm) + 2*mad

    point_under_thresh = norm[norm < thresh1]
    mad2 = np.median(np.abs(point_under_thresh - np.median(point_under_thresh)))
    thresh2 = np.median(point_under_thresh) + 2*mad2

    # peaks are detected on a Gaussian-smoothed copy (sigma = 10 samples)
    norm_smooth = gaussian_filter1d(norm, 10)
    peaks, peak_data = find_peaks(norm_smooth, height=thresh2)

    # visualization: unsmoothed trace, final threshold, and a marker at the
    # sample index of each detected peak (drawn at the trace maximum)
    plt.figure(figsize=(15,5))
    plt.plot(norm, color='#999998', linewidth=1)
    plt.axhline(thresh2, ls="--", c="red")
    plt.scatter(peaks, norm.max()*np.ones(len(peaks)), c="red")
    plt.title(f"{condition, subject, timepoint, context}")

    # Peak count
    peak_count = len(peaks)
    print("Number of peaks:", peak_count)
    print(2*mad2)

In [None]:
from itertools import product

def onset_offset(vect):
    """Locate 0->1 and 1->0 transitions in a binary sequence.

    Parameters
    ----------
    vect : array-like
        One-dimensional sequence of 0/1 (or boolean) values.

    Returns
    -------
    onsets : numpy.ndarray
        Indices ``t`` where ``vect[t] == 0`` and ``vect[t + 1] == 1``, i.e.
        the last sample *before* each run of ones.
    offsets : numpy.ndarray
        Indices ``t`` where ``vect[t] == 1`` and ``vect[t + 1] == 0``, i.e.
        the last sample *of* each run of ones.

    Both arrays are empty when ``vect`` has fewer than two elements.
    """
    vect = np.asarray(vect)
    # Compare each sample with its successor in one vectorised pass instead
    # of building a Python list element by element.
    current, following = vect[:-1], vect[1:]
    onsets = np.flatnonzero((current == 0) & (following == 1))
    offsets = np.flatnonzero((current == 1) & (following == 0))
    return onsets, offsets

In [None]:
# Bout extraction
# For every (context, timepoint, mouse) session of one condition: detect
# freezing from DLC motion, record the fraction of time spent freezing, and
# collect photometry/motion windows around each freezing onset.
condition = 'con'

fs = 10
filt_sigma = 2.5
contexts = file_df.Context.unique()
timepoints = file_df.Timepoint.unique()
mice = file_df.Subject.unique()

# window around each onset, in samples: [-9*fs, +3*fs)
window_size = np.array([-fs*9, fs*3]).astype(int)
windows = np.array([]).reshape(0, window_size[1]-window_size[0])
windows_motion = np.array([]).reshape(0, window_size[1]-window_size[0])
time_index = np.arange(windows.shape[1])/fs + window_size[0]/fs

results = {"mouse":[], "context":[], "timepoint":[], "t":[],  "activity":[], "motion":[]}
results_freezing = {"mouse":[], "context":[], "timepoint":[], "freezing":[], "condition":[]}

# NOTE(review): `start` and `end` (used below to crop the session) are not
# defined in any visible cell, so this cell fails with NameError on a fresh
# kernel. Define them explicitly before running.
df_con = file_df.query("Condition == @condition")
for ix, (context, timepoint, mouse) in tqdm(enumerate(product( contexts, timepoints, mice))):
    query_string = "Context == @context and Timepoint == @timepoint and Subject == @mouse"
    session_files = df_con.query(query_string)
    if len(session_files) == 0:
        continue
    print('processing {} {} {} {}...'.format(condition, timepoint, mouse, context))

    # read photometry data
    photometry_file = session_files.query("DataType == 'photometry'")["Filename"].values[0]
    data = pd.read_csv(photometry_file)
    ts = np.arange(data["Timestamp"][0], data["Timestamp"].values[-1], 1/fs)

    # filter the original photometry data in each channel
    # (filter_channel_data resamples onto the module-level `ts` set above)
    ch1_interp = filter_channel_data(data, led_state=1, sigma=2.5)
    ch2_interp = filter_channel_data(data, led_state=2, sigma=2.5)
    norm = (ch2_interp - ch1_interp)/ch1_interp

    # read DLC data
    dlc_file = session_files.query("DataType == 'dlc'")["Filename"].values[0]
    dlc_data = pd.read_csv(dlc_file, index_col=0)

    # extract the x/y columns of the three tracked points; the first two rows
    # are skipped and columns 2, 5 and 8 are not used
    try:
        x_nose = dlc_data.iloc[2:,0].astype(float).values
        y_nose = dlc_data.iloc[2:,1].astype(float).values
        x_head = dlc_data.iloc[2:,3].astype(float).values
        y_head = dlc_data.iloc[2:,4].astype(float).values
        x_tail = dlc_data.iloc[2:,6].astype(float).values
        y_tail = dlc_data.iloc[2:,7].astype(float).values
    except (ValueError, TypeError, IndexError) as err:
        # Skip sessions whose DLC file cannot be parsed, but say so; a bare
        # `except` would also have swallowed KeyboardInterrupt.
        print('skipping {} {} {} {}: could not parse DLC coordinates ({})'.format(
            condition, timepoint, mouse, context, err))
        continue

    video_timestamp_file = session_files.query("DataType == 'video_timestamps'")["Filename"].values[0]
    video_ts_data = pd.read_csv(video_timestamp_file, header=None)
    video_ts = video_ts_data.iloc[:, 0].values
    video_ts_rel = video_ts - video_ts[0]
    # truncate the coordinate traces to the number of video timestamps
    x_nose, y_nose, x_head, y_head, x_tail, y_tail = (x_nose[:len(video_ts)], y_nose[:len(video_ts)],
                                                      x_head[:len(video_ts)], y_head[:len(video_ts)],
                                                      x_tail[:len(video_ts)], y_tail[:len(video_ts)])

    # resample every coordinate trace onto the photometry time base `ts`
    xy_list = [x_nose, y_nose, x_head, y_head, x_tail, y_tail]
    coords = []

    for c in xy_list:
        f_c = interp1d(video_ts, c, fill_value="extrapolate")
        c_interp = f_c(ts)
        coords.append(c_interp)

    coords = np.stack(coords)

    # crop both signals to the [start, end) window
    coords = coords[:, start*fs:end*fs]
    norm = norm[start*fs:end*fs]

    # high-pass filter (2nd-order Butterworth, 0.01 Hz) and standardise
    b, a = butter(2, 0.01, btype="high", fs=fs)
    norm_filt = scale(filtfilt(b, a, norm))

    # motion processing to detect freezing bouts: mean absolute frame-to-frame
    # displacement over all six coordinates, Gaussian-smoothed
    motion_raw = np.mean(np.abs(coords - np.roll(coords, 1, axis=1)), axis=0)
    motion_raw[0] = 0  # np.roll wraps around, so the first difference is meaningless
    motion = gaussian_filter1d(motion_raw, filt_sigma)
    motion_scale = scale(motion)

    # freezing threshold: 0.5% of the median nose-to-tail distance
    thresh = 0.005*np.median(np.sqrt((coords[0, :]-coords[4, :])**2 + (coords[1, :] - coords[5, :])**2))
    freezing = (motion-motion.min()) < thresh

    results_freezing["freezing"].append(np.mean(freezing))
    results_freezing["context"].append(context)
    results_freezing["timepoint"].append(timepoint)
    results_freezing["mouse"].append(mouse)
    results_freezing["condition"].append(condition)

    freezing_cleaned = freezing.copy()

    onsets, offsets = onset_offset(freezing)

    # Pair each onset with the offset that follows it. The emptiness guard
    # avoids an IndexError for sessions with no transition of either kind.
    if len(onsets) > 0 and len(offsets) > 0 and onsets[0] > offsets[0]:
        offsets = offsets[1:]

    if freezing[-1] == 1:
        onsets = onsets[:-1]

    # eliminate freezing/non-freezing bouts that are shorter than 1 second
    for oo in range(len(onsets)):
        if oo == len(onsets)-1:
            break
        if (offsets[oo] - onsets[oo]) < 1*fs:
            freezing_cleaned[onsets[oo]:offsets[oo]+1] = 0
        if (onsets[oo+1] - offsets[oo]) < 1*fs:
            freezing_cleaned[offsets[oo]:onsets[oo+1]+1] = 1

    onsets, offsets = onset_offset(freezing_cleaned)

    # Re-align after cleaning as well: if the cleaned trace starts frozen, the
    # first offset precedes the first onset and offsets[oo] would no longer be
    # the end of the bout that starts at onsets[oo].
    if len(onsets) > 0 and len(offsets) > 0 and onsets[0] > offsets[0]:
        offsets = offsets[1:]

    if len(onsets) > 1 and window_size[1] > window_size[0]:
        # one row per onset except the last one
        windows = np.zeros((len(onsets) - 1, window_size[1] - window_size[0]))
        windows_motion = np.zeros((len(onsets) - 1, window_size[1] - window_size[0]))
    else:
        print("Error: Invalid dimensions for array creation - len(onsets):", len(onsets), "window_size:", window_size[1] - window_size[0])
        # Skip the session: falling through would re-append the previous
        # session's `windows` to `results` under this session's labels.
        continue

    # NOTE(review): rows for windows skipped by the `continue` statements
    # below stay zero-filled and are still appended to `results`. The original
    # trailing check for fewer than 10 valid samples had no effect and was
    # removed - confirm whether such rows should be dropped or set to NaN.
    for oo in range(len(onsets) - 1):
        # skip windows that would run past either end of the recording
        if (onsets[oo] + window_size[1] >= len(norm_filt)) or (onsets[oo] + window_size[0] <= 0):
            continue

        # NOTE(review): `freezing` is boolean, and pandas' Series.diff() uses
        # XOR for boolean data, so the result never equals -1 and this check
        # never fires. Casting to int would enable it but changes which
        # windows are kept - confirm intent before changing.
        window_segment = pd.Series(freezing[onsets[oo]:onsets[oo] + window_size[1]])
        if any(window_segment.diff() == -1):
            continue  # Skip this window if freezing ends

        windows[oo, :] = norm_filt[onsets[oo] + window_size[0]: onsets[oo] + window_size[1]]
        windows_motion[oo, :] = motion[onsets[oo] + window_size[0]: onsets[oo] + window_size[1]]

        # blank the leading samples that fall inside the previous bout
        if oo > 0 and ((onsets[oo] + window_size[0]) < offsets[oo - 1]):
            overlap = -window_size[0] - (onsets[oo] - offsets[oo - 1])
            windows[oo, :overlap] = np.nan
            windows_motion[oo, :overlap] = np.nan

        # blank the trailing samples after this bout has already ended
        if offsets[oo] < (onsets[oo] + window_size[1]):
            overlap = (onsets[oo] + window_size[1]) - offsets[oo]
            windows[oo, -overlap:] = np.nan
            windows_motion[oo, -overlap:] = np.nan

    # flatten the windows into long-format records
    for oo in range(windows.shape[0]):
        for t_index, time in enumerate(time_index):
            results["t"].append(time)
            results["activity"].append(windows[oo, t_index])
            results["motion"].append(windows_motion[oo, t_index])
            results["context"].append(context)
            results["timepoint"].append(timepoint)
            results["mouse"].append(mouse)

print('done')