In [1]:
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import pickle
import seaborn as sns
import os

import vdmlab as vdm

from tuning_curves_functions import get_tc, get_odd_firing_idx

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3
import info.R063d4_info as r063d4
import info.R063d5_info as r063d5
import info.R063d6_info as r063d6
import info.R066d1_info as r066d1
import info.R066d2_info as r066d2
import info.R066d4_info as r066d4

In [2]:
# Root of the emily_shortcut project on this machine. On the other workstation the root is
# 'C:\\Users\\Emily\\Code\\vandermeerlab\\code-python\\projects\\emily_shortcut\\'.
project_root = 'E:\\code\\vandermeerlab\\code-python\\projects\\emily_shortcut\\'
# Both paths end with a separator because file names are appended to them with '+'.
pickle_filepath = project_root + 'cache\\pickled\\'
output_filepath = project_root + 'plots\\'

In [3]:
# Session to analyse; swap in any of the imported info modules (r063d2 ... r066d4).
info = r063d4

In [4]:
# Load the session, its tuning curves and (cached) spike heatmaps, detect SWRs
# during the pauseB epoch, and select the neurons with a single unique field on
# the U trajectory.
print(info.session_id)
pos = info.get_pos(info.pxl_to_cm)
csc = info.get_csc()
spikes = info.get_spikes()

tc = get_tc(info, pos, pickle_filepath)

# --- Spike heatmaps, cached per session ---
pickled_spike_heatmaps = pickle_filepath + info.session_id + '_spike_heatmaps.pkl'
if os.path.isfile(pickled_spike_heatmaps):
    # pickle.load can execute arbitrary code: only load caches this project wrote.
    # NOTE: caches written before the neuron-0 fix below lack neuron 0; delete the
    # file to rebuild it.
    with open(pickled_spike_heatmaps, 'rb') as fileobj:
        spike_heatmaps = pickle.load(fileobj)
else:
    # spikes is already loaded above, so it is not re-read here. Neurons are
    # 0-indexed (spikes['time'][idx] is used with idx from sort_idx below), so the
    # list must start at 0; range(1, ...) silently dropped the first neuron.
    all_neurons = list(range(len(spikes['time'])))
    spike_heatmaps = vdm.get_heatmaps(all_neurons, spikes, pos)
    with open(pickled_spike_heatmaps, 'wb') as fileobj:
        pickle.dump(spike_heatmaps, fileobj)

# --- Sharp-wave ripples restricted to the pauseB epoch ---
exp_time = 'pauseB'

t_start = info.task_times[exp_time][0]
t_stop = info.task_times[exp_time][1]

csc_time = np.array(csc['time'])  # convert once; reused for both index lookups
t_start_idx = vdm.find_nearest_idx(csc_time, t_start)
t_end_idx = vdm.find_nearest_idx(csc_time, t_stop)

sliced_csc = {'time': csc['time'][t_start_idx:t_end_idx],
              'data': csc['data'][t_start_idx:t_end_idx]}

swr_times, swr_idx, filtered_butter = vdm.detect_swr_hilbert(sliced_csc, fs=info.fs, z_thres=4, power_thres=3,
                                                             merge_thres=0.2, min_length=0.01)

# --- Place fields ---
sort_idx = vdm.get_sort_idx(tc['u'])
odd_firing_idx = get_odd_firing_idx(tc['u'])

# All fields per trajectory, and stricter (hz_thres=3) fields used for comparison.
all_u_fields = vdm.find_fields(tc['u'])
all_shortcut_fields = vdm.find_fields(tc['shortcut'])
all_novel_fields = vdm.find_fields(tc['novel'])

u_compare = vdm.find_fields(tc['u'], hz_thres=3)
shortcut_compare = vdm.find_fields(tc['shortcut'], hz_thres=3)
novel_compare = vdm.find_fields(tc['novel'], hz_thres=3)

# Presumably (per vdmlab): drop fields that also appear on the other two
# trajectories, filter by field size, then keep neurons with exactly one field.
u_fields_unique = vdm.unique_fields(all_u_fields, shortcut_compare, novel_compare)
shortcut_fields_unique = vdm.unique_fields(all_shortcut_fields, u_compare, novel_compare)
novel_fields_unique = vdm.unique_fields(all_novel_fields, u_compare, shortcut_compare)

