## Plot the drift (estimated with DREDGE) across days 

In [1]:
#%load_ext autoreload
#%autoreload 2

In [2]:
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime
# Font type 42 makes matplotlib embed text as TrueType in PDF/PS output
# (instead of the default Type 3), which keeps text editable in vector editors.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
#matplotlib.rcParams.update({'font.size': 18})

from labdata.schema import *
PROBE_NUM = 0  # probe whose sampling rate is looked up in get_spike_data
SHANK_NUM = 0  # only referenced by a commented-out shank restriction in the query cell
T_START_SEC = 300
T_END_SEC = 500 # window is 300-500 s, i.e. minutes 5 to ~8.3 of each session; an earlier note said "minutes 5 to 10" (that would be 600) - TODO confirm the intended end
session_offset_sec = T_END_SEC - T_START_SEC  # seconds kept per session; get_spike_data recomputes this from its own arguments

In [None]:
# Session names of the orientation datasets recorded from subject JC131.
ori_session_keys = (
    Dataset()
    & 'subject_name = "JC131"'
    & 'dataset_name LIKE "%ori%%"'
).fetch('session_name', as_dict=True)
#drop_session_keys = (Dataset() & 'subject_name = "JC131"' & 'dataset_name LIKE "%Droplet%%"').fetch('session_name', as_dict=True)

# Sorted sessions with the wanted probe configuration and sorting parameters,
# recorded before the cutoff date and excluding one faulty recording.
query1 = (
    SpikeSorting * Session * EphysRecording.ProbeSetting()
    & ori_session_keys
    & 'parameter_set_num = 6'
    & 'configuration_id = 3'
    & 'session_datetime < "2023-11-8"'
    & 'session_name <> "20231025_183538"'  # exclude faulty recording
)

# Session names in chronological order.
session_dates = query1.fetch('session_name', order_by='session_datetime')

# Per-unit table with the same restrictions, except that the faulty recording
# is not excluded here. Shank restriction is currently disabled:
#   & f'shank = {SHANK_NUM}'
query = (
    UnitMetrics * SpikeSorting.Unit * Session * EphysRecording.ProbeSetting()
    & ori_session_keys
    & 'parameter_set_num = 6'
    & 'configuration_id = 3'
    & 'session_datetime < "2023-11-8"'
)

# Preview the units of the earliest session (shown as the cell output).
query & f'session_name = "{session_dates[0]}"'

In [None]:
# TODO: restrict to specific shank
# TODO: restrict to specific depth range
# TODO: restrict to single units?

def get_spike_data(session_dates, query, t_start_sec, t_end_sec,
                   shank_x_range=(500, 650), max_depth=4900):
    """Pool spikes from a fixed time window of every session onto one time axis.

    For each session the units in ``query`` are fetched, spikes falling inside
    the time window and the spatial window are kept, and their times are
    shifted so that consecutive sessions are laid end to end (session ``o``
    starts at ``o * (t_end_sec - t_start_sec)`` seconds).

    Parameters
    ----------
    session_dates : sequence of str
        Session names, in the order they should appear on the pooled time axis.
    query : table expression
        Restricted per session with ``session_name``; must provide
        ``spike_amplitudes``, ``spike_times`` and ``spike_positions``.
    t_start_sec, t_end_sec : float
        Time window (seconds, exclusive on both ends) taken from each session.
        ``spike_times`` are compared against these after scaling by the
        sampling rate, i.e. they are treated as sample indices.
    shank_x_range : tuple of (float, float), optional
        Exclusive (low, high) bounds on ``spike_positions[:, 0]``.
        FIXME: this is a temp fix until waveform positions are fixed
        (stands in for a proper shank restriction).
    max_depth : float, optional
        Exclusive upper bound on ``spike_positions[:, 1]``.

    Returns
    -------
    all_spikes : dict
        Keys ``'amps'``, ``'depths_um'`` and ``'times_s'``, each a 1-D array
        with one entry per kept spike (empty arrays when nothing is kept).
    session_breaks : numpy.ndarray
        Start time, on the pooled axis, of every session except the first.
    """
    session_offset_sec = t_end_sec - t_start_sec
    x_low, x_high = shank_x_range

    # Collect per-unit chunks and concatenate once at the end; concatenating
    # onto the growing arrays inside the loop re-copies all previous spikes
    # for every unit.
    amp_chunks = []
    depth_chunks = []
    time_chunks = []
    session_breaks = []

    for o, d in enumerate(tqdm(session_dates)):
        dat = (query & f'session_name = "{d}"').fetch('spike_amplitudes','spike_times','spike_positions', as_dict=True)
        fs = (EphysRecording.ProbeSetting() & f'session_name = "{d}"' & f'probe_num = {PROBE_NUM}').fetch1('sampling_rate')
        session_start = session_offset_sec * o

        for unit in dat:
            spike_samples = unit['spike_times']
            positions = unit['spike_positions']

            in_window = np.logical_and(spike_samples > t_start_sec * fs,
                                       spike_samples < t_end_sec * fs)
            in_x_range = np.logical_and(positions[:, 0] > x_low,
                                        positions[:, 0] < x_high)
            above_max_depth = positions[:, 1] < max_depth
            keep = np.logical_and.reduce([in_window, in_x_range, above_max_depth])

            amp_chunks.append(unit['spike_amplitudes'][keep])
            depth_chunks.append(positions[keep, 1])
            # Convert samples to seconds, then move onto the pooled time axis.
            time_chunks.append(spike_samples[keep] / fs + session_start - t_start_sec)

        session_breaks.append(session_start)

    # The leading empty float array keeps the result dtype identical to
    # growing from an empty list, and makes the no-spikes case well defined.
    all_spikes = {
        'amps': np.concatenate([np.empty(0), *amp_chunks]),
        'depths_um': np.concatenate([np.empty(0), *depth_chunks]),
        'times_s': np.concatenate([np.empty(0), *time_chunks]),
    }
    # The first session starts at 0, which is not a break between sessions.
    session_breaks = np.array(session_breaks[1:])
    return all_spikes, session_breaks

