In [None]:
import cache
import meta
import meta_session

import nept
import numpy as np
import scipy.stats
from shapely.geometry import LineString, Point
from shapely.ops import split
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Session and session group analysed in this notebook.
group = "day1"
info = meta_session.r063d2

def get(key, info=None, group=None):
    """Load a cached value by ``key`` for one session or for a group.

    Per-session values are read from cache ``ind-<info.session_id>``;
    otherwise group values are read from cache ``grp-<group>``. ``info``
    takes precedence when both are given.

    Parameters
    ----------
    key : str
        Name of the cached value (e.g. ``"spikes"``).
    info : object, optional
        Session metadata; its ``session_id`` selects the per-session cache.
    group : str, optional
        Group name selecting the group cache when ``info`` is None.

    Returns
    -------
    Whatever ``cache.load`` returns for that cache/key.

    Raises
    ------
    ValueError
        If neither ``info`` nor ``group`` is given.
    """
    if info is not None:
        return cache.load(f"ind-{info.session_id}", key)
    if group is None:
        # Without this guard the lookup would silently target "grp-None".
        raise ValueError("get() needs either `info` or `group`")
    return cache.load(f"grp-{group}", key)

In [None]:
# Pull every per-session input for `info` from the cache, in a fixed order.
(task_times, lines, raw_linear, spikes, trials) = (
    get(cache_key, info=info)
    for cache_key in ("task_times", "lines", "raw_linear", "spikes", "trials")
)

In [None]:
# Step 1: select the "u_with_feeders" linearization and the tuning-curve
# parameters. Every variable that later steps narrow down (linear, trials,
# tuning_spikes) is re-derived here from the cached inputs, so re-running
# this cell resets the pipeline instead of failing on an already-unwrapped
# `trials`.
linear = raw_linear["u_with_feeders"]
# Upper bound of the linear coordinate: distance along the U line to the
# projection of feeder 2 (shapely `project`), plus meta.feeder_dist.
linear_max = (
    lines["u_with_feeders"].project(Point(*info.path_pts["feeder2"]))
    + meta.feeder_dist
)
maze_times = task_times["maze_times"]
# Re-read from the cache rather than `trials = trials["u"]`, which rebinds
# the loaded dict and breaks on a second run of this cell.
trials = get("trials", info=info)["u"]
edges = nept.get_edges(
    0,
    linear_max,
    binsize=meta.tc_binsize,
    lastbin=False,
)
speed_limit = meta.speed_limit
t_smooth = meta.t_smooth
gaussian_std = meta.gaussian_std
tuning_spikes = spikes

In [None]:
# Step 2: keep only position samples and spikes inside the maze epochs.
tuning_spikes = [
    train.time_slice(maze_times.starts, maze_times.stops) for train in tuning_spikes
]
linear = linear[maze_times]

In [None]:
# Step 3: further restrict position samples and spikes to the trial epochs.
tuning_spikes = [
    train.time_slice(trials.starts, trials.stops) for train in tuning_spikes
]
linear = linear[trials]

In [None]:
# Step 4: speed threshold. Keep only the epochs nept.run_threshold returns
# for `speed_limit` (presumably periods where speed exceeds it).
run_epoch = nept.run_threshold(linear, thresh=speed_limit, t_smooth=t_smooth)
tuning_spikes = [
    train.time_slice(run_epoch.starts, run_epoch.stops) for train in tuning_spikes
]
linear = linear[run_epoch]

In [None]:
# Step 5: remove inactive neurons, i.e. those with at most `min_n_spikes`
# spikes left after the restrictions above.
min_n_spikes = 50
# Positions of the kept neurons in the pre-filter list. Steps 1-4 map
# `spikes` to `tuning_spikes` one-to-one, so these index into `spikes` too.
keep_spikes_idx = [
    idx
    for idx, train in enumerate(tuning_spikes)
    if len(train.time) > min_n_spikes
]
tuning_spikes = [tuning_spikes[idx] for idx in keep_spikes_idx]

In [None]:
seconds_shown = 120

fig, (axtop, axbtm) = plt.subplots(nrows=2, sharex=True, figsize=(12, 10))

# Window of `seconds_shown` seconds starting at the first trial.
epoch = nept.Epoch(
    trials[0].start, trials[0].start + seconds_shown
)

# Top: linearized position within the window.
linear_shown = linear[epoch]
axtop.plot(linear_shown.time, linear_shown.x, ".")
axtop.set_ylabel("Linearized position")

# Bottom: spike raster, one row per neuron that survived Step 5.
axbtm.eventplot(
    [train.time_slice(epoch.starts, epoch.stops).time for train in tuning_spikes],
    colors=["k"],
    linelengths=0.8,
    linewidths=1,
)
# Size the y-axis from the list actually drawn (not the unfiltered `spikes`),
# inverted so neuron 0 is at the top.
axbtm.set_ylim(len(tuning_spikes) - 0.5, -0.5)
axbtm.set_xlim(epoch.start, epoch.stop)
axbtm.set_xlabel("Time (s)")
axbtm.set_ylabel("Neuron");