In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

from scipy.ndimage.filters import gaussian_filter
import scipy.signal as signal

import vdmlab as vdm

In [None]:
# Toy trajectory: three tracked samples (x, y) taken at time = 0, 1, 2.
pos = {
    'x': np.array([1, 5, 9]),
    'y': np.array([6, 10, 1]),
    'time': np.array([0, 1, 2]),
}

# Scatter the tracked positions (blue dots, no connecting line) with a small
# margin around the data range.
plt.plot(pos['x'], pos['y'], color='b', marker='.', linestyle='None', markersize=10)
plt.xlim((0.5, 9.5))
plt.ylim((0.5, 10.5))
plt.show()

In [None]:
# Toy spike times, one row per neuron; each time is matched against
# pos['time'] when building the tuning curves.
spikes = np.array([
    [0.1, 2.1, 2.3],
    [0.2, 0.7, 1.1],
])

In [None]:
binsize = 3


def _bin_edges(values, binsize):
    """Return histogram edges covering ``values`` in steps of ``binsize``.

    Edges start at ``values.min()`` and advance by ``binsize``. When the last
    step falls short of ``values.max()``, the max is appended as a final edge,
    so the last (possibly narrower) bin still covers it. If every value is
    identical, ``np.arange`` yields nothing, and a single bin
    ``[min, min + binsize]`` is returned instead.
    """
    start, stop = values.min(), values.max()
    edges = np.arange(start, stop, binsize)
    if edges.size == 0:
        return np.array([start, start + binsize])
    # `<` rather than `!=`: float arange can overshoot `stop`, and appending a
    # smaller value would make the edges non-monotonic.
    if edges[-1] < stop:
        edges = np.hstack((edges, stop))
    return edges


# Each axis is completed independently; whether one axis needs its max
# appended says nothing about the other.
xedges = _bin_edges(pos['x'], binsize)
yedges = _bin_edges(pos['y'], binsize)

In [None]:
# NOTE(review): despite the name, this is the sampling *period* (time per
# position sample), presumably for a 30 Hz tracker -- TODO confirm. Multiplying
# sample counts by it turns occupancy counts into time spent per bin.
sampling_rate = 1/30.

# histogram2d(y, x, ...) puts y on the rows, so the returned edges come back
# in (y, x) order.
position_heatmap, pos_yedges, pos_xedges = np.histogram2d(pos['y'], pos['x'], bins=[yedges, xedges])
position_heatmap *= sampling_rate

shape = position_heatmap.shape
# Boolean mask of visited bins. It indexes the 2-D maps directly, so no
# flatten/reshape round trip is needed.
occupied_idx = position_heatmap > 0

tuning_curves = []
for neuron_spikes in spikes:
    # Place each spike at the position sample nearest to it in time.
    spike_idx = [vdm.find_nearest_idx(pos['time'], spike_time) for spike_time in neuron_spikes]
    counts_heatmap = np.histogram2d(pos['y'][spike_idx], pos['x'][spike_idx], bins=[yedges, xedges])[0]

    # Firing rate = spike count / occupancy. Unvisited bins stay at 0 rather
    # than becoming 0/0.
    firing_rate = np.zeros(shape)
    firing_rate[occupied_idx] = counts_heatmap[occupied_idx] / position_heatmap[occupied_idx]
    tuning_curves.append(firing_rate)


In [None]:
# Display the unsmoothed per-neuron firing-rate maps (one 2-D array per neuron).
tuning_curves

In [None]:
# Optional smoothing of the 2-D maps. 'gaussian' runs scipy's gaussian_filter
# with `gaussian_sigma` (in bins); any other value leaves the maps as they are.
filter_type = 'gaussian'  # set to None to skip smoothing
gaussian_sigma = 0.5

if filter_type != 'gaussian':
    print('Tuning curve with no filter.')
    tc = tuning_curves
else:
    tc = [gaussian_filter(rate_map, gaussian_sigma) for rate_map in tuning_curves]

In [None]:
# Display the maps after the optional Gaussian smoothing step.
tc

In [None]:
# Corner grid shared by every map. pcolormesh takes the bin *edges*, so
# xx/yy are one larger than each map along both axes.
xx, yy = np.meshgrid(xedges, yedges)

