In [1]:
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import pickle
import seaborn as sns
import os

import vdmlab as vdm

from tuning_curves_functions import get_tc, get_odd_firing_idx, linearize

import info.R066d4_info as r066d4

In [3]:
# pickle_filepath = 'E:\\code\\python-vdmlab\\projects\\emily_shortcut\\cache\\pickled\\'
# output_filepath = 'E:\\code\\python-vdmlab\\projects\\emily_shortcut\\plots\\sequence\\'
pickle_filepath = 'C:\\Users\\Emily\\Code\\python-vdmlab\\projects\\emily_shortcut\\cache\\pickled\\'
output_filepath = 'C:\\Users\\Emily\\Code\\python-vdmlab\\projects\\emily_shortcut\\plots\\sequence\\'

In [4]:
info = r066d4

In [5]:
print(info.session_id)
pos = info.get_pos(info.pxl_to_cm)

t_start = info.task_times['phase3'][0]
t_stop = info.task_times['phase3'][1]

t_start_idx = vdm.find_nearest_idx(pos['time'], t_start)
t_end_idx = vdm.find_nearest_idx(pos['time'], t_stop)

sliced_pos = dict()
sliced_pos['x'] = pos['x'][t_start_idx:t_end_idx]
sliced_pos['y'] = pos['y'][t_start_idx:t_end_idx]
sliced_pos['time'] = pos['time'][t_start_idx:t_end_idx]

linear, zone = linearize(info, pos)

spikes = info.get_spikes()

tc = get_tc(info, sliced_pos, pickle_filepath)

sort_idx = vdm.get_sort_idx(tc['u'])
odd_firing_idx = get_odd_firing_idx(tc['u'])

ordered_spikes = spikes['time'][sort_idx]

R066d4


In [7]:
dt = np.median(np.diff(linear['u']['time']))

edges = np.hstack(((linear['u']['time']-dt)/2, linear['u']['time'][-1]))
subsample = 6
edges = edges[np.arange(1, len(edges)+1, 6)]

centers = (edges[:-1] + dt)/2

In [None]:
position_z = linear['u']
num_bins = 101

linear_start = np.min(position_z['time'])
linear_stop = np.max(position_z['time'])
edges = np.linspace(linear_start, linear_stop, num=num_bins)
centers = np.array((edges[1:] + edges[:-1]) / 2.)

# occupancy = np.zeros(len(centers))

In [10]:
intervals = dict()
intervals['start'] = edges[:-2]
intervals['stop'] = edges[1:-1]

counts = vdm.spike_counts(spikes['time'], intervals)

In [11]:
print(np.min(counts), np.max(counts))

0.0 0.0


In [13]:
print(np.shape(counts))

(86, 7622)


In [None]:
print(intervals['start'][0:5], intervals['stop'][0:5])

In [None]:
plt.plot(counts)
plt.show()

# From matlab

In [257]:
import scipy.io as sio
loading_decoder = sio.loadmat('C:\\Users\\Emily\\Desktop\\_decoding.mat')

In [256]:
decode = dict(time=[])
decode['spikes'] = loading_decoder['pyspikes'][0]
decode['ztime'] = loading_decoder['pyztime'][0]
decode['zdata'] = loading_decoder['pyzdata'][0]
decode['tc'] = loading_decoder['pytc']

In [42]:
print(np.shape(decode['tc']))

(48,)


In [103]:
dt = np.median(np.diff(decode['ztime']))

edges = np.hstack((decode['ztime']-(dt/2), decode['ztime'][-1]))
subsample = 6
edges = edges[::subsample]

In [104]:
print(len(edges), np.min(edges), np.max(edges))

8951 13523.44869 15618.461092


In [105]:
this_dt = np.median(np.diff(edges))

In [106]:
print(this_dt, dt)

0.200074 0.0330939999985


In [234]:
decode['ztime'][:10]

array([ 13523.465237,  13523.497987,  13523.565035,  13523.598096,
        13523.632243,  13523.665258,  13523.698067,  13523.765092,
        13523.798116,  13523.832185])

In [108]:
print(len(centers), np.min(centers), np.max(centers))

