In [None]:
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import pickle
import seaborn as sns
import os

import vdmlab as vdm

from tuning_curves_functions import get_tc, get_odd_firing_idx, linearize

import info.R063d3_info as r063d3
import info.R066d4_info as r066d4

In [None]:
# Project root. Set the SHORTCUT_PROJECT_DIR environment variable to run on
# another machine instead of editing hard-coded absolute paths here.
project_dir = os.environ.get('SHORTCUT_PROJECT_DIR',
                             'C:\\Users\\Emily\\Code\\python-vdmlab\\projects\\emily_shortcut')
# Trailing '' keeps the trailing separator that downstream string joins expect.
pickle_filepath = os.path.join(project_dir, 'cache', 'pickled', '')
output_filepath = os.path.join(project_dir, 'plots', 'sequence', '')

In [None]:
info = r063d3

In [None]:
print(info.session_id)
pos = info.get_pos(info.pxl_to_cm)

t_start = info.task_times['phase3'][0]
t_stop = info.task_times['phase3'][1]

t_start_idx = vdm.find_nearest_idx(pos['time'], t_start)
t_end_idx = vdm.find_nearest_idx(pos['time'], t_stop)

sliced_pos = dict()
sliced_pos['x'] = pos['x'][t_start_idx:t_end_idx]
sliced_pos['y'] = pos['y'][t_start_idx:t_end_idx]
sliced_pos['time'] = pos['time'][t_start_idx:t_end_idx]

linear, zone = linearize(info, pos)

spikes = info.get_spikes()

tc = get_tc(info, sliced_pos, pickle_filepath)

sort_idx = vdm.get_sort_idx(tc['u'])
odd_firing_idx = get_odd_firing_idx(tc['u'])

ordered_spikes = spikes['time'][sort_idx]

In [None]:
# Time-bin edges for the 'u' trajectory: one edge half a sample before each
# position sample, plus the final sample time; then keep every `subsample`-th edge.
# (Mirrors the edge construction used for the MATLAB data below.)
dt = np.median(np.diff(linear['u']['time']))

edges = np.hstack((linear['u']['time'] - dt/2, linear['u']['time'][-1]))
subsample = 6
# Python is 0-based: start at the first edge and step by `subsample`
# (a 1-based start would skip it and can run past the end of the array).
edges = edges[::subsample]

# Bin centres: left edge plus half the (subsampled) bin width.
centers = edges[:-1] + np.median(np.diff(edges))/2

In [None]:
# Alternative binning: `num_bins` evenly spaced edges spanning the time range of
# the 'u' trajectory (which gives num_bins - 1 bins).
position_z = linear['u']
num_bins = 101

linear_start, linear_stop = np.min(position_z['time']), np.max(position_z['time'])
edges = np.linspace(linear_start, linear_stop, num=num_bins)
centers = np.array(0.5 * (edges[:-1] + edges[1:]))

In [None]:
# One interval per pair of consecutive edges, including the last bin
# (edges[:-2] / edges[1:-1] silently dropped it).
intervals = dict()
intervals['start'] = edges[:-1]
intervals['stop'] = edges[1:]

counts = vdm.spike_counts(spikes['time'], intervals)

In [None]:
print(np.min(counts), np.max(counts))

In [None]:
print(np.shape(counts))

In [None]:
print(intervals['start'][0:5], intervals['stop'][0:5])

In [None]:
plt.plot(counts)
plt.show()

# Decoding with data exported from MATLAB

In [None]:
import scipy.io as sio
loading_decoder = sio.loadmat('C:\\Users\\Emily\\Desktop\\_decoding.mat')

In [None]:
decode = dict(time=[])
decode['spikes'] = loading_decoder['pyspikes'][0]
decode['ztime'] = loading_decoder['pyztime'][0]
decode['zdata'] = loading_decoder['pyzdata'][0]
decode['tc'] = loading_decoder['pytc']

In [None]:
print(np.shape(decode['tc']))

In [None]:
dt = np.median(np.diff(decode['ztime']))

edges = np.hstack((decode['ztime']-(dt/2), decode['ztime'][-1]))
subsample = 6
edges = edges[::subsample]

