In [None]:
VERSION = 1  # selects the ver<N> output subfolder that save_root is built from

In [None]:
%reload_ext autoreload
%autoreload 2
import os
import sys
import defopt
import pickle
import numpy as np
import pandas as pd
import scipy.stats
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42 # save text as text not outlines
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.colors import ListedColormap
import matplotlib
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import warnings

# from depth_analysis_2p.filepath import generate_filepaths
# from depth_analysis_2p.vis_stim import vis_stim_structure
# from depth_analysis_2p.process_params.process_params import create_speed_arr, create_trace_arr_per_roi, calculate_OF, thr
# # from depth_analysis_2p.plotting.plotting_utils import get_binned_arr, get_confidence_interval, plot_raster, plot_line_with_error, gaussian_func, plot_dFF_binned_speed, plot_frame_off
# from depth_analysis_2p.plotting.plotting_utils import *

from cottage_analysis.depth_analysis.filepath import generate_filepaths
from cottage_analysis.stimulus_structure.sphere_structure import *
from cottage_analysis.depth_analysis.depth_preprocess.process_params import create_speed_arr, create_trace_arr_per_roi, calculate_OF, thr
from cottage_analysis.depth_analysis.plotting.plotting_utils import *
from cottage_analysis.depth_analysis.plotting.basic_vis_plots import *

In [None]:

def get_trace_arrs(roi, dffs, depth_list, stim_dict,
                   mode='sort_by_depth', protocol='fix_length',
                   blank_period=5, frame_rate=15):
    """Build two per-trial dF/F trace arrays for one ROI via ``create_trace_arr_per_roi``.

    Both calls use ``blank_period=0``. The second call also passes
    ``isStim=False``. Judging by the variable name, that presumably selects the
    blank periods; confirm against ``create_trace_arr_per_roi``.

    Args:
        roi: ROI identifier forwarded to ``create_trace_arr_per_roi``.
        dffs: dF/F traces forwarded to ``create_trace_arr_per_roi``.
        depth_list: virtual depths forwarded to ``create_trace_arr_per_roi``.
        stim_dict: stimulus description forwarded to ``create_trace_arr_per_roi``.
        mode: trial-sorting mode. This is now forwarded; previously it was
            ignored and the default value was hard-coded.
        protocol: stimulus protocol. This is now forwarded; previously it was
            ignored in the same way.
        blank_period: kept for backward compatibility only. It used to feed an
            extra ``create_trace_arr_per_roi`` call whose result was thrown
            away. Both returned arrays are always built with ``blank_period=0``.
        frame_rate: forwarded to ``create_trace_arr_per_roi``.

    Returns:
        Tuple ``(trace_arr_noblank, trace_arr_blank)``. The second return value
        of each ``create_trace_arr_per_roi`` call is discarded.
    """
    trace_arr_noblank, _ = create_trace_arr_per_roi(roi, dffs, depth_list, stim_dict,
                                                    mode=mode, protocol=protocol,
                                                    blank_period=0, frame_rate=frame_rate)
    trace_arr_blank, _ = create_trace_arr_per_roi(roi, dffs, depth_list, stim_dict,
                                                  mode=mode, protocol=protocol,
                                                  isStim=False, blank_period=0,
                                                  frame_rate=frame_rate)
    return trace_arr_noblank, trace_arr_blank

    
# Floor added to exp(log_sigma) so the Gaussian width can never shrink below 0.5.
MIN_SIGMA = 0.5


def gaussian_func(x, a, x0, log_sigma, b):
    """Gaussian with an additive baseline, parameterised for ``scipy.optimize.curve_fit``.

    Returns ``a * exp(-(x - x0)**2 / (2 * sigma**2)) + b`` where
    ``sigma = exp(log_sigma) + MIN_SIGMA``. Fitting ``log_sigma`` instead of
    ``sigma`` keeps the width positive without bounds.

    The previous version divided the whole exponential by ``2 * sigma**2``
    instead of dividing the exponent by it. That made sigma a second amplitude
    factor and fixed the curve width. Parameters fitted with that form are not
    interchangeable with this one.

    Keep the signature exactly ``(x, a, x0, log_sigma, b)``. ``curve_fit``
    infers the number of free parameters from it, so extra arguments would be
    fitted too.

    Args:
        x: evaluation point(s); scalar or array.
        a: peak amplitude above the baseline.
        x0: centre.
        log_sigma: log of (sigma - MIN_SIGMA).
        b: baseline offset.
    """
    sigma = np.exp(log_sigma) + MIN_SIGMA
    return a * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + b

