In [None]:
import brainsss
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.cluster import AgglomerativeClustering
import scipy
from scipy.cluster.hierarchy import dendrogram
from scipy.cluster.hierarchy import fcluster
from scipy.cluster import hierarchy
import matplotlib as mpl
from matplotlib.pyplot import cm
import random
from scipy.stats import sem
import time
import h5py
import ants
import nibabel as nib
import matplotlib
from scipy.ndimage import gaussian_filter1d
from scipy import signal
from scipy.interpolate import interp1d

In [None]:
# Root of the 2P walking dataset, and the location of the cross-fly "superfly"
# supervoxel outputs (cluster labels / signals) loaded further down.
data_dir = '/oak/stanford/groups/trc/data/Brezovec/2P_Imaging/20190101_walking_dataset'
superfly_path = os.path.join(data_dir, 'date_superfly', 'superslices')

# Fly IDs to concatenate; each is read from data_dir/fly_<id>/func_0.
flies = list(range(1, 4))

# Length of each fly's recording in ms (30 min). Every fly is placed back-to-back
# on one concatenated timeline, each occupying exactly this span.
true_exp_len = 30 * 60 * 1000

In [None]:
###########################
### PREP VISUAL STIMULI ###
###########################
# Build one concatenated stimulus timeline across flies. Fly j's onsets are shifted
# by j * true_exp_len so it occupies [j*true_exp_len, (j+1)*true_exp_len) ms, matching
# how the fictrac traces are concatenated in the next cell.

stim_ids = []
stimulus_start_times = []          # seconds, on the concatenated timeline
starts_angle = {0: [], 180: []}    # onsets in 10 ms units (fictrac resolution)

for j, fly in enumerate(flies):
    vision_path = os.path.join(data_dir, F'fly_{fly}', 'func_0', 'visual')

    ### Load Photodiode ###
    t, ft_triggers, pd1, pd2 = brainsss.load_photodiode(vision_path)
    # Shift this fly's onsets (seconds) by the total duration of the preceding flies.
    offset_s = true_exp_len * j / 1000
    stimulus_start_times_ = brainsss.extract_stim_times_from_pd(pd2, t) + offset_s  # in sec
    stimulus_start_times.extend(stimulus_start_times_)
    print(F"fly_{fly}: {len(t)} photodiode samples; last stimulus onset at {stimulus_start_times_[-1]} s")

    ### Get Metadata ###
    stim_ids_, angles = brainsss.get_stimulus_metadata(vision_path)
    stim_ids.extend(stim_ids_)
    print(F"Found {len(stim_ids_)} presented stimuli.")
    if len(angles) != len(stimulus_start_times_):
        # Onsets are paired with metadata angles by position, so a count mismatch
        # means some onsets may be labelled with the wrong angle.
        print(F"WARNING fly_{fly}: {len(stimulus_start_times_)} photodiode onsets "
              F"vs {len(angles)} metadata angles")

    # *100 puts in units of 10ms, which will match fictrac
    for angle in [0, 180]:
        starts_angle[angle].extend([int(sst*100) for i, sst in enumerate(stimulus_start_times_)
                                    if angles[i] == angle])
    print(F"starts_angle_0: {len(starts_angle[0])}. starts_angle_180: {len(starts_angle[180])}")

# Convert once, after every fly is concatenated: 10 ms units -> 1 ms, to match the
# neural timestamps.
starts_angle_ms = {angle: np.array(starts) * 10 for angle, starts in starts_angle.items()}

### Stimulus concatenation is done here ###


In [None]:
####################
### Prep Fictrac ###
####################
# Smooth/interpolate each fly's fictrac to `resolution` ms and concatenate the flies
# back-to-back. Each fly must contribute exactly true_exp_len / resolution samples,
# so that the per-fly offsets applied to the stimulus onsets (j * true_exp_len)
# and fictrac_timestamps below stay aligned.

fps = 100
resolution = 10 #desired resolution in ms
behaviors = ['dRotLabY', 'dRotLabZ']
shorts = ['Y','Z']
fictrac = {'Y':[],'Z':[]}

# true_exp_len is in ms, but fictrac_smo is indexed in samples of `resolution` ms.
# Slicing by true_exp_len directly never truncates anything, which lets a long
# recording shift every subsequent fly on the concatenated timeline.
samples_per_fly = int(true_exp_len // resolution)

for fly in flies:
    fictrac_path = os.path.join(data_dir, F'fly_{fly}', 'func_0', 'fictrac')
    fictrac_raw = brainsss.load_fictrac(fictrac_path)
    expt_len = fictrac_raw.shape[0]/fps*1000
    for behavior,short in zip(behaviors,shorts):
        fictrac_smo = brainsss.smooth_and_interp_fictrac(fictrac_raw, fps, resolution, expt_len, behavior)
        if len(fictrac_smo) < samples_per_fly:
            # A short fly would also misalign all later flies; surface it.
            print(F"WARNING fly_{fly} {behavior}: only {len(fictrac_smo)} of "
                  F"{samples_per_fly} expected samples")
        fictrac[short].extend(fictrac_smo[:samples_per_fly])
fictrac_timestamps = np.arange(0,true_exp_len*len(flies),resolution)


### Fictrac concatenation is done here ###

In [None]:
###########################################
### Extract Stimulus Triggered Behavior ###
###########################################
# Window around each stimulus onset; presumably in fictrac samples (10 ms each at the
# resolution above) — TODO confirm against brainsss.extract_traces.
pre_window = 200
post_window = 400

behavior_traces, mean_trace, sem_trace = {}, {}, {}
for angle in (0, 180):
    onset_traces, onset_mean, onset_sem = brainsss.extract_traces(fictrac,
                                                                 starts_angle[angle],
                                                                 pre_window,
                                                                 post_window)
    behavior_traces[angle] = onset_traces
    mean_trace[angle] = onset_mean
    sem_trace[angle] = onset_sem

In [None]:
# Mean (± SEM) stimulus-triggered trace for each stimulus angle; dashed vertical
# lines mark sample indices 200, 250 and 300 of the window.
fig, ax = plt.subplots(figsize=(10, 10))

angle_colors = {0: 'blue', 180: 'red'}
for angle, color in angle_colors.items():
    mu, err = mean_trace[angle], sem_trace[angle]
    sample_idx = np.arange(len(mu))
    ax.plot(mu, color=color, linewidth=3)
    ax.fill_between(sample_idx, mu - err, mu + err, color=color, alpha=0.3)

for marker in (200, 250, 300):
    ax.axvline(marker, color='k', linestyle='--', lw=2)

In [None]:
#######################################
### Visually evoked behavior ##########
#######################################

stim_start = 250  # make it a bit bigger
stim_stop = 300
av_thresh = 50

# Turning template: the 180° mean is sign-flipped before averaging with the 0° mean,
# presumably so the opposite-direction responses to the two angles reinforce.
mean_turn = (mean_trace[0] - mean_trace[180]) / 2

# Expected turn direction for each stimulus angle.
expected_direction = {0: 'neg', 180: 'pos'}

ve_turns, ve_turn_times = {}, {}
for angle, direction in expected_direction.items():
    ve_turns[angle], ve_turn_times[angle] = brainsss.get_visually_evoked_turns(
        behavior_traces[angle],
        mean_turn=mean_turn,
        start=stim_start,
        stop=stim_stop,
        r_thresh=.2,  # NOTE(review): unit/meaning of 0.2 unconfirmed (presumably a correlation threshold)
        av_thresh=av_thresh,
        stim_times=starts_angle_ms[angle],
        expected_direction=direction,
    )

In [None]:
def get_stimuli_where_no_behavior(traces, start, stop, num_traces_to_return, stim_times):
    """Pick the stimuli with the least behavior inside a response window.

    Behavior is scored per trace as the mean absolute value over columns
    [start, stop). Traces are ranked from quietest to most active.

    Args:
        traces: 2-D array, one row per stimulus.
        start, stop: column window to score.
        num_traces_to_return: how many of the quietest traces to keep.
        stim_times: per-row stimulus times, aligned with the rows of ``traces``.

    Returns:
        (selected rows of ``traces``, matching entries of ``stim_times`` as an array),
        ordered quietest first.
    """
    window = np.abs(traces[:, start:stop])
    activity = np.mean(window, axis=-1)
    quietest = np.argsort(activity)[:num_traces_to_return]
    times = np.asarray(stim_times)
    return traces[quietest, :], times[quietest]

In [None]:
# Control events: for each angle, the stimuli with the least behavior inside the same
# response window used to detect visually evoked turns (stim_start..stim_stop), matched
# in count to ve_turns. Reusing stim_start/stim_stop keeps the two windows in sync
# when the turn window is tuned.
ve_no_turns = {}
ve_no_turn_times = {}
for angle in [0,180]:
    ve_no_turns[angle], ve_no_turn_times[angle] = get_stimuli_where_no_behavior(behavior_traces[angle],
                                  start=stim_start,
                                  stop=stim_stop,
                                  num_traces_to_return=len(ve_turns[angle]), # get the same number as ve_turns
                                  stim_times=starts_angle_ms[angle])

In [None]:
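###### Load Neural data ########
# TODO: define `timestamps` and `neural_bins` here. The STA and explosion cells below use them, but no cell above creates them.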

In [None]:
# Volume dimensions passed to STA_supervoxel_to_full_res (x, y, z); 't' is presumably
# the number of imaging timepoints — TODO confirm.
dims = dict(x=240, y=112, z=36, t=3384)

In [None]:
#########################
### POST-WARP LOADING ###
#########################
# Superfly supervoxel clustering: cluster labels and per-cluster signals. The signals
# are indexed as (z, cluster, time) by make_STA_brain.

n_clusters = 2000

load_file = os.path.join(superfly_path, 'cluster_labels.npy')
cluster_labels = np.load(load_file)

load_file = os.path.join(superfly_path, 'cluster_signals.npy')
all_signals = np.load(load_file)

# Derive from the volume dimensions instead of repeating the literal 36.
dim_z = dims['z']

In [None]:
# Load the template mean brain and the ROI atlas, then build per-ROI masks and
# contours used by place_roi_groups_on_canvas when rendering the explosion plots.
fixed = brainsss.load_fda_meanbrain()  # NOTE(review): not used by any active code in the cells below
atlas = brainsss.load_roi_atlas()
explosion_rois = brainsss.load_explosion_groups()
all_rois = brainsss.unnest_roi_groups(explosion_rois)  # presumably flattens the nested ROI groups
roi_masks = brainsss.make_single_roi_masks(all_rois, atlas)
roi_contours = brainsss.make_single_roi_contours(roi_masks, atlas)

In [None]:
### Make STA_Brain ####

In [None]:
###########################
### Create Notch Filter ###
###########################
# Notch around f0 for signals sampled at fs; applied along the time-bin axis of the
# STA brains.

fs = 10.0      # Sample frequency (Hz)
f0 = 1.8797    # Frequency to be removed from signal (Hz)
Q = .8         # Quality factor <---- IMPORTANT (lower Q => wider notch)
nyquist = fs / 2
w0 = f0 / nyquist  # normalized to Nyquist, as iirnotch expects with its default fs

b_notch, a_notch = signal.iirnotch(w0, Q)

# Frequency response; w is in rad/sample, converted to Hz for the axis.
w, h = signal.freqz(b_notch, a_notch)
freq = w * fs / (2 * np.pi)

In [None]:
def make_STA_brain(neural_signals, neural_timestamps, event_times_list, neural_bins):
    """Event-triggered average of supervoxel signals, computed per z-slice.

    For every event, the slice's sample times are assigned to bins of
    ``neural_bins + event_time`` (via np.digitize). For each interior bin
    1..len(neural_bins)-1, all samples that fall in that bin for any event are
    pooled and averaged over time.

    Args:
        neural_signals: array indexed as [z, cluster, time].
        neural_timestamps: array indexed as [time, z] (per-slice sample times).
        event_times_list: sequence of event times, in the same units as the timestamps.
        neural_bins: 1-D monotonically increasing bin edges relative to each event.

    Returns:
        Array of shape (num_z, len(neural_bins) - 1, num_clusters). An empty bin
        yields NaN (mean of no samples).
    """
    n_slices = neural_signals.shape[0]
    bin_ids = np.arange(1, len(neural_bins))

    per_slice = []
    for z in range(n_slices):
        slice_times = neural_timestamps[:, z]
        # Row e holds the bin index of every sample relative to event e.
        assignment = np.asarray([
            np.digitize(slice_times, neural_bins + event_time)
            for event_time in event_times_list
        ])
        # Column indices (= sample times) landing in bin b for any event; a sample
        # counted by two overlapping events contributes twice.
        slice_avg = [
            np.mean(neural_signals[z, :, list(np.where(assignment == b)[1])], axis=0)
            for b in bin_ids
        ]
        per_slice.append(np.asarray(slice_avg))
    return np.asarray(per_slice)

In [None]:
# Event-triggered average brain per condition, notch-filtered, mapped back to full
# voxel resolution, smoothed over time and upsampled.
# `timestamps` and `neural_bins` come from the neural-data loading step; fail early
# with a clear message instead of a bare NameError if that step was never run.
_missing = [name for name in ('timestamps', 'neural_bins') if name not in globals()]
if _missing:
    raise NameError(F"Define {_missing} (neural data loading step) before building STA brains.")

# condition -> (stimulus angle, use matched no-turn control events?)
# Explicit table instead of parsing substrings out of the condition name.
conditions = {
    've_no_0': (0, True),
    've_no_180': (180, True),
    've_0': (0, False),
    've_180': (180, False),
}
n_time_bins = len(neural_bins) - 1

all_warps = {}
for condition, (angle, use_no_turn_controls) in conditions.items():
    print(condition)
    event_times_list = ve_no_turn_times[angle] if use_no_turn_controls else ve_turn_times[angle]

    t0 = time.time()
    STA_brain = make_STA_brain(neural_signals=all_signals,
                               neural_timestamps=timestamps,
                               event_times_list=event_times_list,
                               neural_bins=neural_bins)
    print(F'STA {time.time()-t0}')
    print("Shape of STA_brain:", STA_brain.shape)

    # Axis 1 of make_STA_brain's output is the time-bin axis.
    STA_brain = signal.filtfilt(b_notch, a_notch, STA_brain, axis=1)
    reformed_STA_brain = brainsss.STA_supervoxel_to_full_res(STA_brain, cluster_labels, dims['x'], dims['y'], dims['z'])
    STA_brain = gaussian_filter1d(reformed_STA_brain, sigma=1, axis=1, truncate=1)

    ### upsample to (2,2,2)
    # Two moveaxis(0, -1) calls rotate the first two axes to the end; presumably
    # (z, t, x, y) -> (x, y, z, t) — TODO confirm against STA_supervoxel_to_full_res.
    upsampled = np.moveaxis(np.moveaxis(STA_brain.copy(), 0, -1), 0, -1)
    upsampled = ants.from_numpy(upsampled)
    upsampled.set_spacing((2.611, 2.611, 5, 1))
    upsampled = ants.resample_image(upsampled, (314, 146, 91, n_time_bins), use_voxels=True)

    # Warping to the template (brainsss.warp_STA_brain) is currently disabled; the
    # upsampled, unwarped volume is stored under this name.
    all_warps[condition] = upsampled.numpy()

In [None]:
# Render every time bin of each condition's upsampled STA volume onto the ROI
# "explosion" canvas with a diverging colormap.
vmax = 1
all_explosions = {}
for condition in ['ve_no_0','ve_no_180','ve_0','ve_180']:
    print(condition)
    volume = all_warps[condition]
    explosions = []
    t0 = time.time()
    for tp in range(len(neural_bins) - 1):
        input_canvas = np.ones((500, 500, 3))   # fresh canvas for every frame
        data_to_plot = volume[:, :, ::-1, tp]   # third spatial axis reversed
        explosion_map = brainsss.place_roi_groups_on_canvas(
            explosion_rois, roi_masks, roi_contours,
            data_to_plot, input_canvas,
            vmax=vmax, cmap='seismic', diverging=True)
        explosions.append(explosion_map)
    print(F'Explosion {time.time()-t0}')
    all_explosions[condition] = explosions