In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import itertools
import scipy
import os
import nept

from loading_data import get_data

In [None]:
# Paths are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *parts)
    for parts in (("cache", "pickled"), ("plots", "phase_fields"))
)

In [None]:
from analyze_tc_shifts import get_tuning_curves, get_pearsons_correlation, find_intersection, find_neighbours, plot_tc_corr, compare_correlations

In [None]:
import info.r067d6 as r067d6
import info.r067d7 as r067d7
import info.r068d7 as r068d7

In [None]:
infos = list((r067d7, r068d7))  # sessions analysed below; r067d6 is imported but not included

In [None]:
# For each session: correlate phase-restricted tuning curves between pairs of phases, average the
# correlation map around the stable vs. novel segment bins, and collect the per-session averages.
SAVE_FIGURES = False  # True: pass a PNG path under output_filepath to plot_tc_corr instead of None

corr_stable12 = []
corr_stable13 = []
corr_stable23 = []

corr_novel12 = []
corr_novel13 = []
corr_novel23 = []

for info in infos:
    print(info.session_id)
    events, position, spikes, lfp, _ = get_data(info)
    xedges, yedges = nept.get_xyedges(position, binsize=3)

    # (n_y_bins, n_x_bins) of the spatial grid.
    tc_shape = (len(yedges) - 1, len(xedges) - 1)

    shortcut1 = find_intersection(info, "shortcut1", xedges, yedges)
    shortcut2 = find_intersection(info, "shortcut2", xedges, yedges)
    novel1 = find_intersection(info, "novel1", xedges, yedges)
    # "novel2" is currently not included among the novel points.
    stable1 = find_intersection(info, "stable1", xedges, yedges)

    novel_points = [shortcut1, shortcut2, novel1]
    stable_points = [stable1]
    novel_neighbours = find_neighbours(tc_shape, novel_points, neighbour_size=2)
    stable_neighbours = find_neighbours(tc_shape, stable_points, neighbour_size=2)

    corr12 = get_pearsons_correlation(info, "phase1", "phase2", xedges, yedges, position, spikes)
    corr13 = get_pearsons_correlation(info, "phase1", "phase3", xedges, yedges, position, spikes)
    corr23 = get_pearsons_correlation(info, "phase2", "phase3", xedges, yedges, position, spikes)
    # phase3 against itself — presumably a sanity-check map; it is plotted but not summarised.
    corr33 = get_pearsons_correlation(info, "phase3", "phase3", xedges, yedges, position, spikes)

    stable12, novel12 = compare_correlations(corr12, stable_neighbours, novel_neighbours)
    stable13, novel13 = compare_correlations(corr13, stable_neighbours, novel_neighbours)
    stable23, novel23 = compare_correlations(corr23, stable_neighbours, novel_neighbours)

    # Collect only the non-NaN session averages.
    for collected, session_mean in [(corr_stable12, stable12), (corr_novel12, novel12),
                                    (corr_stable13, stable13), (corr_novel13, novel13),
                                    (corr_stable23, stable23), (corr_novel23, novel23)]:
        if not np.isnan(session_mean):
            collected.append(session_mean)

    print("phases 1 and 2. Average correlation for stable:", stable12, "compared to novel:", novel12, "segments")
    print("phases 1 and 3. Average correlation for stable:", stable13, "compared to novel:", novel13, "segments")
    print("phases 2 and 3. Average correlation for stable:", stable23, "compared to novel:", novel23, "segments")

    for pair, corr in [("12", corr12), ("13", corr13), ("23", corr23), ("33", corr33)]:
        filepath = None
        if SAVE_FIGURES:
            os.makedirs(output_filepath, exist_ok=True)
            filepath = os.path.join(output_filepath, info.session_id + "_phase-shift" + pair + ".png")
        plot_tc_corr(corr, stable_neighbours, novel_neighbours, filepath)

