In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import warnings
import random
import scipy
import pickle
import os
import nept
import scalebar

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from utils_maze import get_bin_centers

# Paths are resolved relative to the directory the kernel was started in.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "shuffled")
# exist_ok=True replaces the exists()-then-makedirs() check, which could race
# and raise FileExistsError if the directory appeared in between.
os.makedirs(output_filepath, exist_ok=True)

# Seed both global RNGs (stdlib `random` and numpy's legacy global state) so any
# stochastic step later in the notebook is reproducible.
random.seed(0)
np.random.seed(0)

In [None]:
# Hex plot colour for each maze zone.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

# Task-epoch names; each is read as an attribute of a session object (getattr(session, label)).
task_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
# Zone names, in the same order as the keys of `colours`.
zone_labels = list(colours)

In [None]:
# Session summary figure.
#   Row 0: three ways of collapsing the postrecord tuning curves into one spatial map.
#   Row 1: mean decoded likelihood during SWRs for each task epoch, with the
#          u / shortcut / novel zones outlined.
# NOTE(review): `true_session` and `info` are only created in a later cell (the one
# that unpickles "*_likelihoods_true.pkl"); that cell must be run before this one.
session = true_session

vmax = 0.2  # upper colour limit shared by every likelihood panel (vmin is auto-scaled per panel)
maze_highlight = "#fed976"

# Fix: `task_label` used to be read here before this cell assigned it, so `zones`
# silently came from whichever epoch an earlier cell's loop had left behind.
task_label = "postrecord"
zones = getattr(session, task_label).zones
tuning_curves = getattr(session, task_label).tuning_curves

# Bin edges for pcolormesh; bin centres for the zone contours.
xx, yy = np.meshgrid(info.xedges, info.yedges)
xcenters, ycenters = get_bin_centers(info)
xxx, yyy = np.meshgrid(xcenters, ycenters)

fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(4, 5)
gs.update(wspace=0.3, hspace=0.3)

# NOTE(review): tuning_curves is assumed 4-D with the last two axes being the
# spatial (y, x) bins that match xx/yy — TODO confirm.
ax1 = fig.add_subplot(gs[0, 1])
raw_tuning_curves = np.nansum(tuning_curves, axis=(0, 1))
raw_tuning_curves[np.isnan(raw_tuning_curves)] = 0
pp = ax1.pcolormesh(xx, yy, raw_tuning_curves, cmap='bone_r')
ax1.set_title("Raw TCs", fontsize=14)
ax1.axis('off')

# Each spatial map is divided by its own NaN-ignoring total over axes 2 and 3
# before collapsing.
ax2 = fig.add_subplot(gs[0, 2])
with np.errstate(divide="ignore", invalid="ignore"):
    # maps whose total is 0 give 0/0 = NaN, which nansum then skips
    tuning_curve_totals = np.nansum(tuning_curves, axis=(2, 3), keepdims=True)
    norm_tuning_curves = np.nansum(tuning_curves / tuning_curve_totals, axis=(0, 1))
norm_tuning_curves[np.isnan(norm_tuning_curves)] = 0
pp2 = ax2.pcolormesh(xx, yy, norm_tuning_curves, cmap='bone_r')
ax2.set_title("Normalized TCs", fontsize=14)
ax2.axis('off')

# Divide by the size of axis 1 before summing, i.e. an average over axis 1.
ax3 = fig.add_subplot(gs[0, 3])
avg_tuning_curves = np.nansum(tuning_curves / tuning_curves.shape[1], axis=(0, 1))
avg_tuning_curves[np.isnan(avg_tuning_curves)] = 0
pp3 = ax3.pcolormesh(xx, yy, avg_tuning_curves, cmap='bone_r')
ax3.set_title("Average TCs", fontsize=14)
ax3.axis('off')

for i, task_label in enumerate(task_labels):
    ax = fig.add_subplot(gs[1, i])
    with warnings.catch_warnings():
        # all-NaN bins trigger "Mean of empty slice"; they are zeroed just below
        warnings.simplefilter("ignore", category=RuntimeWarning)
        likelihood = np.nanmean(np.array(getattr(session, task_label).likelihoods), axis=(0, 1))
    likelihood[np.isnan(likelihood)] = 0

    ax.plot(session.position.x, session.position.y, ".", color=maze_highlight, ms=1, alpha=0.2)
    ax.set_title("All SWRs " + task_label, fontsize=14)
    ax.axis('off')
    likelihoods = ax.pcolormesh(xx, yy, likelihood, vmax=vmax, cmap='bone_r')
    # NOTE(review): every epoch is outlined with the postrecord `zones` read above.
    for label in ["u", "shortcut", "novel"]:
        ax.contour(xxx, yyy, zones[label], levels=0, linewidths=2, colors=colours[label])

