In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import itertools
import scipy
import os
import nept

from loading_data import get_data

In [None]:
# Every path below is anchored at the directory the kernel was started from.
thisdir = os.getcwd()
# pickle_filepath presumably holds cached pickled data (not used in the cells shown);
# output_filepath is where the phase-field figures are saved.
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *parts)
    for parts in (("cache", "pickled"), ("plots", "phase_fields"))
)

In [None]:
from analyze_tc_shifts import get_tuning_curves, get_pearsons_correlation, find_intersection, find_neighbours, plot_tc_corr, compare_correlations

In [None]:
# Per-session info modules. NOTE(review): in the cells shown they are only
# referenced by the commented-out subset selection below — confirm they are
# still needed before removing.
import info.r067d6 as r067d6
import info.r067d7 as r067d7
import info.r068d7 as r068d7
# To analyse a hand-picked subset of sessions instead, replace the assignment below with e.g.:
# infos = [r067d7, r068d7]

# Sessions analysed in the loop below: every entry of run.spike_sorted_infos.
from run import spike_sorted_infos
infos = spike_sorted_infos

In [None]:
def plot_correlation(correlation, stable_neighbours, novel_neighbours, filepath, n_colours=15):
    """Save a heatmap of a 2D correlation map with segment neighbourhoods overlaid.

    Parameters
    ----------
    correlation : array-like
        2D array of correlation values. NaN entries are drawn with the lowest
        (white) colour. The caller's array is NOT modified.
    stable_neighbours, novel_neighbours : iterable of points
        Each point is plotted with point[0] on the x axis and point[1] on the
        y axis; stable points in red, novel points in blue.
        NOTE(review): confirm these are (x, y) and not (row, col) indices.
    filepath : str
        Destination of the saved figure.
    n_colours : int, optional
        Number of discrete colours in the colormap (default 15). The first one
        is white and receives the NaN bins.
    """
    # Discrete colormap: white at the bottom, followed by copper_r.
    colours = [(1., 1., 1.)]
    colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours - 1)))
    cmap = matplotlib.colors.ListedColormap(colours)

    # Work on a float copy so the caller's correlation map is left intact.
    image = np.array(correlation, dtype=float, copy=True)
    missing = np.isnan(image)
    if missing.all():
        # Nothing finite to scale against; nanmax/nanmin would warn and return NaN.
        fill_value = 0.0
    else:
        # Put NaN bins strictly below every real value so they become the image
        # minimum and therefore take the white bottom colour. For non-negative
        # data this equals -nanmax / n_colours; it also stays correct when
        # correlations are negative.
        fill_value = min(np.nanmin(image), 0.0) - abs(np.nanmax(image)) / n_colours
    image[missing] = fill_value

    # Explicit figure so nothing is drawn onto a stray, previously open figure.
    fig, ax = plt.subplots()
    im = ax.imshow(image, vmax=1.0, cmap=cmap)
    for point in stable_neighbours:
        ax.plot(point[0], point[1], 'r.', ms=15)
    for point in novel_neighbours:
        ax.plot(point[0], point[1], 'b.', ms=15)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig(filepath)
    plt.close(fig)

In [None]:
# For each spatial bin size, compare tuning-curve correlations between pairs of
# phases in bins near stable vs. novel track segments, and save one boxplot per
# bin size (one box per phase pair and segment type, pooled across sessions).
binsizes = [4, 6, 8, 10, 12, 14, 16]
# (label, first phase, second phase) for every phase pair that is compared.
phase_pairs = [("12", "phase1", "phase2"),
               ("13", "phase1", "phase3"),
               ("23", "phase2", "phase3")]

# savefig raises if the target directory does not exist yet.
os.makedirs(output_filepath, exist_ok=True)

for binsize in binsizes:
    # Per-session mean correlations keyed by phase-pair label; NaN results are skipped.
    corr_stable = {label: [] for label, _, _ in phase_pairs}
    corr_novel = {label: [] for label, _, _ in phase_pairs}

    for info in infos:
        print(info.session_id)
        _, position, spikes, _, _ = get_data(info)
        xedges, yedges = nept.get_xyedges(position, binsize=binsize)

        tc_shape = (len(yedges) - 1, len(xedges) - 1)

        shortcut1 = find_intersection(info, "shortcut1", xedges, yedges)
        shortcut2 = find_intersection(info, "shortcut2", xedges, yedges)
        novel1 = find_intersection(info, "novel1", xedges, yedges)
        stable1 = find_intersection(info, "stable1", xedges, yedges)

        novel_points = [shortcut1, shortcut2, novel1]
        stable_points = [stable1]
        novel_neighbours = find_neighbours(tc_shape, novel_points, neighbour_size=2)
        stable_neighbours = find_neighbours(tc_shape, stable_points, neighbour_size=2)

        # Only the three cross-phase pairs are analysed, so only those are computed.
        session_means = {}
        for label, phase_a, phase_b in phase_pairs:
            corr = get_pearsons_correlation(info, phase_a, phase_b, xedges, yedges, position, spikes)
            session_means[label] = compare_correlations(corr, stable_neighbours, novel_neighbours)

        for label, _, _ in phase_pairs:
            stable, novel = session_means[label]
            if not np.isnan(stable):
                corr_stable[label].append(stable)
            if not np.isnan(novel):
                corr_novel[label].append(novel)
            print("phases {} and {}. Mean correlation for stable:".format(label[0], label[1]),
                  stable, "compared to novel:", novel, "segments")

    # Box order: stable12, novel12, stable13, novel13, stable23, novel23.
    labels = []
    box_data = []
    for label, _, _ in phase_pairs:
        labels.extend(["stable" + label, "novel" + label])
        box_data.extend([corr_stable[label], corr_novel[label]])
    print(box_data)

    fig, ax = plt.subplots()
    ax.boxplot(box_data)
    # boxplot places its boxes at positions 1..N by default.
    ax.set_xticks(np.arange(1, len(labels) + 1))
    ax.set_xticklabels(labels, rotation='vertical')
    ax.set_ylabel("Mean correlation")
    ax.set_title("binsize = {}".format(binsize))
    filepath = os.path.join(output_filepath, "Mean correlation_binsize-" + str(binsize) + ".png")
    fig.tight_layout()
    fig.savefig(filepath)
    plt.close(fig)