# Across-session distributions of the collected averages.
summary = [corr_stable12, corr_novel12, corr_stable13, corr_novel13, corr_stable23, corr_novel23]
labels = ["corr_stable12", "corr_novel12", "corr_stable13", "corr_novel13", "corr_stable23", "corr_novel23"]
print(summary)
x = np.arange(len(labels)) + 1
plt.boxplot(summary)
plt.xticks(x, labels, rotation='vertical')
plt.show()

In [None]:
# events, position, spikes, lfp, _ = get_data(r067d7)

# plt.plot(position.x, position.y, "k.", ms=3)
# plt.plot(r067d7.path_pts["stable1"][0], r067d7.path_pts["stable1"][1], "r.", ms=10)
# plt.show()

In [None]:
# Phase-1 tuning curves for the session left in `info`/`position`/`spikes`/`xedges`/`yedges`
# by the loop above (i.e. the last entry of `infos`).
# Fix: the keyword is `phase_id` (as in every other get_tuning_curves call), and position/spikes
# are restricted to the phase window first, matching how the other calls prepare their inputs.
phase_window = info.task_times["phase1"]
tuning_curves = get_tuning_curves(
    info,
    position.time_slice(phase_window.start, phase_window.stop),
    [spiketrain.time_slice(phase_window.start, phase_window.stop) for spiketrain in spikes],
    xedges,
    yedges,
    phase_id="phase1",
)

In [None]:
# Deliberate stop so that "Run All" halts here; the cells below are exploratory and run by hand.
raise RuntimeError("Intentional stop: run the exploratory cells below manually.")

In [None]:
# Summary figure: across-session mean correlation near stable vs. novel segments, per phase pair.
fig, ax = plt.subplots()

stable_colour = '#bf812d'
novel_colour = '#35978f'

# Fix: each pair of boxes uses its own phase pair's data (phases 1-2 were previously repeated 3x).
boxplot = ax.boxplot([corr_stable12, corr_novel12, corr_stable13, corr_novel13, corr_stable23, corr_novel23],
                     positions=[1, 2, 4, 5, 7, 8], widths=0.75, patch_artist=True)

for patch, colour in zip(boxplot['boxes'], [stable_colour, novel_colour] * 3):
    patch.set_facecolor(colour)

plt.setp(boxplot['medians'], color='k')

ax.set_ylim(0.3, 1.)
labels = ["Phases 1-2", "Phases 1-3", "Phases 2-3"]
ax.set_xticks([1.5, 4.5, 7.5])
ax.set_xticklabels(labels)
ax.set_ylabel("Mean correlation")

# Empty lines act purely as legend handles; unlike plotted-then-hidden data they add no data limits.
stable_handle, = ax.plot([], [], '-', color=stable_colour)
novel_handle, = ax.plot([], [], '-', color=novel_colour)
ax.legend((stable_handle, novel_handle), ('Stable segments', 'Novel segments'), bbox_to_anchor=(1., 1.))

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

plt.tight_layout()
plt.show()

In [None]:
# Reload one session (r067d7) for the exploratory cells that follow.
info = r067d7
events, position, spikes, lfp, lfp_theta = get_data(info)

xedges, yedges = nept.get_xyedges(position, binsize=3)
# Grid shape as (number of y bins, number of x bins).
tc_shape = tuple(len(edges) - 1 for edges in (yedges, xedges))

In [None]:
# Position restricted to each task phase's start/stop times.
position1, position2, position3 = (
    position.time_slice(info.task_times[phase_name].start, info.task_times[phase_name].stop)
    for phase_name in ("phase1", "phase2", "phase3")
)

In [None]:
# Overlay of the tracked positions of the three phases; the legend identifies each colour.
for phase_position, marker, phase_label in [(position1, "g.", "phase1"),
                                            (position2, "b.", "phase2"),
                                            (position3, "r.", "phase3")]:
    plt.plot(phase_position.x, phase_position.y, marker, label=phase_label)
plt.legend()
plt.show()

In [None]:
getattr(info, "session_id")  # display which session `info` currently refers to

In [None]:
xx, yy = np.meshgrid(xedges, yedges)

