In [1]:
import h5py
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cmx
import scipy as sp
from scipy.stats import sem
from scipy import fftpack
from scipy.signal import detrend
import plotly.express as px
import seaborn as sns
import pandas as pd
import itertools

from ballrig_analysis.utils.plot_tools import plot_w_error

from flystim1.trajectory import SinusoidalTrajectory

In [2]:
def nansem(a, axis=0, ddof=1, nan_policy='omit'):
    return sem(a, axis, ddof, nan_policy)

def uneven_list2d_to_np(v, fillval=np.nan):
    lens = np.array([len(item) for item in v])
    if len(np.unique(lens)) == 1:
        return np.asarray(v)
    mask = lens[:,None] > np.arange(lens.max())
    out = np.full(mask.shape,fillval)
    out[mask] = np.concatenate(v)
    return out

def expand_relative_timestamps(timestamps, stim_timestamps):
    '''
    Takes in absolute timestamps and stim_timestamps that have time from stimulus start to end.
    Returns a timestamp array with 0 when the stimulus starts and negative values prior to that and continuation of positive values after the stimulus ends.
    This could fail if the stimulus was ever paused, as absolute time would have discontinuation.
    '''
    stim_start_idx = np.where(~np.isnan(stim_timestamps))[0][0]
    return (timestamps - timestamps[stim_start_idx])
def generate_standard_timestamp(timestamps):
    '''
    timestamps: 2d numpy array with nan padding for uneven timestamp lengths

    Finds mean framerate and generates a single timestamp series starting from 0 evenly spaced to the max timestamp.
    '''
    if not isinstance(timestamps, np.ndarray):
        timestamps = uneven_list2d_to_np(timestamps)
    mean_diff = np.nanmean(np.diff(timestamps))
    min_time = np.nanmin(timestamps)
    max_time = np.nanmax(timestamps)

    return np.arange(min_time, max_time, mean_diff)

def interpolate_to_new_timestamp(y, t, nt):
    '''
    y: 1d data, length same as t
    t: original timestamp
    nt: new timestamp to interpolate to
    Returns ny, linearly interpolated data at nt
    '''
    not_nan = ~np.isnan(y)
    return np.interp(nt, t[not_nan], y[not_nan])


def align_traces_to_standardized_timestamp(ts, xs):
    ts_standard = generate_standard_timestamp(ts)
    xs_standardized = np.array([interpolate_to_new_timestamp(xs[i], ts[i], ts_standard) for i in range(len(xs))])

    return ts_standard, xs_standardized
# def fly_trial_theta(fly_h5f, align_to = 0):
#     '''
#     Input:
#         trial_key
#     Returns:
#         ft_timestamps_rel: timestamps relative to trial start time
#         ft_theta
#     '''
#     start_time = fly_h5f.attrs['t_start']
#     ft_timestamps_abs = fly_h5f['ft_timestamps'][()]
#     ft_timestamps_rel = ft_timestamps_abs-start_time

#     ft_theta = -fly_h5f['ft_theta'][()]
#     ft_theta = np.unwrap(ft_theta)
#     ft_theta_at_t0 = np.interp(align_to, ft_timestamps_rel, ft_theta)
#     ft_theta -= ft_theta_at_t0
#     ft_theta = np.rad2deg(ft_theta)

#     return ft_timestamps_rel, ft_theta


# def convert_to_velocity(timestamp, x):
#     return np.diff(x) / np.diff(timestamp)

# def fly_trial_theta_velocity(fly_h5f, align_to=0):
#     ft_timestamps_rel, ft_theta = fly_trial_theta(fly_h5f, align_to)
#     theta_vel = convert_to_velocity(ft_timestamps_rel, ft_theta)

#     return ft_timestamps_rel[1:], theta_vel
def fly_trials_traces(fly_h5f, tkeys, fun_trial_trace=fly_trial_theta, verbose=False):
    '''
    Input:
        tkeys: list of trial keys
    Returns:
        timestamps
        signals
    '''
    ts_sig_pairs = [fun_trial_trace(fly_h5f, tkey) for tkey in tkeys]
    ts_sig_pairs_filtered = [pair for pair in ts_sig_pairs if pair is not None]

    if verbose:
        filtered_out_count = len(ts_sig_pairs) - len(ts_sig_pairs_filtered)
        if filtered_out_count > 0:
            print(f"{filtered_out_count}/{len(ts_sig_pairs)} trials filtered from {fly_h5f.attrs['save_prefix']}.")

    timestamps = [pair[0] for pair in ts_sig_pairs_filtered]
    signals = [pair[1] for pair in ts_sig_pairs_filtered]

    return timestamps, signals