8950 13523.548727 15618.377692


In [247]:
from scipy import signal
gaussian_std = 0.02 / this_dt
gaussian_window = 1.0 / this_dt

gaussian_filter = signal.gaussian(gaussian_window, gaussian_std)
gaussian_filter /= np.sum(gaussian_filter)

In [248]:
q = np.zeros((int(len(decode['spikes'])), int(len(edges)-1)))
for idx, neuron_spikes in enumerate(decode['spikes']):
    q[idx] = np.histogram(neuron_spikes, bins=edges)[0]
    if gaussian_std > this_dt:
        q[idx] = np.convolve(q[idx], gaussian_filter, mode='same')

In [243]:
plt.plot(gaussian_filter)
plt.show()

In [246]:
plt.pcolormesh(q[:,:100])
plt.colorbar()
plt.show()

In [208]:
print(np.shape(np.histogram(neuron_spikes, bins=edges)[0]))

(8950,)


In [14]:
import numpy.matlib
import math

In [398]:
def get_counts(spikes, edges, gaussian_std=0.02, gaussian_window=1.0):
    dt = np.median(np.diff(edges))
    
    apply_filter = False
    gaussian_std /= dt
    gaussian_window /= dt
    
    if gaussian_std > dt:
        apply_filter = True
    
    if apply_filter:
        gaussian_filter = signal.gaussian(gaussian_window, gaussian_std)
        gaussian_filter /= np.sum(gaussian_filter)
        
    counts = np.zeros((int(len(spikes)), int(len(edges)-1)))
    for idx, neuron_spikes in enumerate(spikes):
        counts[idx] = np.histogram(neuron_spikes, bins=edges)[0]
        if apply_filter:
            counts[idx] = np.convolve(q[idx], gaussian_filter, mode='same')
    return counts

dt = np.median(np.diff(decode['ztime']))
edges = np.hstack((decode['ztime']-(dt/2), decode['ztime'][-1]))
subsample = 6
edges = edges[::subsample]
counts = get_counts(decode['spikes'], edges)

In [399]:
plt.pcolormesh(counts[:,:100])
plt.colorbar()
plt.show()

In [400]:
def bayesian_prob(counts, tuning_curves, centers, min_neurons=1, min_spikes=1):
    length = np.shape(counts)[1]
    num_bins = np.shape(tuning_curves)[1]
    bin_size = np.median(np.diff(centers))
    
    prob = np.empty((length, num_bins)) * np.nan
    for idx in range(num_bins):
        # What does tempprod represent?
        tempprod = np.nansum(np.log(tuning_curves[:,idx][..., np.newaxis] ** q), axis=0)

        # What does tempsum represent?
        tempsum = np.exp(-bin_size * np.nansum(tuning_curves[:, idx]))

        prob[:, idx] = np.exp(tempprod) * tempsum * (1/num_bins)


    prob /= np.sum(prob, axis=1)[..., np.newaxis]
    # prob[np.isnan(prob)] = 0 add to docstring
    
    num_active_neurons = np.sum(counts > min_spikes, axis=0)
    prob[num_active_neurons < min_neurons] = np.nan
    return prob

centers = edges[:-1] + np.median(np.diff(edges))/2
prob = bayesian_prob(counts, decode['tc'], centers)

In [401]:
plt.pcolormesh(prob[200::-1])
plt.colorbar()
plt.show()

In [402]:
def find_nearest_indices(array, vals):
    return np.array([vdm.find_nearest_idx(array, val) for val in vals], dtype=int)

In [403]:
def decode_location(prob, linear):
    max_decoded_idx = np.argmax(prob, axis=1)
    decoded = max_decoded_idx * (np.max(linear['position'])-np.min(linear['position'])) / (np.shape(prob)[1]-1)
    decoded += np.min(linear['position'])

    nan_idx = np.sum(np.isnan(prob), axis=1) == (np.shape(prob)[1]-1)
    decoded[nan_idx] = np.nan
    
    return decoded

linear = dict()
linear['position'] = decode['zdata']
linear['time'] =
decoded = decode_location(prob, linear)

