In [1]:
import matplotlib
matplotlib.use('Agg') # disable interactive matplotlib to save RAM
matplotlib.rcParams['agg.path.chunksize'] = 10000

from bs4 import BeautifulSoup
import datetime
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import sys
from tqdm.auto import tqdm
sys.path.append('src')

from importrhdutilities import load_file as read_rhd
from utils import *

# Session day to process (YYMMDD; matched against session folder names and file timestamps).
export_date = '240124'
joystick_length = 15  # cm, lever length used by analog_to_digital()
analog_threshold = 71  # analog units; converted to the target move distance via analog_to_digital()

mice = ['6_2', '6_3']

# Input / output locations, relative to the working directory.
input_folder = os.path.join('data', 'raw', 'Behavior')
reports_folder = os.path.join('data', 'reports', export_date)
si_folder = os.path.join('data', 'spikeinterface-0_98_2', 'Behavior')

os.makedirs(reports_folder, exist_ok=True)
print(mice, 'saving to', reports_folder)

['6_2', '6_3'] saving to data/reports/240124


In [2]:
def read_intan(session_path, trace_indices):
    """Load and concatenate every Intan ``.rhd`` file of one session.

    Files are read in sorted filename order. Files for which the reader reports no data are
    skipped with a message.

    Args:
        session_path: directory containing the session's ``.rhd`` files.
        trace_indices: amplifier-channel row indices to keep (anything usable as a numpy row index).

    Returns:
        ``(intan_recordings, sampling_frequency, session_recording_paths)``:
        ``intan_recordings`` maps 'adc', 'dig', 't_adc', 't_dig' and 'traces' to the per-file arrays
        joined with ``np.hstack``, and 'recordings' to a DataFrame with one row per used file
        (file, file_start, file_length, sampling_frequency; start/length in amplifier samples);
        ``sampling_frequency`` is the amplifier sample rate;
        ``session_recording_paths`` lists the files actually used.

    Raises:
        ValueError: if no file in the session contains data, or if files disagree on the
            amplifier sample rate (concatenated sample indices would then be meaningless).
    """
    recording_paths = sorted(glob.glob(f'{session_path}{os.sep}*.rhd'))
    session_length = 0
    session_recording_paths = []
    sampling_frequency = None

    intan_recordings = { 'adc': [], 'dig': [], 't_adc': [], 't_dig': [], 'traces': [], 'recordings': []  }
    for recording_path in (pbar := tqdm(recording_paths)):
        pbar.set_description(recording_path)

        raw_data, data_present = read_rhd(recording_path)
        if not data_present:
            print(f'[no data] Skipping! {recording_path}')
            continue

        file_sampling_frequency = raw_data['frequency_parameters']['amplifier_sample_rate']
        if sampling_frequency is None:
            sampling_frequency = file_sampling_frequency
        elif file_sampling_frequency != sampling_frequency:
            raise ValueError(
                f'{recording_path}: amplifier sample rate {file_sampling_frequency} differs from '
                f'{sampling_frequency} of earlier files in {session_path}'
            )

        n_samples = raw_data['amplifier_data'].shape[1]
        session_recording_paths.append(recording_path)
        intan_recordings['adc'].append(raw_data['board_adc_data'])
        intan_recordings['dig'].append(raw_data['board_dig_in_data'])
        intan_recordings['t_adc'].append(raw_data['t_board_adc'])
        intan_recordings['t_dig'].append(raw_data['t_dig'])
        intan_recordings['traces'].append(raw_data['amplifier_data'][trace_indices])
        intan_recordings['recordings'].append({
            'file': recording_path,
            'file_start': session_length,  # offset of this file in the concatenated sample axis
            'file_length': n_samples,
            'sampling_frequency': file_sampling_frequency,
        })
        session_length += n_samples

    # Without this check np.hstack([]) fails with an unhelpful "need at least one array" error.
    if not session_recording_paths:
        raise ValueError(f'No .rhd file with data found in {session_path}')

    for key in ('adc', 'dig', 't_adc', 't_dig', 'traces'):
        intan_recordings[key] = np.hstack(intan_recordings[key])
    intan_recordings['recordings'] = pd.json_normalize(intan_recordings['recordings'])

    return intan_recordings, sampling_frequency, session_recording_paths

