In [None]:
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import pickle
import seaborn as sns
import os

import vdmlab as vdm

from load_data import get_pos, get_spikes, get_csc
from tuning_curves_functions import get_tc, get_odd_firing_idx, linearize

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3
import info.R063d4_info as r063d4
import info.R063d5_info as r063d5
import info.R063d6_info as r063d6
import info.R066d1_info as r066d1
import info.R066d2_info as r066d2
import info.R066d3_info as r066d3
import info.R066d4_info as r066d4
import info.R067d1_info as r067d1

In [None]:
# Plot style for every figure below.  'ticks' is the white style plus axis
# ticks; a preceding set_style('white') call would be fully overwritten by this
# one, so only the final style is applied.
sns.set_style('ticks')

In [None]:
pickle_filepath = 'E:\\code\\emi_shortcut\\cache\\pickled\\'
output_filepath = 'E:\\code\\emi_shortcut\\plots\\sequence\\'
# pickle_filepath = 'C:\\Users\\Emily\\Code\\emily_shortcut\\cache\\pickled\\'
# output_filepath = 'C:\\Users\\Emily\\Code\\emily_shortcut\\plots\\sequence\\'

In [None]:
# Session to analyse: pick one of the info modules imported above
# (r063d2 ... r067d1); every later cell reads its metadata from `info`.
info = r067d1

In [None]:
# Load this session's position, LFP (from the first entry of info.good_swr)
# and spikes; the session id is printed first to label the cell output.
print(info.session_id)
pos, csc, spikes = (
    get_pos(info.pos_mat, info.pxl_to_cm),
    get_csc(info.good_swr[0]),
    get_spikes(info.spike_mat),
)

In [None]:
# Tuning curves keyed by trajectory ('u', 'shortcut', 'novel' are read later).
# pickle_filepath is passed in, presumably so get_tc can cache its result -- confirm.
tc = get_tc(info, pos, pickle_filepath, expand_by=2)
# Detect sharp-wave ripples on the LFP.  swr_idx['start'] / swr_idx['stop'] are
# used later as sample indices into csc; filtered_butter is the filtered signal
# that the SWR panel plots against csc['time'].
swr_times, swr_idx, filtered_butter = vdm.detect_swr_hilbert(csc, fs=info.fs)

In [None]:
# Palette for the per-neuron spike rasters and tuning curves (50 colours).
# It is one 25-swatch cycle followed by a second pass over the same cycle that
# has an extra '#fb6a4a' after its first two swatches and drops the final
# '#800026'.
_colour_cycle = ['#bd0026', '#fc4e2a', '#ef3b2c', '#ec7014', '#fe9929',
                 '#78c679', '#41ab5d', '#238443', '#66c2a4', '#41b6c4',
                 '#1d91c0', '#8c6bb1', '#225ea8', '#88419d', '#ae017e',
                 '#dd3497', '#f768a1', '#fcbba1', '#fc9272', '#fb6a4a',
                 '#e31a1c', '#fb6a4a', '#993404', '#b30000', '#800026']
colours = (_colour_cycle
           + _colour_cycle[:2]
           + ['#fb6a4a']
           + _colour_cycle[2:24])
del _colour_cycle

In [None]:
# Place-field detection criteria shared by all three trajectories.  Keeping them
# in one dict means the U, shortcut and novel segments cannot drift apart.
field_params = dict(hz_thresh=8, min_length=3, max_length=15)

all_u_fields = vdm.find_fields(tc['u'], **field_params)
all_shortcut_fields = vdm.find_fields(tc['shortcut'], **field_params)
all_novel_fields = vdm.find_fields(tc['novel'], **field_params)

# Reduce each trajectory's fields with vdm.get_single_field (presumably keeping
# neurons that have exactly one field -- confirm against vdmlab).
u_fields = vdm.get_single_field(all_u_fields)
shortcut_fields = vdm.get_single_field(all_shortcut_fields)
novel_fields = vdm.get_single_field(all_novel_fields)

In [None]:
# Trajectory whose tuning curves and fields are plotted below.
this_tc = tc['shortcut']

