1. ISI (inter-spike interval) distributions
2. Templates over time
3. All templates
4. UMAP over time
5. Features over time (point plots)

In [1]:
import matplotlib
matplotlib.use('Agg')

import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import spikeinterface.core as sc
import spikeinterface.curation as scu
import spikeinterface.extractors as se
import spikeinterface.widgets as sw
import sys 

from tqdm.auto import tqdm

sys.path.append('src')

from utils import plot_isi_by_session, blackrock_channel_indices

sorted_date = '20231229'
sorted_date_folder = os.path.join('data', 'sorted', sorted_date)
# Each mouse is a sub-directory of the date folder. Plain files (csv summaries, logs, ...) are skipped;
# otherwise they would be treated as mice and break the loading cells below.
mice = sorted(
    os.path.basename(mouse_path)
    for mouse_path in glob.glob(os.path.join(sorted_date_folder, '*'))
    if os.path.isdir(mouse_path)
)
print(f'Plotting mice: {mice} sorted on {sorted_date}')

Plotting mice: ['1_5', '5_7', '6_2', '6_3', '6_7', '7_2'] sorted on 20231229


In [None]:
# ISI plots: one PNG per unit at data/sorted/<sorted_date>/<mouse>/isi/<unit_id>.png.
# Units whose PNG already exists are skipped, so re-running only fills in missing plots.
for mouse in (pbar := tqdm(mice)):
    mouse_sorted_folder = f'data{os.sep}sorted{os.sep}{sorted_date}{os.sep}{mouse}'
    mouse_isi_folder = f'{mouse_sorted_folder}{os.sep}isi'
    os.makedirs(mouse_isi_folder, exist_ok=True)

    recording_processed = sc.load_extractor(f'{mouse_sorted_folder}{os.sep}processed')

    sorting = se.NpzSortingExtractor(f'{mouse_sorted_folder}{os.sep}sorting{os.sep}sorter_output{os.sep}firings.npz')
    # spikeinterface https://github.com/SpikeInterface/spikeinterface/pull/1378
    sorting = scu.remove_excess_spikes(sorting, recording_processed)

    # Reset the index after sorting by date so row labels agree with positional session numbers (0..n-1).
    sessions = (
        pd.read_csv(f'{mouse_sorted_folder}{os.sep}sessions.csv')
        .sort_values(by='date')
        .reset_index(drop=True)
    )

    for unit_id in sorting.unit_ids:
        pbar.set_description(f'{mouse} -> [{unit_id} / {len(sorting.unit_ids)}]')
        unit_isi_plot_file = f'{mouse_isi_folder}{os.sep}{unit_id}.png'
        if not os.path.isfile(unit_isi_plot_file):
            plot_isi_by_session(sorting, unit_id, sessions, savepath=unit_isi_plot_file)
            # Release any pyplot figures created for this unit so they don't accumulate across the loop.
            plt.close('all')

In [90]:
"""Functions for calculating and plotting feature pointplots"""

import utils
from calculate_features import features_5
import scipy.stats
import anndata as ad