# A HIGH pulse on digital input 0 lasting longer than this many seconds marks the start of the joystick
# recording (the session cell takes the single pulse above this threshold as the start).
# NOTE: "duraiton" is a typo kept on purpose; the name is referenced by other cells.
intan_joystick_recording_start_duraiton_threshold = 3 # s
# Analogous threshold for a video-recording start pulse; not referenced in the cells shown here.
intan_video_recording_start_duraiton_threshold = 0.5 # s

def extract_trigger_start_and_end_indices(triggers):
    """Find the start and end sample indices of every HIGH run in a digital trigger signal.

    Args:
        triggers: 1-D array-like; any non-zero / True sample counts as HIGH.

    Returns:
        ``(start_indices, end_indices)`` integer arrays of equal length:
        start = index of the first HIGH sample of a run;
        end = index of the first LOW sample after the run, or ``len(triggers) - 1`` (the last,
        still-HIGH sample) when the signal ends HIGH, so every end index stays a valid index.
    """
    # Coerce to bool: bitwise ~/& on raw integers only works for 0/1 (e.g. ~1 & 3 == 2 is truthy).
    triggers = np.asarray(triggers, dtype=bool)
    if triggers.size == 0:
        return np.array([], dtype=np.intp), np.array([], dtype=np.intp)

    start_indices = np.flatnonzero(~triggers[:-1] & triggers[1:]) + 1  # From False to True
    end_indices = np.flatnonzero(triggers[:-1] & ~triggers[1:]) + 1  # From True to False

    if triggers[0]:  # run already in progress at sample 0
        start_indices = np.insert(start_indices, 0, 0)

    if triggers[-1]:  # run still in progress at the last sample
        end_indices = np.append(end_indices, len(triggers) - 1)
    return start_indices, end_indices

def compute_trigger_durations(times, start_indices, end_indices):
    """Return ``times[end] - times[start]`` for each paired start/end index, as a numpy array.

    Pairs are formed positionally (zip), so extra entries in the longer index list are ignored.
    """
    durations = []
    for begin, finish in zip(start_indices, end_indices):
        durations.append(times[finish] - times[begin])
    return np.array(durations)

def extract_spike_train_at_trigger(spike_train, trigger_indices, sampling_frequency, window_ms, offset=True):
    """Collect the spikes falling within +/- ``window_ms`` of each trigger.

    Args:
        spike_train: spike times as sample indices (array-like).
        trigger_indices: trigger positions as sample indices on the same clock as ``spike_train``.
        sampling_frequency: samples per second of that clock.
        window_ms: half-width of the window in milliseconds; both window edges are inclusive.
        offset: if True, spike times are shifted so the window start
            (``trigger_index - window_samples``) becomes 0.

    Returns:
        A list with one numpy array of spike sample indices per trigger, in trigger order.
    """
    # Local constant: the notebook-level `n_ms_per_s` is only defined in a later cell, so depending
    # on it made this function raise NameError on a fresh-kernel "Run All".
    ms_per_s = 1000
    window_samples = int(sampling_frequency / ms_per_s * window_ms)
    spike_train = np.asarray(spike_train)

    events = []
    for trigger_index in trigger_indices:
        window_start = trigger_index - window_samples
        in_window = (spike_train >= window_start) & (spike_train <= trigger_index + window_samples)
        trigger_events = spike_train[in_window]
        if offset:
            trigger_events = trigger_events - window_start
        events.append(trigger_events)
    return events