# def plot_depth_tuning_curve(trace_arr_noblank,
#                             depth_list,
#                             plot_rows, plot_cols, fontsize_dict,
#                             grid=True,
#                             this_depth=None,
#                            gaussian_fit=False, popt = [],
#                            title='Depth tuning', ylabel='dF/F', xlabel='Depth (cm)', linewidth=1):
#     # --- Plot 4 (0,3): Depth tuning curve ---
#     trace_arr_noblank_cp = trace_arr_noblank.copy()
# #     trace_arr_noblank_cp[speed_arr_noblank < speed_thr_cal] = np.nan
#     trace_arr_mean_eachtrial = np.nanmean(trace_arr_noblank_cp, axis=2)
#     CI_lows = np.zeros(len(depth_list))
#     CI_highs = np.zeros(len(depth_list))
#     for idepth in range(len(depth_list)):
#         CI_lows[idepth], CI_highs[idepth] = get_confidence_interval(
#             trace_arr_mean_eachtrial[idepth, :],
#             mean_arr=np.nanmean(trace_arr_mean_eachtrial, axis=1)[idepth].reshape(-1, 1))
        
#     if gaussian_fit:
#         if (this_depth == None) or (this_depth!=len(depth_list)): # we can't plot the tuning curve for non-depth-selective neurons
#             plot_line_with_error(arr=np.nanmean(trace_arr_mean_eachtrial, axis=1), CI_low=CI_lows,
#                                  CI_high=CI_highs, linecolor='b', fontsize_dict=fontsize_dict, linewidth=linewidth)

#             trace_arr_mean_eachtrial = np.nanmean(trace_arr_noblank, axis=2)
#             x = np.log(np.repeat(np.array(depth_list), trace_arr_mean_eachtrial.shape[1]))
#             roi_number = np.where(depth_neurons == roi)[0][0]
            
#             plt.plot(np.linspace(0, len(depth_list) - 1, 100),
#                      gaussian_func(np.linspace(np.log(depth_list[0]*100), np.log(depth_list[-1]*100), 100), *popt), 'gray', linewidth=3)
#             plt.xticks(np.arange(len(depth_list)), (np.array(depth_list) * 100).astype('int'),
#                        fontsize=fontsize_dict['xticks'])
#             plt.yticks(fontsize=fontsize_dict['yticks'])
#             plt.ylabel(ylabel, fontsize=fontsize_dict['title'])
#             plt.xlabel(xlabel, fontsize=fontsize_dict['xlabel'])
#             plt.title(title, fontsize=fontsize_dict['ylabel'])
#             plot_frame_off()
#     else:
#         plot_line_with_error(arr=np.nanmean(trace_arr_mean_eachtrial, axis=1), CI_low=CI_lows,
#                              CI_high=CI_highs, linecolor='royalblue', fontsize_dict=fontsize_dict, linewidth=linewidth)
#         plt.xticks(np.arange(len(depth_list)), (np.array(depth_list) * 100).astype('int'),
#                    fontsize=fontsize_dict['xticks'])
#         plt.ylabel(ylabel, fontsize=fontsize_dict['ylabel'])
#         plt.xlabel(xlabel, fontsize=fontsize_dict['xlabel'])
#         plt.title(title, fontsize=fontsize_dict['title'])
#         plot_frame_off()
        
        
def plot_line_with_error(arr, CI_low, CI_high, linecolor, fontsize_dict,
                         label=None, marker='-', markersize=None, xarr=None, xlabel=None, ylabel=None,
                         title_on=False, title=None, suffix=None, linewidth=0.5, rasterized=False):
    """Plot a line plus a shaded confidence band on the current pyplot axes.

    Args:
        arr: y values of the line.
        CI_low: lower edge of the shaded band.
        CI_high: upper edge of the shaded band.
        linecolor: colour of both the line and the band (band drawn at alpha 0.3).
        fontsize_dict: must provide 'xlabel', 'ylabel' and 'title' font sizes.
        label, marker, markersize, linewidth: forwarded to ``plt.plot``.
        xarr: x values. If ``None`` (the default) or empty, the line is drawn
            with ``plt.plot(arr, ...)`` and the band over
            ``np.arange(len(arr))``. An empty list, the old default, still
            behaves the same.
        xlabel, ylabel: axis labels.
        title_on: if True, the title is ``title`` and ``suffix`` joined by a
            space, skipping whichever is None. Previously a missing piece raised
            TypeError. If False, the title is ``suffix``.
        title, suffix: title pieces (see ``title_on``).
        rasterized: forwarded to both ``plt.plot`` and ``plt.fill_between``.
    """
    if xarr is None or len(xarr) == 0:
        plt.plot(arr, marker, c=linecolor, linewidth=linewidth, label=label, alpha=1,
                 markersize=markersize, rasterized=rasterized)
        plt.fill_between(np.arange(len(arr)), CI_low, CI_high, color=linecolor, alpha=0.3,
                         edgecolor=None, rasterized=rasterized)
    else:
        plt.plot(xarr, arr, marker, c=linecolor, linewidth=linewidth, label=label, alpha=1,
                 markersize=markersize, rasterized=rasterized)
        plt.fill_between(xarr, CI_low, CI_high, color=linecolor, alpha=0.3,
                         edgecolor=None, rasterized=rasterized)
    plt.xlabel(xlabel, fontsize=fontsize_dict['xlabel'])
    plt.ylabel(ylabel, fontsize=fontsize_dict['ylabel'])
    if title_on:
        title_text = ' '.join(part for part in (title, suffix) if part is not None)
    else:
        title_text = suffix
    plt.title(title_text, fontsize=fontsize_dict['title'])
 
 