def cal_features(recording, sorting, waveform_extractor, extremum_channels, unit_id, sessions, channel_indices):
    """
    Calculates eight features across sessions for a given unit.

    Arguments
    -----------------------------------------------
    recording: SI recording object. We will usually want the processed recording object (e.g., SNR calculates noise from processed recording)
    sorting: SI sorting object
    waveform_extractor: SI waveform_extractor object
    extremum_channels: dict of units and corresponding extremum channel (key = unit, value = extremum channel).
                       The value is used as a positional channel index into the waveforms and the recording traces.
    unit_id: unit id
    sessions: pandas dataframe of session data with 'session_start' and 'session_length' columns (in samples).
              Rows are taken in their current order and numbered 0 .. len(sessions) - 1. The index is reset
              internally, so a date-sorted dataframe with a non-reset index is handled consistently.
    channel_indices: numpy array of channels by shank (currently unused; kept for interface compatibility)

    Returns
    -----------------------------------------------
    features: dict of n_sessions key-value pairs (key = session, value = dict of 8 key-value pairs (key = feature_name, value = feature value))
              the 8 features are ['duration', 'peak_trough_ratio', 'halfwidth', 'repolarization_slope', 'recovery_slope', 'amplitude', 'snr', 'firing_rate']
              (dict keys are in this order). 'amplitude' holds one peak-to-peak value per waveform, 'snr' and
              'firing_rate' are scalars, and the first five are whatever features_5 returns.
    """
    # Session numbers are positional (0..n-1) everywhere below; make the row labels agree with them.
    sessions = sessions.reset_index(drop=True)
    sampling_frequency = sorting.sampling_frequency
    extremum_channel = extremum_channels[unit_id]
    n_samples = recording.get_num_samples()

    # Spike trains split by session, and extremum-channel waveforms (at most 1000 per session)
    session_spike_trains = utils.split_unit_spike_train_indicies_by_session(sorting, unit_id, sessions)
    unit_extremum_waveforms = waveform_extractor.get_waveforms(unit_id)[:, :, extremum_channel]
    unit_extremum_waveform_adata = utils.get_session_waveforms_adata(unit_extremum_waveforms, session_spike_trains, max_count_per_session=1000)

    features = {}

    for session_i in range(len(sessions)):
        session_extremum_waveform = np.array(
            unit_extremum_waveform_adata[unit_extremum_waveform_adata.obs['session_i'] == session_i].X
        )

        # duration (= peak_to_valley), peak_trough_ratio, halfwidth, repolarization_slope, recovery_slope
        waveform_features = features_5(session_extremum_waveform, sampling_frequency, feature_names=['peak_to_valley', 'peak_trough_ratio', 'halfwidth', 'repolarization_slope', 'recovery_slope'])
        # Rename 'peak_to_valley' -> 'duration' and put it first, so the key order matches the documented order
        features_session = {'duration': waveform_features.pop('peak_to_valley'), **waveform_features}

        # Peak-to-peak amplitude of every waveform in this session
        amplitude = np.max(session_extremum_waveform, axis=1) - np.min(session_extremum_waveform, axis=1)
        features_session['amplitude'] = amplitude

        # SNR = mean amplitude / noise level (normal-scaled MAD of this session's extremum-channel trace).
        # Only this session's frames are read, instead of loading the whole recording for every unit.
        session_length = sessions.at[session_i, 'session_length']
        session_start = int(sessions.at[session_i, 'session_start'])
        session_end = min(int(session_start + session_length), n_samples)  # mirror slice clipping at the end
        if session_end > session_start:
            session_recording = recording.get_traces(start_frame=session_start, end_frame=session_end)[:, extremum_channel]
        else:
            session_recording = np.empty(0)  # same as slicing past the end of the recording
        noise_level = scipy.stats.median_abs_deviation(session_recording, scale='normal')
        snr = np.nanmean(amplitude) / noise_level
        features_session['snr'] = snr

        # Firing rate (Hz); left as NaN whenever SNR is NaN (e.g. no usable waveforms in this session)
        if np.isnan(snr):
            firing_rate = np.nan
        else:
            n_spikes = len(session_spike_trains[session_i])
            firing_rate = n_spikes / (session_length / sampling_frequency)
        features_session['firing_rate'] = firing_rate

        features[session_i] = features_session

    return features