# One explicitly created figure per neuron, so no pyplot state carries over
# from one map to the next.
for neuron_id, tuning_curve in enumerate(tc):
    fig, ax = plt.subplots()
    pp = ax.pcolormesh(xx, yy, tuning_curve, cmap='YlGn')
    fig.colorbar(pp, ax=ax)
    ax.set_title('Neuron {}'.format(neuron_id))
    ax.axis('off')
    plt.show()

In [None]:
# Inspect the y bin edges used by the maps (display only; nothing depends on this cell).
yedges

In [None]:
# NOTE(review): legacy cell. `num_bins`, `neuron_list` and a dict-style
# `spikes['time'][neuron]` are not defined in this notebook (here `spikes` is a
# 2-D array), so this cell fails on a fresh kernel -- TODO port or delete.
# WARNING: this rebinds the module-level `xedges`/`yedges`; plotting cells
# re-run afterwards will pick up these edges instead of the binsize-based ones.
xedges = np.linspace(np.min(pos['x'])-2, np.max(pos['x'])+2, num_bins+1)
yedges = np.linspace(np.min(pos['y'])-2, np.max(pos['y'])+2, num_bins+1)

heatmaps = dict()
count = 1
for neuron in neuron_list:
    field_x = []
    field_y = []
    for spike in spikes['time'][neuron]:
        spike_idx = vdm.find_nearest_idx(pos['time'], spike)
        field_x.append(pos['x'][spike_idx])
        field_y.append(pos['y'][spike_idx])
    # Histogram once per neuron, after all spikes are collected. Doing it here
    # (not inside the spike loop) avoids quadratic work and gives a neuron with
    # no spikes an all-zero map instead of a stale or undefined `heatmap`.
    heatmap = np.histogram2d(field_x, field_y, bins=[xedges, yedges])[0]
    # histogram2d(x, y) puts x on the rows; transpose so rows follow y.
    heatmaps[neuron] = heatmap.T
    print(str(count) + ' of ' + str(len(neuron_list)))
    count += 1

In [None]:
# NOTE(review): `linear` (a dict with 'position' and 'time' arrays),
# `spike_times` and `gaussian_std` are not defined in this notebook -- TODO
# define or load them before running this cell.

# 1-D edges in `binsize` steps from the minimum. The maximum is appended when
# the last step falls short, so the final bin still covers it.
linear_start = np.min(linear['position'])
linear_stop = np.max(linear['position'])
edges = np.arange(linear_start, linear_stop, binsize)
if edges[-1] < linear_stop:
    edges = np.hstack([edges, linear_stop])

# Occupancy per bin: sample counts scaled by `sampling_rate` (the per-sample
# period, not a rate).
position_counts = np.histogram(linear['position'], bins=edges)[0].astype(float)
position_counts *= sampling_rate
occupied_idx = position_counts > 0

# Kept separate from the 2-D `tc` so re-running the 2-D plotting cells still works.
linear_tc = []
for neuron_spikes in spike_times:
    # Linearized position at the sample nearest to each spike in time.
    spike_positions = [linear['position'][vdm.find_nearest_idx(linear['time'], spike_time)]
                       for spike_time in neuron_spikes]
    spike_counts = np.histogram(spike_positions, bins=edges)[0]

    # Rate = spike count / occupancy; unvisited bins stay at 0.
    firing_rate = np.zeros(len(edges)-1)
    firing_rate[occupied_idx] = spike_counts[occupied_idx] / position_counts[occupied_idx]
    linear_tc.append(firing_rate)

if filter_type == 'gaussian':
    filter_multiplier = 6
    # The kernel spans `filter_multiplier` standard deviations. Use an odd
    # length and a symmetric window (fftbins=False) so its peak sits on the
    # centre sample; an even-length or periodic window would shift the
    # smoothed curve under 'same' convolution.
    n_points = int(round(gaussian_std * filter_multiplier))
    if n_points % 2 == 0:
        n_points += 1
    # Named `gaussian_window` so it does not shadow scipy's `gaussian_filter`,
    # which the 2-D smoothing cell calls.
    gaussian_window = signal.get_window(('gaussian', gaussian_std), n_points, fftbins=False)
    normalized_gaussian = gaussian_window / np.sum(gaussian_window)
    # signal.convolve(mode='same') keeps the length of the first argument even
    # when the kernel is longer than the curve; np.convolve would not.
    out_tc = [signal.convolve(curve, normalized_gaussian, mode='same') for curve in linear_tc]
else:
    print('Tuning curve with no filter.')
    out_tc = linear_tc