def calculate_neuron_sta(roi, depth_list, stim_fp, shift, resolution=1, data_root=None):
    """Stimulus-normalised spike-triggered average (STA) map of one ROI at each depth.

    For every depth in ``depth_list`` the function:

    1. Loads the stimulus ``<data_root>data/stim_depth_<depth>_res<resolution>.npy``
       and the activity ``<data_root>data/spks_depth_<depth>_res<resolution>.npy``.
    2. Rolls the activity by ``int(np.round(shift))`` samples along axis 1 and
       subtracts each row's mean, taken before the roll.
    3. Projects the flattened stimulus (frames along axis 0) onto row ``roi`` of
       that centred activity. The result is reshaped to
       ``(stim_fp.shape[1], stim_fp.shape[2])``.
    4. Divides the map element-wise by the stimulus summed over axis 0 at that
       depth.

    Each file is now loaded exactly once. The previous version recomputed every
    per-depth stimulus sum inside the depth loop, which reloaded each stimulus
    file ``len(depth_list)`` extra times. It also failed with
    ``UnboundLocalError`` for an empty ``depth_list``.

    Args:
        roi: row index into the per-depth activity arrays.
        depth_list: depths; ``str(depth)`` is embedded in the file names.
        stim_fp: only ``stim_fp.shape[1]`` and ``stim_fp.shape[2]`` are used,
            as the size of each output map.
        shift: lag passed to ``np.roll`` after rounding to the nearest integer.
        resolution: embedded in the file names.
        data_root: prefix the file names are appended to. Paths are built by
            string concatenation, so it should end with a separator. Defaults to
            the notebook-global ``root``.

    Returns:
        Float array of shape ``(len(depth_list), stim_fp.shape[1], stim_fp.shape[2])``.
        Pixels whose summed stimulus is zero are divided by zero and become
        nan/inf; NumPy emits a RuntimeWarning.
    """
    if data_root is None:
        data_root = root  # notebook-global data directory from the configuration cell

    map_shape = (stim_fp.shape[1], stim_fp.shape[2])
    lag = int(np.round(shift))
    sta_depths = np.zeros((len(depth_list),) + map_shape)
    stim_sums = np.zeros((len(depth_list),) + map_shape)

    for idepth, depth in enumerate(depth_list):
        file_suffix = '_depth_' + str(depth) + '_res' + str(resolution) + '.npy'
        stim = np.load(data_root + 'data/stim' + file_suffix)
        spks = np.load(data_root + 'data/spks' + file_suffix)

        # Subtract each row's mean (a roll does not change it) from the lagged activity.
        spks_centred = np.roll(spks, lag, axis=1) - np.mean(spks, axis=1).reshape(-1, 1)

        stim_flat = stim.reshape(stim.shape[0], -1)
        sta_depths[idepth] = np.dot(stim_flat.T, spks_centred[roi]).reshape(map_shape)
        # Total stimulation of every pixel at this depth, used as the normaliser.
        stim_sums[idepth] = np.sum(stim, axis=0)

    return np.divide(sta_depths, stim_sums)
        