def plot_features_pointplot(recording, sorting, waveform_extractor, extremum_channels, unit_id, sessions, channel_indices, savepath=None):
    """
    Plots point plots (mean +/- std per session) of eight features across sessions for a given unit.

    Arguments
    -----------------------------------------------
    recording: SI recording object. We will usually want the processed recording object (e.g., SNR calculates noise from processed recording)
    sorting: SI sorting object
    waveform_extractor: SI waveform_extractor object
    extremum_channels: dict of units and corresponding extremum channel (key = unit, value = extremum channel)
    unit_id: unit id
    sessions: pandas dataframe of session data
    channel_indices: numpy array of channels by shank
    savepath: pointplot save location and filename. When given, the figure is saved there and then closed
              (so figures do not pile up when called in a loop); when None the figure is left open.

    Returns
    -----------------------------------------------
    None
    """

    # Calculate features for all sessions for a single unit
    features = cal_features(recording, sorting, waveform_extractor, extremum_channels, unit_id, sessions, channel_indices)

    # Explicit feature-key -> axis-label mapping (in panel order). Labels are looked up by key, so they
    # cannot drift out of sync with the key order of the dicts returned by cal_features.
    feature_titles = {
        'duration': 'Duration (s)',
        'peak_trough_ratio': 'Peak Trough Ratio',
        'halfwidth': 'Halfwidth (s)',
        'repolarization_slope': 'Repolarization Slope',
        'recovery_slope': 'Recovery Slope',
        'amplitude': r'Amplitude ($\mu$V)',  # raw string: '\m' is not a valid escape sequence
        'snr': 'SNR',
        'firing_rate': 'Firing Rate (Hz)',
    }
    n_sessions = len(sessions)
    session_x = np.arange(n_sessions)

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))

    for ax, (feature_name, feature_title) in zip(axs.flat, feature_titles.items()):
        feature_mean = np.array([np.nanmean(features[session_i][feature_name]) for session_i in range(n_sessions)])
        feature_std = np.array([np.nanstd(features[session_i][feature_name]) for session_i in range(n_sessions)])

        # Mean per session as a connected line, std as a vertical bar at each session
        ax.plot(session_x, feature_mean, '-o', c='r', markersize=8, lw=3)
        for session_i in range(n_sessions):
            ax.plot([session_i, session_i], [feature_mean[session_i] - feature_std[session_i], feature_mean[session_i] + feature_std[session_i]], '-_', c='r', markersize=8, lw=3)

        ax.set_xlabel('Sessions', fontsize=15)
        ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(1))
        ax.set_xlim(0 - 0.5, n_sessions - 1 + 0.5)
        ax.set_ylabel(feature_title, fontsize=15)
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

    # Frameless full-figure axes used only to carry the overall title above the grid of panels
    title_ax = fig.add_subplot(1, 1, 1, frame_on=False)
    title_ax.tick_params(labelcolor='none', bottom=False, left=False)
    title_ax.set_title(f'Unit {unit_id} Long Term Waveform Features', pad=30, fontsize=30)
    fig.tight_layout()

    # Save file, then close the figure so repeated calls in a loop don't leak memory
    if savepath is not None:
        fig.savefig(savepath, bbox_inches='tight')
        plt.close(fig)

def get_2d_location_from_amplitude(amplitude, shank_of_channel):
    """
    TBU - Based on _get_2d_location_from_amplitude function from stability_prover_representational_drift.py

    Not implemented yet. Raises explicitly instead of silently returning None, because callers
    unpack the result as ``x, y = ...`` and would otherwise fail with an unrelated unpacking error.
    """
    raise NotImplementedError('get_2d_location_from_amplitude is not implemented yet')
    


def plot_unit_location(recording, sorting, waveform_extractor, extremum_channels, unit_id, sessions, channel_indices, savepath=None):
    """
    TBU - Based on plot_unit_location function from stability_prover_representational_drift.py

    Work in progress. For each session, computes the peak-to-peak amplitude of the unit's mean waveform
    on every channel of its shank and passes it to ``get_2d_location_from_amplitude``. Nothing is plotted
    or saved yet (``recording`` and ``savepath`` are currently unused).

    Returns
    -----------------------------------------------
    None
    """

    # Spike trains split by session, and concatenated shank waveforms (at most 1000 per session).
    # The full recording traces and extremum-channel waveforms are not needed here, so they are not loaded.
    session_spike_trains = utils.split_unit_spike_train_indicies_by_session(sorting, unit_id, sessions)

    unit_shank_waveforms = utils.get_unit_shank_waveforms(waveform_extractor, extremum_channels, channel_indices, unit_id)
    unit_shank_waveform_adata = utils.get_session_waveforms_adata(unit_shank_waveforms, session_spike_trains, max_count_per_session=1000)

    n_channels_per_shank = channel_indices.shape[1]

    # Unit location per session (key = session_i, value = (x, y))
    locs = {}

    for session_i in range(len(sessions)):
        session_shank_waveform = np.array(
            unit_shank_waveform_adata[unit_shank_waveform_adata.obs['session_i'] == session_i].X
        )

        # Mean shank waveform, un-concatenated into (n_channels_per_shank, frames_per_channel).
        # Trailing samples that don't fill a whole channel are dropped, as in the original split loop.
        mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
        frames_per_channel = session_shank_waveform.shape[1] // n_channels_per_shank
        mean_session_shank_waveform = mean_session_shank_waveform[:n_channels_per_shank * frames_per_channel].reshape(n_channels_per_shank, frames_per_channel)

        # Mean waveform amplitude at each channel of the shank
        session_shank_amplitude = np.nanmax(mean_session_shank_waveform, axis=1) - np.nanmin(mean_session_shank_waveform, axis=1)
        # TODO(review): this call passes three arguments, but get_2d_location_from_amplitude is declared as
        # (amplitude, shank_of_channel); reconcile the two when the helper is implemented.
        x, y = get_2d_location_from_amplitude(session_shank_amplitude, unit_id, channel_indices)
        locs[session_i] = (x, y)

    return None


