In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from scipy.spatial.distance import cosine, euclidean
import vr2p
import zarr
import seaborn as sns
import matplotlib as mpl
from matplotlib.ticker import MaxNLocator
mpl.rcParams['pdf.fonttype'] = 42

def non_negative(arr):
    """Return a copy of ``arr`` in which every negative entry is replaced by 0.

    Parameters
    ----------
    arr : array-like
        Numeric data. ndarrays as well as plain lists/tuples are accepted.

    Returns
    -------
    numpy.ndarray
        New array with the same shape as ``arr``. NaN entries stay NaN,
        because ``NaN < 0`` evaluates to False.
    """
    # Coerce first: comparing a plain list with 0 would raise TypeError.
    arr = np.asarray(arr)
    return np.where(arr < 0, 0, arr)

animal_name = 'Tyche-A5-SetA.zarr'
animal = animal_name[:8]  # first 8 characters of the file name, here 'Tyche-A5'
print(animal)

# Load data
path = f'/.../Set A/{animal_name}/'
data = vr2p.ExperimentData(path)
os.makedirs('Figure_3_FINAL', exist_ok=True)

# Collect the leading run of sessions whose only trial set is 'Cue Set A';
# stop at the first session that contains anything else.
day_count = []
for i in range(len(data.signals.multi_session.F)):
    trial_sets = data.vr[i].trial.set.unique()  # computed once per session
    if ('Cue Set A' in trial_sets) and (len(trial_sets) == 1):
        day_count.append(i)
    else:
        break

# Fail with a readable message instead of "max() arg is an empty sequence".
if not day_count:
    raise ValueError(f'{animal}: no leading Cue Set A-only session found')

print(max(day_count))

# Load stored place field analysis for each day
range_A = range(max(day_count)+1)
criteria = 'significant'
zarr_file = zarr.open(f'/SetA/{animal}-PF.zarr', mode="r")
# Group '1' is loaded as Near and group '2' as Far.
pf_all_Near = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{criteria}'][()] for i in range_A]
pf_all_Far = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{criteria}'][()] for i in range_A]
# Reuse the records read above instead of fetching every record from the store a second time.
binF_Near = [pf['binF'] for pf in pf_all_Near]
binF_Far = [pf['binF'] for pf in pf_all_Far]

# Analyse one session; `session` is a 0-based index into binF_Near / binF_Far
session = 5

# The activity threshold computed below is mean + threshold_std * SD of the pooled binned activity
threshold_std = 2  # Number of standard deviations above the mean

# Per-cell results for the cells that pass the threshold; all five lists stay index-aligned
corr_coeffs = []
max_amp_diff_scores = []
binF_Near_filtered = []
binF_Far_filtered = []
original_cell_numbers = []

# Peak activity of each cell across all positions (axis=1) for both trial types,
# taken after clipping negative values to 0
max_amplitude_Near = np.max(non_negative(binF_Near[session]), axis=1)
max_amplitude_Far = np.max(non_negative(binF_Far[session]), axis=1)

# Threshold from the mean and SD of all Near and Far values of this session pooled together
# (the raw values are used here, not the clipped ones)
activity_values = np.concatenate((binF_Near[session][:, :], binF_Far[session][:, :]))
threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)

for j in range(len(max_amplitude_Near)):
    # Keep a cell when its peak reaches the threshold in at least one trial type
    if max_amplitude_Near[j] >= threshold or max_amplitude_Far[j] >= threshold:
        # Pearson correlation between the cell's Near and Far activity profiles
        corr_coeff, _ = pearsonr(binF_Near[session][j, :], binF_Far[session][j, :])
        corr_coeffs.append(corr_coeff)

        # Difference score: for positive peaks this is 1 - (smaller peak / larger peak),
        # i.e. 0 for equal peaks and close to 1 when one peak is much larger.
        # NOTE(review): a peak of exactly 0 in one trial type makes one of the two
        # divisions a division by zero - confirm the resulting score is acceptable.
        max_amp_ratio = max_amplitude_Near[j] / max_amplitude_Far[j]
        max_amp_diff_score = 1 - min(max_amp_ratio, 1/max_amp_ratio)
        max_amp_diff_scores.append(max_amp_diff_score)

        # Append the filtered binF arrays
        binF_Near_filtered.append(binF_Near[session][j, :])
        binF_Far_filtered.append(binF_Far[session][j, :])

        # Store the original cell number (row index into binF of this session)
        original_cell_numbers.append(j)

# Convert the filtered binF arrays to numpy arrays (one row per retained cell)
binF_Near_filtered = np.array(binF_Near_filtered)
binF_Far_filtered = np.array(binF_Far_filtered)

# x-axis values for the tuning curves: evenly spaced from 0 to 230, one per column.
# If no cell passed the threshold the array above is 1-D and `.shape[1]` raises IndexError.
positions = np.linspace(0, 230, binF_Near_filtered.shape[1])

# Set the style to "white"
#sns.set_style("white")

# Scatter plot of all retained cells: correlation (x) against difference score (y)
fig_scatter, ax_scatter = plt.subplots(figsize=(16, 14), dpi=500)
sns.scatterplot(x=corr_coeffs, y=max_amp_diff_scores, s=120, color='darkgray', alpha=1, ax=ax_scatter)

# Correlation value that separates the left quadrants from the right ones
corr_middle = 0.2

# Boolean masks over the retained cells (same indexing as corr_coeffs), one per quadrant.
# The quadrants meet at (corr_middle, 0.5).
quadrants = [
    (np.array(corr_coeffs) < corr_middle) & (np.array(max_amp_diff_scores) >= 0.5),  # Quadrant 1: Splitter
    (np.array(corr_coeffs) >= corr_middle) & (np.array(max_amp_diff_scores) >= 0.5),  # Quadrant 2: Splitter
    (np.array(corr_coeffs) < corr_middle) & (np.array(max_amp_diff_scores) < 0.5),  # Quadrant 3: Remapping Splitter
    (np.array(corr_coeffs) >= corr_middle) & (np.array(max_amp_diff_scores) < 0.5)  # Quadrant 4: Place Cell
]

# Display names in the same order as `quadrants`
cell_type_names = ['Splitter', 'Splitter', 'Remapping Splitter', 'Place Cell']

# Blue cells: per non-empty quadrant, the cell farthest from the point (corr_middle, 0.5)
selected_cells_blue = []
for quadrant in quadrants:
    if np.sum(quadrant) > 0:
        # Indices into corr_coeffs / max_amp_diff_scores of the cells in this quadrant
        quadrant_cells = np.where(quadrant)[0]
        corr_vals = np.array(corr_coeffs)[quadrant_cells]
        diff_scores = np.array(max_amp_diff_scores)[quadrant_cells]

        # Find the cell with extreme combination of correlation and difference score in each quadrant
        distances = [euclidean([corr, diff], [corr_middle, 0.5]) for corr, diff in zip(corr_vals, diff_scores)]
        extreme_idx = quadrant_cells[np.argmax(distances)]
        # Translate the filtered index back to the original cell number
        selected_cells_blue.append(original_cell_numbers[extreme_idx])