In [None]:
print(len(edges), np.min(edges), np.max(edges))

In [None]:
this_dt = np.median(np.diff(edges))

In [None]:
print(this_dt, dt)

In [None]:
decode['ztime'][:10]

In [None]:
print(len(centers), np.min(centers), np.max(centers))

In [None]:
from scipy import signal

# Kernel parameters expressed in bins: 0.02 SD and 1.0 total width in the units
# of this_dt (presumably seconds, i.e. 20 ms SD over a 1 s window).
gaussian_std = 0.02 / this_dt
gaussian_window = 1.0 / this_dt

# The window length must be a whole, positive number of samples.
# `signal.gaussian` is the legacy alias of `signal.windows.gaussian`; it was
# deprecated and later removed from the `scipy.signal` namespace.
try:
    _gaussian = signal.windows.gaussian
except AttributeError:
    _gaussian = signal.gaussian
gaussian_filter = _gaussian(max(int(round(gaussian_window)), 1), gaussian_std)
gaussian_filter /= np.sum(gaussian_filter)

In [None]:
# Per-neuron spike counts in each time bin of `edges`, Gaussian-smoothed when
# the kernel is wider than one bin.
q = np.zeros((len(decode['spikes']), len(edges) - 1))
for idx, neuron_spikes in enumerate(decode['spikes']):
    q[idx] = np.histogram(neuron_spikes, bins=edges)[0]
    # gaussian_std is expressed in bins (0.02 / this_dt), so "wider than one
    # bin" means > 1; this_dt is in time units and is not comparable.
    if gaussian_std > 1:
        q[idx] = np.convolve(q[idx], gaussian_filter, mode='same')

In [None]:
plt.plot(gaussian_filter)
plt.show()

In [None]:
plt.pcolormesh(q[:,:100])
plt.colorbar()
plt.show()

In [None]:
print(np.shape(np.histogram(neuron_spikes, bins=edges)[0]))

In [None]:
print(linear['u'].keys())

In [None]:
def get_counts(spikes, edges, gaussian_std=0.02, gaussian_window=1.0):
    """Bin each neuron's spike times, optionally smoothing with a Gaussian.

    Parameters
    ----------
    spikes : sequence of array-like
        One array of spike times per neuron.
    edges : array-like
        Increasing time-bin edges, in the same units as ``spikes``.
    gaussian_std : float
        Standard deviation of the smoothing kernel, in the same time units
        (the defaults presumably mean seconds: 20 ms SD, 1 s window).
    gaussian_window : float
        Total length of the smoothing kernel, in the same time units.

    Returns
    -------
    counts : np.ndarray, shape (len(spikes), len(edges) - 1)
        Spike counts per neuron and bin. Each row is convolved with a
        normalised Gaussian only when ``gaussian_std`` exceeds the median bin
        width; otherwise the raw counts are returned.
    """
    edges = np.asarray(edges)
    counts = np.zeros((len(spikes), len(edges) - 1))
    for idx, neuron_spikes in enumerate(spikes):
        counts[idx] = np.histogram(neuron_spikes, bins=edges)[0]

    if counts.shape[1] == 0:
        return counts

    dt = np.median(np.diff(edges))
    # Smooth only when the kernel SD spans more than one bin. Both sides are in
    # time units (`not >` also skips smoothing if dt is NaN).
    if not gaussian_std > dt:
        return counts

    try:
        from scipy.signal.windows import gaussian
    except ImportError:  # very old SciPy
        from scipy.signal import gaussian

    # Kernel length and SD in bins; the window length must be a positive integer.
    window_bins = max(int(round(gaussian_window / dt)), 1)
    kernel = gaussian(window_bins, gaussian_std / dt)
    kernel /= np.sum(kernel)

    # Centred slice of the full convolution. This equals
    # np.convolve(..., mode='same') when the kernel is shorter than the row,
    # and still has the row's length when the kernel is longer.
    start = (len(kernel) - 1) // 2
    num_bins = counts.shape[1]
    for idx in range(counts.shape[0]):
        # Smooth this neuron's own counts (previously read the global `q`).
        counts[idx] = np.convolve(counts[idx], kernel)[start:start + num_bins]
    return counts