In [3]:
"""Features over time"""

# Feature point plots: one PNG per unit at data/sorted/<sorted_date>/<mouse>/features/<unit_id>.png.
# Units whose PNG already exists are skipped.
for mouse in (pbar := tqdm(mice)):
    mouse_sorted_folder = f'data{os.sep}sorted{os.sep}{sorted_date}{os.sep}{mouse}'
    mouse_features_folder = f'{mouse_sorted_folder}{os.sep}features'
    os.makedirs(mouse_features_folder, exist_ok=True)

    recording_processed = sc.load_extractor(f'{mouse_sorted_folder}{os.sep}processed')

    sorting = se.NpzSortingExtractor(f'{mouse_sorted_folder}{os.sep}sorting{os.sep}sorter_output{os.sep}firings.npz')
    # spikeinterface https://github.com/SpikeInterface/spikeinterface/pull/1378
    sorting = scu.remove_excess_spikes(sorting, recording_processed)

    waveform_extractor = sc.load_waveforms(
        folder=f'{mouse_sorted_folder}{os.sep}waveforms', with_recording=True, sorting=sorting
    )
    extremum_channels = sc.get_template_extremum_channel(waveform_extractor, peak_sign='neg')

    # Reset the index after sorting by date so row labels agree with the positional session numbers
    # used inside cal_features (sessions.at[session_i, ...]).
    sessions = (
        pd.read_csv(f'{mouse_sorted_folder}{os.sep}sessions.csv')
        .sort_values(by='date')
        .reset_index(drop=True)
    )

    for unit_id in sorting.unit_ids:
        pbar.set_description(f'{mouse} -> [{unit_id} / {len(sorting.unit_ids)}]')
        unit_features_plot_file = f'{mouse_features_folder}{os.sep}{unit_id}.png'
        if not os.path.isfile(unit_features_plot_file):
            plot_features_pointplot(recording_processed, sorting, waveform_extractor, extremum_channels, unit_id, sessions, blackrock_channel_indices, savepath=unit_features_plot_file)
            # Free the figure(s) just saved so memory does not grow with the number of units.
            plt.close('all')


  0%|          | 0/6 [00:00<?, ?it/s]

  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  snr = np.nanmean(amplitude) / noise_level
  feature_mean.append(np.nanmean(features[session_i][feature_name]))
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  snr = np.nanmean(amplitude) / noise_level
  feature_mean.append(np.nanmean(features[session_i][feature_name]))
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  snr = np.nanmean(amplitude) / noise_level
  feature_mean.append(np.nanmean(features[session_i][feature_name]))
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + 

In [91]:
"""Location over time"""

# Unit location per session. Output folder: data/sorted/<sorted_date>/<mouse>/location/.
# NOTE(review): plot_unit_location is still work in progress and does not write savepath yet.
for mouse in (pbar := tqdm(mice)):
    mouse_sorted_folder = f'data{os.sep}sorted{os.sep}{sorted_date}{os.sep}{mouse}'
    mouse_loc_folder = f'{mouse_sorted_folder}{os.sep}location'
    os.makedirs(mouse_loc_folder, exist_ok=True)

    recording_processed = sc.load_extractor(f'{mouse_sorted_folder}{os.sep}processed')

    sorting = se.NpzSortingExtractor(f'{mouse_sorted_folder}{os.sep}sorting{os.sep}sorter_output{os.sep}firings.npz')
    # spikeinterface https://github.com/SpikeInterface/spikeinterface/pull/1378
    sorting = scu.remove_excess_spikes(sorting, recording_processed)

    waveform_extractor = sc.load_waveforms(
        folder=f'{mouse_sorted_folder}{os.sep}waveforms', with_recording=True, sorting=sorting
    )
    extremum_channels = sc.get_template_extremum_channel(waveform_extractor, peak_sign='neg')

    # Reset the index after sorting by date so row labels agree with positional session numbers (0..n-1).
    sessions = (
        pd.read_csv(f'{mouse_sorted_folder}{os.sep}sessions.csv')
        .sort_values(by='date')
        .reset_index(drop=True)
    )

    for unit_id in sorting.unit_ids:
        pbar.set_description(f'{mouse} -> [{unit_id} / {len(sorting.unit_ids)}]')
        unit_loc_plot_file = f'{mouse_loc_folder}{os.sep}{unit_id}.png'
        if not os.path.isfile(unit_loc_plot_file):
            plot_unit_location(recording_processed, sorting, waveform_extractor, extremum_channels, unit_id, sessions, blackrock_channel_indices, savepath=unit_loc_plot_file)

  0%|          | 0/6 [00:00<?, ?it/s]

