In [1]:
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import pickle
import seaborn as sns
import os

import vdmlab as vdm

from tuning_curves_functions import get_tc, get_odd_firing_idx, linearize

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3
import info.R063d4_info as r063d4
import info.R063d5_info as r063d5
import info.R063d6_info as r063d6
import info.R066d1_info as r066d1
import info.R066d2_info as r066d2
import info.R066d3_info as r066d3
import info.R066d4_info as r066d4

In [2]:
# Both cache and figure paths live under one project checkout. To run on the
# other machine, swap project_root for the commented alternative. Trailing
# separators are kept because downstream code may append file names directly.
project_root = 'E:\\code\\python-vdmlab\\projects\\emily_shortcut\\'
# project_root = 'C:\\Users\\Emily\\Code\\python-vdmlab\\projects\\emily_shortcut\\'
pickle_filepath = project_root + 'cache\\pickled\\'
output_filepath = project_root + 'plots\\sequence\\'

In [3]:
# Session whose info module drives every following cell. The imported
# r063d*/r066d* modules are presumably interchangeable here; confirm that they
# expose the same attributes before swapping.
info = r066d2

In [4]:
# Load this session's data through the selected info module.
print(info.session_id)
# Position samples with 'x', 'y', 'time' entries. These keys are indexed below;
# the argument presumably converts pixels to cm (confirm).
pos = info.get_pos(info.pxl_to_cm)
# Signal for the first entry of info.good_swr (presumably a channel id).
# NOTE(review): csc is not used in any of the cells shown here.
csc = info.get_csc(info.good_swr[0])
# Later cells read per-cell spike times as spikes['time'][cell_idx].
spikes = info.get_spikes()

R066d2


In [5]:
# Keep only position samples recorded while the animal was running (smoothed
# speed >= run_threshold); tuning curves are then built from those samples.
speed = vdm.get_speed(pos)

run_threshold = 0.45  # NOTE(review): units follow vdm.get_speed's 'smoothed' output; confirm
t_run = speed['time'][speed['smoothed'] >= run_threshold]

# True where a position timestamp exactly equals one of the running timestamps.
# This is the same result as OR-ing `pos['time'] == t` over every t in t_run,
# but vectorised (sort-based) instead of an O(len(pos) * len(t_run)) Python loop.
run_idx = np.isin(pos['time'], t_run)

# Same keys, same insertion order ('x', 'y', 'time') as the position dict.
run_pos = {key: pos[key][run_idx] for key in ('x', 'y', 'time')}

tc = get_tc(info, run_pos, pickle_filepath)

In [6]:
# One fill colour per plotted cell index (looked up as colours[idx]).
# Layout, in order:
#   1. a 25-colour base palette;
#   2. the base palette repeated, with an extra '#fb6a4a' after its first two entries;
#   3. nine black ('k') entries of padding, so every index 0..59 is valid.
base_colours = ['#bd0026', '#fc4e2a', '#ef3b2c', '#ec7014', '#fe9929',
                '#78c679', '#41ab5d', '#238443', '#66c2a4', '#41b6c4',
                '#1d91c0', '#8c6bb1', '#225ea8', '#88419d', '#ae017e',
                '#dd3497', '#f768a1', '#fcbba1', '#fc9272', '#fb6a4a',
                '#e31a1c', '#fb6a4a', '#993404', '#b30000', '#800026']
colours = (base_colours
           + base_colours[:2] + ['#fb6a4a'] + base_colours[2:]
           + ['k'] * 9)

In [7]:
# Detect place fields on each route's tuning curves, keep fields no longer than
# max_length (presumably tuning-curve bins; confirm), then keep cells with a
# single such field.
all_u_fields = vdm.find_fields(tc['u'])
all_shortcut_fields = vdm.find_fields(tc['shortcut'])
all_novel_fields = vdm.find_fields(tc['novel'])

