In [None]:
from Py2P import core,sync,plot
from matplotlib import pyplot as plt
import numpy as np
import os
import shutil

In [None]:
# IPython autoreload: reload modules whose source changed (e.g. Py2P) before
# executing each cell, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Input locations. os.path.join makes the suite2p folder path valid on any OS
# (the previous 'suite2p\\plane0\\' literal only worked on Windows); the
# trailing "" component keeps the trailing separator the literal had, in case
# Rec2P concatenates file names directly onto data_path.
datapath = os.path.join('suite2p', 'plane0', '')
syncfile = 'rec_000_000.mat'
sequencefile = 'stim_dict.json'

# Integer key -> trial label, passed to Sync.generate_data_structure below and
# used to order the rows of the per-population heatmaps.
trials_names = {0: "IPSI", 1: "CONTRA", 2: "BOTH"}

In [None]:
# Build the synchronization structure from the sync (.mat) and stimulus
# sequence (.json) files, then hand it to the recording object together with
# the suite2p output folder.
synchro = sync.Sync()
synchro.generate_data_structure(syncfile, sequencefile, trials_names)
rec = core.Rec2P(sync=synchro, data_path=datapath)

In [None]:
# load_params() is called before get_cells(); `cells` is indexed by the
# population member ids in the plotting cell below.
rec.load_params()
cells = rec.get_cells()

In [None]:
### analyze and plot populations ###

# Arguments for Rec2P.get_populations, gathered in one place so they are easy
# to tune; passed through unchanged as keywords.
population_kwargs = dict(
    stims_names=["chirp"],
    trials_names=None,
    n_clusters=None,
    use_tsne=True,
    type='dff',
    normalize='z',
    plot=True,
)
pops = rec.get_populations(**population_kwargs)

# Saves pyplot's current figure — presumably the one drawn by get_populations
# when plot=True (confirm in Py2P.core).
plt.savefig("PCA_populations.png")

In [None]:
### plot ###
# For every population: a trials x stimuli grid of heatmaps, the full="dff"
# heatmap, and response plots (population average and every cell). Each
# population's figures go into its own "POPULATION_#<k>" directory, which is
# deleted and recreated on every run so outputs from different runs never mix.


def _reset_dir(path):
    """Delete ``path`` if it is an existing directory, then create it empty."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _plot_trial_stim_heatmaps(pop_cells, sync_obj, trials, stims, title, out_file):
    """Save a trials x stimuli grid of plot.plot_averages_heatmap panels.

    Each panel uses type='dff' and normalize='z'. A panel whose trial is not
    in ``sync_obj.sync_ds[stim]`` is left blank (axis turned off).

    Parameters
    ----------
    pop_cells : list
        Cell objects (items of ``cells``) belonging to one population.
    sync_obj : sync.Sync
        Synchronization object providing ``sync_ds``.
    trials, stims : sequence of str
        Row labels (trials) and column labels (stimuli), in plotting order.
    title : str
        Figure super-title.
    out_file : str
        Destination image path.
    """
    # squeeze=False keeps `axs` 2-D for every grid shape (1x1, 1xN, Nx1), so
    # panels are always addressed as axs[row, col]. The previous shape-dependent
    # indexing crashed on `ax[0]` with a single stimulus and labelled the wrong
    # panel with a single trial.
    fig, axs = plt.subplots(len(trials), len(stims), figsize=(10, 15), squeeze=False)
    fig.suptitle(title, fontsize=16)

    for row, trial in enumerate(trials):
        for col, stim in enumerate(stims):
            ax = axs[row, col]
            if trial in sync_obj.sync_ds[stim]:
                plot.plot_averages_heatmap(pop_cells,
                                           sync_obj,
                                           stims=stim,
                                           trials=trial,
                                           type='dff',
                                           stim_bar=False,
                                           vmin=None,
                                           vmax=None,
                                           normalize='z',
                                           cb_label="\u0394F/F (z-score)",
                                           ax=ax)
            else:
                ax.axis("off")
            if row == 0:
                ax.set_title(stim, fontsize=20)
        # Row label goes on the first column, after that row has been drawn.
        axs[row, 0].set_ylabel(trial, fontsize=18)

    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(out_file, bbox_inches="tight")
    plt.close(fig)


def _plot_full_heatmap(pop_cells, sync_obj, title, out_file):
    """Save the plot.plot_averages_heatmap(full="dff", normalize='z') figure.

    NOTE(review): full="dff" presumably spans the whole recording rather than
    a single stimulus/trial — confirm in Py2P.plot.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    fig.suptitle(title, fontsize=16)
    plot.plot_averages_heatmap(pop_cells,
                               sync_obj,
                               full="dff",
                               vmin=None,
                               vmax=None,
                               normalize='z',
                               cb_label="\u0394F/F (z-score)",
                               stim_bar=False,
                               ax=ax)
    fig.savefig(out_file, bbox_inches="tight")
    plt.close(fig)