actual_idx = find_nearest_indices(decode['ztime'], centers)
actual_location = decode['zdata'][rat_location_idx]

decode_error = np.abs(actual_location - decoded)
np.mean(decode_error)

19.078440508736477

In [404]:
plt.plot(centers, decoded)
plt.plot(decode['ztime'], decode['zdata'], 'r.')
plt.show()

In [250]:
centers = edges[:-1] + np.median(np.diff(edges))/2

In [263]:
print(np.shape(decode['tc'])[1])

48


In [267]:
min_neurons = 1
min_spikes = 1

length = np.shape(q)[1]
num_bins = np.shape(decode['tc'])[1]
bin_size = np.median(np.diff(centers))

# occ_uniform = np.ones((1, num_bins))[0] * bin_size



In [297]:
q.shape

(86, 8950)

In [298]:
prob = np.empty((length, num_bins)) * np.nan
for idx in range(num_bins):
    # What does tempprod represent?
    tempprod = np.nansum(np.log(decode['tc'][:,idx][..., np.newaxis] ** q), axis=0)
    
    # What does tempsum represent?
    tempsum = np.exp(-bin_size * np.nansum(decode['tc'][:, idx]))
    
    prob[:, idx] = np.exp(tempprod) * tempsum * (1/num_bins)
    

prob /= np.sum(prob, axis=1)[..., np.newaxis]
# prob[np.isnan(prob)] = 0 add to docstring


In [301]:
num_active_neurons = np.sum(q > min_spikes, axis=0)
prob[num_active_neurons < min_neurons] = np.nan

In [305]:
plt.pcolormesh(prob[200::-1])
plt.colorbar()
plt.show()

In [306]:
max_decoded_idx = np.argmax(prob, axis=1)

In [322]:
print(np.min(decode['zdata']), np.max(decode['zdata']))
print(0, np.shape(prob)[1]-1)
print(np.min(decoded), np.max(decoded))

1.0 146.0
0 47
1.0 146.0


In [321]:
decoded = max_decoded_idx * (np.max(decode['zdata'])-np.min(decode['zdata'])) / (np.shape(prob)[1]-1)
decoded += np.min(decode['zdata'])

In [354]:
nan_idx = np.sum(np.isnan(prob), axis=1) == (np.shape(prob)[1]-1)

In [357]:
decoded[nan_idx] = np.nan

In [366]:
def find_nearest_indices(array, vals):
    return np.array([vdm.find_nearest_idx(array, val) for val in vals], dtype=int)

In [372]:
rat_location_idx = find_nearest_indices(decode['ztime'], centers)

In [373]:
np.shape(rat_location_idx)

(8950,)

In [374]:
rat_location = decode['zdata'][rat_location_idx]

In [375]:
decode_error = np.abs(rat_location - decoded)

In [377]:
plt.plot(decode_error)
plt.show()

In [378]:
np.mean(decode_error)

19.078440508736477

In [379]:
plt.plot(centers, decoded)
plt.plot(decode['ztime'], decode['zdata'], 'r.')
plt.show()

In [391]:
import nengo
smoothed_decoded = nengo.Lowpass(0.001).filtfilt(decoded)

In [392]:
print(np.mean(np.abs(rat_location - smoothed_decoded)))

17.9240187898


In [394]:
plt.plot(centers, smoothed_decoded)
plt.plot(decode['ztime'], decode['zdata'], 'r.')
plt.show()

In [348]:
a = np.random.rand(4, 2)
a[:, 1] = np.nan
print(a)
np.argmax(a, axis=0)

[[ 0.57765858         nan]
 [ 0.60636395         nan]
 [ 0.99415385         nan]
 [ 0.2813858          nan]]


array([2, 0], dtype=int64)

In [353]:
np.sum(np.isnan([np.nan, 2, np.nan]))

2

In [90]:
intervals = dict()
intervals['start'] = edges[:-2]
intervals['stop'] = edges[1:-1]

counts = vdm.spike_counts(decode['spikes'], intervals)

In [92]:
len(intervals['start'])

8949

In [91]:
print(len(counts), np.min(counts), np.max(counts))

86 0.0 302.0