# Neuron indices in the order returned by vdm.get_sort_idx (presumably sorted
# by tuning-curve peak location -- confirm); this order sets the raster rows.
sort_idx = vdm.get_sort_idx(this_tc)

# Neurons flagged by get_odd_firing_idx (max_mean_firing=10) are left out.
odd_firing_idx = get_odd_firing_idx(this_tc, max_mean_firing=10)

# Neuron indices that have a field on this trajectory (keys of shortcut_fields).
these_fields = list(shortcut_fields)

# Sets make the per-neuron membership tests below O(1) instead of list scans.
excluded_idx = set(odd_firing_idx)
field_idx = set(these_fields)

# Spike times and tuning curves of the kept neurons, in sort_idx order.
field_spikes = []
field_tc = []
for idx in sort_idx:
    if idx in field_idx and idx not in excluded_idx:
        field_spikes.append(spikes['time'][idx])
        field_tc.append(this_tc[idx])

In [None]:
# Linearize position over the whole session: from the start of the prerecord
# epoch to the end of the postrecord epoch.
t_start, t_stop = info.task_times['prerecord'][0], info.task_times['postrecord'][1]
linear, zone = linearize(info, pos, t_start, t_stop, expand_by=2)

In [None]:
# Linearized position on the 'shortcut' trajectory -- must name the same
# trajectory as this_tc / shortcut_fields so position and fields line up.
this_linear = linear['shortcut']

In [None]:
# Report how many neurons have a shortcut field and how many of them are kept
# in field_spikes (after the odd-firing filter) and therefore drawn below.
print('{}: {} neurons with a shortcut field, {} plotted'.format(
    info.session_id, len(these_fields), len(field_spikes)))

In [None]:
# Sequence figure: the linearized run (wide middle panel), a zoom on a single
# sharp-wave ripple (right panel) and each plotted neuron's tuning curve (left
# column).  Rasters and tuning curves share row order AND colour, so the k-th
# neuron of field_spikes / field_tc looks the same in every panel.

# --- figure styling ---------------------------------------------------------
ms = 10                # marker size of the raster ticks
loc = 1                # location argument for vdm.add_scalebar (presumably a legend-style loc code -- confirm)
spike_loc = 2          # vertical spacing between consecutive raster rows
raster_offset = 3      # raster row k sits at y = k*spike_loc + raster_offset + 1, above the traces
position_scale = 90    # linearized position is divided by this before plotting
lfp_offset = 1.25      # vertical offset of the raw LFP trace
run_lfp_gain = 800     # LFP gain in the run panel
swr_lfp_gain = 1000    # LFP and filtered-signal gain in the SWR panel
filtered_offset = 0.5  # vertical offset of the filtered signal in the SWR panel

# --- time windows (same time base as csc['time'] and the spike times) ---------
start_time = 15529
stop_time = 15573

start_time_swr = 16763.8
stop_time_swr = 16764.8
# Another SWR window of interest: 16843.9 - 16844.2.

rows = len(field_spikes) + 1
cols = 7


def neuron_colour(k):
    """Colour of the k-th neuron in field_spikes / field_tc.

    Used by both the rasters and the tuning curves so each neuron has a single
    colour across panels.  Wraps around when there are more neurons than colours.
    """
    return colours[k % len(colours)]


def _plot_raster(ax, all_spikes):
    """Draw one row of '|' ticks per neuron on `ax`; neuron 0 is the bottom row."""
    for k, neuron_spikes in enumerate(all_spikes):
        row_y = k * spike_loc + raster_offset + 1
        ax.plot(neuron_spikes, np.full(len(neuron_spikes), row_y, dtype=float), '|',
                color=neuron_colour(k), ms=ms, mew=2)
    ax.set_ylim([0, len(all_spikes) * spike_loc + raster_offset])


def _plot_lfp(ax, gain):
    """Draw the raw LFP in black, scaled by `gain` and shifted up by lfp_offset."""
    ax.plot(csc['time'], csc['data'] * gain + lfp_offset, 'k', lw=1)