# Disabled alternative: keep only fields unique to one route before sizing.
# u_compare = vdm.find_fields(tc['u'], hz_thres=3)
# shortcut_compare = vdm.find_fields(tc['shortcut'], hz_thres=3)
# novel_compare = vdm.find_fields(tc['novel'], hz_thres=3)
#
# u_fields_unique = vdm.unique_fields(all_u_fields, shortcut_compare, novel_compare)
# shortcut_fields_unique = vdm.unique_fields(all_shortcut_fields, u_compare, novel_compare)
# novel_fields_unique = vdm.unique_fields(all_novel_fields, u_compare, shortcut_compare)

u_fields_sized = vdm.sized_fields(all_u_fields, max_length=15)
shortcut_fields_sized = vdm.sized_fields(all_shortcut_fields, max_length=15)
novel_fields_sized = vdm.sized_fields(all_novel_fields, max_length=15)

u_fields_single = vdm.get_single_field(u_fields_sized)
shortcut_fields_single = vdm.get_single_field(shortcut_fields_sized)
novel_fields_single = vdm.get_single_field(novel_fields_sized)


def select_field_cells(route_tc, single_fields, spike_times):
    """Collect tuning curves and spike trains of single-field cells on one route.

    Cells are visited in the order returned by ``vdm.get_sort_idx(route_tc)``.
    A cell is kept when ``get_odd_firing_idx(route_tc)`` does not flag it and
    it is a member of ``single_fields``.

    Parameters
    ----------
    route_tc : indexable by cell index
        Tuning curves for one route (e.g. ``tc['u']``).
    single_fields : container of cell indices
        Output of ``vdm.get_single_field`` for the same route.
    spike_times : indexable by cell index
        Per-cell spike times (``spikes['time']``).

    Returns
    -------
    (field_tc, field_spikes) : tuple of lists
        ``route_tc[idx]`` and ``spike_times[idx]`` for each kept cell, in sort order.
    """
    sort_idx = vdm.get_sort_idx(route_tc)
    odd_firing_idx = get_odd_firing_idx(route_tc)
    kept = [idx for idx in sort_idx
            if idx not in odd_firing_idx and idx in single_fields]
    return [route_tc[idx] for idx in kept], [spike_times[idx] for idx in kept]


u_tc, u_field_spikes = select_field_cells(tc['u'], u_fields_single, spikes['time'])
shortcut_tc, shortcut_field_spikes = select_field_cells(
    tc['shortcut'], shortcut_fields_single, spikes['time'])
novel_tc, novel_field_spikes = select_field_cells(
    tc['novel'], novel_fields_single, spikes['time'])

In [8]:
# Number of selected single-field cells per route: U, shortcut, novel.
print(*(len(route_spikes) for route_spikes in (u_field_spikes, shortcut_field_spikes, novel_field_spikes)))

29 15 18


In [9]:
# Seaborn 'ticks' axes style for the figures that follow.
sns.set_style('ticks')

# Start/stop times of the phase-3 window for this session.
phase3_window = info.task_times['phase3']
t_start, t_stop = phase3_window[0], phase3_window[1]

# Samples of the running-only positions nearest each boundary. The slice below
# is half-open, so the sample nearest t_stop is itself excluded.
t_start_idx = vdm.find_nearest_idx(run_pos['time'], t_start)
t_stop_idx = vdm.find_nearest_idx(run_pos['time'], t_stop)

sliced_pos = {key: run_pos[key][t_start_idx:t_stop_idx] for key in ('x', 'y', 'time')}

In [12]:
# One panel per selected U-route cell:
#   - phase-3 running positions (grey);
#   - the cell's tuning curve as a filled strip drawn above the track;
#   - positions where the cell spiked during phase 3.
cols = 8
panel_cells = list(range(1, 26))  # NOTE(review): skips cell 0 and cells beyond 25 of u_tc; confirm intended
# Enough rows for every panel. The previous fixed 3x8 grid had 24 slots for 25
# panels, and its row/col branches skipped i == cols, so a panel was overplotted.
rows = int(np.ceil(len(panel_cells) / cols))

# Use a separate name: rebinding `tc` would clobber the per-route tuning-curve dict.
curves = u_tc
curve_x_offset = 15  # where the tuning-curve strip starts on the x axis (plot units)
curve_y_offset = 60  # baseline of the strip on the y axis (plot units)
curve_height = 5     # the strip is rescaled to span [0, curve_height]