In [None]:
# --- Output location: figures are written under a ver<VERSION> subfolder ---
save_root = '/camp/lab/znamenskiyp/home/shared/presentations/Cosyne2023/ver'+str(VERSION)+'/STA_examples/'
os.makedirs(save_root, exist_ok=True)
# --- Input location: data files are read from root + 'data/...' ---
root = '/camp/lab/znamenskiyp/home/shared/presentations/PetrTalk202302/STA/'

# --- Analysis parameters ---
depth_list = [0.06, 0.19, 0.60, 1.9, 6]  # virtual depths; figures label them as depth * 100 cm
blank_period = 0
cmap = cm.cool.reversed()
legend_on = False
speed_thr_cal = 0.2  # m/s, threshold for running speed when calculating depth neurons
speed_thr = 0.01  # m/s

# --- Sphere stimulus geometry (degrees) ---
max_sphere_num = 24
azi_min = -30
azi_max = 210
ele_min = -40
ele_max = 40
sphere_size = 10  # degrees
# NOTE(review): this cell used to assign frame_rate = 15 and then overwrite it
# with 30 a few lines later, so only 30 ever took effect. The speed arrays in
# the data cell are built with it; the plotting cell resets frame_rate to 15.
# Confirm which rate each computation should use.
frame_rate = 30
resolution = 1
azi_n = int((azi_max-azi_min)/resolution)
ele_n = int((ele_max-ele_min)/resolution)
azi_range = azi_max - azi_min
ele_range = ele_max - ele_min



# Load the recording and precomputed results from <root>data/.
# NOTE(review): pickle.load can execute arbitrary code, so these must be trusted lab files.
dffs = np.load(f'{root}data/dffs_ast.npy')
with open(f'{root}data/img_VS.pickle', 'rb') as handle:
    img_VS = pickle.load(handle)
with open(f'{root}data/stim_dict.pickle', 'rb') as handle:
    stim_dict = pickle.load(handle)

# Read-only int32 memory map with one (azi_n x ele_n) frame per row of img_VS.
# NOTE(review): np.memmap reads raw bytes and does not skip a .npy header. If
# this file was written with np.save rather than dumped as a raw memmap,
# np.load(..., mmap_mode='r') would be needed. Confirm how it was created.
stim_fp = np.memmap(
    f'{root}data/stim_all_res{resolution}_new.npy',
    dtype=np.int32,
    mode='r',
    shape=(len(img_VS), int(azi_n), int(ele_n)),
)
depth_neurons = np.load(f'{root}data/depth_neurons.npy')
with open(f'{root}data/gaussian_depth_tuning_fit_new_0.5.pickle', 'rb') as handle:
    gaussian_depth = pickle.load(handle)


# Running speed: change in MouseZ over change in HarpTime between consecutive
# samples (with no playback, EyeZ and MouseZ should be the same).
speeds = img_VS.MouseZ.diff() / img_VS.HarpTime.diff()
# diff() leaves NaN in the first position. Assign positionally: the old
# label-based speeds[0] would append a new element if the index did not
# contain the label 0.
speeds.iloc[0] = 0
speeds = thr(speeds, speed_thr)
speed_arr_original, _ = create_speed_arr(speeds, depth_list, stim_dict, mode='sort_by_depth', protocol='fix_length',
                                         blank_period=0, frame_rate=frame_rate)
