In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import warnings
import random
import scipy
import pickle
import os
import nept
import scalebar

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from utils_maze import get_bin_centers

# All paths are resolved relative to the notebook's current working directory.
thisdir = os.getcwd()
# Directory holding pickled decoding results (read by the loading cell below).
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
# Output directory under plots/, created up front. `exist_ok=True` replaces the
# check-then-create pattern, which can race and is more verbose.
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "shuffled")
os.makedirs(output_filepath, exist_ok=True)

# Seed both the stdlib and the legacy NumPy global RNGs so any random
# operations in this notebook are reproducible across runs.
random.seed(0)
np.random.seed(0)

In [None]:
# Hex plot colour for each maze zone, keyed by zone label.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

# Task-time attribute names on the decoded session, in plotting order.
task_labels = "prerecord pauseA pauseB postrecord".split()
# Zone labels follow the key order of the colour map above.
zone_labels = list(colours)

In [None]:
from analyze_classy_decode import Session, TaskTime

In [None]:
import info.r063d2 as info

# Load the cached "true" likelihoods object for this session from the pickle cache.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted, locally
# generated cache files here.
true_path = os.path.join(
    pickle_filepath,
    "{}_likelihoods_true.pkl".format(info.session_id),
)
with open(true_path, "rb") as pickled_file:
    true_session = pickle.load(pickled_file)

In [None]:

# Colour for the tracked-position trace drawn underneath each likelihood map.
maze_highlight = "#fed976"
# Shared upper colour limit for every likelihood panel (vmin is auto-scaled per panel).
vmax = 0.1

# Assign `task_label` before it is read: reading `zones` first relied on a value
# leaked from an earlier run of this cell and raised NameError on a fresh kernel.
task_label = "postrecord"
postrecord_time = getattr(true_session, task_label)
zones = postrecord_time.zones
tuning_curves = postrecord_time.tuning_curves

# Bin-edge grid for pcolormesh and bin-centre grid for contour.
xx, yy = np.meshgrid(info.xedges, info.yedges)
xcenters, ycenters = get_bin_centers(info)
xxx, yyy = np.meshgrid(xcenters, ycenters)

# Collapse the two leading axes of the tuning curves into a single 2-D map, three ways.
# NOTE(review): assumes tuning_curves is 4-D with the last two axes being the (y, x)
# spatial bins matching the grids above — confirm against get_only_tuning_curves.
raw_tuning_curves = np.nansum(tuning_curves, axis=(0, 1))
# NOTE(review): divides only by shape[1], not shape[0] * shape[1] — confirm this is the
# intended "average".
avg_tuning_curves = np.nansum(tuning_curves / tuning_curves.shape[1], axis=(0, 1))
# Scale each curve by its own spatial total before summing. A zero total makes the
# division emit RuntimeWarnings; they are silenced here, the resulting values are unchanged.
with np.errstate(divide="ignore", invalid="ignore"):
    per_curve_totals = np.nansum(tuning_curves, axis=(2, 3))[..., np.newaxis, np.newaxis]
    norm_tuning_curves = np.nansum(tuning_curves / per_curve_totals, axis=(0, 1))

fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(4, 5)
gs.update(wspace=0.3, hspace=0.3)

# Row 0: one panel per tuning-curve summary; NaNs are drawn as 0.
tc_panels = [
    ("Raw TCs", raw_tuning_curves),
    ("Average TCs", avg_tuning_curves),
    ("Normalized TCs", norm_tuning_curves),
]
for col, (title, tc_map) in enumerate(tc_panels):
    tc_map[np.isnan(tc_map)] = 0
    tc_ax = plt.subplot(gs[0, col])
    tc_ax.pcolormesh(xx, yy, tc_map, cmap='bone_r')
    tc_ax.set_title(title, fontsize=14)
    tc_ax.axis('off')

# Row 1: likelihood averaged over its two leading axes for each task time, drawn over
# the tracked positions, with the u / shortcut / novel zone outlines.
for i, task_label in enumerate(task_labels):
    ax = plt.subplot(gs[1, i])
    with warnings.catch_warnings():
        # All-NaN bins trigger "Mean of empty slice"; they come out as NaN and are zeroed below.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        likelihood = np.nanmean(np.array(getattr(true_session, task_label).likelihoods), axis=(0, 1))
    likelihood[np.isnan(likelihood)] = 0

    ax.plot(true_session.position.x, true_session.position.y, ".", color=maze_highlight, ms=1, alpha=0.2)
    ax.set_title("All SWRs " + task_label, fontsize=14)
    ax.axis('off')
    likelihood_mesh = ax.pcolormesh(xx, yy, likelihood, vmax=vmax, cmap='bone_r')
    for label in ["u", "shortcut", "novel"]:
        # NOTE(review): an int `levels` is a level *count* in matplotlib, not a level value;
        # if zones[label] is a 0/1 mask, `levels=[0.5]` may be what is intended — confirm.
        ax.contour(xxx, yyy, zones[label], levels=0, linewidths=2, colors=colours[label])

# Colorbar is built from the last likelihood panel's mesh (shared vmax, that panel's vmin).
cbar = fig.colorbar(likelihood_mesh)
cbar.ax.tick_params(labelsize=10)

plt.tight_layout()
plt.show()