# <div style="text-align: center"> Code for correlation and angle matrices for all animals </div>

## Load dependencies

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm, colors
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42
import pandas as pd
from ipyfilechooser import FileChooser
from linear2ac.io import get_main_data_folder
import zarr
import vr2p


In [None]:
import os

# Every non-hidden entry of the Set A folder is treated as one animal's dataset
# by the cells below. Hidden entries (e.g. macOS '.DS_Store', '._*' files,
# '.ipynb_checkpoints') are skipped. The result is sorted because os.listdir
# returns entries in arbitrary, filesystem-dependent order, and a stable order
# keeps the printed output and saved figures reproducible.
names = sorted(
    name for name in os.listdir('/.../Set A/')
    if not name.startswith('.')
)
print(names)
print(f'Found {len(names)} files')

## Load place-field data for each animal and plot T1 vs. T2 correlation matrices

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr



# For each animal, correlate the T1 ('Cue Set A/1') and T2 ('Cue Set A/2') binF
# vectors for every pair of position bins. Each Cue Set A-only session gets one
# heatmap panel, and one PDF is saved per animal.

SET_A_DIR = '/.../Set A'
# NOTE(review): absolute, machine-specific path kept from the original cell;
# consider deriving it from get_main_data_folder() if that points at the same root.
CORR_PF_DIR = '/nrs/spruston/Tyche/vr2p_datasets/placefields/50_600_SetA'
PF_CRITERIA = 'putative'
# Bin indices at which dotted guide lines are drawn; each (low, high) pair is
# also outlined with a solid square on the diagonal.
# NOTE(review): presumably cue/landmark zone boundaries in bin units — confirm.
GUIDE_LINES = [12, 20, 26, 30, 36, 40]
GUIDE_BOXES = [(12, 20), (26, 30), (36, 40)]


def _cross_correlation_matrix(a, b):
    """Pearson correlation between every row of ``a`` and every row of ``b``.

    Gives the same values as calling ``scipy.stats.pearsonr(a[i], b[j])[0]`` for
    every (i, j), but in a single vectorized ``np.corrcoef`` call.

    Parameters
    ----------
    a, b : np.ndarray, 2-D
        Must have the same number of columns; each row is one variable.

    Returns
    -------
    np.ndarray, shape (len(a), len(b))
        Entry [i, j] is Pearson r of ``a[i]`` and ``b[j]``. It is NaN when either
        row is constant.
    """
    n_rows_a = a.shape[0]
    # corrcoef stacks the rows of a and b; the upper-right block holds a-vs-b.
    return np.corrcoef(a, b)[:n_rows_a, n_rows_a:]


for animal_name in names:
    animal_id = animal_name[:8]
    print(animal_id)
    data = vr2p.ExperimentData(f'{SET_A_DIR}/{animal_name}/')

    # Count the leading run of sessions whose trials come only from Cue Set A.
    # Counting stops at the first session that contains any other cue set.
    n_sessions = 0
    for i in range(len(data.signals.multi_session.F)):
        cue_sets = data.vr[i].trial.set.unique()
        if ('Cue Set A' in cue_sets) and (len(cue_sets) == 1):
            n_sessions += 1
        else:
            break
    if n_sessions == 0:
        # Nothing to plot; skip this animal rather than failing the whole loop.
        print(f'{animal_id}: first session is not Cue Set A only, skipping')
        continue
    print(f'{n_sessions} Cue Set A-only session(s)')

    # Load the stored place-field analysis; each entry is read once and only
    # its 'binF' field is used.
    zarr_file = zarr.open(f'{CORR_PF_DIR}/{animal_id}-PF.zarr', mode='r')
    binF_T1 = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{PF_CRITERIA}'][()]['binF']
               for i in range(n_sessions)]
    binF_T2 = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{PF_CRITERIA}'][()]['binF']
               for i in range(n_sessions)]

    # .T reverses every axis of the stacked (session, ...) array, so session
    # becomes the last axis. Below, axis 0 is treated as position bin and
    # axis 1 holds the vector entries being correlated.
    # NOTE(review): assumes each per-session binF is (n_cells, n_bins) — confirm.
    binF_T1_all = np.array(binF_T1).T
    binF_T2_all = np.array(binF_T2).T

    # squeeze=False keeps axs 2-D even for a single session; subplots(1, 1)
    # would otherwise return a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(1, n_sessions, figsize=(4 * n_sessions, 4), dpi=600,
                            squeeze=False)
    for session, ax in enumerate(axs[0]):
        corr_matrix = _cross_correlation_matrix(binF_T1_all[:, :, session],
                                                binF_T2_all[:, :, session])
        sns.heatmap(corr_matrix, cmap='icefire', vmin=-1, vmax=1, ax=ax,
                    cbar=False, xticklabels=False, yticklabels=False, linewidths=0)
        ax.set_aspect('equal')  # square panels
        ax.set_title(f'Session {session + 1}')  # 1-based for display

        for pos in GUIDE_LINES:
            ax.axvline(pos, linestyle=(0, (2, 5)), color='white', linewidth=1.5)
            ax.axhline(pos, linestyle=(0, (2, 5)), color='white', linewidth=1.5)
        for low, high in GUIDE_BOXES:
            ax.plot([low, high, high, low, low], [low, low, high, high, low],
                    color='white', linewidth=3)

    fig.tight_layout()
    fig.savefig(f'corr_plot_{animal_id}.pdf', format='pdf', dpi=600)
    plt.show()
    plt.close(fig)  # release the 600-dpi figure before the next animal