# Black cells: per non-empty quadrant, a cell close to (but not the closest to)
# the point (corr_middle, 0.5) where the quadrants meet
selected_cells_black = []
for quadrant in quadrants:
    if np.sum(quadrant) > 0:
        quadrant_cells = np.where(quadrant)[0]
        corr_vals = np.array(corr_coeffs)[quadrant_cells]
        diff_scores = np.array(max_amp_diff_scores)[quadrant_cells]

        # Rank the quadrant's cells by distance from (corr_middle, 0.5), nearest first
        distances = [euclidean([corr, diff], [corr_middle, 0.5]) for corr, diff in zip(corr_vals, diff_scores)]
        sorted_indices = np.argsort(distances)
        middle_idx = quadrant_cells[sorted_indices[len(distances) // 20]]  # Select the cell at the 5th percentile
        selected_cells_black.append(original_cell_numbers[middle_idx])

# Red cells: per non-empty quadrant, the cell closest to the midpoint between
# that quadrant's blue and black cells
selected_cells_red = []
for quadrant in quadrants:
    if np.sum(quadrant) > 0:
        quadrant_cells = np.where(quadrant)[0]
        corr_vals = np.array(corr_coeffs)[quadrant_cells]
        diff_scores = np.array(max_amp_diff_scores)[quadrant_cells]

        # len(selected_cells_red) counts the non-empty quadrants handled so far, so it
        # addresses the matching blue/black entry (all three loops skip the same empty quadrants).
        # Despite their names, extreme_idx / middle_idx hold original cell numbers here.
        extreme_idx = selected_cells_blue[len(selected_cells_red)]
        middle_idx = selected_cells_black[len(selected_cells_red)]
        extreme_corr, extreme_diff = corr_coeffs[original_cell_numbers.index(extreme_idx)], max_amp_diff_scores[original_cell_numbers.index(extreme_idx)]
        middle_corr, middle_diff = corr_coeffs[original_cell_numbers.index(middle_idx)], max_amp_diff_scores[original_cell_numbers.index(middle_idx)]
        # Midpoint between the blue and the black cell in (correlation, difference score) space
        target_corr, target_diff = (extreme_corr + middle_corr) / 2, (extreme_diff + middle_diff) / 2
        distances = [euclidean([corr, diff], [target_corr, target_diff]) for corr, diff in zip(corr_vals, diff_scores)]
        inner_idx = quadrant_cells[np.argmin(distances)]
        selected_cells_red.append(original_cell_numbers[inner_idx])

# Blue cells: hard-coded original cell numbers replace the automatic pick for
# quadrants 1, 2 and 4; only quadrant 3 keeps the automatic selection.
# list.index raises ValueError if a hard-coded cell did not pass the activity threshold.
blue_cell_numbers = [1063, 248, selected_cells_blue[2], 2917]
blue_cell_indices = [original_cell_numbers.index(cell_number) for cell_number in blue_cell_numbers]

# Black cells: automatic selection for all four quadrants
# (indexing 0..3 assumes none of the four quadrants was empty)
black_cell_numbers = [selected_cells_black[0], selected_cells_black[1], selected_cells_black[2], selected_cells_black[3]]
black_cell_indices = [original_cell_numbers.index(cell_number) for cell_number in black_cell_numbers]


# Red cells: automatic selection for the first three entries, hard-coded cell 6 for the fourth
red_cell_numbers = [selected_cells_red[0], selected_cells_red[1], selected_cells_red[2], 6]
red_cell_indices = [original_cell_numbers.index(cell_number) for cell_number in red_cell_numbers]

# Overlay each group on the scatter plot as larger coloured dots labelled '<Group> <n>'
for i, cell_idx in enumerate(blue_cell_indices):
    sns.scatterplot(x=[corr_coeffs[cell_idx]], y=[max_amp_diff_scores[cell_idx]], s=300, color='blue', alpha=1, edgecolor='black', linewidth=0.5, ax=ax_scatter)
    ax_scatter.annotate(f'Blue {i+1}', (corr_coeffs[cell_idx], max_amp_diff_scores[cell_idx]),
                        xytext=(10, 10), textcoords='offset points', fontsize=40, color='blue')

for i, cell_idx in enumerate(black_cell_indices):
    sns.scatterplot(x=[corr_coeffs[cell_idx]], y=[max_amp_diff_scores[cell_idx]], s=300, color='black', alpha=1, edgecolor='black', linewidth=0.5, ax=ax_scatter)
    ax_scatter.annotate(f'Black {i+1}', (corr_coeffs[cell_idx], max_amp_diff_scores[cell_idx]),
                        xytext=(10, 10), textcoords='offset points', fontsize=40, color='black')

for i, cell_idx in enumerate(red_cell_indices):
    sns.scatterplot(x=[corr_coeffs[cell_idx]], y=[max_amp_diff_scores[cell_idx]], s=300, color='red', alpha=1, edgecolor='black', linewidth=0.5, ax=ax_scatter)
    ax_scatter.annotate(f'Red {i+1}', (corr_coeffs[cell_idx], max_amp_diff_scores[cell_idx]),
                        xytext=(10, 10), textcoords='offset points', fontsize=40, color='red')

# Axis labels, fixed ticks and tick styling
ax_scatter.set_xlabel('Correlation (Near vs Far Trials)', fontsize=40)
ax_scatter.set_ylabel('Difference Score', fontsize=40)
ax_scatter.tick_params(axis='both', labelsize=32, width=2, length=10, direction='out')
ax_scatter.set_xticks([-0.5, 0, 0.5, 1])  # Set x-axis ticks
ax_scatter.set_yticks([0, 0.5, 1])  # Set y-axis ticks
ax_scatter.xaxis.set_tick_params(width=2)  # Same width as already set by tick_params above
ax_scatter.yaxis.set_tick_params(width=2)  # Same width as already set by tick_params above


# Set the xlim and ylim before despine to ensure proper plot boundaries
ax_scatter.set_xlim(-0.8, 1)
ax_scatter.set_ylim(0, 1)

sns.despine(ax=ax_scatter, top=True, right=True)

# plt.* acts on the current figure, which is still fig_scatter at this point
plt.tight_layout()
plt.subplots_adjust(top=0.95)  # Adjust the top spacing to prevent overlapping
plt.savefig('Figure_3_FINAL/scatter_plot_selected_cells_square.pdf', dpi=500, bbox_inches='tight')

# Define colors for tuning curves
near_color = 'orchid'
far_color = 'seagreen'


def _plot_tuning_curves(cell_indices, group_label, title_color, out_path, *,
                        x, near, far, show_legend=False, first_panel_ymax=None):
    """Draw Near/Far tuning curves on a 2x2 grid, one panel per cell, and save the figure.

    Parameters
    ----------
    cell_indices : sequence of int
        Row indices into ``near`` / ``far``; at most four are expected (2x2 grid).
    group_label : str
        Prefix of each panel title, followed by the 1-based panel number.
    title_color : str
        Matplotlib colour of the panel titles.
    out_path : str
        File the figure is written to.
    x : array-like
        x-axis values shared by all curves.
    near, far : numpy.ndarray
        2-D arrays of activity, indexed as ``[cell, position]``.
    show_legend : bool
        When True, the first panel gets 'Near'/'Far' labels and a legend.
    first_panel_ymax : float or None
        When given, upper y-limit forced on the first panel only.

    Returns
    -------
    (Figure, numpy.ndarray)
        The figure and its flattened array of axes.
    """
    fig, axs = plt.subplots(2, 2, figsize=(16, 16), dpi=500)
    axs = axs.flatten()

    for i, cell_idx in enumerate(cell_indices):
        ax = axs[i]
        # Labels are attached only when a legend is requested; an empty label
        # keeps the later panels out of any legend.
        near_kwargs = {'label': 'Near' if i == 0 else ''} if show_legend else {}
        far_kwargs = {'label': 'Far' if i == 0 else ''} if show_legend else {}
        ax.plot(x, near[cell_idx, :], linestyle='-', color=near_color, linewidth=4, **near_kwargs)
        ax.plot(x, far[cell_idx, :], linestyle='-', color=far_color, linewidth=4, **far_kwargs)

        ax.set_xlabel('Position (cm)', fontsize=32)
        ax.set_ylabel('Activity (dF/F0)', fontsize=32)
        ax.set_title(f'{group_label} {i+1}', fontsize=36, color=title_color)
        ax.tick_params(axis='both', labelsize=28, width=2, length=10, direction='out')
        ax.set_xticks([0, 50, 100, 150, 200])  # Set x-axis ticks
        ax.xaxis.set_tick_params(width=2)  # Increase x-axis tick thickness
        ax.yaxis.set_tick_params(width=2)  # Increase y-axis tick thickness
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        if show_legend and i == 0:
            ax.legend(fontsize=28, loc='upper right', frameon=False)
        sns.despine(ax=ax, top=True, right=True)  # Remove top and right axis box
        if first_panel_ymax is not None and i == 0:
            ax.set_ylim(None, first_panel_ymax)

    fig.tight_layout()
    fig.savefig(out_path, dpi=500, bbox_inches='tight')
    return fig, axs


# Plot tuning curves for the selected cells (one figure per colour group).
# NOTE(review): the black group is written to '..._inner_cells.pdf' and the red group to
# '..._middle_cells.pdf', while the selection code names the black pick `middle_idx` and
# the red pick `inner_idx` - confirm the file names are not swapped. Paths are kept unchanged.
fig_blue, axs_blue = _plot_tuning_curves(
    blue_cell_indices, 'Blue', 'blue', 'Figure_3_FINAL/tuning_curves_blue_cells.pdf',
    x=positions, near=binF_Near_filtered, far=binF_Far_filtered, show_legend=True)

fig_black, axs_black = _plot_tuning_curves(
    black_cell_indices, 'Black', 'black', 'Figure_3_FINAL/tuning_curves_inner_cells.pdf',
    x=positions, near=binF_Near_filtered, far=binF_Far_filtered)

fig_red, axs_red = _plot_tuning_curves(
    red_cell_indices, 'Red', 'red', 'Figure_3_FINAL/tuning_curves_middle_cells.pdf',
    x=positions, near=binF_Near_filtered, far=binF_Far_filtered, first_panel_ymax=1)

In [None]:
import scipy.ndimage as ndimage
import matplotlib.colors as mcolors
import matplotlib.lines as mlines

# This cell reuses range_A, binF_Near, binF_Far, threshold_std, animal and non_negative
# from the first cell of the notebook.
sessions = range_A

# Select the first, middle, and last sessions
selected_sessions = [sessions[0], sessions[len(sessions)//2], sessions[-1]]

# Track zones as half-open [start, end) ranges. They are compared below with the
# column index of a cell's peak in binF, so the units are binF columns.
markers = [
    {'name': 'Track Boundaries', 'position': [0, 12]},
    {'name': 'Track Boundaries', 'position': [40, 46]},
    {'name': 'Indicator', 'position': [12, 20]},
    {'name': 'Reward / Pre-Reward', 'position': [20, 40]},
]

# Dot colour for each zone name
colors = {
    'Track Boundaries': 'black',
    'Indicator': 'red',
    'Reward / Pre-Reward': 'lightblue',
}

# Create the directory if it doesn't exist
os.makedirs(f'plots/{animal}/New_figure_3', exist_ok=True)

# One scatter panel (correlation vs difference score) per selected session, dots coloured by zone
fig1, axs1 = plt.subplots(1, len(selected_sessions), figsize=(16, 6), dpi=300)
#fig1.suptitle(f'Correlation Coefficient vs Max Amplitude Difference (Scatter Plots)', fontsize=24)

for i, session in enumerate(selected_sessions):
    corr_coeffs = []
    max_amp_diff_scores = []
    positions = []

    # Column index of each cell's peak, located on the zero-clipped activity
    max_amplitude_Near_idx = np.argmax(non_negative(binF_Near[session]), axis=1)
    max_amplitude_Far_idx = np.argmax(non_negative(binF_Far[session]), axis=1)

    # Peak value read from the raw (unclipped) activity at that column.
    # NOTE(review): for a row with no positive value argmax returns column 0 and the raw
    # value there can be negative, which would push the difference score above 1. The
    # single-session analysis in the first cell takes the max of the clipped data instead -
    # confirm which behaviour is intended.
    max_amplitude_Near = binF_Near[session][np.arange(len(max_amplitude_Near_idx)), max_amplitude_Near_idx]
    max_amplitude_Far = binF_Far[session][np.arange(len(max_amplitude_Far_idx)), max_amplitude_Far_idx]

    # Calculate threshold based on mean and standard deviation of activity values for the session
    activity_values = np.concatenate((binF_Near[session][:, :], binF_Far[session][:, :]))
    threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)

    for j in range(len(max_amplitude_Near)):
        # Keep a cell when its peak reaches the threshold in at least one trial type
        if max_amplitude_Near[j] >= threshold or max_amplitude_Far[j] >= threshold:
            corr_coeff, _ = pearsonr(binF_Near[session][j, :], binF_Far[session][j, :])
            corr_coeffs.append(corr_coeff)

            # Difference score = 1 - min(ratio, 1/ratio) of the two peak values
            max_amp_ratio = max_amplitude_Near[j] / max_amplitude_Far[j]
            max_amp_diff_score = 1 - min(max_amp_ratio, 1/max_amp_ratio)
            max_amp_diff_scores.append(max_amp_diff_score)

            # Peak column of the trial type with the larger peak (Near wins ties)
            position = max_amplitude_Near_idx[j] if max_amplitude_Near[j] >= max_amplitude_Far[j] else max_amplitude_Far_idx[j]
            positions.append(position)

    # Determine the color of each dot based on its position
    dot_colors = []
    for position in positions:
        for marker in markers:
            if marker['position'][0] <= position < marker['position'][1]:
                dot_colors.append(colors[marker['name']])
                break
        else:
            # for/else: reached only when no zone matched (no break)
            dot_colors.append('gray')  # Default color if position doesn't fall into any group

    # Plot correlation coefficient vs max amplitude difference score (scatter plots)
    ax1 = axs1[i]
    scatter = ax1.scatter(corr_coeffs, max_amp_diff_scores, s=50, c=dot_colors, edgecolors='black', linewidths=0.5)

    ax1.set_xlabel('Corr (Near vs Far Trials)', fontsize=20)
    ax1.set_ylabel('Max Amplitude Difference Score', fontsize=20)
    ax1.set_title(f'Session {session+1}', fontsize=24)  # sessions are shown 1-based
    ax1.tick_params(axis='both', labelsize=18)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust the spacing between subplots
plt.savefig(f'plots/{animal}/New_figure_3/{animal}_corr_coeff_vs_max_amp_diff_scatter_plots_color_coded.pdf', dpi=300, bbox_inches='tight')

# Create a separate plot for the legend
fig_legend, ax_legend = plt.subplots(figsize=(6, 1), dpi=300)

# One marker-only handle per zone colour; no data is drawn on this axis
legend_handles = [mlines.Line2D([], [], color=color, marker='o', linestyle='None', markersize=10, label=label) for label, color in colors.items()]
ax_legend.legend(handles=legend_handles, loc='center', ncol=len(colors), fontsize=16, frameon=False)

ax_legend.axis('off')

plt.tight_layout()
plt.savefig(f'plots/{animal}/New_figure_3/{animal}_corr_coeff_vs_max_amp_diff_scatter_plots_legend.pdf', dpi=300, bbox_inches='tight')

In [None]:
# This cell reuses range_A, binF_Near, binF_Far, threshold_std, animal and non_negative
# from the first cell of the notebook.
sessions = range_A
# Select the first, middle, and last sessions
selected_sessions = [sessions[0], sessions[len(sessions)//2], sessions[-1]]
# Track zones as half-open [start, end) ranges over binF column indices.
# 'Track Boundaries' holds a list of ranges, the other zones a single range.
markers = [
    {'name': 'Track Boundaries', 'position': [[0, 12], [40, 46]], 'color': 'gray'},
    {'name': 'Indicator', 'position': [12, 20], 'color': 'orange'},
    {'name': 'Intra-Track', 'position': [20, 40], 'color': 'mediumturquoise'},
]
# Create the directory if it doesn't exist
os.makedirs(f'plots/{animal}/New_figure_3', exist_ok=True)
# The figure below is saved into Figure_3_FINAL, so make sure that exists as well
os.makedirs('Figure_3_FINAL', exist_ok=True)
# Set Seaborn style
sns.set(style="ticks")
# Grid of scatter plots: one row per zone (marker), one column per selected session
fig1, axs1 = plt.subplots(3, 3, figsize=(16, 16), dpi=500, sharex=True, sharey=True)
for j, session in enumerate(selected_sessions):
    # The per-cell metrics depend on the session only, so compute them once per
    # session and reuse them for every zone.
    corr_coeffs = []
    max_amp_diff_scores = []
    peak_positions = []
    # Column index of each cell's peak, located on the zero-clipped activity
    max_amplitude_Near_idx = np.argmax(non_negative(binF_Near[session]), axis=1)
    max_amplitude_Far_idx = np.argmax(non_negative(binF_Far[session]), axis=1)
    # Peak value read from the raw (unclipped) activity at that column.
    # NOTE(review): a row with no positive value yields column 0 and possibly a negative
    # "peak", which pushes the difference score above 1; the single-session analysis in
    # the first cell uses the clipped maximum instead - confirm which is intended.
    max_amplitude_Near = binF_Near[session][np.arange(len(max_amplitude_Near_idx)), max_amplitude_Near_idx]
    max_amplitude_Far = binF_Far[session][np.arange(len(max_amplitude_Far_idx)), max_amplitude_Far_idx]
    # Calculate threshold based on mean and standard deviation of activity values for the session
    activity_values = np.concatenate((binF_Near[session][:, :], binF_Far[session][:, :]))
    threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)
    for k in range(len(max_amplitude_Near)):
        # Keep a cell when its peak reaches the threshold in at least one trial type
        if max_amplitude_Near[k] >= threshold or max_amplitude_Far[k] >= threshold:
            corr_coeff, _ = pearsonr(binF_Near[session][k, :], binF_Far[session][k, :])
            corr_coeffs.append(corr_coeff)
            max_amp_ratio = max_amplitude_Near[k] / max_amplitude_Far[k]
            max_amp_diff_score = 1 - min(max_amp_ratio, 1/max_amp_ratio)
            max_amp_diff_scores.append(max_amp_diff_score)
            # Peak column of the trial type with the larger peak (Near wins ties)
            position = max_amplitude_Near_idx[k] if max_amplitude_Near[k] >= max_amplitude_Far[k] else max_amplitude_Far_idx[k]
            peak_positions.append(position)

    corr_arr = np.asarray(corr_coeffs, dtype=float)
    diff_arr = np.asarray(max_amp_diff_scores, dtype=float)
    pos_arr = np.asarray(peak_positions)

    for i, marker in enumerate(markers):
        # Bind the panel before use so styling never lands on a stale axis from
        # another figure or another panel.
        ax1 = axs1[i, j]
        if marker['name'] == 'Track Boundaries':
            zone_ranges = marker['position']
        else:
            zone_ranges = [marker['position']]
        # Cells whose peak column falls into any of the zone's [start, end) ranges
        in_zone = np.zeros(len(pos_arr), dtype=bool)
        for start, end in zone_ranges:
            in_zone |= (pos_arr >= start) & (pos_arr < end)
        # One scatter call per panel instead of one call per cell
        ax1.scatter(corr_arr[in_zone], diff_arr[in_zone], s=80, c=marker['color'], alpha=0.8)
        ax1.spines['right'].set_visible(False)
        ax1.spines['top'].set_visible(False)
# Set x-axis range and tick intervals for all subplots
for i in range(3):
    for j in range(3):
        axs1[i, j].set_xlim(-0.8, 1)
        axs1[i, j].set_xticks([-0.5, 0, 0.5, 1])
        axs1[i, j].set_xticklabels([-0.5, 0, 0.5, 1], fontsize=12)  # Set x-tick labels explicitly
        axs1[i, j].set_yticks([0, 0.5, 1])
        axs1[i, j].tick_params(axis='both', labelsize=12, width=2, length=6)
        axs1[i, j].spines['bottom'].set_visible(True)
        axs1[i, j].spines['left'].set_visible(True)
        # Report the (shared) y-limits each panel ended up with
        ylim = axs1[i, j].get_ylim()
        print(f"Subplot ({i}, {j}) - Y-axis limits: {ylim}")

# Add y-axis labels for the subplots in the first column
for i in range(3):
    axs1[i, 0].set_ylabel('Difference Score', fontsize=14)
# Add x-axis labels for the subplots in the last row
for j in range(3):
    axs1[2, j].set_xlabel('Correlation (Near vs Far Trials)', fontsize=14)
# Adjust the spacing between subplots
plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.savefig(f'Figure_3_FINAL/{animal}_corr_coeff_vs_max_amp_diff_scatter_plots_color_coded_separate_merged_seaborn_modified_4.pdf', dpi=300, bbox_inches='tight')

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from scipy.interpolate import interp1d
import vr2p
import zarr

def non_negative(arr):
    """Return ``arr`` with negative entries set to 0 and all other entries unchanged."""
    negative_mask = arr < 0
    return np.where(negative_mask, 0, arr)

def calculate_percentages(corr_coeffs, max_amp_diff_scores, threshold=0.2):
    """Percentage of cells falling into each correlation / difference-score quadrant.

    Parameters
    ----------
    corr_coeffs : sequence of float
        Near-vs-Far correlation of each cell.
    max_amp_diff_scores : sequence of float
        Difference score of each cell, aligned with ``corr_coeffs``.
    threshold : float
        Correlation value separating low (``<``) from high (``>=``) correlation.

    Returns
    -------
    list
        Four percentages in the order Strong Splitter, Weak Splitter,
        Multiplex Splitter, Place Cell. The denominator is ``len(corr_coeffs)``;
        a NaN correlation matches no quadrant but still counts in the denominator.
    """
    corr = np.array(corr_coeffs)
    diff = np.array(max_amp_diff_scores)

    low_corr = corr < threshold
    high_corr = corr >= threshold
    large_diff = diff >= 0.5
    small_diff = diff < 0.5

    quadrant_masks = (
        low_corr & large_diff,    # Quadrant 1: Strong Splitter
        high_corr & large_diff,   # Quadrant 2: Weak Splitter
        low_corr & small_diff,    # Quadrant 3: Multiplex Splitter
        high_corr & small_diff,   # Quadrant 4: Place Cell
    )

    n_cells = len(corr_coeffs)
    percentages = []
    for mask in quadrant_masks:
        percentages.append(np.sum(mask) / n_cells * 100)
    return percentages

# Reuse cached summary arrays when all three files exist. The check is by file
# existence only, so the cache is not refreshed when the underlying data change.
if os.path.exists('plots/interp_sessions.npy') and os.path.exists('plots/mean_percentages.npy') and os.path.exists('plots/sem_percentages.npy'):
    # Load the saved variables
    interp_sessions = np.load('plots/interp_sessions.npy')
    mean_percentages = np.load('plots/mean_percentages.npy')
    sem_percentages = np.load('plots/sem_percentages.npy')
else:
    # Get the list of animal names (every directory entry except '.DS_Store')
    names = os.listdir('/.../Set A/')
    names = [name for name in names if name != '.DS_Store']

    # One entry per animal: list over sessions of the four quadrant percentages
    all_percentages = []

    for animal_name in names:
        animal = animal_name[:8]  # first 8 characters of the entry name
        print(f"Processing animal: {animal}")

        # Create a directory for the animal's plots
        os.makedirs(f'plots/{animal}', exist_ok=True)

        # Load data
        path = f'/.../Set A/{animal_name}/'
        data = vr2p.ExperimentData(path)

        # Leading run of sessions whose only trial set is 'Cue Set A';
        # stops at the first session that contains anything else
        day_count = []
        for i in range(len(data.signals.multi_session.F)):
            if ('Cue Set A' in data.vr[i].trial.set.unique()) and (len(data.vr[i].trial.set.unique())==1):
                day_count.append(i)
            else:
                break

        # max() raises ValueError here if no session qualified
        print(max(day_count))

        # Load stored place field analysis for each day
        # (group '1' is loaded as Near, group '2' as Far)
        range_A = range(max(day_count)+1)
        criteria = 'significant'
        zarr_file = zarr.open(f'/SetA/{animal}-PF.zarr', mode="r")
        pf_all_Near = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{criteria}'][()] for i in range_A]
        pf_all_Far = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{criteria}'][()] for i in range_A]
        binF_Near = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{criteria}'][()]['binF'] for i in range_A]
        binF_Far = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{criteria}'][()]['binF'] for i in range_A]

        sessions = range_A

        # Track zones as [start, end) ranges; not read again within this cell
        markers = [
            {'name': 'Track Boundaries', 'position': [0, 12]},
            {'name': 'Track Boundaries', 'position': [40, 46]},
            {'name': 'Indicator', 'position': [12, 20]},
            {'name': 'Intra-Track', 'position': [20, 40]},
        ]

        # Colour per zone name; not read again within this cell
        colors = {
            'Track Boundaries': 'black',
            'Indicator': 'red',
            'Intra-Track': 'lightblue',
        }

        # Define threshold based on mean and standard deviation
        threshold_std = 2  # Number of standard deviations above the mean

        # One list of four quadrant percentages per session of this animal
        percentage_data = []

        for session in sessions:
            corr_coeffs = []
            max_amp_diff_scores = []
            positions = []

            # Column index of each cell's peak, located on the zero-clipped activity
            max_amplitude_Near_idx = np.argmax(non_negative(binF_Near[session]), axis=1)
            max_amplitude_Far_idx = np.argmax(non_negative(binF_Far[session]), axis=1)

            # Peak value read from the raw (unclipped) activity at that column.
            # NOTE(review): for a row with no positive value this can be negative and
            # push the difference score above 1 - confirm this is intended.
            max_amplitude_Near = binF_Near[session][np.arange(len(max_amplitude_Near_idx)), max_amplitude_Near_idx]
            max_amplitude_Far = binF_Far[session][np.arange(len(max_amplitude_Far_idx)), max_amplitude_Far_idx]

            # Calculate threshold based on mean and standard deviation of activity values for the session
            activity_values = np.concatenate((binF_Near[session][:, :], binF_Far[session][:, :]))
            threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)

            for j in range(len(max_amplitude_Near)):
                # Keep a cell when its peak reaches the threshold in at least one trial type
                if max_amplitude_Near[j] >= threshold or max_amplitude_Far[j] >= threshold:
                    corr_coeff, _ = pearsonr(binF_Near[session][j, :], binF_Far[session][j, :])
                    corr_coeffs.append(corr_coeff)

                    # Difference score = 1 - min(ratio, 1/ratio) of the two peak values
                    max_amp_ratio = max_amplitude_Near[j] / max_amplitude_Far[j]
                    max_amp_diff_score = 1 - min(max_amp_ratio, 1/max_amp_ratio)
                    max_amp_diff_scores.append(max_amp_diff_score)

                    # Peak column of the trial type with the larger peak (Near wins ties);
                    # collected but not used further in this cell
                    position = max_amplitude_Near_idx[j] if max_amplitude_Near[j] >= max_amplitude_Far[j] else max_amplitude_Far_idx[j]
                    positions.append(position)

            # Calculate the percentages for each quadrant
            percentages = calculate_percentages(corr_coeffs, max_amp_diff_scores, threshold=0.2)
            percentage_data.append(percentages)

        all_percentages.append(percentage_data)

        # Plot the percentage changes for each quadrant over sessions for the current animal
        fig_animal, ax_animal = plt.subplots(figsize=(8, 6), dpi=300)

        for i, quadrant in enumerate(['Q1: Strong Splitter', 'Q2: Weak Splitter', 'Q3: Multiplex Splitter', 'Q4: Place Cell']):
            # `data` here is local to the comprehension and does not rebind the ExperimentData above
            percentages = [data[i] for data in percentage_data]
            ax_animal.plot(range(1, len(sessions) + 1), percentages, marker='o', label=quadrant, linewidth=2)

        ax_animal.set_xlabel('Session', fontsize=14, fontweight='bold')
        ax_animal.set_ylabel('Percentage (%)', fontsize=14, fontweight='bold')
        ax_animal.set_title(f'Percentage of Cell Types Across Sessions ({animal})', fontsize=16, fontweight='bold')
        ax_animal.tick_params(axis='both', labelsize=14)
        ax_animal.spines['top'].set_visible(False)
        ax_animal.spines['right'].set_visible(False)
        ax_animal.legend(fontsize=12, frameon=False, loc='upper left')

        plt.tight_layout()
        plt.savefig(f'plots/{animal}/{animal}_percentage_changes.pdf', dpi=300, bbox_inches='tight')
        # Close the per-animal figure so figures do not accumulate across the loop
        plt.close(fig_animal)

    # Put every animal on a common 0..1 session axis with as many points as the
    # longest animal has sessions, using linear interpolation
    normalized_percentages = []
    max_sessions = max(len(data) for data in all_percentages)

    # From here on `data` is rebound to one animal's list of per-session percentages
    for data in all_percentages:
        sessions = np.linspace(0, 1, len(data))
        interp_sessions = np.linspace(0, 1, max_sessions)
        interp_data = []

        # NOTE(review): presumably interp1d cannot be built for an animal with a single
        # session (one sample point) - TODO confirm
        for i in range(4):
            interp_func = interp1d(sessions, [d[i] for d in data], kind='linear')
            interp_data.append(interp_func(interp_sessions))

        normalized_percentages.append(interp_data)

    # Mean and SEM across animals (axis 0 indexes animals).
    # NOTE(review): np.std is called with its default ddof (0, population SD); an SEM is
    # more commonly based on the sample SD (ddof=1) - confirm the intended definition.
    mean_percentages = np.mean(normalized_percentages, axis=0)
    sem_percentages = np.std(normalized_percentages, axis=0) / np.sqrt(len(normalized_percentages))

    # Save variables for replotting (these are the files checked at the top of this cell)
    np.save('plots/interp_sessions.npy', interp_sessions)
    np.save('plots/mean_percentages.npy', mean_percentages)
    np.save('plots/sem_percentages.npy', sem_percentages)

# Summary figure: mean percentage of each cell type across animals over normalized sessions
fig_summary, ax_summary = plt.subplots(figsize=(8, 6), dpi=300)

# Curve labels, in the same order as the rows of mean_percentages / sem_percentages
quadrant_labels = ['Q1: Strong Splitter', 'Q2: Weak Splitter', 'Q3: Multiplex Splitter', 'Q4: Place Cell']

# Mean curve per quadrant with a shaded mean +/- SEM band
for i, quadrant in enumerate(quadrant_labels):
    mean_trace = mean_percentages[i]
    sem_trace = sem_percentages[i]
    ax_summary.plot(interp_sessions, mean_trace, marker='o', label=quadrant, linewidth=2)
    ax_summary.fill_between(
        interp_sessions,
        mean_trace - sem_trace,
        mean_trace + sem_trace,
        alpha=0.3,
        edgecolor='none',
    )

# Axis labels and title
ax_summary.set_xlabel('Normalized Session', fontsize=14, fontweight='bold')
ax_summary.set_ylabel('Percentage (%)', fontsize=14, fontweight='bold')
ax_summary.set_title('Percentage of Cell Types Across Sessions (All Animals)', fontsize=16, fontweight='bold')

# Tick label size; hide the top and right spines
ax_summary.tick_params(axis='both', labelsize=14)
for side in ('top', 'right'):
    ax_summary.spines[side].set_visible(False)

# Legend without a frame
ax_summary.legend(fontsize=12, frameon=False, loc='upper left')

# Make sure the output directory exists, then lay out, save and show the figure
os.makedirs('plots', exist_ok=True)
plt.tight_layout()
plt.savefig('plots/all_animals_percentage_changes_summary.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from scipy.interpolate import interp1d
import vr2p
import zarr

def non_negative(arr):
    """Clip ``arr`` at zero from below: negative entries become 0, the rest is kept."""
    is_negative = arr < 0
    clipped = np.where(is_negative, 0, arr)
    return clipped

def calculate_percentages(corr_coeffs, max_amp_diff_scores, threshold=0.2):
    """Return the percentage of cells falling into each of three response classes.

    Parameters
    ----------
    corr_coeffs : sequence of float
        One correlation coefficient per cell.
    max_amp_diff_scores : sequence of float
        One amplitude difference score per cell, in the same order as ``corr_coeffs``.
    threshold : float, optional
        Correlation cut-off separating the two low-difference classes (default 0.2).

    Returns
    -------
    list of float
        ``[splitter, multiplex_splitter, place_cell]`` as percentages of
        ``len(corr_coeffs)``. All three entries are NaN when no cells are given.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of cells.

    Notes
    -----
    A cell whose correlation is NaN and whose difference score is below 0.5
    satisfies neither correlation test and is therefore counted in no class.
    """
    # Convert once instead of re-wrapping the inputs for every comparison.
    corr = np.asarray(corr_coeffs, dtype=float)
    diff = np.asarray(max_amp_diff_scores, dtype=float)
    if corr.shape != diff.shape:
        raise ValueError("corr_coeffs and max_amp_diff_scores must have the same length")

    n_cells = len(corr)
    if n_cells == 0:
        # Avoid the 0/0 division (and its RuntimeWarning); the result is undefined.
        return [np.nan, np.nan, np.nan]

    low_difference = diff < 0.5
    quadrants = [
        diff >= 0.5,                          # Splitter (Strong and Weak combined)
        (corr < threshold) & low_difference,   # Multiplex Splitter
        (corr >= threshold) & low_difference,  # Place Cell
    ]

    return [np.sum(quadrant) / n_cells * 100 for quadrant in quadrants]

# Get the list of animal names.
# os.listdir() returns entries in arbitrary order, so sort them to get a reproducible
# processing order, and skip every hidden entry (this includes '.DS_Store').
names = sorted(name for name in os.listdir('/.../Set A/') if not name.startswith('.'))

# One list of per-session percentages per animal; filled by the loop below.
all_percentages = []

for animal_name in names:
    animal = animal_name[:8]
    print(f"Processing animal: {animal}")

    # Create a directory for the animal's plots
    os.makedirs(f'plots/{animal}', exist_ok=True)

    # Load data
    path = f'/.../Set A/{animal_name}/'
    data = vr2p.ExperimentData(path)

    # Generate index for days animal is performing Cue Set A only; stop at the
    # first session that contains any other cue set.
    day_count = []
    for i in range(len(data.signals.multi_session.F)):
        sets_in_session = data.vr[i].trial.set.unique()
        if ('Cue Set A' in sets_in_session) and (len(sets_in_session) == 1):
            day_count.append(i)
        else:
            break

    if not day_count:
        # max() below would raise on an empty list; report and move on instead.
        print(f"Skipping {animal}: no leading 'Cue Set A'-only session found")
        continue

    print(max(day_count))

    # Load stored place field analysis for each day
    range_A = range(max(day_count)+1)
    criteria = 'significant'
    zarr_file = zarr.open(f'/SetA/{animal}-PF.zarr', mode="r")
    pf_all_Near = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{criteria}'][()] for i in range_A]
    pf_all_Far = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{criteria}'][()] for i in range_A]
    # Reuse the records loaded above instead of reading every entry from disk a second time.
    binF_Near = [pf['binF'] for pf in pf_all_Near]
    binF_Far = [pf['binF'] for pf in pf_all_Far]

    sessions = range_A

    # Define threshold based on mean and standard deviation
    threshold_std = 2  # Number of standard deviations above the mean

    percentage_data = []

    for session in sessions:
        corr_coeffs = []
        max_amp_diff_scores = []

        # Peak amplitude of each cell across positions for both trial types.
        # Negative values are clipped to 0 first, so a cell without any positive
        # response gets amplitude 0 (not the raw value of bin 0), which keeps the
        # difference score below inside [0, 1].
        max_amplitude_Near = np.max(non_negative(binF_Near[session]), axis=1)
        max_amplitude_Far = np.max(non_negative(binF_Far[session]), axis=1)

        # Activity threshold: mean + threshold_std * SD over all bins of both trial types
        activity_values = np.concatenate((binF_Near[session], binF_Far[session]))
        threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)

        for j in range(len(max_amplitude_Near)):
            peak_near = max_amplitude_Near[j]
            peak_far = max_amplitude_Far[j]
            if peak_near >= threshold or peak_far >= threshold:
                # pearsonr returns NaN (with a warning) when either trace is constant.
                corr_coeff, _ = pearsonr(binF_Near[session][j, :], binF_Far[session][j, :])
                corr_coeffs.append(corr_coeff)

                # 1 - smaller/larger peak: 0 = equal peaks, 1 = silent in one trial type.
                # Same quantity as 1 - min(r, 1/r) with r = near/far, without dividing by zero.
                larger_peak = max(peak_near, peak_far)
                if larger_peak > 0:
                    max_amp_diff_score = 1 - min(peak_near, peak_far) / larger_peak
                else:
                    # Only reachable when the activity threshold itself is <= 0.
                    max_amp_diff_score = 0.0
                max_amp_diff_scores.append(max_amp_diff_score)

        # Calculate the percentages for each quadrant
        percentages = calculate_percentages(corr_coeffs, max_amp_diff_scores, threshold=0.2)
        percentage_data.append(percentages)

    all_percentages.append(percentage_data)

    # Plot the percentage changes for each quadrant over sessions for the current animal
    fig_animal, ax_animal = plt.subplots(figsize=(8, 6), dpi=300)

    for i, quadrant in enumerate(['Splitter', 'Multiplex Splitter', 'Place Cell']):
        # Column i of the per-session table (loop variable named so it cannot be
        # confused with the `data` experiment object above).
        percentages = [session_row[i] for session_row in percentage_data]
        ax_animal.plot(range(1, len(sessions) + 1), percentages, marker='o', label=quadrant, linewidth=2)

    ax_animal.set_xlabel('Session', fontsize=14, fontweight='bold')
    ax_animal.set_ylabel('Percentage (%)', fontsize=14, fontweight='bold')
    ax_animal.set_title(f'Percentage of Cell Types Across Sessions ({animal})', fontsize=16, fontweight='bold')
    ax_animal.tick_params(axis='both', labelsize=14)
    ax_animal.spines['top'].set_visible(False)
    ax_animal.spines['right'].set_visible(False)
    ax_animal.legend(fontsize=12, frameon=False, loc='upper left')

    plt.tight_layout()
    plt.savefig(f'plots/{animal}/{animal}_percentage_changes.pdf', dpi=300, bbox_inches='tight')
    # Close the figure so figures do not accumulate across animals.
    plt.close(fig_animal)

# Put every animal on a common 0..1 session axis and resample its three
# percentage traces onto that axis by linear interpolation.
normalized_percentages = []
max_sessions = max(len(data) for data in all_percentages)

# Shared target axis: as many points as the animal with the most sessions
interp_sessions = np.linspace(0, 1, max_sessions)

for data in all_percentages:
    # This animal's own sessions, spread evenly over 0..1
    sessions = np.linspace(0, 1, len(data))
    interp_data = [
        interp1d(sessions, [d[i] for d in data], kind='linear')(interp_sessions)
        for i in range(3)
    ]
    normalized_percentages.append(interp_data)

# Mean and SEM across animals (axis 0 indexes animals)
# NOTE(review): np.std defaults to ddof=0 (population SD); the conventional SEM
# uses ddof=1 - confirm which one is intended.
stacked_percentages = np.asarray(normalized_percentages)
mean_percentages = stacked_percentages.mean(axis=0)
sem_percentages = stacked_percentages.std(axis=0) / np.sqrt(stacked_percentages.shape[0])

# Summary figure: across-animal mean of each cell class vs. normalized session
fig_summary, ax_summary = plt.subplots(figsize=(8, 6), dpi=300)

class_labels = ('Splitter', 'Multiplex Splitter', 'Place Cell')

# Draw each class as a line with markers plus a mean +/- SEM band
for row, quadrant in enumerate(class_labels):
    centre = mean_percentages[row]
    spread = sem_percentages[row]
    ax_summary.plot(interp_sessions, centre, marker='o', label=quadrant, linewidth=2)
    ax_summary.fill_between(interp_sessions, centre - spread, centre + spread,
                            alpha=0.3, edgecolor='none')

# Bold axis labels and title
axis_label_style = dict(fontsize=14, fontweight='bold')
ax_summary.set_xlabel('Normalized Session', **axis_label_style)
ax_summary.set_ylabel('Percentage (%)', **axis_label_style)
ax_summary.set_title('Percentage of Cell Types Across Sessions (All Animals)', fontsize=16, fontweight='bold')

# Tick label size; drop the top and right frame lines
ax_summary.tick_params(axis='both', labelsize=14)
for side in ('top', 'right'):
    ax_summary.spines[side].set_visible(False)

# Legend without a frame, upper left
ax_summary.legend(fontsize=12, frameon=False, loc='upper left')

# Finalise layout, ensure the output folder exists, save and display
fig_summary.tight_layout()
os.makedirs('plots', exist_ok=True)
fig_summary.savefig('plots/all_animals_percentage_changes_summary.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Colour-coded version of the across-animal summary plot
fig_summary, ax_summary = plt.subplots(figsize=(6, 6), dpi=300)

# One colour per class, in the same order as the class labels below
colors = ['chocolate', 'mediumseagreen','slategray' ]
class_labels = ['Splitter', 'Multiplex Splitter', 'Place Cell']

# Mean trace per class with a matching-colour mean +/- SEM band
for row, quadrant in enumerate(class_labels):
    class_color = colors[row]
    centre = mean_percentages[row]
    spread = sem_percentages[row]
    ax_summary.plot(interp_sessions, centre, marker='o', label=quadrant, linewidth=2, color=class_color)
    ax_summary.fill_between(interp_sessions, centre - spread, centre + spread,
                            alpha=0.3, color=class_color, edgecolor='none')

# Axis labels (no title is set on this panel)
ax_summary.set_xlabel('Normalized Training Time', fontsize=16)
ax_summary.set_ylabel('Percentage (%)', fontsize=16)

# Tick label size; hide the top and right frame lines
ax_summary.tick_params(axis='both', labelsize=14)
for side in ('top', 'right'):
    ax_summary.spines[side].set_visible(False)

# Tick marks only on the bottom and left axes
ax_summary.xaxis.set_ticks_position('bottom')
ax_summary.yaxis.set_ticks_position('left')

# Finalise layout and save
fig_summary.tight_layout()
fig_summary.savefig('Figure_3_FINAL//all_animals_percentage_changes_summary_colored.pdf', dpi=300, bbox_inches='tight')

# Stand-alone legend figure: three coloured tiles, each with a text label
fig_legend, ax_legend = plt.subplots(figsize=(3, 3), dpi=300)
ax_legend.axis('off')

# (lower-left corner, width, height, face colour, label text, label centre)
legend_tiles = (
    ((0, 0.5), 1, 0.5, 'chocolate', 'S', (0.5, 0.75)),
    ((0, 0), 0.5, 0.5, 'mediumseagreen', 'RS', (0.25, 0.25)),
    ((0.5, 0), 0.5, 0.5, 'slategray', 'P\n/P-S', (0.75, 0.25)),
)

for corner, tile_width, tile_height, face, text, (text_x, text_y) in legend_tiles:
    ax_legend.add_patch(Rectangle(corner, tile_width, tile_height, facecolor=face))
    ax_legend.text(text_x, text_y, text, ha='center', va='center', fontsize=36)

# Finalise layout, save the legend, then display the open figures
fig_legend.tight_layout()
fig_legend.savefig('Figure_3_FINAL/legend_plot.pdf', dpi=300, bbox_inches='tight')

plt.show()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import vr2p
import zarr

def non_negative(arr):
    """Clip ``arr`` from below at zero: entries that are < 0 become 0, all others are kept."""
    keep_value = np.logical_not(arr < 0)
    return np.where(keep_value, arr, 0)

def calculate_percentages(corr_coeffs, max_amp_diff_scores, threshold=0.2):
    """Return the percentage of cells falling into each of three response classes.

    Parameters
    ----------
    corr_coeffs : sequence of float
        One correlation coefficient per cell.
    max_amp_diff_scores : sequence of float
        One amplitude difference score per cell, in the same order as ``corr_coeffs``.
    threshold : float, optional
        Correlation cut-off separating the two low-difference classes (default 0.2).

    Returns
    -------
    list of float
        ``[splitter, multiplex_splitter, place_cell]`` as percentages of
        ``len(corr_coeffs)``. All three entries are NaN when no cells are given.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of cells.

    Notes
    -----
    A cell whose correlation is NaN and whose difference score is below 0.5
    satisfies neither correlation test and is therefore counted in no class.
    """
    # Convert once instead of re-wrapping the inputs for every comparison.
    corr = np.asarray(corr_coeffs, dtype=float)
    diff = np.asarray(max_amp_diff_scores, dtype=float)
    if corr.shape != diff.shape:
        raise ValueError("corr_coeffs and max_amp_diff_scores must have the same length")

    n_cells = len(corr)
    if n_cells == 0:
        # Avoid the 0/0 division (and its RuntimeWarning); the result is undefined.
        return [np.nan, np.nan, np.nan]

    low_difference = diff < 0.5
    quadrants = [
        diff >= 0.5,                          # Splitter (Strong and Weak combined)
        (corr < threshold) & low_difference,   # Multiplex Splitter
        (corr >= threshold) & low_difference,  # Place Cell
    ]

    return [np.sum(quadrant) / n_cells * 100 for quadrant in quadrants]

# Get the list of animal names.
# os.listdir() returns entries in arbitrary order, so sort them to get a reproducible
# processing order, and skip every hidden entry (this includes '.DS_Store').
names = sorted(name for name in os.listdir('/.../Set A/') if not name.startswith('.'))

# One (n_positions, 3) percentage table per animal; filled by the loop below.
all_percentages = []

for animal_name in names:
    animal = animal_name[:8]
    print(f"Processing animal: {animal}")

    # Load data
    path = f'/.../Set A/{animal_name}/'
    data = vr2p.ExperimentData(path)

    # Generate index for days animal is performing Cue Set A only; stop at the
    # first session that contains any other cue set.
    day_count = []
    for i in range(len(data.signals.multi_session.F)):
        sets_in_session = data.vr[i].trial.set.unique()
        if ('Cue Set A' in sets_in_session) and (len(sets_in_session) == 1):
            day_count.append(i)
        else:
            break

    if not day_count:
        # max() below would raise on an empty list; report and move on instead.
        print(f"Skipping {animal}: no leading 'Cue Set A'-only session found")
        continue

    print(max(day_count))

    # Load stored place field analysis for the last session
    last_session = max(day_count)
    criteria = 'significant'
    zarr_file = zarr.open(f'/SetA/{animal}-PF.zarr', mode="r")
    binF_Near = zarr_file[f'Cue Set A/1/excl_no_response/{last_session}/{criteria}'][()]['binF']
    binF_Far = zarr_file[f'Cue Set A/2/excl_no_response/{last_session}/{criteria}'][()]['binF']

    # Define threshold based on mean and standard deviation
    threshold_std = 2  # Number of standard deviations above the mean

    # Activity threshold: mean + threshold_std * SD over all bins of both trial types
    activity_values = np.concatenate((binF_Near, binF_Far))
    threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)

    # Clip negatives once; used both for the peak location and the peak amplitude,
    # so a cell without any positive response has amplitude 0 and the difference
    # score below stays inside [0, 1].
    clipped_Near = non_negative(binF_Near)
    clipped_Far = non_negative(binF_Far)
    peak_Near = np.max(clipped_Near, axis=1)
    peak_Far = np.max(clipped_Far, axis=1)

    # Find max peak position for each cell (over both trial types)
    max_peak_positions = np.argmax(np.maximum(clipped_Near, clipped_Far), axis=1)

    n_positions = binF_Near.shape[1]
    # NaN marks position bins where no cell qualifies, so the NaN-aware averaging
    # across animals skips them instead of counting them as 0 % in every class.
    percentage_data = np.full((n_positions, 3), np.nan)

    for position in range(n_positions):
        corr_coeffs = []
        max_amp_diff_scores = []

        # Visit only the cells that peak at this position.
        for j in np.flatnonzero(max_peak_positions == position):
            near_peak = peak_Near[j]
            far_peak = peak_Far[j]
            if not (near_peak >= threshold or far_peak >= threshold):
                continue

            # pearsonr returns NaN (with a warning) when either trace is constant.
            corr_coeff, _ = pearsonr(binF_Near[j, :], binF_Far[j, :])
            corr_coeffs.append(corr_coeff)

            # 1 - smaller/larger peak: 0 = equal peaks, 1 = silent in one trial type.
            # Same quantity as 1 - min(r, 1/r) with r = near/far, without dividing by zero.
            larger_peak = max(near_peak, far_peak)
            if larger_peak > 0:
                max_amp_diff_score = 1 - min(near_peak, far_peak) / larger_peak
            else:
                # Only reachable when the activity threshold itself is <= 0.
                max_amp_diff_score = 0.0
            max_amp_diff_scores.append(max_amp_diff_score)

        # Calculate the percentages for each quadrant at the current position
        if len(corr_coeffs) > 0:
            percentages = calculate_percentages(corr_coeffs, max_amp_diff_scores, threshold=0.2)
            percentage_data[position, :] = percentages

    all_percentages.append(percentage_data)

# Calculate mean and SEM across animals (axis 0 indexes animals), ignoring NaN entries.
stacked_percentages = np.asarray(all_percentages, dtype=float)
mean_percentages = np.nanmean(stacked_percentages, axis=0)

# nanstd ignores NaN entries, so the SEM must divide by the number of animals that
# actually contributed a value to each (position, class) entry, not by the total count.
n_contributing = np.sum(~np.isnan(stacked_percentages), axis=0)
with np.errstate(invalid='ignore', divide='ignore'):
    # Entries with no contributing animal come out as NaN.
    sem_percentages = np.nanstd(stacked_percentages, axis=0) / np.sqrt(n_contributing)

# Create a figure and axis for the position vs percentage plot
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)

# Plot the mean percentage for each quadrant along positions with SEM shading.
# One x value per position bin (1-based), derived from the data instead of a
# hard-coded bin count so the x and y lengths always match.
positions = np.arange(1, mean_percentages.shape[0] + 1)
for i, quadrant in enumerate(['Splitter', 'Multiplex Splitter', 'Place Cell']):
    ax.plot(positions, mean_percentages[:, i], marker='o', label=quadrant, linewidth=2)
    ax.fill_between(positions, mean_percentages[:, i] - sem_percentages[:, i], mean_percentages[:, i] + sem_percentages[:, i],
                    alpha=0.3, edgecolor='none')

# Set axis labels and title
ax.set_xlabel('Position', fontsize=14, fontweight='bold')
ax.set_ylabel('Percentage (%)', fontsize=14, fontweight='bold')
ax.set_title('Percentage of Cell Types Along Positions (Last Session, All Animals)', fontsize=16, fontweight='bold')

# Set tick parameters and remove top and right spines
ax.tick_params(axis='both', labelsize=14)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Add legend
ax.legend(fontsize=12, frameon=False, loc='upper left')

# Adjust layout and show the plot (saving is currently disabled)
plt.tight_layout()
#os.makedirs('plots', exist_ok=True)
#plt.savefig('plots/all_animals_percentage_along_positions.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Create a figure and axis for the position vs percentage plot
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)

# Define colors for each category
colors = ['chocolate', 'mediumseagreen','slategray' ]

# Plot the mean percentage for each quadrant along positions with SEM shading.
# One x value per position bin, derived from the data instead of a hard-coded
# bin count so the x and y lengths always match.
# NOTE(review): the factor 5 presumably converts the bin index to cm (bin width) - confirm.
positions = np.arange(1, mean_percentages.shape[0] + 1)*5
for i, quadrant in enumerate(['Splitter', 'Multiplex Splitter', 'Place Cell']):
    ax.plot(positions, mean_percentages[:, i], marker='o', label=quadrant, linewidth=2, color=colors[i])
    ax.fill_between(positions, mean_percentages[:, i] - sem_percentages[:, i], mean_percentages[:, i] + sem_percentages[:, i],
                    alpha=0.3, color=colors[i], edgecolor='none')

# Set axis labels (no title is set on this panel)
ax.set_xlabel('Position (cm)', fontsize=16)
ax.set_ylabel('Percentage (%)', fontsize=16)
#ax.set_title('Percentage of Cell Types Along Positions (Last Session, All Animals)', fontsize=18)

# Show the full 0-100 % range with a 5-point margin on either side
ax.set_ylim(-5, 105)

# Set tick parameters and remove top and right spines
ax.tick_params(axis='both', labelsize=14)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Add axis tick marks
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

# Adjust layout and save the plot
plt.tight_layout()
plt.savefig('Figure_3_FINAL/all_animals_percentage_along_positions_colored.pdf', dpi=300, bbox_inches='tight')



In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from scipy.ndimage import gaussian_filter
import vr2p
import zarr

def non_negative(arr):
    """Zero out the negative entries of ``arr``; the result is a new array."""
    clipped = np.where(arr < 0, 0, arr)
    return clipped

# Get the list of animal names.
# os.listdir() returns entries in arbitrary order, so sort them to get a reproducible
# processing order, and skip every hidden entry (this includes '.DS_Store').
names = sorted(name for name in os.listdir('/.../Set A/') if not name.startswith('.'))

# Cell types and the peak-position ranges (bin indices, as (start, stop) pairs)
# that define them; a type may cover more than one range.
cell_types = [
    dict(name=type_name, position=position_ranges)
    for type_name, position_ranges in (
        ('Track Boundary', [(0, 12), (40, 46)]),
        ('Intra-Track', [(20, 40)]),
        ('Indicator', [(12, 20)]),
    )
]

# Create a directory to save the data and plots
os.makedirs('plots/heatmaps', exist_ok=True)

# Reset the plotting style once, before any figure is created, so every figure in
# this cell is drawn with the same settings.
plt.style.use('default')
# The 'default' style restores rcParams to their defaults, which would undo the
# notebook-wide 'pdf.fonttype' = 42 setting made at the top of the notebook;
# re-apply it so the PDFs saved from here on keep that font type.
plt.rcParams['pdf.fonttype'] = 42
cmap = 'viridis'

# Number of histogram bins along each axis of the (correlation, difference score) plane
n_bins = 50

for cell_type in cell_types:
    # Check if the data file exists
    # NOTE: cached files are loaded as-is; delete them to recompute the heatmaps
    # after changing the analysis below.
    data_file = f'plots/heatmaps/heatmap_data_{cell_type["name"]}.npz'
    if os.path.exists(data_file):
        # Load the saved data
        loaded_data = np.load(data_file)
        avg_heatmap_first = loaded_data['avg_heatmap_first']
        avg_heatmap_middle = loaded_data['avg_heatmap_middle']
        avg_heatmap_last = loaded_data['avg_heatmap_last']
    else:
        # Initialize variables to store heatmaps for each session and cell type
        heatmaps_first = []
        heatmaps_middle = []
        heatmaps_last = []

        for animal_name in names:
            animal = animal_name[:8]
            print(f"Processing animal: {animal}")

            # Load data
            path = f'/.../Set A/{animal_name}/'
            data = vr2p.ExperimentData(path)

            # Generate index for days animal is performing Cue Set A only; stop at
            # the first session that contains any other cue set.
            day_count = []
            for i in range(len(data.signals.multi_session.F)):
                sets_in_session = data.vr[i].trial.set.unique()
                if ('Cue Set A' in sets_in_session) and (len(sets_in_session) == 1):
                    day_count.append(i)
                else:
                    break

            if not day_count:
                # max() below would raise on an empty list; report and move on instead.
                print(f"Skipping {animal}: no leading 'Cue Set A'-only session found")
                continue

            print(max(day_count))

            # Load stored place field analysis for each day
            range_A = range(max(day_count)+1)
            criteria = 'significant'
            zarr_file = zarr.open(f'/SetA/{animal}-PF.zarr', mode="r")
            pf_all_Near = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{criteria}'][()] for i in range_A]
            pf_all_Far = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{criteria}'][()] for i in range_A]
            # Reuse the records loaded above instead of reading every entry a second time.
            binF_Near = [pf['binF'] for pf in pf_all_Near]
            binF_Far = [pf['binF'] for pf in pf_all_Far]

            sessions = range_A

            # Define threshold based on mean and standard deviation
            threshold_std = 2  # Number of standard deviations above the mean

            # First, middle and last session, each paired with the list that collects
            # its heatmap. Pairing by slot (instead of comparing session numbers
            # afterwards) keeps the assignment correct when two of the three slots
            # refer to the same session, i.e. for animals with one or two sessions.
            selected_sessions = (
                (sessions[0], heatmaps_first),
                (sessions[len(sessions)//2], heatmaps_middle),
                (sessions[-1], heatmaps_last),
            )

            for session, heatmap_store in selected_sessions:
                corr_coeffs = []
                max_amp_diff_scores = []

                # Peak position and amplitude of each cell for both trial types.
                # Negatives are clipped to 0 first, so a cell without any positive
                # response has amplitude 0 and its difference score is 1, inside
                # the histogram range used below.
                clipped_Near = non_negative(binF_Near[session])
                clipped_Far = non_negative(binF_Far[session])
                max_amplitude_Near_idx = np.argmax(clipped_Near, axis=1)
                max_amplitude_Far_idx = np.argmax(clipped_Far, axis=1)
                max_amplitude_Near = np.max(clipped_Near, axis=1)
                max_amplitude_Far = np.max(clipped_Far, axis=1)

                # Activity threshold: mean + threshold_std * SD over all bins of both trial types
                activity_values = np.concatenate((binF_Near[session], binF_Far[session]))
                threshold = np.mean(activity_values) + threshold_std * np.std(activity_values)

                for j in range(len(max_amplitude_Near)):
                    peak_near = max_amplitude_Near[j]
                    peak_far = max_amplitude_Far[j]
                    if not (peak_near >= threshold or peak_far >= threshold):
                        continue

                    position_Near = max_amplitude_Near_idx[j]
                    position_Far = max_amplitude_Far_idx[j]

                    # Keep the cell only if its peak in either trial type falls inside
                    # one of the (start <= bin < stop) ranges of the current cell type.
                    in_cell_type = any(
                        lo <= position_Near < hi or lo <= position_Far < hi
                        for lo, hi in cell_type['position']
                    )
                    if not in_cell_type:
                        continue

                    # pearsonr returns NaN (with a warning) when either trace is constant.
                    corr_coeff, _ = pearsonr(binF_Near[session][j, :], binF_Far[session][j, :])

                    # 1 - smaller/larger peak: 0 = equal peaks, 1 = silent in one trial type.
                    larger_peak = max(peak_near, peak_far)
                    if larger_peak > 0:
                        max_amp_diff_score = 1 - min(peak_near, peak_far) / larger_peak
                    else:
                        # Only reachable when the activity threshold itself is <= 0.
                        max_amp_diff_score = 0.0

                    corr_coeffs.append(corr_coeff)
                    max_amp_diff_scores.append(max_amp_diff_score)

                if corr_coeffs:
                    # Create heatmap
                    heatmap, _, _ = np.histogram2d(corr_coeffs, max_amp_diff_scores, bins=n_bins,
                                                   range=[[-1, 1], [0, 1]], density=True)
                    # Apply Gaussian smoothing to the heatmap
                    smoothed_heatmap = gaussian_filter(heatmap, sigma=1)
                else:
                    # No cell of this type in this session: the density is undefined.
                    # Store NaN so the NaN-aware average below leaves this animal out.
                    smoothed_heatmap = np.full((n_bins, n_bins), np.nan)

                # Append the smoothed heatmap to the list for this slot
                heatmap_store.append(smoothed_heatmap)

        # Average the heatmaps across animals, ignoring animals without data
        avg_heatmap_first = np.nanmean(heatmaps_first, axis=0)
        avg_heatmap_middle = np.nanmean(heatmaps_middle, axis=0)
        avg_heatmap_last = np.nanmean(heatmaps_last, axis=0)

        np.savez(data_file, avg_heatmap_first=avg_heatmap_first, avg_heatmap_middle=avg_heatmap_middle, avg_heatmap_last=avg_heatmap_last)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6), dpi=300)

    # One panel per session slot; all three share the same formatting.
    panels = (
        (ax1, 'First', avg_heatmap_first),
        (ax2, 'Middle', avg_heatmap_middle),
        (ax3, 'Last', avg_heatmap_last),
    )
    for panel_ax, slot_label, avg_heatmap in panels:
        # Transposed so that correlation runs along x and difference score along y
        panel_ax.imshow(avg_heatmap.T, origin='lower', extent=[-1, 1, 0, 1], cmap=cmap, aspect='auto', vmin=0, vmax=1.5)
        panel_ax.set_xlabel('Correlation', fontsize=20, fontweight='bold')
        panel_ax.set_title(f'{cell_type["name"]} - {slot_label} Session', fontsize=24, fontweight='bold')
        panel_ax.tick_params(axis='both', labelsize=18, width=2, length=6)
        for side in ('top', 'right', 'bottom', 'left'):
            panel_ax.spines[side].set_linewidth(1.5)

    # Only the left-most panel carries the y-axis label
    ax1.set_ylabel('Difference Score', fontsize=20, fontweight='bold')

    plt.tight_layout(pad=4.0)
    plt.savefig(f'plots/heatmaps/all_animals_heatmaps_{cell_type["name"]}_first_middle_last_formatted.pdf', dpi=300, bbox_inches='tight', facecolor='w')
    plt.show()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns

# Cell types and the colour used for each one's heatmap row.
# NOTE: this rebinds `cell_types`, which the heatmap-computation cell defines with
# a 'position' key instead of 'color'.
cell_types = [
    dict(name='Track Boundary', color='gray'),
    dict(name='Indicator', color='orange'),
    dict(name='Intra-Track', color='mediumturquoise'),
]

sns.set(style="ticks")

# 3 x 3 grid: one row per cell type, one column per session slot (first, middle, last)
fig, axes = plt.subplots(3, 3, figsize=(16, 16), dpi=500, sharex=True, sharey=True)

for i, cell_type in enumerate(cell_types):
    # Averaged heatmaps previously saved for this cell type
    data_file = f'plots/heatmaps/heatmap_data_{cell_type["name"]}.npz'
    loaded_data = np.load(data_file)
    avg_heatmap_first = loaded_data['avg_heatmap_first']
    avg_heatmap_middle = loaded_data['avg_heatmap_middle']
    avg_heatmap_last = loaded_data['avg_heatmap_last']

    aspect_ratio = 2/(1.1)

    # Colour map running from white to the colour assigned to this cell type
    cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white", cell_type['color']])

    # Columns in order: first, middle, last session
    row_heatmaps = (avg_heatmap_first, avg_heatmap_middle, avg_heatmap_last)
    for col, avg_heatmap in enumerate(row_heatmaps):
        panel = axes[i, col]
        panel.imshow(avg_heatmap.T, origin='lower', extent=[-1, 1, 0, 1], cmap=cmap, vmin=0, vmax=1.5, aspect=aspect_ratio)
        panel.set_xticks([-1, 0, 1])
        panel.set_yticks([0, 0.5, 1])

# Common tick labels, tick styling, frame lines and y-limits for every panel
# (axes.flat walks the grid row by row)
for panel in axes.flat:
    panel.set_xticklabels([-1, 0, 1], fontsize=12)
    panel.set_yticklabels([0, 0.5, 1], fontsize=12)
    panel.tick_params(axis='both', labelsize=12, width=2, length=6)
    for hidden_side in ('right', 'top'):
        panel.spines[hidden_side].set_visible(False)
    for shown_side in ('bottom', 'left'):
        panel.spines[shown_side].set_visible(True)
    panel.set_ylim(-0.05, 1.05)  # small margin below 0 and above 1

# y-axis label on the first column only
for left_panel in axes[:, 0]:
    left_panel.set_ylabel('Difference Score', fontsize=14)

# x-axis label on the last row only
for bottom_panel in axes[2, :]:
    bottom_panel.set_xlabel('Correlation (Near vs Far Trials)', fontsize=14)

# Nearly touching panels
fig.subplots_adjust(wspace=0.05, hspace=0.05)

# Save the grid as a PDF, then display it
fig.savefig('Figure_3_FINAL/all_animals_heatmaps_first_middle_last_minimal_seaborn_3by3_square_subplots_revised.pdf', dpi=300, bbox_inches='tight', facecolor='w')

plt.show()

In [None]:

# Stand-alone vertical colour bar (reversed grayscale, spanning 0 to 1.5)
fig_cbar, ax_cbar = plt.subplots(figsize=(1, 12), dpi=300)

norm = plt.Normalize(vmin=0, vmax=1.5)
sm = plt.cm.ScalarMappable(norm=norm, cmap='gray_r')
sm.set_array([])  # the mappable carries no data; it only supplies cmap + norm

cbar = fig_cbar.colorbar(sm, cax=ax_cbar, orientation='vertical')
cbar.set_label('Normalized density', fontsize=40)
cbar.ax.tick_params(labelsize=30)       # larger tick labels
cbar.ax.tick_params(axis='y', width=2)  # thicker tick marks on the y axis

# Save the colour bar as a PDF, then display it
fig_cbar.savefig('Figure_3_FINAL/colorbar.pdf', dpi=500, bbox_inches='tight', facecolor='w')
plt.show()