[[-3.70589960e+00 -3.18502400e+00 -2.33357659e+00 -1.33444805e+00
   4.30248314e-01  2.61395362e+00  4.29991226e+00  5.84015484e+00
   6.58543387e+00  3.60104930e+00 -2.48435133e+00 -5.57491949e+00
  -2.91257103e+00  9.36880427e-01  1.93927086e+00  9.76808339e-01
  -5.49757330e-01 -2.62223374e+00 -4.33564095e+00 -4.32905796e+00
  -2.61057677e+00 -9.37694852e-02  2.30736958e+00  3.84915224e+00
   4.32363285e+00  4.02863027e+00  3.26251559e+00  2.29025930e+00
   1.35825468e+00  4.34067735e-01]
 [-4.87154595e+00 -4.12132275e+00 -2.91658561e+00 -1.73974203e+00
   1.58227020e-01  2.74747677e+00  5.19467467e+00  7.48166923e+00
   8.03342182e+00  3.42560634e+00 -4.48444084e+00 -7.85200187e+00
  -3.96292780e+00  1.23398894e+00  2.96167897e+00  2.12659478e+00
   2.09370298e-01 -2.44824094e+00 -4.43697345e+00 -4.31158188e+00
  -2.29269207e+00  6.40851728e-01  3.46639355e+00  5.13155042e+00
   5.30894562e+00  4.51845227e+00  3.26670609e+00  1.84603825e+00
   5.22718520e-01 -6.57653987e-01]
 [-3.4

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[ 2.99829567e-01  2.15110468e+00  3.16162764e+00  3.76459810e+00
   4.16885272e+00  3.47876158e+00  2.46636435e+00  2.04616247e+00
  -1.00053365e+00 -8.69196120e+00 -1.45250882e+01 -1.14743908e+01
  -2.96634453e+00  3.29188386e+00  5.90055310e+00  7.28210744e+00
   7.90835164e+00  7.46512597e+00  6.35161612e+00  4.51993174e+00
   1.97733834e+00 -2.56781569e-01 -1.48432846e+00 -2.30137234e+00
  -3.32432494e+00 -4.15986038e+00 -4.19571775e+00 -3.58955038e+00
  -2.97634992e+00 -2.53551280e+00]
 [ 2.70395605e-01  2.53929538e+00  3.92614235e+00  4.79342764e+00
   5.41940286e+00  4.67612774e+00  3.38943754e+00  2.63021505e+00
  -1.55547275e+00 -1.14368191e+01 -1.86348547e+01 -1.45172382e+01
  -3.65398619e+00  4.39097954e+00  7.90587764e+00  9.62022848e+00
   1.00923171e+01  9.28505759e+00  7.87543405e+00  5.81792912e+00
   2.99696062e+00  1.47987964e-01 -2.05394612e+00 -3.59095533e+00
  -4.76422421e+00 -5.60834768e+00 -5.63262705e+00 -4.73405703e+00
  -3.76954149e+00 -3.18355964e+00]
 [ 3.8

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))
  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[ 20.59402029  22.33017808  20.55202261  17.00235829  11.96532021
    5.41215281  -1.1379915   -6.55339279 -11.23776255 -14.63418663
  -15.17605159 -13.32801191 -10.96281562  -8.3178379   -4.8396439
   -1.42471975   0.83604508   2.24185052   3.31560495   4.00420439
    4.40580944   4.58119893   4.13290957   3.20895011   2.78508479
    3.10398575   3.19546846   2.4830253    1.52713087   0.86222012]
 [ 27.59548568  29.88063054  27.56391056  22.59428582  15.70166409
    7.07869217  -1.60172063  -8.88655188 -15.0020296  -19.4202193
  -20.40382204 -18.18577617 -14.86108192 -11.14390471  -6.70152661
   -2.41571488   0.68675882   2.93565137   4.70741917   5.71660471
    6.04926054   5.99383182   5.39889009   4.47357243   4.0146354
    4.14571601   4.09595178   3.39848665   2.30050863   1.20518432]
 [ 34.32827711  37.89481     35.67246711  29.8386187   21.65061585
   11.52045513   0.91091734  -8.44438982 -16.18892196 -21.89661273
  -24.14707887 -22.83641168 -19.63114243 -15.40886514 -10.22527

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))
  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[ 4.75514008e-01  4.77767709e-01  7.59344110e-02 -3.84799115e-01
  -4.42577016e-01 -9.57788910e-01 -1.85246204e+00 -1.23304022e+00
  -1.53144484e+00 -8.65014776e+00 -1.91063386e+01 -2.03800139e+01
  -9.06147814e+00  5.09216911e+00  1.42492481e+01  1.85810782e+01
   1.96143607e+01  1.78263671e+01  1.43064946e+01  9.87674542e+00
   4.71985213e+00 -5.23902365e-01 -5.05817867e+00 -8.51858625e+00
  -1.06274042e+01 -1.12357346e+01 -1.06377467e+01 -9.28745790e+00
  -7.48531030e+00 -5.44798626e+00]
 [ 5.84334326e-01  6.86448012e-01 -5.03533810e-02 -8.56824537e-01
  -8.98923169e-01 -1.53631675e+00 -2.58923316e+00 -1.54904209e+00
  -2.32217888e+00 -1.27311627e+01 -2.67725457e+01 -2.72385259e+01
  -1.10012880e+01  7.75509861e+00  1.92943029e+01  2.50021172e+01
   2.68010068e+01  2.45720262e+01  1.97392070e+01  1.37652284e+01
   6.71976266e+00 -7.53082570e-01 -7.18322399e+00 -1.18152050e+01
  -1.46286736e+01 -1.55924555e+01 -1.48012904e+01 -1.28043236e+01
  -1.02862376e+01 -7.55319464e+00]
 [ 1.4

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))
  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]]