In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cosine

# For each animal, compute the angle (degrees) between the T1 ('Cue Set A/1')
# and T2 ('Cue Set A/2') binF vectors for every pair of position bins. Each
# Cue Set A-only session gets one heatmap panel, and one PDF is saved per animal.

SET_A_DIR = '/.../Set A'
ANGLE_PF_DIR = '/.../placefields/SetA'
PF_CRITERIA = 'putative'
# Bin indices at which dotted guide lines are drawn; each (low, high) pair is
# also outlined with a solid square on the diagonal.
# NOTE(review): presumably cue/landmark zone boundaries in bin units — confirm.
GUIDE_LINES = [12, 20, 26, 30, 36, 40]
GUIDE_BOXES = [(12, 20), (26, 30), (36, 40)]


def _angle_matrix_deg(a, b):
    """Angle in degrees between every row of ``a`` and every row of ``b``.

    Same quantity as
    ``np.degrees(np.arccos(1 - scipy.spatial.distance.cosine(a[i], b[j])))``
    for every (i, j), computed with a single matrix product.

    Parameters
    ----------
    a, b : array-like, 2-D
        Must have the same number of columns; each row is one vector.

    Returns
    -------
    np.ndarray, shape (len(a), len(b))
        Angles in [0, 180]. NaN where either row has zero norm.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    # Round-off can push |cos| slightly past 1, where arccos would return NaN.
    cos_sim = np.clip(a_unit @ b_unit.T, -1.0, 1.0)
    return np.degrees(np.arccos(cos_sim))


for animal_name in names:
    animal_id = animal_name[:8]
    print(animal_id)
    data = vr2p.ExperimentData(f'{SET_A_DIR}/{animal_name}/')

    # Count the leading run of sessions whose trials come only from Cue Set A.
    # Counting stops at the first session that contains any other cue set.
    n_sessions = 0
    for i in range(len(data.signals.multi_session.F)):
        cue_sets = data.vr[i].trial.set.unique()
        if ('Cue Set A' in cue_sets) and (len(cue_sets) == 1):
            n_sessions += 1
        else:
            break
    if n_sessions == 0:
        # Nothing to plot; skip this animal rather than failing the whole loop.
        print(f'{animal_id}: first session is not Cue Set A only, skipping')
        continue
    print(f'{n_sessions} Cue Set A-only session(s)')

    zarr_file = zarr.open(f'{ANGLE_PF_DIR}/{animal_id}-PF.zarr', mode='r')
    binF_T1 = [zarr_file[f'Cue Set A/1/excl_no_response/{i}/{PF_CRITERIA}'][()]['binF']
               for i in range(n_sessions)]
    binF_T2 = [zarr_file[f'Cue Set A/2/excl_no_response/{i}/{PF_CRITERIA}'][()]['binF']
               for i in range(n_sessions)]

    # .T reverses every axis so session becomes the last axis. Axis 0 is
    # treated as position bin.
    # NOTE(review): assumes each per-session binF is (n_cells, n_bins) — confirm.
    binF_T1_all = np.array(binF_T1).T
    binF_T2_all = np.array(binF_T2).T

    # squeeze=False keeps axs 2-D even for a single session.
    fig, axs = plt.subplots(1, n_sessions, figsize=(4 * n_sessions, 4), dpi=600,
                            squeeze=False)
    for session, ax in enumerate(axs[0]):
        angle_matrix = _angle_matrix_deg(binF_T1_all[:, :, session],
                                         binF_T2_all[:, :, session])
        # Color scale is fixed at 0–90°, so any angle above 90° renders at the
        # darkest color (same range as the separate colorbar cell).
        sns.heatmap(angle_matrix, cmap='Blues', vmin=0, vmax=90, ax=ax,
                    cbar=False, xticklabels=False, yticklabels=False, linewidths=0)
        ax.set_aspect('equal')
        ax.set_title(f'Session {session + 1}')

        for pos in GUIDE_LINES:
            ax.axvline(pos, linestyle=(0, (2, 5)), color='white', linewidth=1.5)
            ax.axhline(pos, linestyle=(0, (2, 5)), color='white', linewidth=1.5)
        for low, high in GUIDE_BOXES:
            ax.plot([low, high, high, low, low], [low, low, high, high, low],
                    color='white', linewidth=3)

    fig.tight_layout()
    fig.savefig(f'angle_plot_{animal_id}.pdf', format='pdf', dpi=600)
    plt.show()
    plt.close(fig)  # release the 600-dpi figure before the next animal


In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

# Stand-alone vertical colorbar for the angle heatmaps. It uses the 'Blues'
# colormap over 0–90 degrees, the same cmap/vmin/vmax as the angle-matrix cell.
cbar_fig, cbar_ax = plt.subplots(figsize=(1, 6), dpi=300)
cbar_fig.subplots_adjust(left=0.5)

# A ScalarMappable carries the colormap and normalization without any plotted data.
angle_mappable = mpl.cm.ScalarMappable(
    norm=mpl.colors.Normalize(vmin=0, vmax=90),
    cmap='Blues',
)

# cax=cbar_ax draws the bar into the whole axes instead of stealing space from it.
angle_cbar = cbar_fig.colorbar(angle_mappable, cax=cbar_ax, orientation='vertical')
angle_cbar.set_label('Angles (degrees)')

cbar_fig.savefig("colorbar.pdf", format='pdf', bbox_inches='tight')
plt.show()