u_fields_size = vdm.sized_fields(u_fields_unique)
shortcut_fields_size = vdm.sized_fields(shortcut_fields_unique)
novel_fields_size = vdm.sized_fields(novel_fields_unique)

u_fields = vdm.get_single_field(u_fields_size)
shortcut_fields = vdm.get_single_field(shortcut_fields_size)
novel_fields = vdm.get_single_field(novel_fields_size)

# Neuron ids (the keys of u_fields) that have a single unique U field.
these_fields = list(u_fields)

# Spike trains and U tuning curves of those neurons, in tuning-curve sort order,
# skipping neurons flagged by get_odd_firing_idx.
field_spikes = []
field_tc = []
for idx in sort_idx:
    if idx not in odd_firing_idx and idx in these_fields:
        field_spikes.append(spikes['time'][idx])
        field_tc.append(tc['u'][idx])

R063d4
Number of SWR events found:  255


In [5]:
# Per-neuron plotting colours: 29 hand-picked hex colours followed by the
# single-letter matplotlib colours cycled eight times (85 entries in total).
# Entry i is used for the i-th plotted neuron. Note that '#8c6bb1' appears twice
# (indices 13 and 28).
_hand_picked = ['#7f0000', '#b30000', '#d7301f', '#ef6548', '#ec7014', '#fe9929', '#78c679', '#238443',
                '#7fcdbb', '#41b6c4', '#1d91c0', '#225ea8', '#8c96c6', '#8c6bb1', '#7a0177', '#ae017e',
                '#f768a1', '#fa9fb5', '#fcc5c0', '#d9f0a3', '#addd8e', '#f7fcb9', '#fec44f', '#ffeda0',
                '#ffffbf', '#dd3497', '#9ebcda', '#88419d', '#8c6bb1']
_basic_cycle = ['k', 'b', 'c', 'g', 'm', 'r', 'y']
colours = _hand_picked + _basic_cycle * 8


In [6]:
from matplotlib.offsetbox import AnchoredOffsetbox

# Adapted from mpl_toolkits.axes_grid2
class AnchoredScaleBar(AnchoredOffsetbox):
    """Horizontal and/or vertical scale bar drawn in data units, anchored inside an axes."""

    def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None,
                 loc=4, pad=0.1, borderpad=0.1, sep=2, prop=None, **kwargs):
        """
        Draw a horizontal and/or vertical bar whose size is given in the data
        coordinates of the given axes, with an optional label for each bar.

        Parameters
        ----------
        transform : the coordinate frame (typically axes.transData)
        sizex, sizey : width of x, y bar, in data units; 0 (or any falsy value) to omit
        labelx, labely : labels for x, y bars; None to omit. labelx is drawn
            centred below the x bar, labely rotated to the left of the y bar.
        loc : position in containing axes (matplotlib location code,
            e.g. 1 = upper right, 4 = lower right)
        pad, borderpad : padding, in fraction of the legend font size (or prop)
        sep : accepted for compatibility with the original recipe but unused;
            label offsets are fixed at about 5 px (x label) and 2 px (y label).
        **kwargs : additional arguments passed to the AnchoredOffsetbox constructor
        """
        from matplotlib.lines import Line2D
        from matplotlib.offsetbox import AuxTransformBox
        from matplotlib.text import Text

        bars = AuxTransformBox(transform)
        # Size of one display pixel in data units, used to offset labels from the bars.
        # NOTE(review): computed once here, so offsets go stale if the axes are rescaled later.
        inv = transform.inverted()
        pixelxy = inv.transform((1, 1)) - inv.transform((0, 0))

        if sizex:
            barx = Line2D([sizex, 0], [0, 0], transform=transform, color='k')
            bars.add_artist(barx)

        if sizey:
            bary = Line2D([0, 0], [0, sizey], transform=transform, color='k')
            bars.add_artist(bary)

        if sizex and labelx:
            textx = Text(text=labelx, x=sizex / 2.0, y=-5 * pixelxy[1], ha='center', va='top')
            bars.add_artist(textx)

        if sizey and labely:
            texty = Text(text=labely, rotation='vertical', y=sizey / 2.0, x=-2 * pixelxy[0],
                         va='center', ha='right')
            bars.add_artist(texty)

        super(AnchoredScaleBar, self).__init__(loc=loc, pad=pad, borderpad=borderpad,
                                               child=bars, prop=prop, frameon=False, **kwargs)