fig = plt.figure()
for i, idx in enumerate(panel_cells):
    row, col = divmod(i, cols)
    ax = plt.subplot2grid((rows, cols), (row, col))
    ax.plot(sliced_pos['x'], sliced_pos['y'], '.', color='#bdbdbd', ms=1)

    # Min-max rescale on a copy so u_tc is left untouched and re-running is safe.
    curve = np.asarray(curves[idx], dtype=float)
    lo, hi = np.min(curve), np.max(curve)
    scaled = (curve - lo) * curve_height / (hi - lo) if hi > lo else np.zeros_like(curve)
    xs = np.arange(curve_x_offset, curve_x_offset + len(curve))
    ax.plot(xs, scaled + curve_y_offset, color='#252525')
    ax.fill_between(xs, curve_y_offset, scaled + curve_y_offset, facecolor=colours[idx])

    # Map each phase-3 spike to its nearest running-position sample; draw them all at once.
    spike_idx = [vdm.find_nearest_idx(sliced_pos['time'], spike)
                 for spike in u_field_spikes[idx] if t_start < spike < t_stop]
    if spike_idx:
        ax.plot(sliced_pos['x'][spike_idx], sliced_pos['y'][spike_idx], 'o', color=colours[idx],
                markeredgecolor='#252525', fillstyle='full', markeredgewidth=0.1, ms=3)
    # Strip ticks on every panel, not only panels that happen to contain spikes.
    plt.setp(ax, xticks=[], xticklabels=[], yticks=[])

sns.despine(bottom=True, left=True)
plt.tight_layout()
fig.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()

In [22]:
# One panel per selected shortcut-route cell:
#   - phase-3 running positions (grey);
#   - the cell's tuning curve as a filled strip drawn above the track;
#   - positions where the cell spiked during phase 3.
cols = 5
panel_cells = list(range(15))
# 15 panels fill a 3x5 grid exactly when laid out row-major with divmod. The
# previous branches skipped i == cols, overplotting panel (0, 4) and leaving
# (2, 4) empty.
rows = int(np.ceil(len(panel_cells) / cols))

# Use a separate name: rebinding `tc` would clobber the per-route tuning-curve dict.
curves = shortcut_tc
curve_x_offset = 15  # where the tuning-curve strip starts on the x axis (plot units)
curve_y_offset = 60  # baseline of the strip on the y axis (plot units)
curve_height = 5     # the strip is rescaled to span [0, curve_height]

fig = plt.figure()
for i, idx in enumerate(panel_cells):
    row, col = divmod(i, cols)
    ax = plt.subplot2grid((rows, cols), (row, col))
    ax.plot(sliced_pos['x'], sliced_pos['y'], '.', color='#bdbdbd', ms=1)

    # Min-max rescale on a copy so shortcut_tc is left untouched and re-running is safe.
    curve = np.asarray(curves[idx], dtype=float)
    lo, hi = np.min(curve), np.max(curve)
    scaled = (curve - lo) * curve_height / (hi - lo) if hi > lo else np.zeros_like(curve)
    xs = np.arange(curve_x_offset, curve_x_offset + len(curve))
    ax.plot(xs, scaled + curve_y_offset, color='#252525')
    ax.fill_between(xs, curve_y_offset, scaled + curve_y_offset, facecolor=colours[idx])

    # Map each phase-3 spike to its nearest running-position sample; draw them all at once.
    spike_idx = [vdm.find_nearest_idx(sliced_pos['time'], spike)
                 for spike in shortcut_field_spikes[idx] if t_start < spike < t_stop]
    if spike_idx:
        ax.plot(sliced_pos['x'][spike_idx], sliced_pos['y'][spike_idx], 'o', color=colours[idx],
                markeredgecolor='#252525', fillstyle='full', markeredgewidth=0.1, ms=3)
    # Strip ticks on every panel, not only panels that happen to contain spikes.
    plt.setp(ax, xticks=[], xticklabels=[], yticks=[])

sns.despine(bottom=True, left=True)
plt.tight_layout()
fig.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()