In [3]:
# Load, trim, spike-sort and summarise one session per mouse.
# NOTE(review): this cell relies on names not defined in the cells above it: `se`, `sc`, create_probe,
# preprocess_recording, plot_traces, create_probegroup, plot_unit, intan_channel_indices, memory_limit,
# sorter_parameters, surgery_dates (presumably from `from utils import *`), and ms_before / ms_after,
# which are otherwise only defined in a later cell. Confirm before a fresh-kernel Run All.
mice_recordings = []
mice_intan = {}
for mouse in mice:
    # sorted(): glob order is filesystem-dependent, so session_paths[0] was not reproducible.
    session_paths = sorted(glob.glob(f'{input_folder}{os.sep}{mouse}{os.sep}{mouse}_*_{export_date}_*'))
    if len(session_paths) == 0:
        print(f'{mouse} with no sessions found')
        continue
    elif len(session_paths) > 1:
        print(f'{mouse} with {len(session_paths)} sessions found')
        for session_index, session_path in enumerate(session_paths):
            print(f'    [{session_index + 1}] {session_path}')

    # Only the first session (by name) is processed.
    session_path = session_paths[0]
    session = session_path.split(os.sep)[-1]
    session_si_path = f'{si_folder}{os.sep}{mouse}{os.sep}{session}'
    os.makedirs(session_si_path, exist_ok=True)

    # Amplifier channels 7 and 8 are excluded from the traces.
    mice_intan[mouse], sampling_frequency, session_recording_paths = read_intan(session_path, list(range(0, 7)) + list(range(9, 32)))

    # Rising/falling edges of digital input 0 and the duration (s) of every HIGH pulse.
    intan_joystick_trigger_indices = extract_trigger_start_and_end_indices(mice_intan[mouse]['dig'][0])
    intan_joystick_trigger_durations = compute_trigger_durations(mice_intan[mouse]['t_dig'], *intan_joystick_trigger_indices)

    # The joystick recording starts at the single pulse longer than the threshold (the trigger box's
    # start marker); pulses after it are the recorded triggers.
    long_pulse_indices = np.flatnonzero(intan_joystick_trigger_durations > intan_joystick_recording_start_duraiton_threshold)
    if len(long_pulse_indices) != 1:
        raise ValueError(
            f'{session}: expected exactly one digital-in pulse longer than '
            f'{intan_joystick_recording_start_duraiton_threshold} s marking the joystick recording start, '
            f'found {len(long_pulse_indices)}'
        )
    intan_joystick_recording_start_indices_index = long_pulse_indices[0]
    intan_joystick_recording_start_index = intan_joystick_trigger_indices[0][intan_joystick_recording_start_indices_index]

    # Trigger onsets after the start pulse, relative to the recording start.
    intan_recorded_trigger_indices = intan_joystick_trigger_indices[0][intan_joystick_recording_start_indices_index+1:] - intan_joystick_recording_start_index

    # Drop everything before the start pulse so sample 0 == joystick recording start.
    # NOTE: the per-file 'file_start' values in mice_intan[mouse]['recordings'] stay in untrimmed coordinates.
    for key in ('adc', 'dig', 'traces'):
        mice_intan[mouse][key] = mice_intan[mouse][key][:, intan_joystick_recording_start_index:]
    for key in ('t_adc', 't_dig'):
        mice_intan[mouse][key] = mice_intan[mouse][key][intan_joystick_recording_start_index:]
    session_duration = mice_intan[mouse]['traces'].shape[1] / sampling_frequency

    with open(f'{session_path}{os.sep}settings.xml', 'r') as f:
        intan_settings = f.read()
    intan_settings = BeautifulSoup(intan_settings, 'xml')
    session_trigger_enabled = all([channel['Enabled']=='True' for channel in intan_settings.find_all('SignalGroup', {'Prefix': 'DIGITAL-IN'})[0].find_all('Channel')])

    mice_recordings.append({
        'date': export_date,
        'mouse': mouse,
        'session': session,
        'days_trained': re.match(rf'{mouse}_(?P<days_trained>.*)_{export_date}', session).group('days_trained'),
        'timestamp': datetime.datetime.strptime(re.search(r'\d{6}_\d{6}$', session).group(), '%y%m%d_%H%M%S'),
        'intan_trigger_enabled': session_trigger_enabled,
        'intan_sampling_frequency': sampling_frequency,
        'intan_duration(s)': session_duration,
        'intan_recording_paths': session_recording_paths,
        'intan_recorded_trigger_indices': intan_recorded_trigger_indices,
    })

    # Preprocess the trimmed traces once; later runs reuse the cached result on disk.
    if not os.path.isfile(f'{session_si_path}{os.sep}processed{os.sep}traces.png'):
        recording = se.NumpyRecording(traces_list=mice_intan[mouse]['traces'].T, sampling_frequency=sampling_frequency)
        multi_shank_probe = create_probe(intan_channel_indices, savepath=f'{session_si_path}{os.sep}probe.png')
        recording.set_probe(multi_shank_probe, in_place=True)
        recording_processed = preprocess_recording(recording, steps=['bp', 'cmr', 'clip'])
        recording_processed.save(folder=f'{session_si_path}{os.sep}processed', memory=memory_limit)
        plot_traces(recording_processed.get_traces().T, recording.sampling_frequency, intan_channel_indices, title=f'{mouse} -> {session}', savepath=f'{session_si_path}{os.sep}processed{os.sep}traces.png')

    recording_processed = sc.load_extractor(f'{session_si_path}{os.sep}processed')
    probegroup = create_probegroup(intan_channel_indices)
    recording_processed.set_probegroup(probegroup, group_mode='by_probe', in_place=True)

    print('Begin sorting ...')
    if not os.path.isfile(f'{session_si_path}{os.sep}sorting{os.sep}sorter_output{os.sep}firings.npz'):
        sorting = ss.run_sorter_by_property(
            sorter_name='mountainsort4',
            recording=recording_processed,
            grouping_property='group',
            working_folder=f'{session_si_path}{os.sep}sorting',
            mode_if_folder_exists='overwrite',
            **sorter_parameters,
        )

        os.makedirs(f'{session_si_path}{os.sep}sorting{os.sep}sorter_output', exist_ok=True)
        se.NpzSortingExtractor.write_sorting(sorting, f'{session_si_path}{os.sep}sorting{os.sep}sorter_output{os.sep}firings.npz')

    sorting = se.NpzSortingExtractor(f'{session_si_path}{os.sep}sorting{os.sep}sorter_output{os.sep}firings.npz')

    # Peri-trigger raster (+/- 1 s) per unit, 5 units per row. Spike trains and trigger indices
    # are both relative to the trimmed recording start.
    unit_ids = list(sorting.unit_ids)
    n_cols = 5
    n_rows = max(1, int(np.ceil(len(unit_ids) / n_cols)))
    fig = plt.figure(figsize=(n_cols*5, n_rows*5))
    # Position comes from enumerate: unit ids need not be 0..n-1 (unit_id+1 could overflow the grid).
    for plot_index, unit_id in enumerate(unit_ids):
        ax = fig.add_subplot(n_rows, n_cols, plot_index + 1)
        unit_spike_train = sorting.get_unit_spike_train(unit_id=unit_id)
        unit_events = extract_spike_train_at_trigger(unit_spike_train, intan_recorded_trigger_indices, sampling_frequency, window_ms=1000)
        ax.eventplot(unit_events)
        ax.set_axis_off()
        ax.set_title(f'unit {unit_id}')
    fig.savefig(f'{reports_folder}{os.sep}{mouse}-units.png', bbox_inches='tight')
    plt.close(fig)  # otherwise one figure per mouse accumulates in memory

    print('Begin waveform extraction ...')
    if not os.path.isfile(f'{session_si_path}{os.sep}waveforms{os.sep}templates_average.npy'):
        sc.extract_waveforms(
            recording_processed, sorting,
            folder=f'{session_si_path}{os.sep}waveforms',
            ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=1000,
            return_scaled=False,
            overwrite=True,
            use_relative_path=True,
        )

    waveform_extractor = sc.load_waveforms(
        folder=f'{session_si_path}{os.sep}waveforms', with_recording=True, sorting=sorting
    )
    extremum_channels = sc.get_template_extremum_channel(waveform_extractor, peak_sign='neg')

    os.makedirs(f'{session_si_path}{os.sep}units', exist_ok=True)
    for unit_id in sorting.unit_ids:
        unit_plot_file = f'{session_si_path}{os.sep}units{os.sep}{unit_id}.png'
        if not os.path.isfile(unit_plot_file):
            plot_unit(waveform_extractor, extremum_channels, sorting, unit_id, intan_channel_indices, initdate=surgery_dates[mouse], savepath=unit_plot_file)

    mice_intan[mouse]['recordings'].to_csv(f'{session_si_path}{os.sep}files.csv', index=False)

  0%|          | 0/32 [00:00<?, ?it/s]