def _highlight_swrs(ax, gain):
    """Overdraw every detected SWR (start/stop sample indices into csc) in red."""
    for sw_start, sw_stop in zip(swr_idx['start'], swr_idx['stop']):
        ax.plot(csc['time'][sw_start:sw_stop],
                csc['data'][sw_start:sw_stop] * gain + lfp_offset, 'r', lw=2)


def _plot_tuning_curves(tuning_curves):
    """Draw each tuning curve in its own row of the left-most grid column.

    The top row holds the last neuron, matching the rasters where the last
    neuron is drawn highest.  Does nothing when there are no curves.
    """
    n_neurons = len(tuning_curves)
    for ax_loc in range(n_neurons):
        k = n_neurons - 1 - ax_loc
        curve = tuning_curves[k]
        ax = plt.subplot2grid((rows, cols), (ax_loc, 0))
        ax.plot(curve, color=neuron_colour(k))
        ax.fill_between(np.arange(len(curve)), 0, curve, facecolor=neuron_colour(k))
        # Label the peak value (rounded up) just left of the first peak bin.
        ax.text(np.argmax(curve) - 3, 1, str(int(np.ceil(np.max(curve)))))
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        plt.setp(ax, xticks=[], xticklabels=[], yticks=[])


def plot_sequence(start_time, stop_time, start_time_swr, stop_time_swr, savepath=None):
    """Draw the sequence figure for one run window and one SWR window.

    Parameters
    ----------
    start_time, stop_time : float
        x-limits of the run panel.
    start_time_swr, stop_time_swr : float
        x-limits of the SWR panel.
    savepath : str, optional
        If given, the figure is saved there (300 dpi) and closed instead of
        shown; saving before any show() avoids writing a blank image.

    Returns
    -------
    (fig, ax_run, ax_swr)
    """
    fig = plt.figure()
    ax_run = plt.subplot2grid((rows, cols), (0, 1), rowspan=rows, colspan=4)
    ax_swr = plt.subplot2grid((rows, cols), (0, 5), rowspan=rows, colspan=2)

    # Run panel: grey line at the maximum scaled position, position samples,
    # LFP with SWRs highlighted, and the spike rasters.
    run_times = this_linear['time']
    scaled_position = this_linear['position'] / position_scale
    ax_run.plot(run_times, np.full(len(run_times), np.max(scaled_position)),
                color='#bdbdbd', lw=1)
    ax_run.plot(run_times, scaled_position, 'b.', ms=4)
    _plot_lfp(ax_run, run_lfp_gain)
    _highlight_swrs(ax_run, run_lfp_gain)
    _plot_raster(ax_run, field_spikes)
    ax_run.set_xlim([start_time, stop_time])
    vdm.add_scalebar(ax_run, matchy=False, loc=loc)

    # SWR panel: raw LFP, the filtered signal below it, SWRs highlighted, rasters.
    _plot_lfp(ax_swr, swr_lfp_gain)
    ax_swr.plot(csc['time'], filtered_butter * swr_lfp_gain + filtered_offset, 'b', lw=1)
    _highlight_swrs(ax_swr, swr_lfp_gain)
    _plot_raster(ax_swr, field_spikes)
    ax_swr.set_xlim([start_time_swr, stop_time_swr])
    vdm.add_scalebar(ax_swr, matchy=False, loc=loc)

    _plot_tuning_curves(field_tc)

    sns.despine(fig=fig)
    if savepath is None:
        plt.show()
    else:
        fig.savefig(savepath, dpi=300, bbox_inches='tight')
        plt.close(fig)
    return fig, ax_run, ax_swr


fig, ax1, ax2 = plot_sequence(start_time, stop_time, start_time_swr, stop_time_swr)

# To export one figure per entry of info.sequence instead:
# for i, times in enumerate(zip(info.sequence['run_start'], info.sequence['run_stop'],
#                               info.sequence['swr_start'], info.sequence['swr_stop'])):
#     filename = info.session_id + '_sequence-swr' + str(i) + '.png'
#     plot_sequence(*times, savepath=os.path.join(output_filepath, filename))