In [None]:
# Decode the 'u' trajectory. To decode the MATLAB-exported data instead, use
# decode['zdata'] / decode['ztime'] as position / time, decode['spikes'] as the
# spike times and decode['tc'] as the tuning curves.
#
# `linear` and `tc` are narrowed in place. The guards make re-running this
# cell harmless; without them a second run indexes the already-narrowed objects.
if 'u' in linear:
    linear = linear['u']
if not isinstance(tc, np.ndarray):
    tc = np.array(tc['u'])

# Time-bin edges: half a sample before each position sample plus the final
# sample time, keeping every `subsample`-th edge.
dt = np.median(np.diff(linear['time']))
edges = np.hstack((linear['time'] - dt/2, linear['time'][-1]))
subsample = 6
edges = edges[::subsample]
counts = get_counts(spikes['time'], edges)

In [None]:
# Binned counts for the first 100 time bins; one row per neuron.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(counts[:, :100])
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
def bayesian_prob(counts, tuning_curves, centers, min_neurons=1, min_spikes=1):
    """Posterior probability of each position bin for every time bin.

    Assumes independent Poisson spiking and a flat prior over position bins.
    For time bin t and position bin x:

        log P(x | n_t) = sum_i n_it * log f_i(x) - tau * sum_i f_i(x) + const

    where ``f_i`` is neuron i's tuning curve and ``tau`` is the time-bin width.

    Parameters
    ----------
    counts : np.ndarray, shape (n_neurons, n_time_bins)
        Spike counts per neuron and time bin.
    tuning_curves : np.ndarray, shape (n_neurons, n_position_bins)
        Tuning curve of each neuron. NaN entries are ignored.
    centers : np.ndarray
        Time-bin centres; their median spacing is used as ``tau``.
    min_neurons : int
        Time bins with fewer active neurons than this are set to NaN.
    min_spikes : int
        A neuron counts as active in a time bin if it fired at least this many
        spikes there.

    Returns
    -------
    prob : np.ndarray, shape (n_time_bins, n_position_bins)
        Rows sum to 1. A row is NaN where every likelihood is zero or too few
        neurons were active.
    """
    counts = np.asarray(counts, dtype=float)
    tuning_curves = np.asarray(tuning_curves, dtype=float)
    num_time_bins = counts.shape[1]
    num_bins = tuning_curves.shape[1]
    bin_size = np.median(np.diff(centers))

    log_likelihood = np.empty((num_time_bins, num_bins))
    # log(0) and 0 * log(0) are expected below; silence their warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_tc = np.log(tuning_curves)
        for idx in range(num_bins):
            # n * log(f) rather than log(f ** n): f ** n under/overflows for
            # large counts. 0 * log(0) -> NaN is dropped by nansum, consistent
            # with 0 ** 0 == 1 -> log(1) == 0.
            tempprod = np.nansum(counts * log_tc[:, idx][:, np.newaxis], axis=0)
            # Expected-spike-count term of the Poisson likelihood.
            tempsum = -bin_size * np.nansum(tuning_curves[:, idx])
            log_likelihood[:, idx] = tempprod + tempsum

    # Normalise each row in log space. Subtracting the row maximum keeps exp()
    # in range; rows with no finite entry stay all-zero and become NaN (0 / 0).
    # The flat prior 1/num_bins cancels in this normalisation.
    row_max = np.max(log_likelihood, axis=1, keepdims=True)
    row_max[~np.isfinite(row_max)] = 0.0
    with np.errstate(invalid='ignore'):
        prob = np.exp(log_likelihood - row_max)
        prob /= np.sum(prob, axis=1, keepdims=True)

    # "At least min_spikes" spikes makes a neuron active (was `>`, off by one).
    num_active_neurons = np.sum(counts >= min_spikes, axis=0)
    prob[num_active_neurons < min_neurons] = np.nan
    return prob