Reading data/raw/Behavior/6_2/6_2_day24_240124_083523/6_2_day24_240124_083523.rhd
Found 32 amplifier channels.
Found 3 auxiliary input channels.
Found 2 board ADC channels.
Found 2 board digital input channels.
File contains 60.002 seconds of data.  Amplifiers were sampled at 30.00 kS/s.
No missing timestamps in data.

Reading data/raw/Behavior/6_2/6_2_day24_240124_083523/6_2_day24_240124_083623.rhd
Found 32 amplifier channels.
Found 3 auxiliary input channels.
Found 2 board ADC channels.
Found 2 board digital input channels.
File contains 60.002 seconds of data.  Amplifiers were sampled at 30.00 kS/s.
No missing timestamps in data.

Reading data/raw/Behavior/6_2/6_2_day24_240124_083523/6_2_day24_240124_083723.rhd
Found 32 amplifier channels.
Found 3 auxiliary input channels.
Found 2 board ADC channels.
Found 2 board digital input channels.
File contains 60.002 seconds of data.  Amplifiers were sampled at 30.00 kS/s.
No missing timestamps in data.

Reading data/raw/Behavior/6_2/6_2_day

In [None]:
# Turn the per-mouse records into a table ordered by session start time.
# The isinstance guard keeps the cell idempotent: once mice_recordings is a DataFrame (from this cell
# or a later merge), re-running is a no-op instead of re-normalizing a DataFrame. Re-run the loading
# cell to rebuild from scratch.
if not isinstance(mice_recordings, pd.DataFrame):
    if len(mice_recordings) > 0:
        mice_recordings = pd.json_normalize(mice_recordings).sort_values(by='timestamp').reset_index(drop=True)
    else:
        # Same empty-table convention as the Arduino/video cells: keep a datetime 'timestamp' for merge_asof.
        mice_recordings = pd.DataFrame({'mouse': [], 'session': [], 'timestamp': []})
        mice_recordings['timestamp'] = pd.to_datetime(mice_recordings['timestamp'])