# Pool the T_START_SEC..T_END_SEC window of every session in session_dates onto one time axis.
all_spikes, session_breaks = get_spike_data(session_dates, query, T_START_SEC, T_END_SEC)

In [None]:
# `sys` is imported here explicitly rather than relying on the star import
# from labdata.schema happening to expose it.
import sys

# DREDGE is loaded from a local checkout; only add the path once so that
# re-running this cell does not keep growing sys.path.
if '/home/joao/lib/dredge/dredge-python/' not in sys.path:
    sys.path.append('/home/joao/lib/dredge/dredge-python/')
from dredge.dredge_ap import register

# The pooled spikes are passed as keyword arguments (amps, depths_um, times_s).
# Only the first returned value (the motion estimate) is used below.
motion_est, _ = register(**all_spikes)

In [None]:
from spks.viz import plot_drift_raster

# Raster of all pooled spikes (time vs. depth, with amplitudes passed as the third argument).
plt.figure(figsize=(20, 8))
plot_drift_raster(
    all_spikes['times_s'],
    all_spikes['depths_um'],
    all_spikes['amps'],
    rasterized=True,
)

# Dashed vertical lines at the session boundaries, spanning the current y-limits.
y_bottom, y_top = plt.gca().get_ylim()
plt.vlines(session_breaks, y_bottom, y_top, linestyles='--', colors='black')

# Overlay the estimated displacement, shifted by a fixed 1800 offset
# (presumably chosen to sit inside the plotted depth range - TODO confirm).
# Alternatives tried previously:
#plt.plot(motion_est.spatial_bin_centers_um + motion_est.displacement.T)
#plt.plot(4650 + motion_est.displacement.T)
plt.plot(1800 + motion_est.displacement.T)
plt.show()

In [None]:
# Distribution of the per-spike depths that were pooled above.
pooled_depths = all_spikes['depths_um']
plt.hist(pooled_depths)
plt.xlabel('depth (from kilosort)')
plt.show()

# Distribution of the per-unit `depth` attribute for units on shank 0.
unit_depths = (query & 'shank = 0').fetch('depth')
plt.hist(unit_depths)
plt.xlabel('depth (from avg waveform)')
plt.show()

# Pooled spike times in stored order (plotted against their index).
plt.plot(all_spikes['times_s'])

In [None]:
# Channel coordinates of the first entry matching configuration 3 on probe 20403312753.
probe_config = ProbeConfiguration() & 'configuration_id = 3' & 'probe_id = 20403312753'
xy = probe_config.fetch('channel_coords')[0]

In [None]:
# Scatter the channel coordinates; each row of xy.T is passed as one positional
# argument (assumes xy has one row per channel with x and y columns - TODO confirm).
plt.scatter(*xy.T)