In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import warnings
import random
import scipy
import pickle
import os
import nept
import scalebar

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from utils_maze import get_bin_centers

# Cached pickles are read from, and figures written to, directories located
# relative to the current working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "classy", "compare_task-times")
# exist_ok=True makes this idempotent and avoids the check-then-create race of
# testing os.path.exists() before calling os.makedirs().
os.makedirs(output_filepath, exist_ok=True)

# Set random seeds
random.seed(0)
np.random.seed(0)

In [None]:
# Hex colour used for each zone label in the figures.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

# Attribute names looked up on each session object (one per task time).
task_labels = [
    "prerecord",
    "pauseA",
    "pauseB",
    "postrecord",
]
# Zone labels, in the order the bars are drawn.
zone_labels = [
    "u",
    "shortcut",
    "novel",
    "other",
]

In [None]:
from analyze_classy_decode import Session, TaskTime

In [None]:
# import info.r063d2 as r063d2
# import info.r063d3 as r063d3
# infos = [r063d2, r063d3]

In [None]:
# Sessions to analyse: one figure is produced per entry. Each entry is read
# through its session_id, xedges and yedges attributes and is also passed to
# get_bin_centers().
from run import analysis_infos
infos = analysis_infos

In [None]:
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so this should
    only ever be pointed at cache files generated by this project.
    """
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)


def _plot_map(ax, xx, yy, values, title):
    """Draw ``values`` on ``ax`` as a 'bone_r' pcolormesh with the axes hidden.

    Any NaN entries of ``values`` are set to 0 in place before plotting.
    """
    values[np.isnan(values)] = 0
    ax.pcolormesh(xx, yy, values, cmap='bone_r')
    ax.set_title(title, fontsize=14)
    ax.axis('off')


def _plot_likelihood_panel(ax, session, task_label, title, xx, yy, xxx, yyy,
                           zones, vmax, maze_highlight):
    """Draw the mean likelihood map of one task time and return its mesh.

    The likelihoods stored on ``getattr(session, task_label)`` are averaged
    over their first two axes (NaN-aware) and drawn on top of the session's
    position samples. Outlines of the "u", "shortcut" and "novel" zones are
    overlaid, and the SWR count (0 when ``swrs`` is None) is written on the
    panel.

    Parameters
    ----------
    ax : matplotlib axes to draw into.
    session : object exposing ``position.x``, ``position.y`` and one attribute
        per task label.
    task_label : name of the task-time attribute to read from ``session``.
    title : panel title.
    xx, yy : bin-edge grids passed to ``pcolormesh``.
    xxx, yyy : bin-centre grids passed to ``contour``.
    zones : mapping from zone label to the array that is contoured.
    vmax : upper colour limit of the likelihood map.
    maze_highlight : colour of the position samples.

    Returns
    -------
    The mesh returned by ``pcolormesh`` (usable as a colorbar mappable).
    """
    task = getattr(session, task_label)
    likelihood = np.nanmean(np.array(task.likelihoods), axis=(0, 1))
    likelihood[np.isnan(likelihood)] = 0

    ax.plot(session.position.x, session.position.y, ".", color=maze_highlight, ms=1, alpha=0.2)
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    mesh = ax.pcolormesh(xx, yy, likelihood, vmax=vmax, cmap='bone_r')
    for label in ["u", "shortcut", "novel"]:
        ax.contour(xxx, yyy, zones[label], levels=0, linewidths=2, colors=colours[label])

    n_swrs = task.swrs.n_epochs if task.swrs is not None else 0
    ax.text(np.max(session.position.x)/4, -10, "# swr: "+str(n_swrs), fontsize=10)
    return mesh


# Settings shared by every session (loop-invariant, so defined once).
n_shuffles = 100
maze_highlight = "#fed976"
vmax = 0.1  # upper colour limit shared by all likelihood panels

for info in infos:
    print(info.session_id)

    # Cached decoding results for this session.
    true_session = _load_pickle(
        os.path.join(pickle_filepath, info.session_id + "_likelihoods_true.pkl"))
    passthresh_session = _load_pickle(
        os.path.join(pickle_filepath, info.session_id + "_likelihoods_true_passthresh.pkl"))
    shuffled_session = _load_pickle(
        os.path.join(pickle_filepath,
                     info.session_id + "_likelihoods_shuffled-%03d.pkl" % n_shuffles))

    zones = getattr(true_session, "prerecord").zones

    # Bin edges are used for pcolormesh, bin centres for the zone contours.
    xx, yy = np.meshgrid(info.xedges, info.yedges)
    xcenters, ycenters = get_bin_centers(info)
    xxx, yyy = np.meshgrid(xcenters, ycenters)

    fig = plt.figure(figsize=(12, 8))
    gs = gridspec.GridSpec(4, 5)
    gs.update(wspace=0.2, hspace=0.5)

    # Row 0: tuning-curve summaries, all taken from the "postrecord" task time.
    tc_task_label = "postrecord"
    tuning_curves = getattr(true_session, tc_task_label).tuning_curves
    shuffled_tuning_curves = getattr(shuffled_session, tc_task_label).tuning_curves

    _plot_map(fig.add_subplot(gs[0, 0]), xx, yy,
              np.nansum(tuning_curves, axis=(0, 1)),
              "Raw TCs")
    _plot_map(fig.add_subplot(gs[0, 1]), xx, yy,
              np.nansum(tuning_curves/tuning_curves.shape[1], axis=(0, 1)),
              "Average TCs")
    # Each tuning curve is divided by its own sum over the last two axes
    # before the curves are summed together.
    _plot_map(fig.add_subplot(gs[0, 2]), xx, yy,
              np.nansum(
                  tuning_curves/np.nansum(tuning_curves, axis=(2, 3))[..., np.newaxis, np.newaxis],
                  axis=(0, 1)),
              "Normalized TCs")
    _plot_map(fig.add_subplot(gs[0, 3]), xx, yy,
              np.nansum(shuffled_tuning_curves, axis=(0, 1)),
              "Shuffled TCs")

    # Row 1: mean likelihood per task time for the true session.
    for i, task_label in enumerate(task_labels):
        likelihoods = _plot_likelihood_panel(
            fig.add_subplot(gs[1, i]), true_session, task_label,
            task_label + "\n All SWRs",
            xx, yy, xxx, yyy, zones, vmax, maze_highlight)

    # Row 2: the same panels for the passthresh session.
    for i, task_label in enumerate(task_labels):
        likelihoods = _plot_likelihood_panel(
            fig.add_subplot(gs[2, i]), passthresh_session, task_label,
            "Passthresh",
            xx, yy, xxx, yyy, zones, vmax, maze_highlight)

    # Row 3: for each task time, count how often each zone holds the maximum.
    for i, task_label in enumerate(task_labels):
        task = getattr(passthresh_session, task_label)
        # Evaluate maxes() once per zone and reuse the result for both the
        # stacked maximum and the per-zone comparison.
        zone_maxes = [task.maxes(zone_label) for zone_label in zone_labels]
        max_of_maxes = np.max(np.vstack(zone_maxes), axis=0)
        max_counts = [np.sum(zone_max == max_of_maxes, axis=1)[0]
                      for zone_max in zone_maxes]

        bar_ax = fig.add_subplot(gs[3, i])
        bar_ax.bar(np.arange(len(zone_labels)),
                   max_counts,
                   color=[colours[zone_label] for zone_label in zone_labels], edgecolor='k')
        bar_ax.set_title("Max", fontsize=14)
        bar_ax.set_xticks(np.arange(len(zone_labels)))
        bar_ax.set_xticklabels(zone_labels, rotation=90)
        for side in ('right', 'top', 'left'):
            bar_ax.spines[side].set_visible(False)
        bar_ax.tick_params(left=False)
        bar_ax.tick_params(labelleft=False)

    # The colorbar gets its own narrow axes in the otherwise unused fifth grid
    # column, next to the two likelihood rows. Passing a data panel as `cax`
    # would draw the colorbar over that panel's likelihood map.
    cbar_gs = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=gs[1:3, 4])
    cbar_ax = fig.add_subplot(cbar_gs[0, 0])
    cbar = fig.colorbar(likelihoods, cax=cbar_ax)
    cbar.ax.tick_params(labelsize=10)

    plt.tight_layout()

    filename = info.session_id + "_summary-likelihoods.png"
    fig.savefig(os.path.join(output_filepath, filename))
    # Close explicitly so figures do not accumulate across sessions.
    plt.close(fig)

#     plt.show()