In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves
from analyze_decode import get_decoded
from run import analysis_infos
from utils_maze import get_trials

In [None]:
# Paths are resolved relative to the directory the notebook kernel runs in.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "session-diagnostics")
# Every plotting cell below saves into output_filepath; create it up front so
# savefig does not fail with FileNotFoundError on a fresh checkout.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# import info.r067d7 as r067d7
# import info.r068d6 as r068d6

# infos = [r067d7, r068d6]

In [None]:
def plot_tuning_curves(neurons, xx, yy, info=None):
    """Plot the sum of all neurons' tuning curves and save it as a PNG.

    Parameters
    ----------
    neurons : object
        Must expose ``tuning_shape`` (shape of one tuning curve),
        ``n_neurons`` and an indexable ``tuning_curves``.
    xx, yy : array-like
        Mesh coordinates passed straight to ``pcolormesh``.
    info : object, optional
        Session info providing ``session_id``. If omitted, the notebook-global
        ``info`` (set by the per-session loop) is used for backward
        compatibility.

    The figure is titled ``<session_id>-tuning_curve-all`` and written to
    ``output_filepath`` under the same name with a ``.png`` extension.
    """
    if info is None:
        # Backward compatibility: earlier callers relied on the global ``info``.
        info = globals()["info"]

    summed_tuning_curves = np.zeros(neurons.tuning_shape)
    for i in range(neurons.n_neurons):
        summed_tuning_curves += neurons.tuning_curves[i]

    # pyplot.get_cmap(name, lut) replaces plt.cm.get_cmap, which was removed
    # in matplotlib 3.9; 25 discrete colour levels as before.
    cmap = plt.get_cmap('bone_r', 25)

    fig, ax = plt.subplots()
    try:
        # Colour scale capped at 100 so a few very active bins don't wash out the map.
        mesh = ax.pcolormesh(xx, yy, summed_tuning_curves, vmax=100., cmap=cmap)
        fig.colorbar(mesh, ax=ax)
        ax.axis('off')
        title = info.session_id + '-tuning_curve-all'
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(os.path.join(output_filepath, title + ".png"))
    finally:
        # Always release the figure, even if saving fails, to avoid leaking
        # memory when called once per session in a loop.
        plt.close(fig)

In [None]:
def plot_num_swr(info, lfp=None):
    """Count sharp-wave ripples (SWRs) in each task epoch and save a bar chart.

    Parameters
    ----------
    info : object
        Session info providing ``task_times`` (mapping epoch name -> object
        with ``start``/``stop``), ``fs`` and ``session_id``.
    lfp : object, optional
        LFP signal supporting ``time_slice(start, stop)``. If omitted, the
        notebook-global ``lfp`` (set by the per-session loop) is used for
        backward compatibility.

    SWRs are detected with ``nept.detect_swr_hilbert``. The chart is saved to
    ``output_filepath`` as ``<session_id>-n_swrs.png``.
    """
    if lfp is None:
        # Backward compatibility: earlier callers relied on the global ``lfp``.
        lfp = globals()["lfp"]

    task_times = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]

    # SWR detection parameters, identical for every epoch (hoisted out of the loop).
    z_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05

    n_swrs = []
    for task_time in task_times:
        epoch = info.task_times[task_time]
        sliced_lfp = lfp.time_slice(epoch.start, epoch.stop)
        swrs = nept.detect_swr_hilbert(sliced_lfp, fs=info.fs, z_thresh=z_thresh,
                                       merge_thresh=merge_thresh, min_length=min_length)
        n_swrs.append(swrs.n_epochs)

    fig, ax = plt.subplots()
    try:
        ind = np.arange(len(task_times))
        ax.bar(ind, n_swrs)
        ax.set_xticks(ind)
        ax.set_xticklabels(task_times, rotation=75, fontsize=14)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        # Dotted reference line at 10 SWRs (threshold chosen by the author).
        ax.axhline(y=10, color='r', linestyle=':')
        title = info.session_id + '-n_swrs'
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(os.path.join(output_filepath, title + ".png"))
    finally:
        plt.close(fig)