mice_recordings

In [None]:
# Find the Arduino joystick-event logs recorded on export_date. The first field of a log's first
# data line is an integer Unix timestamp (seconds), used to match the log to an Intan session.
arduino_events = []
for arduino_event_path in sorted(glob.glob(f'{input_folder}{os.sep}Event{os.sep}*')):
    with open(arduino_event_path, 'r') as f:
        line = f.readline()
        while line.startswith('Date'):  # skip header line(s)
            line = f.readline()
    if not line.strip():
        # Header-only or empty log: int('') below would otherwise abort the whole cell.
        print(f'[no data] Skipping! {arduino_event_path}')
        continue
    # Naive UTC datetime, equivalent to the deprecated datetime.utcfromtimestamp().
    # NOTE(review): session timestamps come from folder names (presumably local time); the merge
    # tolerance only matches if the Arduino clock is set so the two agree — confirm.
    timestamp = datetime.datetime.fromtimestamp(int(line.split(',')[0]), tz=datetime.timezone.utc).replace(tzinfo=None)
    if timestamp.strftime('%y%m%d') == export_date:
        arduino_events.append({
            'timestamp': timestamp,
            'arduino_event_path': arduino_event_path,
        })
if len(arduino_events) > 0:
    arduino_events = pd.json_normalize(arduino_events).sort_values(by='timestamp')
else:
    arduino_events = pd.DataFrame({'timestamp': [], 'arduino_event_path': []})
    arduino_events['timestamp'] = pd.to_datetime(arduino_events['timestamp'])
arduino_events

In [None]:
# Merge export_date's joystick events: attach the Arduino log starting nearest (within 300 s) to each session.
# Columns from a previous run are dropped first so re-running this cell does not create *_x/*_y duplicates.
mice_recordings = pd.merge_asof(
    mice_recordings.drop(columns=['arduino_event_path'], errors='ignore'),
    arduino_events,
    on='timestamp', direction='nearest', tolerance=pd.Timedelta('300s'),
)
mice_recordings

