In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('../bioAI.mplstyle')
import expipe
import sys
import pathlib
import numpy as np
import numpy.ma as ma
import scipy
import tqdm
import pandas as pd

from scipy.interpolate import interp1d

sys.path.append('../ca2-mec') if '../ca2-mec' not in sys.path else None 
import dataloader as dl
from utils import *
from plotting_functions import *

In [3]:
# Open the expipe project whose path is supplied by the local `dataloader` module (imported as `dl`).
project = expipe.get_project(dl.project_path())
# Bare last expression: displays the project's actions (rendered as a searchable browser widget in the saved output).
project.actions

HBox(children=(VBox(children=(Text(value='', placeholder='Search'), Select(layout=Layout(height='200px'), opti…

### Initialise data loader

In [4]:
lim = [0, 1200]  # limit recording times - in seconds

# Select ONE list of actions (recording sessions) to analyse by uncommenting it.
# All actions in the chosen list must belong to the same animal (checked below).

# CA2 - rat 001:
#include_actions = ['001-181220-2', '001-181220-3', '001-181220-4', '001-181220-5', '001-181220-6', '001-191220-1', '001-191220-3', '001-191220-6', '001-191220-7', '001-191220-8'] # choose actions to include
#include_actions = ['001-191220-1', '001-191220-3', '001-191220-6', '001-191220-7', '001-191220-8']
include_actions = ['001-181220-2', '001-181220-3', '001-181220-4', '001-181220-5', '001-181220-6']
#include_actions = ['001-211220-1', '001-211220-2', '001-211220-3', '001-211220-4', '001-211220-5']

# CA2 - rat 007:
#include_actions = ['007-081221-1', '007-081221-2', '007-081221-3', '007-081221-4', '007-081221-5', '007-081221-6']
#include_actions = ['007-091221-1', '007-091221-2', '007-091221-3', '007-091221-4', '007-091221-5', '007-091221-7']

# CA2 - rat 011:
#include_actions = ['011-120321-2', '011-120321-3', '011-120321-4', '011-120321-5', '011-120321-6']

# CA2 - rat 022:
#include_actions = ['022-160322-1', '022-160322-2', '022-160322-3', '022-160322-4', '022-160322-5', '022-160322-7']

# CA1 - rat 144:
#include_actions = ['144-100621-1', '144-100621-2', '144-100621-3', '144-100621-4', '144-100621-5']

# MEC - rats 001 and 002 (Maria's recordings from MEC):
#include_actions = ['001-280721-1', '001-280721-2', '001-280721-3', '001-280721-4', '001-280721-5']
#include_actions = ['002-050721-1', '002-050721-2', '002-050721-3', '002-050721-4', '002-050721-5']

# Raise an AssertionError if include_actions mixes actions from several animal
# entities. The entity is the first '-'-separated field of an action id. Compare
# that field exactly: a substring test would also match the entity string
# elsewhere in the id (e.g. entity '001' occurs in the date of '002-100121-1').
animal_entity = include_actions[0].split('-')[0]
same_entity = all(action_id.split('-')[0] == animal_entity for action_id in include_actions)
assert same_entity, "Requires only actions from same animal entity! Read start of notebook!"

# Load spike trains (pooled over all actions) and per-action tracking data,
# both restricted to the `lim` time window.
spikes = []
tracking = {}  # action_id -> tracking array (unpacked later as x, y, t, <4th column>)
for action_id in include_actions:
    spikes += dl.load_spiketrains(action_id, lim=lim, identify_neurons=True)
    tracking[action_id] = dl.load_tracking(action_id, lim=lim, ca2_transform_data=True)

print("load #spikes: ", len(spikes))

# Optional unit filtering (currently disabled) - uncomment the lines to apply:
## correct for inconsistent mua-annotations
#spikes = dl.correct_mua(spikes, only_good_mua=True)
#print("#units after mua-corrections:", len(spikes))
## SELECT brain region(s) to include cells from
#spikes = dl.in_brain_regions(spikes, ['ca2'])
#print("#units after brain region selection:", len(spikes))
## only include cells that are persistent across all actions
#spikes = dl.persistent_units(spikes, include_actions)
#print(f"Num spike_trains: {len(spikes)}. Num persistent units: {len(spikes) / len(include_actions)}")

# NOTE(review): `sp` has no explicit import in this notebook; presumably it is
# provided by one of the star imports (utils / plotting_functions) - confirm.
spatial_map = sp.SpatialMap()

load #spikes:  61


In [5]:
# Order the spike trains for plotting. list.sort is stable (also with
# reverse=True), so applying the keys from least to most significant yields a
# lexicographic order:
#   1. number of persistent trials (descending)
#   2. unit_idnum (ascending)
#   3. action_id (ascending)
#   4. mua_quality (descending)
spikes.sort(key=lambda sptr: sptr.annotations["mua_quality"], reverse=True)
spikes.sort(key=lambda sptr: sptr.annotations["action_id"])
spikes.sort(key=lambda sptr: sptr.annotations["unit_idnum"])
spikes.sort(key=lambda sptr: len(sptr.annotations["persistent_trials"]), reverse=True)
# Lexicographic order; the plotting cell lays out one panel per action in this order.
include_actions.sort()

# Debug listing of the resulting order (uncomment to inspect). The last statement
# above returns None, so nothing is displayed and no `_ = ...` trick is needed
# (assigning `_` would stop IPython from updating its `_`/`__`/`___` output cache).
#for sptr in spikes:
#    print(sptr.annotations["action_id"], sptr.annotations["unit_idnum"], len(sptr.annotations["persistent_trials"]))

In [None]:
# Distinct unit ids in first-seen order. `spikes` was sorted in the cell above,
# so despite the "unsorted" name this list follows that ordering (the ids
# themselves are not re-sorted). A list is used rather than a set to keep that
# order without requiring hashable ids.
unique_unsorted_unit_idnums = []
for sptr in spikes:
    if sptr.annotations["unit_idnum"] in unique_unsorted_unit_idnums:
        continue  # already recorded
    unique_unsorted_unit_idnums.append(sptr.annotations["unit_idnum"])

In [None]:
# Quick look at `cell_rate` for the first spike train after sorting; the plotting
# cell below uses element [0] of this result as the average cell rate.
cell_rate(spikes[0])

In [None]:
# Plot one row of rate maps per unit (grouped by unit id): one panel per action,
# in the order of `include_actions`. A panel stays blank (axes off) when the unit
# has no spike train in that action.
figscale = 1.2      # edge length of each panel in inches
save_figs = False   # set True to also save each row as a PDF under plot_dir
plot_dir = pathlib.Path('./plots')
# Session prefix of the selected actions, e.g. '001-181220' for '001-181220-2'.
session = include_actions[0].rsplit('-', 1)[0]

add_title = True  # only the first row gets action ids in its panel titles
for unit_idnum in unique_unsorted_unit_idnums:
    # Map action_id -> spike train for this unit. Looking trains up by action
    # (instead of walking a pointer in parallel with include_actions) stays
    # correct when a unit has duplicate or out-of-order trains. `setdefault`
    # keeps the first match, i.e. the one ranked highest by the sorting cell.
    unit_trains = {}
    for sptr in spikes:
        if sptr.annotations["unit_idnum"] == unit_idnum:
            unit_trains.setdefault(sptr.annotations["action_id"], sptr)

    # squeeze=False keeps `axs` indexable even when only one action is selected.
    fig, axs = plt.subplots(ncols=len(include_actions),
                            figsize=(len(include_actions) * figscale, figscale),
                            squeeze=False)
    axs = axs[0]
    for ax, action_id in zip(axs, include_actions):
        axis_off_labels_on(ax)
        spike_train = unit_trains.get(action_id)
        if spike_train is None:
            continue  # unit not present in this action
        # Named throwaway instead of `_`, which would shadow IPython's output cache.
        x, y, t, _unused_col = tracking[action_id].T
        ratemap = spatial_map.rate_map(x, y, t, spike_train)
        ax.imshow(ratemap.T, origin='lower')
        # Average cell rate in every title; action id added only in the first row.
        rate = cell_rate(spike_train)[0]
        ax.set_title(f"{action_id}\n{rate}" if add_title else str(rate))
    add_title = False

    axs[0].set_ylabel(unit_idnum)

    if save_figs:
        plot_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_dir / f"object-{session}-{unit_idnum}.pdf", dpi=300)
    # Display and close this figure now, so many units don't pile up open figures.
    plt.show()