# `likelihoods` holds the last epoch's mesh; all likelihood panels share `vmax`.
fig.colorbar(likelihoods)

plt.tight_layout()
plt.show()

In [None]:
# Shape of the normalised, collapsed postrecord tuning-curve map from the summary figure.
np.shape(norm_tuning_curves)

In [None]:
# FIXME: `multiple_tuning_curves` is only built further down (the cell that sums the
# pauseA tuning curves), so on a fresh kernel a bare `multiple_tuning_curves.shape`
# raised NameError and halted Restart & Run All. Guarded until this cell is moved
# below that one.
multiple_tuning_curves.shape if "multiple_tuning_curves" in globals() else "multiple_tuning_curves is not defined yet"

In [None]:
# pauseA tuning curves, inspected by the scratch cells that follow.
tc = getattr(session, "pauseA").tuning_curves

In [None]:
# Dimensions of the pauseA tuning-curve array.
np.shape(tc)

In [None]:
# Divide every curve by its own NaN-ignoring total over axes 2 and 3;
# keepdims leaves the totals broadcastable against tc.
tt = tc / np.nansum(tc, axis=(2, 3), keepdims=True)

In [None]:
# Scale by the length of axis 1, so a (NaN-free) sum over axis 1 equals the mean over it.
ty = np.divide(tc, tc.shape[1])

In [None]:
# Sum the curves along axis 1 of the first entry on axis 0 into one 2-D map.
# This is a vectorised reduction in place of the manual accumulation loop. It
# accumulates in float64 like the old np.zeros buffer did. It uses a plain
# (not NaN-skipping) sum, as before: a NaN in any curve makes that bin NaN, and
# the next cell zeroes those bins.
# NOTE(review): axis 1 is presumably the neuron axis — TODO confirm.
multiple_tuning_curves = np.sum(tc[0], axis=0, dtype=np.float64)

In [None]:
# Overwrite NaN bins with 0, in place.
np.copyto(multiple_tuning_curves, 0, where=np.isnan(multiple_tuning_curves))

In [None]:
# Quick look at the summed pauseA tuning curves; vmax=500 caps the colour scale.
fig_tc, ax_tc = plt.subplots()
ax_tc.pcolormesh(xx, yy, multiple_tuning_curves, vmax=500, cmap='bone_r')
plt.show()

In [None]:
from analyze_classy_decode import Session, TaskTime

In [None]:
import info.r063d2 as info

# Load the cached "_likelihoods_true.pkl" results for this session.
# This cell must run before the summary-figure cell earlier in the notebook,
# which reads `true_session` and `info`.
true_path = os.path.join(
    pickle_filepath,
    info.session_id + "_likelihoods_true.pkl",
)

# NOTE: pickle.load can execute arbitrary code — only load trusted, locally generated caches.
with open(true_path, 'rb') as pickled_session:
    true_session = pickle.load(pickled_session)

# FIXME(review): `plot_likelihood_overspace` is not defined or imported in any cell
# shown here — confirm which module provides it, or this loop raises NameError.
epoch_titles = ["Posterior for " + label + " all SWRs" for label in task_labels]
for task_label, title in zip(task_labels, epoch_titles):
    plot_likelihood_overspace(info, true_session, task_label, colours, title)

In [None]:
# Fixing random state for reproducibility
np.random.seed(0)

# fake up some data
spread = np.random.rand(50) * 100
center = np.ones(25) * 50
flier_high = np.random.rand(10) * 100 + 100
flier_low = np.random.rand(10) * -100
data = np.concatenate((spread, center, flier_high, flier_low))

fig, axs = plt.subplots(2, 4)

# basic plot
axs[0, 0].plot(data)
axs[0, 0].set_title('basic plot')

# notched plot
axs[0, 1].plot(data)
axs[0, 1].set_title('notched plot')

# change outlier point symbols
axs[0, 2].plot(data)
axs[0, 2].set_title('change outlier\npoint symbols')

# don't show outlier points
axs[1, 0].plot(data)
axs[1, 0].set_title("don't show\noutlier points")

# horizontal boxes
axs[1, 1].plot(data)
axs[1, 1].set_title('horizontal boxes')

# change whisker length
axs[1, 2].plot(data)
axs[1, 2].set_title('change whisker length')

fig.subplots_adjust(left=0.08, right=0.98, bottom=0.05, top=0.9,
                    hspace=0.4, wspace=0.3)

plt.show()