In [None]:
def extract_video_df(name, paths, date=None):
    """Index video-related files by the timestamp embedded in their filename.

    Filenames must contain a ``YYYY-MM-DD HH_MM_SS`` stamp. Paths without one are ignored, as are
    stamps falling on a day other than ``date``.

    Args:
        name: column name under which matching paths are stored.
        paths: iterable of file paths.
        date: day to keep, formatted ``YYMMDD``; defaults to the notebook-level ``export_date``.

    Returns:
        DataFrame with columns ``timestamp`` (datetime) and ``name``, sorted by timestamp (as
        ``merge_asof`` requires) with a fresh 0..n-1 index. Empty but datetime-typed when nothing matches.
    """
    if date is None:
        date = export_date
    stamp_pattern = re.compile(r'\d{4}-\d{2}-\d{2} \d{2}_\d{2}_\d{2}')

    videos = []
    for path in paths:
        matched = stamp_pattern.search(path)
        if matched is None:
            continue
        timestamp = datetime.datetime.strptime(matched.group(), '%Y-%m-%d %H_%M_%S')
        if timestamp.strftime('%y%m%d') == date:
            videos.append({
                'timestamp': timestamp,
                name: path,
            })

    if len(videos) == 0:
        videos = pd.DataFrame({'timestamp': [], name: []})
        videos['timestamp'] = pd.to_datetime(videos['timestamp'])
        return videos
    return pd.json_normalize(videos).sort_values(by='timestamp').reset_index(drop=True)

video_folder = os.path.join(input_folder, 'Video')
videos = extract_video_df('video_path', glob.glob(os.path.join(video_folder, '*.h264')))
video_timestamps = extract_video_df('video_timestamp_path', glob.glob(os.path.join(video_folder, '*timestamp.txt')))
# Pair each video with the timestamp file whose filename stamp is nearest (within 60 s).
videos = pd.merge_asof(
    videos, video_timestamps,
    on='timestamp', direction='nearest', tolerance=pd.Timedelta('60s'),
)
videos

In [None]:
# Merge export_date's video recordings: attach the video starting nearest (within 300 s) to each session.
# Video columns from a previous run are dropped first so re-running does not create *_x/*_y duplicates.
mice_recordings = pd.merge_asof(
    mice_recordings.drop(columns=[column for column in videos.columns if column != 'timestamp'], errors='ignore'),
    videos,
    on='timestamp', direction='nearest', tolerance=pd.Timedelta('300s'),
)
mice_recordings

In [None]:
n_ms_per_s = 1000
n_s_per_min = 60
n_mm_per_cm = 10
ms_before, ms_after = 100, 300  # peri-trigger window (ms) for trajectory plots and waveform extraction
analog_halfwidth = 2048  # +/- analog_halfwidth maps to +/- max_joystick_angle in analog_to_digital()
v_half = 3.3 / 2  # V; Intan ADC volts are centred on this before scaling to +/- analog_halfwidth
bin_size = 20 #ms
bin_scale = n_ms_per_s / bin_size  # bins per second

def analog_to_digital(analog, max_joystick_angle=22.5, joystick_length_cm=None):
    """Convert a centred joystick analog value into a displacement in cm.

    Despite the name, the result is a physical distance, not a digital value. The input is scaled
    linearly to a tilt angle (+/- analog_halfwidth <-> +/- max_joystick_angle degrees) and projected
    through the lever: ``length * sin(angle)``. An input of 0 maps to 0 cm; the sign is preserved.

    Args:
        analog: scalar, numpy array or pandas Series of centred analog values.
        max_joystick_angle: tilt in degrees reached at +/- analog_halfwidth.
        joystick_length_cm: lever length in cm; defaults to the notebook-level ``joystick_length``.

    Returns:
        Displacement in cm, with the same shape/type as ``analog``.
    """
    if joystick_length_cm is None:
        joystick_length_cm = joystick_length
    return joystick_length_cm * np.sin(np.deg2rad(max_joystick_angle * analog / analog_halfwidth))  # cm