In [None]:
def plot_decode_error_by_position(decode, xedges, yedges, xx=None, yy=None, info=None):
    """Plot the mean decoding error at each actual position bin and save a PNG.

    Parameters
    ----------
    decode : dict
        Must contain ``'errors'`` (one error per decoded sample) and
        ``'position'`` (object with parallel ``.x`` / ``.y`` sequences).
    xedges, yedges : array-like
        Grid coordinates; each sample is assigned to the nearest edge value
        via ``nept.find_nearest_idx``.
    xx, yy : array-like, optional
        Mesh coordinates for ``pcolormesh``. Default to
        ``np.meshgrid(xedges, yedges)``, which is what the per-session loop
        builds (the original version read those as globals).
    info : object, optional
        Session info providing ``session_id``. If omitted, the notebook-global
        ``info`` is used for backward compatibility.

    Bins that were never visited are NaN and are left blank in the plot.
    Prints a message and returns without plotting if there are no errors.
    """
    if len(decode["errors"]) == 0:
        print("no errors to plot. Skipping...")
        return

    if info is None:
        # Backward compatibility: earlier callers relied on the global ``info``.
        info = globals()["info"]
    if xx is None or yy is None:
        xx, yy = np.meshgrid(xedges, yedges)

    grid_shape = (len(yedges), len(xedges))
    error_sum = np.zeros(grid_shape)
    # Counts must start at 0. The original started at 1 to dodge division by
    # zero, which biased every mean to sum / (count + 1).
    n_samples = np.zeros(grid_shape)

    for error, x, y in zip(decode['errors'], decode['position'].x, decode['position'].y):
        x_idx = nept.find_nearest_idx(xedges, x)
        y_idx = nept.find_nearest_idx(yedges, y)
        error_sum[y_idx, x_idx] += error
        n_samples[y_idx, x_idx] += 1

    # Mean error per visited bin; unvisited bins stay NaN, which pcolormesh masks.
    error_byactual = np.full(grid_shape, np.nan)
    np.divide(error_sum, n_samples, out=error_byactual, where=n_samples > 0)

    fig, ax = plt.subplots()
    try:
        # Colour scale fixed to [0, 50] so sessions are visually comparable.
        mesh = ax.pcolormesh(xx, yy, error_byactual, vmin=0., vmax=50., cmap='Greys')
        fig.colorbar(mesh, ax=ax)
        ax.axis('off')
        title = info.session_id + '-errors-by-position'
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(os.path.join(output_filepath, title + ".png"))
    finally:
        plt.close(fig)

In [None]:
# Session metadata objects to analyze, taken from run.analysis_infos. Each is
# expected to provide session_id, task_times and fs, as used by the cells below.
infos = analysis_infos

In [None]:
# Per-session summary metrics. The loop below appends exactly one entry to
# each list per session, so the lists stay aligned by index.
session_ids, n_neurons, bins_decoded, mean_decode_errors = [], [], [], []

In [None]:
# For every session: restrict data to one task phase, build tuning curves,
# decode position, record summary metrics and save per-session diagnostic plots.
# Note: info, lfp, xx and yy are left as notebook globals here because the
# plotting helpers above fall back to them.
for info in infos:
    print(info.session_id)
    events, position, spikes, lfp, lfp_theta = get_data(info)

    # Single source of truth for the analyzed phase; previously "phase3" was
    # also hard-coded for phase_id and decoding_times, so editing `phase`
    # would have silently desynchronized the analysis.
    phase = "phase3"
    phase_start = info.task_times[phase].start
    phase_stop = info.task_times[phase].stop

    position = position.time_slice(phase_start, phase_stop)
    spikes = [spiketrain.time_slice(phase_start, phase_stop) for spiketrain in spikes]

    xedges, yedges = nept.get_xyedges(position, binsize=8)
    xx, yy = np.meshgrid(xedges, yedges)

    neurons = get_tuning_curves(info, position, spikes, xedges, yedges, speed_limit=0.5,
                                phase_id=phase, cache=False)

    n_neurons.append(neurons.n_neurons)
    session_ids.append(info.session_id)

    args = dict(info=info,
                dt=0.025,
                gaussian_std=0.0075,
                min_neurons=2,
                min_spikes=1,
                min_swr=3,
                neurons=neurons,
                normalized=False,
                run_time=True,
                speed_limit=10.,
                t_smooth=0.8,
                shuffle_id=False,
                window=0.025,
                decoding_times=info.task_times[phase],
                min_proportion_decoded=0.1,
                decode_sequences=False,
                random_shuffle=False,
                )

    decode = dict()
    (decode['decoded'],
     decode['decoded_epochs'],
     decode['errors'],
     decode['position'],
     decode['likelihood'],
     decode['percent_decoded']) = get_decoded(**args)

    bins_decoded.append(decode['percent_decoded'])
    # np.mean([]) yields nan plus a RuntimeWarning; make the empty case explicit.
    errors = decode['errors']
    mean_decode_errors.append(np.mean(errors) if len(errors) > 0 else np.nan)

    plot_tuning_curves(neurons, xx, yy)

    plot_num_swr(info)

    plot_decode_error_by_position(decode, xedges, yedges)