def format_time_label(size, scale=1000, unit='ms'):
    """Format a scale-bar length as a whole number of display units, e.g. 0.02 -> '20 ms'.

    The scaled length is rounded, not truncated. Tick spacings derived from large
    time stamps carry float error (a 0.02 spacing can come out as 0.0199999999999818),
    and int() alone would then label the bar '19 ms'.
    """
    return '{} {}'.format(int(round(size * scale)), unit)


def add_scalebar(ax, matchx=True, matchy=True, hidex=True, hidey=True,
                 xlabel_scale=1000, xlabel_unit='ms', **kwargs):
    """Add scale bars to axes.

    Adds a set of scale bars to *ax*, sizing them to the spacing of the plot's
    major ticks. Because the size is read from the ticks, call this before the
    ticks are cleared.

    Parameters
    ----------
    ax : the axes to attach the scale bars to
    matchx, matchy : if True, set the bar size to the spacing between the first
        two major ticks and derive its label (with fewer than two ticks the size
        is False and that bar is omitted); if False, the size should be set using
        the sizex / sizey keyword arguments
    hidex, hidey : accepted for compatibility with the recipe this was adapted
        from, but currently ignored -- the axes are NOT hidden here
    xlabel_scale, xlabel_unit : the x label is round(sizex * xlabel_scale)
        followed by xlabel_unit; the defaults (1000, 'ms') assume x is in seconds
    **kwargs : additional arguments passed to AnchoredScaleBar

    Returns the created scale bar object.
    """
    def find_loc(axis):
        # Spacing between the first two major ticks, or False if there are fewer than two.
        loc = axis.get_majorticklocs()
        return len(loc) > 1 and (loc[1] - loc[0])

    if matchx:
        kwargs['sizex'] = find_loc(ax.xaxis)
        kwargs['labelx'] = format_time_label(kwargs['sizex'], xlabel_scale, xlabel_unit)

    if matchy:
        kwargs['sizey'] = find_loc(ax.yaxis)
        kwargs['labely'] = str(kwargs['sizey'])

    scalebar = AnchoredScaleBar(ax.transData, **kwargs)
    ax.add_artist(scalebar)

    return scalebar

In [68]:
# Raster of the single-field U neurons during a run epoch (middle panel) and a
# sharp-wave ripple (right panel), with each neuron's U tuning curve in the left column.

# Run window: 20 s of phase3, starting 190 s after its onset. Other windows tried:
#   phase3 + 235 .. phase3 + 250
#   phase2 + 860 .. phase2 + 880
start_time = info.task_times['phase3'][0] + 190
stop_time = info.task_times['phase3'][0] + 210

# SWR window, hand-picked (same time base as the spike times). A detected event
# can be used instead, e.g. event 39 or 46:
#   start_time_swr = swr_times['start'][39]
#   stop_time_swr = swr_times['stop'][39]
start_time_swr = 1371.23
stop_time_swr = 1371.33

if not field_tc:
    raise ValueError('no single-field U neurons to plot for session ' + str(info.session_id))

rows = len(field_tc)  # one tuning-curve panel per plotted neuron (was hard-coded to 12)
cols = 7
spike_loc = 2  # vertical spacing between raster rows


def plot_field_raster(ax, spike_trains, t_min, t_max, spacing):
    """Draw one raster row per spike train on ax, showing only [t_min, t_max].

    Train i is drawn at y = 2 + i * spacing in colours[i], the same colour as its
    tuning-curve panel. Axis ticks are replaced by a time scale bar.
    """
    location = 1
    for neuron_idx, neuron_spikes in enumerate(spike_trains):
        ax.plot(neuron_spikes, np.ones(len(neuron_spikes)) + location, '|',
                color=colours[neuron_idx], ms=10, mew=2)
        location += spacing
    ax.set_xlim([t_min, t_max])
    ax.set_ylim([1, location])
    # add_scalebar sizes the bar from the current tick spacing, so it must run
    # before the ticks are cleared.
    add_scalebar(ax, matchy=False, loc=1)
    plt.setp(ax, xticks=[], xticklabels=[], yticks=[])


fig = plt.figure()
ax1 = plt.subplot2grid((rows, cols), (0, 1), rowspan=rows, colspan=4)
ax2 = plt.subplot2grid((rows, cols), (0, 5), rowspan=rows, colspan=2)
plot_field_raster(ax1, field_spikes, start_time, stop_time, spike_loc)
plot_field_raster(ax2, field_spikes, start_time_swr, stop_time_swr, spike_loc)