def _plot_responses(pop_cells, sync_obj, out_dir, average):
    """Have plot.plot_multipleStim save its dF/F response figures into ``out_dir``.

    ``average`` is forwarded unchanged. This notebook uses True for the
    population average and False for per-cell plots. stims=None / trials=None
    presumably select every stimulus and trial — confirm in Py2P.plot.
    """
    plot.plot_multipleStim(pop_cells,
                           sync_obj,
                           average=average,
                           save=True,
                           save_path=out_dir,
                           stims=None,
                           trials=None,
                           full='dff',
                           share_x=False,
                           share_y=False,
                           group_trials=False,
                           legend=True)


## plot FOV
# NOTE(review): the meaning of k=3 is not visible here — confirm in Py2P.plot.
plot.plot_FOV(pops, rec, k=3)

ordered_trials = sorted(trials_names.values())
stim_names = synchro.stims_names

for pop_idx, pop in enumerate(pops):

    # Paths are built with os.path.join (not a hard-coded "\\") so the figures
    # land inside the population directory on every OS.
    save_path = "POPULATION_#%d" % pop_idx
    _reset_dir(save_path)

    pop_cells = [cells[c] for c in pop]
    title = 'POPULATION #%d, N:%d' % (pop_idx, len(pop_cells))

    ### plot heatmaps (trials x stimuli) and the full heatmap
    _plot_trial_stim_heatmaps(pop_cells, synchro, ordered_trials, stim_names,
                              title, os.path.join(save_path, "heatmaps_dff.png"))
    _plot_full_heatmap(pop_cells, synchro, title,
                       os.path.join(save_path, "full_heatmap_dff.png"))

    ### plot averaged responses
    _plot_responses(pop_cells, synchro, save_path, average=True)

    ### plot all the cells
    save_path_allcells = os.path.join(save_path, "all_cells")
    _reset_dir(save_path_allcells)
    _plot_responses(pop_cells, synchro, save_path_allcells, average=False)

    # TODO(disabled): modulation histograms (BOTH vs CONTRA) for this population.
    # Re-enabling requires `from tqdm.auto import tqdm` in the imports cell.
    # pop_rmis = {}
    # pop_rand_mod = {}
    #
    # for stim in ["full_field"]:
    #     pop_rmis |= {stim: []}
    #     pop_rand_mod |= {stim: []}
    #
    # for cell_idx in tqdm(pop):
    #     cell = cells[cell_idx]
    #     for stim in ["full_field"]:
    #         rmi = cell.calculate_modulation(stim, "BOTH", "CONTRA")
    #         rand_mod = cell.calculate_random_modulation(stim, "CONTRA", n_shuff=100)
    #         pop_rmis[stim].append(rmi)
    #         pop_rand_mod[stim].append(rand_mod)
    #
    # plot.plot_histogram(pop_rmis, pop_rand_mod,
    #                     os.path.join(save_path, "BothVsContra.png"))