In [None]:
# Per-session joystick QC: compare Intan- and Arduino-recorded joystick traces, plot velocity and reach
# statistics, save one figure per session, and write summary columns back into mice_recordings.
for session_i, mouse in enumerate(mice_recordings['mouse']):
    mice_recording = mice_recordings.iloc[session_i:session_i+1]
    session_label = mice_recordings.index[session_i]  # row label for the .loc writes below
    session = mice_recording['session'].iloc[0]
    arduino_event_path = mice_recording['arduino_event_path'].iloc[0]

    if not os.path.isfile(str(arduino_event_path)):
        print(f'[No Event] {session} @ {arduino_event_path}')
        continue

    arduino_event = pd.read_csv(arduino_event_path, delimiter=',')
    # Re-zero the Arduino clock at its first sample and convert ms -> s.
    arduino_event['time'] = (arduino_event['time'] - arduino_event['time'].iloc[0]) / n_ms_per_s

    fig = plt.figure(figsize=(20, 20), layout='constrained')
    fig.suptitle(f'{session} -> {arduino_event_path}', y=1, fontsize=25)

    # --- Intan-recorded joystick position ---
    # mice_intan[mouse] arrays were already trimmed to the joystick-recording start when the session was
    # loaded, and 'intan_recorded_trigger_indices' are relative to that start, so no further offset applies.
    ax = fig.add_axes([0.05, 0.8, 0.9, 0.16])
    intan_adc = mice_intan[mouse]['adc']
    intan_t = mice_intan[mouse]['t_adc'] / n_s_per_min
    intan_position = np.sqrt(
        analog_to_digital((intan_adc[0] - v_half) / v_half * analog_halfwidth) ** 2 +  # x**2
        analog_to_digital((intan_adc[1] - v_half) / v_half * analog_halfwidth) ** 2    # y**2
    )
    intan_trigger_indices = mice_recording['intan_recorded_trigger_indices'].iloc[0]
    ax.plot(intan_t, intan_position, linewidth=0.5, color='black')
    ax.scatter(
        intan_t[intan_trigger_indices],
        [-0.05] * len(intan_trigger_indices),
        color='steelblue', s=50, marker='|'
    )
    ax.set_title('Intan ADC')
    ax.set_ylabel('position (cm)')
    ax.set_xlabel('time (min)')
    ax.set_xlim(-1, 31)
    ax.set_ylim(-0.2, 2.5)

    # --- Arduino-recorded joystick position ---
    ax = fig.add_axes([0.05, 0.6, 0.9, 0.16])
    state = arduino_event['state'].to_numpy()
    arduino_joystick_start_indices = np.flatnonzero((state[:-1] == 0) & (state[1:] == 1)) + 1  # state 0 -> 1

    arduino_event['position'] = np.sqrt(
        analog_to_digital(arduino_event['x'] + arduino_event['joystickX_offset'] - analog_halfwidth) ** 2 +
        analog_to_digital(arduino_event['y'] + arduino_event['joystickY_offset'] - analog_halfwidth) ** 2)
    ax.plot(arduino_event['time'] / n_s_per_min, arduino_event['position'], linewidth=0.5, color='black')
    ax.scatter(
        arduino_event['time'][arduino_joystick_start_indices] / n_s_per_min,
        [-0.05] * len(arduino_joystick_start_indices),
        color='steelblue', s=50, marker='|'
    )
    ax.set_title('Arduino ADC')
    ax.set_ylabel('position (cm)')
    ax.set_xlabel('time (min)')
    ax.set_xlim(-1, 31)
    ax.set_ylim(-0.2, 2.5)

    # --- Outward velocity from positions averaged in bin_size-ms bins ---
    ax = fig.add_axes([0.05, 0.4, 0.9, 0.16])
    arduino_event['binned_time'] = np.floor(arduino_event['time'] * bin_scale) / bin_scale
    binned_arduino_event = arduino_event[['binned_time', 'position']].groupby('binned_time').mean().reset_index().rename(columns={'position': 'binned_position'})
    binned_arduino_event['dbinned_position'] = np.diff([binned_arduino_event['binned_position'][0]] + binned_arduino_event['binned_position'].tolist())
    binned_arduino_event['dbinned_time'] = np.diff([0] + binned_arduino_event['binned_time'].tolist())
    binned_arduino_event['outward_velocity'] = binned_arduino_event['dbinned_position'] / binned_arduino_event['dbinned_time']
    ax.plot(
        binned_arduino_event['binned_time'] / n_s_per_min,
        binned_arduino_event['outward_velocity'],
        color='black', linewidth=0.5
    )
    ax.set_xlabel('time (min)')
    ax.set_ylabel('Outward\nVelocity\n(cm/s)')
    ax.set_xlim(-1, 31)
    ax.set_ylim(-50, 50)

    # --- Maximum position reached per trial ---
    ax = fig.add_axes([0.05, 0.2, 0.3, 0.16])
    trial_max_position = arduino_event.groupby('trial')['position'].max()
    arduino_event['trial_position'] = arduino_event['trial'].map(trial_max_position)
    # One value per trial; .unique() of the per-row column merged trials that reached the same value.
    reached_positions = trial_max_position.to_numpy()
    sns.histplot(data=reached_positions, binwidth=0.1, kde=True, ax=ax)
    ax.set_xlabel('Reached Position (cm)')

    # --- Joystick trajectories around each Arduino trigger ---
    ax = fig.add_axes([0.42, 0.2, 0.16, 0.16])
    for arduino_joystick_start_index in arduino_joystick_start_indices:
        trigger_time = arduino_event['time'][arduino_joystick_start_index]
        trigger_start_time = trigger_time - ms_before / n_ms_per_s
        trigger_end_time = trigger_time + ms_after / n_ms_per_s
        in_window = (arduino_event['time'] >= trigger_start_time) & (arduino_event['time'] <= trigger_end_time)
        ax.plot(analog_to_digital(arduino_event.loc[in_window, 'x']), analog_to_digital(arduino_event.loc[in_window, 'y']))
    ax.set_title(f'({ms_before}ms before - {ms_after}ms after) triggered')
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    ax.set_xlabel('Position (cm)')
    ax.set_ylabel('Position (cm)')

    # --- Inter-trigger intervals ---
    ax = fig.add_axes([0.65, 0.2, 0.3, 0.16])
    sns.histplot(data=np.diff(arduino_event['time'].to_numpy()[arduino_joystick_start_indices]), binwidth=5, kde=True, ax=ax)
    ax.set_xlabel('Inter-Trigger Interval (s)')

    # --- Start (offset) and reached positions per trial ---
    ax = fig.add_axes([0.05, 0, 0.9, 0.16])
    arduino_event['offset_position'] = np.sqrt(analog_to_digital(arduino_event['joystickX_offset']-analog_halfwidth)**2 + analog_to_digital(arduino_event['joystickY_offset']-analog_halfwidth)**2)
    ax.plot(arduino_event['trial'], arduino_event['offset_position'], color='red', linewidth=1, linestyle='dashed')
    ax.plot(arduino_event['trial'], arduino_event['trial_position'], color='green', linewidth=1)
    ax.set_xlabel('Trial')
    ax.set_ylabel('Start (Red) and Reached(Green) \nPositions\n(cm)')
    ax.set_ylim(-0.1, 2.5)
    ax.set_xticks(np.arange(arduino_event['trial'].max()+1))

    # Save plots. splitext handles '.TXT' and '.txt' alike; replacing only '.TXT' left lowercase
    # paths without a .png suffix, so savefig failed on the unknown 'txt' format.
    joystick_event_plot_file = f'{os.path.splitext(arduino_event_path)[0]}-{mouse}.png'.replace(os.sep, '_')
    fig.savefig(f'{reports_folder}{os.sep}{joystick_event_plot_file}', bbox_inches='tight')
    plt.close(fig)  # Agg backend: nothing is displayed inline, so release the figure right away

    # Intan joystick duration from this mouse's own (already trimmed) digital input: recording start
    # (sample 0) to the end of the last trigger pulse. Previously this used the
    # intan_joystick_trigger_indices global left over from the last mouse loaded, with untrimmed indices.
    _, intan_trigger_ends = extract_trigger_start_and_end_indices(mice_intan[mouse]['dig'][0])
    intan_t_adc = mice_intan[mouse]['t_adc']
    intan_joystick_duration = intan_t_adc[intan_trigger_ends[-1]] - intan_t_adc[0] if len(intan_trigger_ends) else np.nan

    mice_recordings.loc[session_label, 'joystick_event_plot'] = joystick_event_plot_file
    mice_recordings.loc[session_label, 'target_move_distance(mm)'] = analog_to_digital(analog_threshold) * n_mm_per_cm
    mice_recordings.loc[session_label, 'joystick_length(cm)'] = joystick_length
    mice_recordings.loc[session_label, 'successful_trials'] = arduino_event['trial'].max()
    mice_recordings.loc[session_label, 'arduino_joystick_duration(s)'] = arduino_event['time'].max() - arduino_event['time'].min()
    mice_recordings.loc[session_label, 'intan_joystick_duration(s)'] = intan_joystick_duration
    mice_recordings.loc[session_label, 'mean_reached_distance(mm)'] = reached_positions.mean() * n_mm_per_cm