In [None]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import vr2p
# import vr2p.signal
from vr2p.gimbl.transform import add_ranged_timestamp_values
import linear2ac.place

import os
import pickle

import dask
from dask import delayed, compute
from tqdm import tqdm

# Helper functions

In [None]:
def create_pf_detect(min_speed, df_window_size=50, df_sigma_baseline=600, bootstrap_do_test=False):
    """Build a ``linear2ac.place.Tyche1dProtocol`` detector with the given configuration.

    Parameters
    ----------
    min_speed : value stored in ``config.min_speed``.
    df_window_size : value stored in ``config.df_window_size`` (default 50).
    df_sigma_baseline : value stored in ``config.df_sigma_baseline`` (default 600).
    bootstrap_do_test : value stored in ``config.bootstrap_do_test`` (default False).

    Returns
    -------
    The configured ``Tyche1dProtocol`` instance; ``detect`` has not been called on it.
    """
    detector = linear2ac.place.Tyche1dProtocol()
    # Applied in the same order as before: speed, bootstrap flag, then the df_* settings.
    for name, value in (
        ("min_speed", min_speed),
        ("bootstrap_do_test", bootstrap_do_test),
        ("df_window_size", df_window_size),
        ("df_sigma_baseline", df_sigma_baseline),
    ):
        setattr(detector.config, name, value)
    return detector

In [None]:
def get_correlation_matrix(data, data_ref):
    """Pearson correlation between every row of ``data_ref`` and every row of ``data``.

    Parameters
    ----------
    data : 2D array
        Rows to correlate; the row index selects the output column.
    data_ref : 2D array
        Reference rows; the row index selects the output row. Must have the same
        number of columns as ``data``.

    Returns
    -------
    np.ndarray of shape ``(data_ref.shape[0], data.shape[0])`` where entry
    ``[r, c]`` is the Pearson r of ``data[c]`` and ``data_ref[r]``.
    """
    from scipy.stats import pearsonr

    n_ref, n_data = data_ref.shape[0], data.shape[0]
    corr_matrix = np.zeros((n_ref, n_data))
    # np.ndindex visits (row, col) in the same row-major order as two nested loops.
    for row, col in np.ndindex(n_ref, n_data):
        corr_matrix[row, col] = pearsonr(data[col, :], data_ref[row, :])[0]
    return corr_matrix

import seaborn as sns
def plot_correlation_matrix_sns(ax, data, data_ref, ttl=None, vmin = -1, vmax = 1, reward_positions = (), reward_positions_ref = (), show_reward_positions=False, bin_size=5):
    """Draw the PV-correlation matrix between two position-binned activity arrays.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    data : 2D array, one row per position bin; shown along the x axis.
    data_ref : 2D array, one row per position bin; shown along the y axis.
        Rows of ``data`` and ``data_ref`` are correlated pairwise with
        ``get_correlation_matrix``, so both need the same number of columns.
    ttl : axes title (None gives an empty title).
    vmin, vmax : colour-scale limits of the heatmap.
    reward_positions, reward_positions_ref : reward locations in cm for ``data`` /
        ``data_ref``; drawn as shaded +/-10 cm bands only if ``show_reward_positions``.
    show_reward_positions : whether to draw the reward bands.
    bin_size : spatial bin width in cm (default 5), used to convert bin indices to cm.
    """
    corr_matrix = get_correlation_matrix(data, data_ref)

    sns.heatmap(corr_matrix, cmap='icefire', vmin=vmin, vmax=vmax, ax=ax, square=True, cbar=True)
    ax.set_xlabel('Position (cm)', fontsize=10)
    ax.set_ylabel('Position (cm)', fontsize=10)

    # One tick every 10 bins, on bin edges (heatmap cell k spans [k, k + 1]).
    tick_every = 10
    xticks = [k * tick_every for k in range(1 + data.shape[0] // tick_every)]
    yticks = [k * tick_every for k in range(1 + data_ref.shape[0] // tick_every)]
    # Bug fix: `fontsize` used to be passed to set_xticks/set_yticks without `labels`,
    # which raises ValueError on Matplotlib >= 3.8. The font size is set on the labels.
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels([str(t * bin_size) for t in xticks], fontsize=10)
    ax.set_yticklabels([str(t * bin_size) for t in yticks], fontsize=10)

    ax.set_title(ttl, fontsize=20)

    if show_reward_positions:
        half_width = 10  # cm on either side of each reward position
        for rwd_ps in reward_positions:
            ax.axvspan((rwd_ps - half_width) / bin_size, (rwd_ps + half_width) / bin_size, color='blue', alpha=0.25, linewidth=0)

        for rwd_ps in reward_positions_ref:
            ax.axhspan((rwd_ps - half_width) / bin_size, (rwd_ps + half_width) / bin_size, color='blue', alpha=0.25, linewidth=0)

# Load data

In [None]:
# Folder holding the per-animal input files (<animal>_F_raw.npy, _vr.pkl, _position.pkl).
# Keep the trailing '/' in case paths are built by plain string concatenation.
data_dir = "../stretched_trials/"

In [None]:
# Datasets analysed in this notebook, one per mouse.
animals = ["D1", "F1_4", "F5_2"]
# NOTE(review): `animal` is reassigned by the data-loading loop below
# (`for animal in animals`), so this value is only a placeholder.
animal = animals[-1]

In [None]:
# Load, for every animal, the raw imaging array (<animal>_F_raw.npy), the pickled `vr`
# object (<animal>_vr.pkl) and the behavioural position DataFrame (<animal>_position.pkl).
# The three lists are index-aligned with `animals`.
F_raw_list = []
vr_list = []
position_list = []

for animal in animals:
    # All paths use os.path.join. Previously two of the three paths were built by string
    # concatenation, which silently breaks when `data_dir` lacks a trailing separator.
    F_raw = np.load(os.path.join(data_dir, animal + '_F_raw.npy'))

    # SECURITY NOTE: pickle.load and pd.read_pickle can execute arbitrary code.
    # Only load these files from a trusted source.
    with open(os.path.join(data_dir, animal + '_vr.pkl'), 'rb') as file:
        vr = pickle.load(file)

    position = pd.read_pickle(os.path.join(data_dir, animal + '_position.pkl'))

    F_raw_list.append(F_raw)
    vr_list.append(vr)
    position_list.append(position)

# Detect Place Fields 

In [None]:
# Parameters

# Spatial bin width (cm) passed to place-field detection.
bin_size = 5
# Minimum running speed written to the detector config (presumably cm/s; confirm in linear2ac).
min_speed = 5

# Track lengths in cm: regular trials (230) and stretched trials (330, used as
# `track_size` for stretched-trial detection). The "potision" spelling is kept
# because later cells refer to these names.
stop_potision = 230
stop_potision_stretch = 330

# Reward positions in cm: row 0 = regular trials, row 1 = stretched trials;
# column 0 = reward id 1, column 1 = reward id 2.
reward_positions_array = np.array([[140, 190], [190, 290]])

# Place-field set that defines the cell ordering: 'putative' or 'significant'.
pf_criteria = 'significant'

For each of the three datasets, detect place fields for each cell in both regular and stretched trials, and obtain the average neural response for each trial type (near and far).

**Note:** running the next cell takes a long time (~30 min).

In [None]:
class PlaceFieldData:
    """Plain attribute container for one animal's place-field results.

    Each attribute is a list with one entry per processed reward id (index 0 ->
    reward_id 1, index 1 -> reward_id 2; shown as 'NEAR' / 'FAR' in the figures):

    - ``smooth_binF_normal`` / ``smooth_binF_probe``: detector ``smooth_binF`` for
      regular / stretched trials.
    - ``order_normal`` / ``order_probe``: ``.order`` of the place-field set chosen
      by ``pf_criteria``.
    - ``centers_normal`` / ``centers_probe``: ``pf_significant.centers``.
    """
    pass


# Per-animal cap on the last trial included. Animals not listed use every trial in
# their position table.
# NOTE(review): the reason F5_2 is cut at trial 111 is not recorded here; please document it.
last_trial_override = {'F5_2': 111}

if pf_criteria not in ('significant', 'putative'):
    raise ValueError(f"pf_criteria must be 'putative' or 'significant', got {pf_criteria!r}")

datasets = []

for ii in range(len(animals)):
    F_raw = F_raw_list[ii]
    vr = vr_list[ii]
    position = position_list[ii]

    trial = vr.trial.copy()
    # Bug fix: the cap used to be keyed on the global `animal`. After the loading loop
    # that global always holds the last animal ('F5_2'), so trial 111 was applied as
    # the cutoff for every dataset instead of only F5_2.
    stopTrial = last_trial_override.get(animals[ii], position.trial_number.max())

    # The task closes over this iteration's F_raw / vr / trial / position / stopTrial.
    # That is safe because compute() below runs before they are reassigned.
    @delayed
    def process_reward_id(reward_id):
        """Detect place fields for one reward id in regular and stretched trials."""
        # Responded trials of this reward id, up to the trial cap.
        selected_trials = trial.loc[(trial.reward_id == reward_id) & (trial.status != 'NO_RESPONSE') &
                                    (trial.trial_number <= stopTrial), 'trial_number']
        # Imaging frames of those trials, split by probe_status
        # (False = regular, True = stretched).
        in_selected = position['trial_number'].isin(selected_trials)
        selected_frames_normal = position.loc[(position['probe_status'] == False) & in_selected, 'frame']
        selected_frames_stretch = position.loc[(position['probe_status'] == True) & in_selected, 'frame']

        pf_detect_normal = create_pf_detect(min_speed)
        pf_detect_normal.detect(F_raw, vr, bin_size, selected_frames_normal, verbose=False)

        # Stretched trials are detected over the longer track.
        pf_detect_probe = create_pf_detect(min_speed)
        pf_detect_probe.detect(F_raw, vr, bin_size, selected_frames_stretch, track_size=stop_potision_stretch, verbose=False)

        # 'pf_significant' or 'pf_putative' (validated above) supplies the cell order.
        # NOTE(review): centers always come from pf_significant, even when
        # pf_criteria == 'putative'. Confirm this is intended.
        criteria_attr = 'pf_' + pf_criteria
        return {
            "smooth_binF_normal": pf_detect_normal.smooth_binF,
            "centers_normal": pf_detect_normal.pf_significant.centers,
            "smooth_binF_probe": pf_detect_probe.smooth_binF,
            "centers_probe": pf_detect_probe.pf_significant.centers,
            "order_normal": getattr(pf_detect_normal, criteria_attr).order,
            "order_probe": getattr(pf_detect_probe, criteria_attr).order,
        }

    # Run both reward ids in parallel with dask.
    tasks = [process_reward_id(reward_id) for reward_id in [1, 2]]
    results = compute(*tasks)

    # Turn the per-reward-id result dicts into per-field lists.
    pf_dataset = PlaceFieldData()
    for key in ("smooth_binF_normal", "smooth_binF_probe", "order_normal",
                "order_probe", "centers_normal", "centers_probe"):
        setattr(pf_dataset, key, [result[key] for result in results])

    datasets.append(pf_dataset)
        

# Figure 5 e, f

Identify each neuron’s place field during both regular and stretched trials.

In [None]:
def get_pf_locations(n_neuron, centers1, centers2):
    """Pair up each neuron's place-field centers from two conditions.

    Parameters
    ----------
    n_neuron : int
        Number of neurons; neuron ids ``0 .. n_neuron - 1`` are looked up.
    centers1, centers2 : 2D arrays
        Place-field tables whose column 0 is the neuron id and column 1 the field
        center. (Callers pass regular-trial and stretched-trial centers.)

    Returns
    -------
    np.ndarray of shape ``(n_rows, 3)`` with columns ``(neuron_id, center1, center2)``.
    A neuron with k1 fields in ``centers1`` and k2 in ``centers2`` contributes k1*k2
    rows (every combination). A condition with no field for that neuron is filled
    with -1, so every neuron contributes at least one row. ``n_neuron == 0`` gives
    an empty ``(0, 3)`` array.
    """
    missing = np.array([-1.0])
    rows = []
    for neuron in range(n_neuron):
        com1 = centers1[centers1[:, 0] == neuron, 1]
        com2 = centers2[centers2[:, 0] == neuron, 1]

        # Bug fix: test for "no fields" with .size. The previous `~com.any()` also
        # discarded neurons whose only field center was exactly 0.
        if com1.size == 0:
            com1 = missing
        if com2.size == 0:
            com2 = missing

        # Cartesian product of the two center lists: com1 outer, com2 inner.
        pairs = np.array(np.meshgrid(com1, com2)).T.reshape(-1, 2)
        ids = np.full((len(pairs), 1), neuron, dtype=float)
        rows.append(np.hstack((ids, pairs)))

    if not rows:
        # np.vstack([]) raises, so handle the empty case explicitly.
        return np.empty((0, 3))
    return np.vstack(rows)


In [None]:
# For each animal and reward id, pair every neuron's regular-trial place-field centers
# with its stretched-trial centers (see get_pf_locations).
# pf_arrays_normal_probe[animal_idx][reward_idx] -> array with columns
# (neuron, regular-trial center, stretched-trial center).
pf_arrays_normal_probe = [
    [
        get_pf_locations(
            dataset.smooth_binF_normal[0].shape[0],  # number of neurons
            dataset.centers_normal[reward_idx],
            dataset.centers_probe[reward_idx],
        )
        for reward_idx in range(2)
    ]
    for dataset in datasets
]

Place-field locations in stretched trials (x axis) plotted against those in regular trials (y axis), for Near (left) and Far (right) trials (n = 3 mice, data pooled together).

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

trial_type = ['NEAR', 'FAR']
colors = ['magenta', 'green']

# Stretched regions on the stretched-trial position axis, as (start, end) in cm.
stretched_regions_cm = [(130, 180), (230, 280)]

# Each panel spans the stretched track (x) against the regular track (y), so its
# aspect ratio is stretched length / regular length (330 / 230).
aspect_ratio = stop_potision_stretch / stop_potision

# Figure width can be defined freely; the height follows from the aspect ratio.
fig_width = 8
fig_height = fig_width / aspect_ratio

# Two side-by-side panels (NEAR, FAR), hence twice the single-panel width.
fig, ax = plt.subplots(1, 2, figsize=(fig_width * 2, fig_height), dpi=500)

for rid in range(2):
    # Pool every animal. This replaces three copy-pasted scatter calls for datasets 0-2.
    # Column 2 = stretched-trial location (x), column 1 = regular-trial location (y).
    # Cells without a field in one condition are stored at -1, just outside the axes limits.
    for pf_array_animal in pf_arrays_normal_probe:
        ax[rid].scatter(pf_array_animal[rid][:, 2], pf_array_animal[rid][:, 1], s=2, color=colors[rid], rasterized=True)
    ax[rid].set_xlim(0, stop_potision_stretch)
    ax[rid].set_ylim(0, stop_potision)
    ax[rid].set_title(trial_type[rid])
    ax[rid].set_xlabel('Position (cm)')
    ax[rid].set_ylabel('Position (cm)')
    ax[rid].set_aspect('equal')
    ax[rid].invert_yaxis()

    # Dashed outlines of the stretched regions, spanning the whole regular track.
    for x_start, x_end in stretched_regions_cm:
        ax[rid].add_patch(Rectangle((x_start, 0), x_end - x_start, stop_potision, linewidth=2, edgecolor='black', facecolor='none', linestyle='--'))

# plt.savefig('stretched_trials.pdf', format='pdf', dpi=500)  # Save the figure as a pdf
plt.show()

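Histograms (mean ± SEM across mice) of regular-trial place-field locations, for cells whose stretched-trial place field lies in each of the two stretched regions (dashed rectangles above).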

In [None]:
# For each animal and trial type, take the cells whose stretched-trial place field lies
# inside one of the two stretched regions, and histogram their regular-trial locations.

# Stretched regions on the stretched-trial axis, as (start, end) in cm.
stretched_regions_cm = [(130, 180), (230, 280)]
# Regular-trial track length (cm); locations outside [0, regular_track_cm] are ignored.
regular_track_cm = 230
# Histogram edges (cm) for the regular-trial locations.
# NOTE(review): the last edge is 225, so locations in (225, 230] are dropped by
# np.histogram even though the filter keeps them. Widening `bins` would also change the
# number of bins the plotting cell receives.
bins = range(0, 230, 5)
# Number of animals. Previously hard-coded as 3 in both the loop and the SEM.
n_animals = len(pf_arrays_normal_probe)


def _fields_in_region(pf_array, x_start, x_end):
    """Rows of a (neuron, regular, stretched) array with the stretched location in [x_start, x_end] and the regular location in [0, regular_track_cm]."""
    x = pf_array[:, 2]
    y = pf_array[:, 1]
    return pf_array[(x >= x_start) & (x <= x_end) & (y >= 0) & (y <= regular_track_cm)]


def _mean_sem(hists):
    """Mean and SEM across animals of a list of equal-length histograms.

    NOTE(review): np.std uses ddof=0 (population std), as in the original analysis.
    The usual sample SEM uses ddof=1. Confirm which is intended.
    """
    return np.mean(hists, axis=0), np.std(hists, axis=0) / np.sqrt(len(hists))


rect1_hist_data_near = []
rect2_hist_data_near = []
rect1_hist_data_far = []
rect2_hist_data_far = []

for ii in range(n_animals):
    for rid in range(2):
        pf_array = pf_arrays_normal_probe[ii][rid]
        rect1_data = _fields_in_region(pf_array, *stretched_regions_cm[0])
        rect2_data = _fields_in_region(pf_array, *stretched_regions_cm[1])

        # Density-normalised histograms of the regular-trial locations (column 1).
        rect1_hist, _ = np.histogram(rect1_data[:, 1], bins=bins, density=True)
        rect2_hist, _ = np.histogram(rect2_data[:, 1], bins=bins, density=True)

        if trial_type[rid] == 'NEAR':
            rect1_hist_data_near.append(rect1_hist)
            rect2_hist_data_near.append(rect2_hist)
        else:  # 'FAR'
            rect1_hist_data_far.append(rect1_hist)
            rect2_hist_data_far.append(rect2_hist)

# Mean and SEM across animals.
rect1_hist_near_mean, rect1_hist_near_sem = _mean_sem(rect1_hist_data_near)
rect2_hist_near_mean, rect2_hist_near_sem = _mean_sem(rect2_hist_data_near)
rect1_hist_far_mean, rect1_hist_far_sem = _mean_sem(rect1_hist_data_far)
rect2_hist_far_mean, rect2_hist_far_sem = _mean_sem(rect2_hist_data_far)

In [None]:
fig, ax = plt.subplots(1, 4, figsize=(8, 4), dpi=500)

# Use darker shades of green and pink-magenta
dark_green = '#008000'
light_green = '#90EE90'
magenta = '#FF00FF'
pink = '#FFC0CB'

# Bug fix: plot at the true bin centers (midpoints of the histogram edges in `bins`).
# The previous np.arange(5, 230, 5) gave the right bin edges, shifting curves by half a
# bin. Deriving the centers from `bins` also keeps the length in sync with the histograms.
bin_edges = np.asarray(bins)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

# (mean, sem, colour, label) for each panel, left to right.
panels = [
    (rect1_hist_near_mean, rect1_hist_near_sem, pink, 'Rectangle 1, NEAR'),
    (rect2_hist_near_mean, rect2_hist_near_sem, magenta, 'Rectangle 2, NEAR'),
    (rect1_hist_far_mean, rect1_hist_far_sem, light_green, 'Rectangle 1, FAR'),
    (rect2_hist_far_mean, rect2_hist_far_sem, dark_green, 'Rectangle 2, FAR'),
]

for panel_ax, (mean, sem, color, label) in zip(ax, panels):
    # Mean curve with a shaded +/- SEM band, drawn vertically along position.
    panel_ax.plot(mean, bin_centers, color=color, label=label)
    panel_ax.fill_betweenx(bin_centers, mean - sem, mean + sem, color=color, alpha=0.3)
    panel_ax.set_xlim(0, 0.023)  # common x-range for all panels
    # NOTE(review): the histograms use density=True, so these values are densities
    # per cm rather than probabilities. Consider relabelling.
    panel_ax.set_xlabel('Probability')
    panel_ax.set_ylabel('Y Position (cm)')
    panel_ax.invert_yaxis()

plt.tight_layout()
plt.savefig('stretched_trials_histogram_mean.pdf', format='pdf', dpi=500)
plt.show()

# Extended Data figure

## (a) Licking Behavior

Example data from mouse F1_4.

In [None]:
# Example session: the second dataset in `animals` (F1_4).
vr, position = vr_list[1], position_list[1]
trial = vr.trial.copy()

# Hand-picked example trials: for each reward id, 5 regular ("normal") trials
# followed by 5 stretched ("probe") trials.
trial_normal_near, trial_probe_near = [57, 61, 64, 68, 69], [80, 85, 100, 105, 110]
trial_near = [*trial_normal_near, *trial_probe_near]

trial_normal_far, trial_probe_far = [23, 24, 27, 32, 39], [40, 45, 70, 90, 95]
trial_far = [*trial_normal_far, *trial_probe_far]

# Indexed by reward id - 1.
trial_example = [trial_near, trial_far]

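Example of licking behavior in a single session: lick positions on the example trials (top row) and lick-position histograms for regular vs. stretched trials (bottom row), for Near (left) and Far (right) trials.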

In [None]:
import seaborn as sns
color_cue = ['black', 'gray']
labels = ['Normal','Stretched']
# Each example list in `trial_example` holds this many regular trials followed by
# the same number of stretched trials.
n_per_condition = 5
fig, axs = plt.subplots(2,2,figsize=(16, 10))

for rewarding_id in range(1,3):
    col = rewarding_id - 1

    # --- Top row: lick raster of the example trials, one row per trial.
    lick_example = [
        vr.lick.loc[vr.lick.trial_number == trial_num, 'position'].to_numpy()
        for trial_num in trial_example[col]
    ]
    # Bug fix: iterate over this reward id's own example trials. The original used
    # len(trial_near) for both NEAR and FAR and only worked because both lists have 10 entries.
    for ii, licks in enumerate(lick_example):
        group = ii // n_per_condition  # 0 = regular, 1 = stretched
        # Label only the first trial of each group so the legend has one entry per condition.
        lgd = labels[group] if ii % n_per_condition == 0 else ''
        axs[0, col].scatter(licks, ii * np.ones_like(licks), color=color_cue[group], label=lgd)
    axs[0, col].invert_yaxis()

    # --- Bottom row: lick-position histograms for regular vs stretched trials.
    # NOTE(review): this selection drops 'INCOMPLETE' and guided trials, whereas
    # place-field detection drops 'NO_RESPONSE' trials. Confirm the difference is intended.
    trials = trial.loc[(trial.reward_id == rewarding_id) & (trial.status != 'INCOMPLETE') & (trial.is_guided==False)]
    trial_num_normal = position.loc[(position['probe_status'] == False),'trial_number']
    trial_num_probe = position.loc[(position['probe_status'] == True),'trial_number']

    trials_normal = trials.loc[trials['trial_number'].isin(trial_num_normal),'trial_number']
    trials_probe = trials.loc[trials['trial_number'].isin(trial_num_probe),'trial_number']
    licks_normal = vr.lick.loc[vr.lick.trial_number.isin(trials_normal)]
    licks_probe = vr.lick.loc[vr.lick.trial_number.isin(trials_probe)]

    binwidth = 5
    bins_hist = range(0, stop_potision_stretch+binwidth, binwidth)
    sns.histplot(licks_normal.position.to_numpy(), bins=bins_hist, stat='density', color=color_cue[0], label="Normal", kde=True, ax=axs[1, col])
    sns.histplot(licks_probe.position.to_numpy(), bins=bins_hist, stat='density', color=color_cue[1], label="Stretched", kde=True, ax=axs[1, col])

    # --- Task regions, drawn on both rows.
    for jj in range(2):
        panel = axs[jj, col]
        # Reward zones: row `zone_idx` of reward_positions_array (0 = regular, 1 = stretched).
        for zone_idx in range(2):
            half_size = vr.environment.reward_sizes[0][zone_idx] / 2
            reward_center = reward_positions_array[zone_idx][col]
            panel.axvspan(reward_center - half_size, reward_center + half_size, color=color_cue[zone_idx], alpha=0.3, linewidth=0)

        # Indicator region.
        min_x = vr.environment.indicator_position.values-(vr.environment.indicator_size.values/2)
        max_x = vr.environment.indicator_position.values+(vr.environment.indicator_size.values/2)
        panel.axvspan(min_x, max_x, color='b', alpha=0.3, linewidth=0)
        # Gray zone.
        min_x = vr.environment.gray_zone_position.values-(vr.environment.gray_zone_size.values/2)
        max_x = vr.environment.gray_zone_position.values+(vr.environment.gray_zone_size.values/2)
        panel.axvspan(min_x, max_x, color='gray', alpha=0.3, linewidth=0)
        # End of the stretched track.
        panel.axvline(x=stop_potision_stretch, color='red', linestyle='--')

        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        panel.set_xlim(0, stop_potision_stretch)

    # Row-specific cosmetics; these were previously re-applied on every jj iteration.
    axs[0, col].set_yticklabels([])  # raster rows are not labelled with trial numbers
    axs[1, col].set_ylabel('')


axs[0,0].legend()
axs[1,0].legend()
plt.show()

## (b) PV Correlation: Regular vs Stretched trials

Population-vector (PV) correlation between the average neural population activity in regular and stretched trials, shown for near trials (left panel) and far trials (right panel).

In [None]:
# Trial-type labels indexed by reward id - 1 (same values as in the scatter-plot cell).
trial_type = ["NEAR", "FAR"]

In [None]:
ii = 0
fig, ax = plt.subplots(1, 2, figsize=(21, 6), dpi=600)
# One panel per trial type: x axis = stretched-trial bins, y axis = regular-trial bins.
# NOTE(review): reward positions are passed but not drawn, because
# show_reward_positions stays at its default (False). Confirm this is intended.
for jj, panel in enumerate(ax):
    ttl = f"{animals[ii]} : {trial_type[jj]}"
    plot_correlation_matrix_sns(
        panel,
        datasets[ii].smooth_binF_probe[jj].T,
        datasets[ii].smooth_binF_normal[jj].T,
        ttl,
        reward_positions=[190, 290],
        reward_positions_ref=[140, 190],
    )

fig.tight_layout()

In [None]:
ii = 1
fig, ax = plt.subplots(1, 2, figsize=(21, 6), dpi=600)
# One panel per trial type: x axis = stretched-trial bins, y axis = regular-trial bins.
# NOTE(review): reward positions are passed but not drawn, because
# show_reward_positions stays at its default (False). Confirm this is intended.
for jj, panel in enumerate(ax):
    ttl = f"{animals[ii]} : {trial_type[jj]}"
    plot_correlation_matrix_sns(
        panel,
        datasets[ii].smooth_binF_probe[jj].T,
        datasets[ii].smooth_binF_normal[jj].T,
        ttl,
        reward_positions=[190, 290],
        reward_positions_ref=[140, 190],
    )

fig.tight_layout()

In [None]:
ii = 2
fig, ax = plt.subplots(1, 2, figsize=(21, 6), dpi=600)
# One panel per trial type: x axis = stretched-trial bins, y axis = regular-trial bins.
# NOTE(review): reward positions are passed but not drawn, because
# show_reward_positions stays at its default (False). Confirm this is intended.
for jj, panel in enumerate(ax):
    ttl = f"{animals[ii]} : {trial_type[jj]}"
    plot_correlation_matrix_sns(
        panel,
        datasets[ii].smooth_binF_probe[jj].T,
        datasets[ii].smooth_binF_normal[jj].T,
        ttl,
        reward_positions=[190, 290],
        reward_positions_ref=[140, 190],
    )

fig.tight_layout()