In [None]:
# Centre of each time bin: left edge plus half the median bin width.
centers = edges[:-1] + np.median(np.diff(edges)) / 2
prob = bayesian_prob(counts, tc, centers)

# Posterior for the first 201 time bins, drawn in reverse row order.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(prob[200::-1])
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
def decode_location(prob, linear):
    """Map each time bin's most probable position bin back to a position.

    Parameters
    ----------
    prob : np.ndarray, shape (n_time_bins, n_position_bins)
        Posterior per time bin, e.g. from ``bayesian_prob``.
    linear : dict
        Must contain 'position'. Position bins are taken to be evenly spaced
        from its (NaN-ignoring) min to max. NOTE(review): confirm this matches
        how the tuning curves were binned.

    Returns
    -------
    decoded : np.ndarray, shape (n_time_bins,)
        Decoded position; NaN where the whole row of ``prob`` is NaN.
    """
    prob = np.asarray(prob, dtype=float)
    positions = np.asarray(linear['position'], dtype=float)
    num_bins = prob.shape[1]

    # Bin k -> min + k * (max - min) / (num_bins - 1). linspace also handles
    # num_bins == 1 without dividing by zero.
    bin_positions = np.linspace(np.nanmin(positions), np.nanmax(positions), num_bins)

    nan_prob = np.isnan(prob)
    # A row is undecodable only if EVERY bin is NaN (comparing the NaN count to
    # num_bins - 1 never matched fully-NaN rows).
    all_nan = np.all(nan_prob, axis=1)
    # argmax treats NaN as the maximum; hide NaNs so they never win.
    max_decoded_idx = np.argmax(np.where(nan_prob, -np.inf, prob), axis=1)

    decoded = bin_positions[max_decoded_idx]
    decoded[all_nan] = np.nan
    return decoded

In [None]:
def find_nearest_indices(array, vals):
    """Apply ``vdm.find_nearest_idx(array, val)`` to every value in `vals`.

    Returns the collected indices as an integer numpy array, in the order of
    `vals` (empty if `vals` is empty).
    """
    indices = []
    for value in vals:
        indices.append(vdm.find_nearest_idx(array, value))
    return np.array(indices, dtype=int)

In [None]:
decoded = decode_location(prob, linear)

# Actual position at each time-bin centre.
actual_idx = find_nearest_indices(linear['time'], centers)
actual_location = linear['position'][actual_idx]

# Score only time bins that could be decoded. Undecodable bins are NaN here;
# zero-filling them first would count them as a decode at position 0.
decode_error = np.abs(actual_location - decoded)

# The low-pass smoothing below cannot handle NaN, so fill those bins now.
decoded[np.isnan(decoded)] = 0
np.nanmean(decode_error)

In [None]:
np.shape(decoded)

In [None]:
plt.plot(centers, decoded)
plt.plot(linear['time'], linear['position'], 'r.')
plt.show()

In [None]:
import nengo
smoothed_decoded = nengo.Lowpass(0.002).filtfilt(decoded)

In [None]:
decode_error = np.abs(actual_location - smoothed_decoded)
print(np.mean(decode_error))

plt.plot(centers, smoothed_decoded)
plt.plot(linear['time'], linear['position'], 'r.')
plt.show()

In [None]:
for thisposition in linear['time'][10000:10200]:
    x = vdm.find_nearest_idx(pos['time'], thisposition)
    plt.plot(pos['time'][x], pos['x'][x], 'b.')
    plt.plot(pos['time'][x], pos['y'][x], 'g.')

# plt.plot(pos['x'], pos['y'], 'b')
# plt.xlim(linear['time'][10000], linear['time'][12000])
plt.show()

In [None]:
plt.plot(linear['time'], linear['position'], 'b.')
plt.show()

In [None]:
centers = edges[:-1] + np.median(np.diff(edges))/2

In [None]:
print(np.shape(decode['tc'])[1])

In [None]:
min_neurons = 1
min_spikes = 1

length = np.shape(q)[1]
num_bins = np.shape(decode['tc'])[1]
bin_size = np.median(np.diff(centers))

# occ_uniform = np.ones((1, num_bins))[0] * bin_size



In [None]:
q.shape