# Tuning curves computed from phase1 data only.
phase = "phase1"
sliced_position = position.time_slice(info.task_times[phase].start, info.task_times[phase].stop)
sliced_spikes = [
    train.time_slice(info.task_times[phase].start, info.task_times[phase].stop)
    for train in spikes
]
neurons = get_tuning_curves(
    info, sliced_position, sliced_spikes, xedges, yedges,
    speed_limit=4., phase_id=phase, min_n_spikes=None,
    trial_times=None, trial_number=None, cache=False,
)

# Bin-wise sum of every neuron's tuning curve.
multiple_tuning_curves = sum(
    (neurons.tuning_curves[idx] for idx in range(neurons.n_neurons)),
    np.zeros(neurons.tuning_shape),
)

plt.figure(figsize=(6, 5))
pp = plt.pcolormesh(xx, yy, multiple_tuning_curves, vmin=0.01, cmap="Greys")
plt.colorbar(pp)
plt.tight_layout()
plt.show()

In [None]:
xx, yy = np.meshgrid(xedges, yedges)

# Tuning curves computed from phase3 data only.
phase = "phase3"
sliced_position = position.time_slice(info.task_times[phase].start, info.task_times[phase].stop)
sliced_spikes = [spiketrain.time_slice(info.task_times[phase].start, info.task_times[phase].stop) for spiketrain in spikes]

# NOTE(review): speed_limit=3. here, while the phase1 map and the phase1/phase3 comparison cell
# use 4.; confirm the difference is intended before comparing these maps.
neurons = get_tuning_curves(info, sliced_position, sliced_spikes, xedges, yedges, speed_limit=3.,
                            phase_id=phase, min_n_spikes=None, trial_times=None, trial_number=None, cache=False)

# Bin-wise sum of every neuron's tuning curve (a single accumulator is enough).
multiple_tuning_curves = np.zeros(neurons.tuning_shape)
for i in range(neurons.n_neurons):
    multiple_tuning_curves += neurons.tuning_curves[i]

plt.figure(figsize=(6, 5))
pp = plt.pcolormesh(xx, yy, multiple_tuning_curves, vmin=0.01, cmap="Greys")
plt.colorbar(pp)
plt.tight_layout()
plt.show()

In [None]:
np.nonzero(multiple_tuning_curves == multiple_tuning_curves.max())[0]  # row indices of the peak bin(s)

In [None]:
# (row, col) index pairs of every bin holding the maximum; the previous cell lists only the rows.
np.argwhere(multiple_tuning_curves == np.max(multiple_tuning_curves))

In [None]:
# Spike count per neuron (over the full `spikes` trains, not phase-sliced), then the total.
spike_counts = [len(spikes[idx].time) for idx in range(neurons.n_neurons)]
for count in spike_counts:
    print(count)
total_spikes = sum(spike_counts)
print('total spikes:', total_spikes)

In [None]:
def _phase_tuning_curves(info, position, spikes, xedges, yedges, phase, speed_limit=4.):
    """Tuning curves for `info` computed from position/spikes restricted to one task phase."""
    start, stop = info.task_times[phase].start, info.task_times[phase].stop
    phase_position = position.time_slice(start, stop)
    phase_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]
    return get_tuning_curves(info, phase_position, phase_spikes, xedges, yedges, speed_limit=speed_limit,
                             phase_id=phase, min_n_spikes=None, trial_times=None, trial_number=None, cache=False)


# The two phases being compared. The second is phase3; the old name `phase2` hid that.
first_phase, second_phase = "phase1", "phase3"
neurons1 = _phase_tuning_curves(info, position, spikes, xedges, yedges, first_phase)
neurons2 = _phase_tuning_curves(info, position, spikes, xedges, yedges, second_phase)

In [None]:
xx, yy = np.meshgrid(xedges, yedges)
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9; `plt.get_cmap` is supported.
cmap = plt.get_cmap('bone_r')

In [None]:
# Bin-wise sum of all tuning curves in `neurons2` (computed in the two-phase cell above).
neurons = neurons2

# Fix: initialise the accumulator once; resetting it inside the loop kept only the last neuron.
multiple_tuning_curves = np.zeros(neurons.tuning_shape)
for i in range(neurons.n_neurons):
    multiple_tuning_curves += neurons.tuning_curves[i]