x = list(range(len(field_tc[0])))

# Tuning curves go top to bottom in reverse order, so each panel sits level with
# its raster row (neuron 0 is the bottom raster row).
for ax_loc in range(rows):
    ax = plt.subplot2grid((rows, cols), (ax_loc, 0))

    neuron_idx = rows - ax_loc - 1
    ax.plot(field_tc[neuron_idx], color=colours[neuron_idx])
    ax.fill_between(x, 0, field_tc[neuron_idx], facecolor=colours[neuron_idx])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.setp(ax, xticks=[], xticklabels=[], yticks=[])

sns.despine()
# To save the figure, call savefig BEFORE plt.show(): once show() returns, the
# figure may already be closed and an empty image would be written.
# savepath = output_filepath + info.session_id + '_sequence-swr1.png'
# plt.savefig(savepath, dpi=300, bbox_inches='tight')
plt.show()
# plt.close()

In [8]:
# Spike trains of every neuron, in U tuning-curve sort order, except those
# flagged by get_odd_firing_idx.
ordered_spikes = []
for idx in sort_idx:
    if idx in odd_firing_idx:
        continue
    ordered_spikes.append(spikes['time'][idx])

In [11]:
# Re-detect SWRs in the pauseB LFP with z_thres=3 / power_thres=4 (the session
# loading cell used 4 / 3). This overwrites swr_times, swr_idx and filtered_butter.
swr_times, swr_idx, filtered_butter = vdm.detect_swr_hilbert(
    sliced_csc,
    fs=info.fs,
    z_thres=3,
    power_thres=4,
    merge_thres=0.2,
    min_length=0.01,
)

Number of SWR events found:  318


In [72]:
# LFP with detected SWRs in red, above a raster of all regular-firing neurons in
# U tuning-curve order; neurons with a single unique U field are drawn in red.
idx = 10  # SWR event of interest (the former one-pass loop existed only to set this)
lfp_scale = 1500  # magnify the LFP so it is visible against raster rows spaced 2 apart

# Neuron ids in the same order as ordered_spikes. Highlighting must compare these
# ids with these_fields (which holds neuron ids); the old code compared the
# position in the sorted list instead, which highlighted the wrong neurons.
ordered_neurons = [neuron for neuron in sort_idx if neuron not in odd_firing_idx]

plt.plot(csc['time'], csc['data'] * lfp_scale)
# NOTE(review): swr_idx was produced by detect_swr_hilbert on sliced_csc, so its
# indices are assumed to be relative to sliced_csc, not the full csc -- confirm.
for swr_start, swr_stop in zip(swr_idx['start'], swr_idx['stop']):
    plt.plot(sliced_csc['time'][swr_start:swr_stop],
             sliced_csc['data'][swr_start:swr_stop] * lfp_scale, 'r')

location = 1
for neuron, neuron_spikes in zip(ordered_neurons, ordered_spikes):
    colour = 'r' if neuron in these_fields else 'k'
    plt.plot(neuron_spikes, np.ones(len(neuron_spikes)) + location, '|', color=colour, ms=10, mew=1)
    location += 2

# Alternative windows:
#   plt.xlim(sliced_csc['time'][swr_idx['start'][idx]] - 1, sliced_csc['time'][swr_idx['start'][idx]] + 1)
#   plt.xlim(1347.419835, 1348.419835)
plt.xlim(1371.23, 1371.326)
plt.show()

In [None]:
# Print a 1 s window centred on the start of SWR event 10. The event number is
# explicit rather than read from a loop variable left over from another cell.
swr_number = 10
# NOTE(review): swr_idx comes from detection on sliced_csc, so it is assumed to
# index sliced_csc rather than the full csc -- confirm against detect_swr_hilbert.
swr_start_time = sliced_csc['time'][swr_idx['start'][swr_number]]
print(swr_start_time - 0.5, swr_start_time + 0.5)

In [37]:
# Number of neurons with a single unique U-trajectory field, i.e. raster rows in the figure.
print(len(field_spikes))

12


In [55]:
# Scratch check: floor of an odd number halved, as used by the floor-based
# colour indexing in the raster plot.
y = 7
t = np.floor(np.true_divide(y, 2))
print(t)

3.0