def fly_trials_traces_aligned(fly_h5f, tkeys, fun_trial_trace=fly_trial_theta, verbose=False):
    '''
    Input:
        tkeys: list of trial keys
    Returns:
        timestamp_standard
        signals_standardized
    '''
    timestamps, signals = fly_trials_traces(fly_h5f, tkeys, fun_trial_trace, verbose)

    timestamps_np = uneven_list2d_to_np(timestamps)
    signals_np = uneven_list2d_to_np(signals)

    timestamp_standard, signals_standardized = align_traces_to_standardized_timestamp(timestamps_np, signals_np)

    return timestamp_standard, signals_standardized

def mean_and_error(signals_aligned, error_fun=nansem):
    '''
    Input:
        timestamp: 1d array of timestamps
        signals_aligned: 2d array of signals, trials by time
        do_plot
    Returns:
        fly_mean
        fly_error (stdev or sem)
    Side effect:
        Plot the mean with errorbars
    '''

    fly_mean = np.nanmean(signals_aligned, axis=0)
    fly_error = error_fun(signals_aligned, axis=0)

    return fly_mean, fly_error

def fly_mean_and_error_with_flip(fly_h5f, tidxes_reg, tidxes_flip, fun_trial_trace=fly_trial_theta, verbose=False):
    '''

    '''
    tidxes_combined = tidxes_reg + tidxes_flip
    ts, signals = fly_trials_traces_aligned(fly_h5f, tidxes_combined, fun_trial_trace, verbose)
    if len(tidxes_flip) > 0:
        signals[-len(tidxes_flip):] *= -1 #Flip trials
    mean, error = mean_and_error(signals)
    return ts, mean, error, signals

NameError: name 'fly_trial_theta' is not defined

In [None]:
data_dir = '/Users/Shirley/Desktop/fixationdatanew/'
flies = [x for x in sorted(os.listdir(data_dir)) if x[-3:]==".h5"]
flies


In [None]:
period = 8
savefig = False
fix_sine_amplitude = 15
fix_sine_period = 2
sin_traj = SinusoidalTrajectory(amplitude=fix_sine_amplitude, period=fix_sine_period) # period of 1 second

In [None]:
for fly in flies:
    fly_path = os.path.join(data_dir, fly)
    fly_h5f = h5py.File(fly_path, 'r')
    ts, trace = fly_trial_theta(fly_h5f, align_to=0)
    closed_loop = fly_h5f.attrs['closed_loop']
    fly_h5f.close()
    
    start_idxes = np.concatenate(([0],np.where(np.diff(np.digitize(ts, np.arange(0,ts[-1],period))).astype(bool))[0]+1,[len(ts)]))

    tss = []
    traces_absolute = []
    traces_relative = []
    for i in range(len(start_idxes)-1):
        tss.append(ts[start_idxes[i]:start_idxes[i+1]]%period)
        trial_trace = trace[start_idxes[i]:start_idxes[i+1]]
        traces_relative.append(trial_trace - trial_trace[0])
        traces_absolute.append((trial_trace+180) % 360 - 180)

    ts_std, traces_absolute_std = align_traces_to_standardized_timestamp(tss, traces_absolute)
    _, traces_relative_std = align_traces_to_standardized_timestamp(tss, traces_relative)

    bar_traj = sin_traj.eval_at(ts_std)
    
    fig, ax = plt.subplots(1,2, figsize=(14,4))
    plot_w_error(x = ts, y = [trace], title=fly, xlabel = "Time [s]", ylabel = "\u03B8 [\u00B0]", ax=ax[0])
    
    if not closed_loop:
        plot_w_error(x = ts_std,  
                     y = [np.mean(traces_relative_std, axis=0), bar_traj], 
                     ye= [    sem(traces_relative_std, axis=0),   [None]],
                     xlabel = "Time [s]", ylabel = "\u03B8 [\u00B0]",
                     legend=['Fly', 'Bar'], show_legend=True, 
                     ax=ax[1])
    else:
        plot_w_error(x = ts_std,  
                     y = [np.mean(traces_absolute_std, axis=0), np.mean(traces_relative_std, axis=0), bar_traj], 
                     ye= [    sem(traces_absolute_std, axis=0),     sem(traces_relative_std, axis=0), [None]],
                     xlabel = "Time [s]", ylabel = "\u03B8 [\u00B0]",
                     legend=['Fly (abs)', 'Fly (rel)', 'Bar'], show_legend=True, 
                     ax=ax[1])

    if savefig:
        fig.savefig(data_dir+os.path.sep+fly[:-3]+".png")