plt.figure(figsize=(6, 5))
pp = plt.pcolormesh(xx, yy, multiple_tuning_curves, vmin=0.01, cmap=cmap)
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# for info in infos:
events, position, spikes, lfp, lfp_theta = get_data(info)
xedges, yedges = nept.get_xyedges(position)

tc_shape = (len(yedges) - 1, len(xedges) - 1)

shortcut1 = find_intersection(info, "shortcut1", xedges, yedges)
shortcut2 = find_intersection(info, "shortcut2", xedges, yedges)
novel1 = find_intersection(info, "novel1", xedges, yedges)
# novel2 = find_intersection(info, "novel2", xedges, yedges)
stable1 = (np.array([1]), np.array([10]))

novel_points = [shortcut1, shortcut2, novel1]
stable_points = [stable1]

novel_neighbours = find_neighbours(tc_shape, novel_points, neighbour_size=2)
stable_neighbours = find_neighbours(tc_shape, stable_points, neighbour_size=2)

corr12 = get_pearsons_correlation(info, "phase1", "phase2", xedges, yedges, position, spikes)
corr13 = get_pearsons_correlation(info, "phase1", "phase3", xedges, yedges, position, spikes)
corr23 = get_pearsons_correlation(info, "phase2", "phase3", xedges, yedges, position, spikes)
corr33 = get_pearsons_correlation(info, "phase3", "phase3", xedges, yedges, position, spikes)

filepath = os.path.join(output_filepath, info.session_id + "_phase-shift12.png")
plot_tc_corr(corr12, stable_neighbours, novel_neighbours)

plot_tc_corr(corr13, stable_neighbours, novel_neighbours)

plot_tc_corr(corr23, stable_neighbours, novel_neighbours)

filepath = os.path.join(output_filepath, info.session_id + "_phase-shift33.png")
plot_tc_corr(corr33, stable_neighbours, novel_neighbours)

In [None]:
def compare_correlations(correlations, stable_neighbours, novel_neighbours):
    """Average a correlation map over the stable and the novel neighbourhood points.

    Each point is looked up as ``correlations[pt[1]][pt[0]]``: ``pt[1]`` selects the outer
    index and ``pt[0]`` the inner one. NaN entries are ignored via ``np.nanmean``; an empty
    or all-NaN neighbourhood yields NaN (numpy also emits a RuntimeWarning).

    NOTE(review): this redefines ``compare_correlations`` imported from ``analyze_tc_shifts``
    and shadows it for every later call in the kernel.

    Returns
    -------
    (stable, novel) : tuple of float
        Mean correlation over the stable points and over the novel points.
    """
    def _values_at(points):
        return [correlations[point[1]][point[0]] for point in points]

    stable_values = _values_at(stable_neighbours)
    novel_values = _values_at(novel_neighbours)
    return np.nanmean(stable_values), np.nanmean(novel_values)

compare_correlations(corr12, stable_neighbours=stable_neighbours, novel_neighbours=novel_neighbours)

In [None]:
compare_correlations(corr13, stable_neighbours=stable_neighbours, novel_neighbours=novel_neighbours)

In [None]:
compare_correlations(corr23, stable_neighbours=stable_neighbours, novel_neighbours=novel_neighbours)

In [None]:
compare_correlations(corr33, stable_neighbours=stable_neighbours, novel_neighbours=novel_neighbours)

In [None]:
# Hard-coded values for a quick two-box comparison.
# NOTE(review): provenance is not recorded here — presumably copied from earlier outputs; TODO document.
# NOTE(review): this rebinds `yy`, which earlier cells use for the pcolormesh grid.
yy, tt = (
    [0.9448256327499559, 0.87462859828618689],
    [0.81271753867513286, 0.38649566177067601],
)

In [None]:
labels = ['yy', 'tt']
x = np.arange(len(labels)) + 1  # boxplot positions are 1-based
plt.boxplot([yy, tt])
plt.xticks(x, labels, rotation='vertical')
plt.show()