In [None]:
# Posterior over position bins for every time bin: Poisson likelihood with a
# flat prior. Computed in log space because tc ** q under/overflows for busy
# bins, which turns whole rows into 0 / 0 = NaN.
with np.errstate(divide='ignore', invalid='ignore'):
    log_tc = np.log(decode['tc'])
    log_likelihood = np.empty((length, num_bins))
    for idx in range(num_bins):
        # sum_i q_i * log f_i(x); 0 * log(0) -> NaN is skipped by nansum,
        # consistent with 0 ** 0 == 1 -> log(1) == 0.
        tempprod = np.nansum(q * log_tc[:, idx][..., np.newaxis], axis=0)
        # -bin_size * sum_i f_i(x): expected-spike-count term.
        tempsum = -bin_size * np.nansum(decode['tc'][:, idx])
        log_likelihood[:, idx] = tempprod + tempsum

    # Normalise each row; subtracting the row max keeps exp() in range. The
    # flat prior 1/num_bins cancels in the normalisation. Rows with no finite
    # entry become NaN (0 / 0).
    row_max = np.max(log_likelihood, axis=1, keepdims=True)
    row_max[~np.isfinite(row_max)] = 0.0
    prob = np.exp(log_likelihood - row_max)
    prob /= np.sum(prob, axis=1)[..., np.newaxis]


In [None]:
# Blank out time bins where fewer than `min_neurons` neurons fired at least
# `min_spikes` spikes (`>` made min_spikes=1 require two spikes).
num_active_neurons = np.sum(q >= min_spikes, axis=0)
prob[num_active_neurons < min_neurons] = np.nan

In [None]:
plt.pcolormesh(prob[200::-1])
plt.colorbar()
plt.show()

In [None]:
max_decoded_idx = np.argmax(prob, axis=1)

In [None]:
print(np.min(decode['zdata']), np.max(decode['zdata']))
print(0, np.shape(prob)[1]-1)
print(np.min(decoded), np.max(decoded))

In [None]:
decoded = max_decoded_idx * (np.max(decode['zdata'])-np.min(decode['zdata'])) / (np.shape(prob)[1]-1)
decoded += np.min(decode['zdata'])

In [None]:
nan_idx = np.sum(np.isnan(prob), axis=1) == (np.shape(prob)[1]-1)

In [None]:
decoded[nan_idx] = np.nan

In [None]:
def find_nearest_indices(array, vals):
    """Collect ``vdm.find_nearest_idx(array, val)`` for each value in `vals`.

    NOTE(review): duplicate of the identical definition earlier in the
    notebook; keep the two in sync or remove one.
    """
    return np.fromiter((vdm.find_nearest_idx(array, value) for value in vals), dtype=int)

In [None]:
rat_location_idx = find_nearest_indices(decode['ztime'], centers)

In [None]:
np.shape(rat_location_idx)

In [None]:
rat_location = decode['zdata'][rat_location_idx]

In [None]:
decode_error = np.abs(rat_location - decoded)

In [None]:
plt.plot(decode_error)
plt.show()

In [None]:
np.mean(decode_error)

In [None]:
plt.plot(centers, decoded)
plt.plot(decode['ztime'], decode['zdata'], 'r.')
plt.show()

In [None]:
import nengo
smoothed_decoded = nengo.Lowpass(0.001).filtfilt(decoded)

In [None]:
print(np.mean(np.abs(rat_location - smoothed_decoded)))

In [None]:
plt.plot(centers, smoothed_decoded)
plt.plot(decode['ztime'], decode['zdata'], 'r.')
plt.show()

In [None]:
a = np.random.rand(4, 2)
a[:, 1] = np.nan
print(a)
np.argmax(a, axis=0)

In [None]:
np.sum(np.isnan([np.nan, 2, np.nan]))

In [None]:
intervals = dict()
intervals['start'] = edges[:-2]
intervals['stop'] = edges[1:-1]

counts = vdm.spike_counts(decode['spikes'], intervals)

In [None]:
len(intervals['start'])

In [None]:
print(len(counts), np.min(counts), np.max(counts))