[nan nan nan nan nan nan]
[[-6.50790802e-01 -1.16053598e-01  5.14592901e-01 -5.28995338e-01
  -9.25581342e-01  8.55702812e-01  6.64224992e-01 -2.22487147e+00
  -1.16027541e+00  4.22588573e+00  5.09309316e+00  4.63283291e-02
  -2.54652730e+00 -6.9839

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[ 1.70637376e+00  1.64535102e+00  1.21807665e+00  8.06961431e-01
   5.00141716e-01 -7.61757242e-01 -1.55382341e+00 -8.61017028e-01
  -5.25648828e+00 -1.88363984e+01 -2.91097084e+01 -2.19621500e+01
  -3.73921305e+00  9.74572398e+00  1.49098186e+01  1.67281230e+01
   1.66246229e+01  1.42691651e+01  1.13047171e+01  8.55570550e+00
   5.55874522e+00  2.76609601e+00  6.73537380e-01 -1.26056092e+00
  -3.18246563e+00 -4.54079598e+00 -5.36474019e+00 -6.01393318e+00
  -6.32108678e+00 -6.09133654e+00]
 [ 1.69936112e+00  1.32724626e+00  5.74827880e-01  3.43425891e-01
   1.74782862e-01 -1.57286707e+00 -2.46633942e+00 -1.39192285e+00
  -7.35953046e+00 -2.44856229e+01 -3.57055810e+01 -2.47347949e+01
  -1.66857078e+00  1.42367153e+01  1.98892551e+01  2.14120758e+01
   2.04499344e+01  1.71104982e+01  1.33834778e+01  9.94123775e+00
   6.24811707e+00  2.73999429e+00 -1.03834153e-01 -2.56457537e+00
  -4.65571637e+00 -6.09979501e+00 -7.00882908e+00 -7.53590455e+00
  -7.57920145e+00 -7.21168568e+00]
 [ 1.3

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[ 3.24549962e-01  3.51397722e-01 -3.77726728e-01 -1.61466902e+00
  -2.82380214e+00 -4.27340478e+00 -5.79042084e+00 -5.67718172e+00
  -3.06399108e+00  8.43138581e-01  4.97608701e+00  8.45599391e+00
   9.11184321e+00  5.98410495e+00  1.55912706e+00 -1.28509978e+00
  -2.21657012e+00 -2.39416631e+00 -2.60838808e+00 -2.64010866e+00
  -2.02239258e+00 -1.17385489e+00 -6.61069607e-01 -1.08998694e-01
   6.12524542e-01  8.27281370e-01  4.16133834e-01 -1.11710570e-01
  -4.40317504e-01 -2.63963462e-01]
 [ 1.16101402e+00  2.21156128e+00  2.78811429e+00  3.93384481e+00
   6.27135571e+00  8.50092981e+00  9.52326936e+00  6.72370417e+00
  -4.98539615e+00 -2.23355586e+01 -2.96092506e+01 -1.84307206e+01
  -7.78976158e-01  8.91693801e+00  1.04982444e+01  9.34611802e+00
   6.66074436e+00  3.49254224e+00  2.05434232e+00  1.56032671e+00
   2.43946354e-01 -9.19710974e-01 -6.36225254e-01  2.40042272e-01
   5.74327751e-01  4.30338998e-01  2.97271977e-01  3.30955967e-01
   3.65444629e-01  2.57069686e-01]
 [ 2.1

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[ -0.23603398   0.06554448   0.4199603   -0.21733275  -2.0710238
   -2.87658662  -0.56751191   3.50000598   5.32434182   1.76990221
   -5.35134014  -9.33397496  -6.08625128   0.59783691   4.89562017
    5.97650568   5.89180964   5.2614592    3.92699159   2.40463795
    1.09621847  -0.06798018  -1.04604184  -1.80403706  -2.48249776
   -2.96392204  -2.95584772  -2.64222539  -2.4990506   -2.39471914]
 [ -2.75397777  -2.70589421  -2.37238626  -3.24706426  -5.51240388
   -5.69518819  -1.31281963   5.31866267   9.75782737   8.77022174
    2.99040961  -1.78085446  -0.60461722   3.92119475   6.3027245
    5.77333658   4.43971408   2.81452125   0.81881936  -0.90319553
   -2.23330665  -3.47775543  -4.36979685  -4.73502965  -4.87490419
   -4.73147253  -4.07968651  -3.30151655  -2.75122848  -2.01568148]
 [ -0.87666363  -1.07177742  -0.53794536  -2.20907637  -7.13507521
   -9.25158712  -3.85500183   4.47722173   8.75069806   7.15618043
    1.84409148  -2.86429151  -2.79405913   1.14529111   4.2327

  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))
  mean_session_shank_waveform = np.nanmean(session_shank_waveform, axis=0)
  print(np.nanmax(mean_session_shank_waveform, axis=1))


[[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan]]
[nan nan nan nan nan nan]
[[ 4.49913529e+00  6.33086099e+00  8.62731014e+00  1.26677455e+01
   1.47523925e+01  1.12195842e+01  5.47800169e+00 -2.58927833e-01
  -1.10239451e+01 -2.57774653e+01 -3.18349263e+01 -2.34769034e+01
  -1.09457471e+01 -3.8768

KeyboardInterrupt: 

In [74]:
# Scratch check of np.nanmean's axis convention: axis=0 averages down the rows (one value per column).
a = np.array([
    [1, 1],
    [2, 2],
    [4, 4],
])
np.nanmean(a, axis=0)

array([2.33333333, 2.33333333])

In [75]:
# Shape of the scratch array from the previous cell, as (rows, columns)
np.shape(a)

(3, 2)

In [64]:
# NOTE(review): scratch cell. `b` is not defined anywhere in this notebook, so this cell only ran against
# leftover kernel state, and its recorded output is a TypeError. Remove it (or define `b`) before sharing.
c = b[np.array([0])]
c

TypeError: only integer scalar arrays can be converted to a scalar index