In [None]:
# Summary across sessions: number of neurons with tuning curves per session.
fig, ax = plt.subplots()
ind = np.arange(len(session_ids))

ax.bar(ind, n_neurons)
ax.set_xticks(ind)
ax.set_xticklabels(session_ids, rotation=90, fontsize=10)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
# Dotted reference line at 40 neurons (threshold chosen by the author).
ax.axhline(y=40, color='r', linestyle=':')
# Keep the descriptive title; previously it was immediately overwritten by the
# output filename (including ".png").
ax.set_title("Number of active neurons")
fig.tight_layout()
fig.savefig(os.path.join(output_filepath, 'all-n_active-neurons.png'))
plt.close(fig)

In [None]:
# Summary across sessions: decode['percent_decoded'] per session.
fig, ax = plt.subplots()
ind = np.arange(len(session_ids))

ax.bar(ind, bins_decoded)
ax.set_xticks(ind)
ax.set_xticklabels(session_ids, rotation=90, fontsize=10)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
# Dotted reference line at 25 (threshold chosen by the author).
ax.axhline(y=25, color='r', linestyle=':')
# Descriptive title (typo "Perent" fixed); previously it was immediately
# overwritten by the output filename.
ax.set_title("Percent of bins decoded")
fig.tight_layout()
fig.savefig(os.path.join(output_filepath, 'all-percent_decoded.png'))
plt.close(fig)

In [None]:
# Summary across sessions: mean decoding error per session (NaN when a
# session produced no errors).
fig, ax = plt.subplots()
ind = np.arange(len(session_ids))

ax.bar(ind, mean_decode_errors)
ax.set_xticks(ind)
ax.set_xticklabels(session_ids, rotation=90, fontsize=10)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
# Keep the descriptive title; previously it was immediately overwritten by the
# output filename (including ".png").
ax.set_title("Mean decoding errors (cm)")
fig.tight_layout()
fig.savefig(os.path.join(output_filepath, 'all-mean_decoding_errors.png'))
plt.close(fig)

In [None]:
import cv2

In [None]:
# Session whose summed tuning-curve image is inspected below. The filename
# follows the '<session_id>-tuning_curve-all.png' pattern saved by plot_tuning_curves.
example_session_id = 'R068d5'
image_path = os.path.join(output_filepath, example_session_id + '-tuning_curve-all.png')

In [None]:
# Load as a single-channel grayscale image. cv2.imread does not raise on a
# missing/unreadable file; it returns None, which would only fail later inside
# cv2.threshold with an obscure error, so fail here with a clear message.
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
if image is None:
    raise FileNotFoundError("Could not read image: " + image_path)

In [None]:
# Binarize the grayscale image: pixels above 127 become 255, all others 0.
# ret receives the threshold value actually used.
ret, thresh1 = cv2.threshold(
    image,
    127,
    255,
    cv2.THRESH_BINARY,
)

In [None]:
# Display the binarized image in grayscale, without axis ticks.
fig, ax = plt.subplots()
ax.imshow(thresh1, cmap='gray')
ax.set_xticks([])
ax.set_yticks([])

plt.show()