speed_arr_mean_original = np.nanmean(speed_arr_original, axis=1)
# The no-blank array was built with exactly the same arguments as
# speed_arr_original, so reuse that result instead of recomputing it. The copy
# keeps the two names independent if either is later modified in place.
speed_arr_noblank_original = speed_arr_original.copy()
speed_arr_noblank_mean_original = np.nanmean(speed_arr_noblank_original, axis=1)
speed_arr_blank_original, _ = create_speed_arr(speeds, depth_list, stim_dict, mode='sort_by_depth', protocol='fix_length',
                                               isStim=False, blank_period=0, frame_rate=frame_rate)
frame_num_pertrial_max_original = speed_arr_noblank_original.shape[2]
total_trials_original = speed_arr_noblank_original.shape[1]
optics_original = calculate_OF(rs=speeds, img_VS=img_VS, mode='no_RF')


# plt.figure(figsize=(5,7))
# for i, roi in enumerate(select_rois):
#     plt.subplot(len(select_rois),1,i+1)
#     trace_arr_noblank, trace_arr_blank = get_trace_arrs(roi=roi, dffs=dffs, 
#                                                      depth_list=depth_list, stim_dict=stim_dict,
#                                 mode='sort_by_depth', protocol='fix_length',
#                                 blank_period=blank_period, frame_rate=frame_rate)

#     plot_depth_tuning_curve(dffs=dffs, 
#                             speeds=speeds, 
#                             roi=roi, 
#                             speed_thr_cal=speed_thr_cal, 
#                             depth_list=depth_list, 
#                             stim_dict=stim_dict, 
#                             depth_neurons=depth_neurons,
#                             gaussian_depth=gaussian_depth, 
#                             fontsize_dict=fontsize_dict, 
#                             ylim=None, 
#                             frame_rate=15,
#                             this_depth=5)
#     ylim_tuning = plt.gca().get_ylim()
#     plt.ylim((-0.01,ylim_tuning[1]))
#     plt.ylabel('\u0394F/F', fontsize=fontsize_dict['ylabel'])
#     plt.yticks(fontsize=fontsize_dict['yticks'])
#     plt.title('')


# plt.savefig(save_root+'ROI'+str(select_rois)+'_depth_tuning.pdf')

        
    

In [None]:

# One RGBA colour per depth, sampled from `cmap` on a log-depth scale. The
# colour is taken as 8-bit integers and rescaled to [0, 1].
norm = matplotlib.colors.Normalize(vmin=np.log(min(depth_list)), vmax=np.log(max(depth_list)))
line_colors = [
    tuple(channel / 255 for channel in cmap(norm(np.log(depth)), bytes=True))
    for depth in depth_list
]

# White-to-red colormap: red and alpha stay at 1 while green and blue fall from 1 to 0.
N = 256
vals = np.ones((N, 4))
vals[:, 1] = np.linspace(1, 0, N)
vals[:, 2] = np.linspace(1, 0, N)
WhRdcmap = ListedColormap(vals)

#Fontsizes
fontsize_dict = {
    'title':10,
    'xlabel':30,
    'ylabel':30,
    'xticks':20,
    'yticks':20,
    'legend':15,
    'text':15
}

frame_rate = 15
shift = -0.13*frame_rate
extent=[-120,120,-40,40]


# select_rois = [110,164, 1, 18, 58, 15, 61, 22, 0, 4, 10, 21, 42]
# vmaxs=[30,30, 200,20,40,40,40, 15, 40, 20,20, 20, 6]
# Example ROIs and the symmetric colour limit (+/- vmax) of each ROI's STA maps.
select_rois = [110, 1, 58, 61, 18, 15, 21, 42]
vmaxs = [30, 200, 40, 15, 20, 40, 10, 7]
assert len(vmaxs) == len(select_rois), 'vmaxs needs one colour limit per ROI'
n_depths = len(depth_list)
# Row 0 is left free for the depth-tuning insets; rows 1..n_depths hold one STA map per depth.
plot_rows = n_depths + 1
plot_cols = len(select_rois)
# Draw a colorbar beside every STA map. The flag is also recorded in the output filename.
colorbar_on = True

fig = plt.figure(figsize=(82/2.54, 18/2.54))  # 82 x 18 cm
# Set the spacing before any axes exist. Automatic ticks read back with
# get_yticks below may depend on axes size, so every ROI now sees the same
# final layout. Previously this ran at the end of each ROI iteration, which
# meant the first ROI's ticks were chosen under the default spacing.
fig.subplots_adjust(wspace=0.4, hspace=0.3)

for i, roi in enumerate(select_rois):
    STA_depths_normed_stim = calculate_neuron_sta(roi=roi, depth_list=depth_list, stim_fp=stim_fp,
                                                  shift=shift, resolution=1)

    for idepth in range(n_depths):
        is_bottom_row = idepth == n_depths - 1
        ax_main = fig.add_subplot(plot_rows, plot_cols, plot_cols*(idepth+1)+i+1)
        img = ax_main.imshow(STA_depths_normed_stim[idepth].T, extent=extent, cmap='RdBu_r',
                             vmin=-vmaxs[i], vmax=vmaxs[i], aspect='auto')
        ax_main.text(-110, 10, str(int(depth_list[idepth]*100))+' cm', fontsize=fontsize_dict['text'], c='gray')

        # Axis titles appear once: elevation on the middle row of the first
        # column, azimuth under the middle column of the bottom row.
        if i == 0 and idepth == n_depths // 2:
            ax_main.set_ylabel('Elevation (degrees)', fontsize=fontsize_dict['ylabel'])
        if i == len(select_rois) // 2 and is_bottom_row:
            ax_main.set_xlabel('Azimuth (degrees)', fontsize=fontsize_dict['xlabel'])

        # Tick labels only on the bottom row (x) and the first column (y).
        if not is_bottom_row:
            ax_main.tick_params(axis='x', bottom=True, labelbottom=False)
        if i != 0:
            ax_main.tick_params(axis='y', left=True, labelleft=False)
        ax_main.set_xticks([-120, 0, 120])
        ax_main.set_xticklabels([-120, 0, 120], fontsize=fontsize_dict['xticks'], rotation=30)
        yticks = ax_main.get_yticks()
        ax_main.set_yticks([yticks[1], yticks[-2]])  # keep two interior ticks of the automatic set
        ax_main.tick_params(axis='y', labelsize=fontsize_dict['yticks'])

        if colorbar_on:
            cax = make_axes_locatable(ax_main).append_axes("right", size="2%", pad=0.05)
            cbar = fig.colorbar(img, cax=cax)
            cbar.ax.tick_params(labelsize=fontsize_dict['legend'])
            if not is_bottom_row:
                # Only the bottom row keeps its colorbar tick labels.
                cbar.ax.tick_params(labelright=False)

        if idepth == 0:
            # Depth-tuning curve in an inset above the top STA map of this column.
            inset1 = ax_main.inset_axes([0, 1.2, 0.5, 1])
            plot_depth_tuning_curve(dffs=dffs,
                                    speeds=speeds,
                                    roi=roi,
                                    speed_thr_cal=speed_thr_cal,
                                    depth_list=depth_list,
                                    stim_dict=stim_dict,
                                    depth_neurons=depth_neurons,
                                    gaussian_depth=gaussian_depth,
                                    fontsize_dict=fontsize_dict,
                                    ylim=None,
                                    frame_rate=15,
                                    this_depth=5,
                                    linewidth=2,
                                    ax=inset1)
            ylim_tuning = inset1.get_ylim()
            inset1.set_ylim((-0.01, ylim_tuning[1]))
            inset1.set_xticks(np.arange(n_depths))  # one tick per depth (was hard-coded to 5)
            inset1.tick_params(axis='x', bottom=True, labelbottom=False)
            yticks = inset1.get_yticks()
            inset1.set_yticks([yticks[1], yticks[-2]])
            inset1.tick_params(axis='y', left=True, labelleft=False)
            inset1.set_xlabel('')
            inset1.set_ylabel('')
            inset1.set_title('')
            inset1.spines["top"].set_visible(False)
            inset1.spines["right"].set_visible(False)

fig.savefig(f'{save_root}{select_rois}_RF_colorbar{colorbar_on}.pdf')