## Plot the drift (estimated with DREDGE) across days 

In [None]:
#%load_ext autoreload
#%autoreload 2

In [None]:
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd

matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
#matplotlib.rcParams.update({'font.size': 18})

from labdata.schema import *

from pathlib import Path  # explicit import: SAVEPATH below must not depend on what the star import happens to export

# Directory the figure PDFs are written to.
SAVEPATH = Path(r'/home/mmelin/chronic_manuscript_figures')
# Time window (seconds) of each session from which spikes are taken.
T_START_SEC = 300
T_END_SEC = 500  # i.e. 300 s to 500 s of each session (5 min to ~8.3 min)
# Probe whose sampling rate is looked up when converting spike times to seconds.
PROBE_NUM = 0
# Shanks that are processed and plotted.
SHANK_NUMS = [0, 1, 2, 3]

# Length of the window kept per session; on the concatenated time axis
# consecutive sessions start this many seconds apart.
session_offset_sec = T_END_SEC - T_START_SEC

In [None]:
# Select the sessions to analyse.
# Only the "Droplets" selection is executed. The orientation-session fetch that used to run
# first was overwritten on the very next line, so it is kept as a commented-out alternative
# instead of issuing a query whose result is discarded.
# NOTE: the variable keeps the name `ori_session_keys` although it holds the Droplets sessions.
#ori_session_keys = (Dataset() & 'subject_name = "JC131"' & 'dataset_name LIKE "%ori%%"').fetch('session_name', as_dict=True)
ori_session_keys = (Dataset() & 'subject_name = "JC131"' & 'dataset_name LIKE "%Droplets%"').fetch('session_name', as_dict=True)
#drop_session_keys = (Dataset() & 'subject_name = "JC131"' & 'dataset_name LIKE "%Droplet%%"').fetch('session_name', as_dict=True)

# query for sessions with the proper probe configuration and sorting parameters (no motion correction applied)
# Parenthesised so that optional restrictions can be toggled by (un)commenting a line,
# without relying on backslash continuations that run into comment lines.
query1 = (
    SpikeSorting * Session * EphysRecording.ProbeSetting()
    & ori_session_keys
    & 'parameter_set_num = 8'
    & 'configuration_id = 3'
    #& f'session_datetime < "2023-11-8"'
    #& 'session_name <> "20231025_183538"' # exclude broken recording
)

# Session names ordered by acquisition time; this order defines the concatenated time axis.
session_dates = query1.fetch('session_name', order_by='session_datetime')

# Same restrictions, joined with the per-unit tables.
query = (
    UnitMetrics * SpikeSorting.Unit * Session * EphysRecording.ProbeSetting()
    & ori_session_keys
    & 'parameter_set_num = 8'
    & 'configuration_id = 3'
    #& f'session_datetime < "2023-11-8"'
    #& f'shank = {SHANK_NUM}'
)

# Preview the units of the first session (displayed as the cell output).
query & f'session_name = "{session_dates[0]}"'

#session_date = session_dates[0:3] # just for testing

In [None]:
def get_spike_data(session_dates, query, shank_num, t_start_sec, t_end_sec):
    """Collect the spikes of one shank across sessions on a concatenated time axis.

    For each session the units on ``shank_num`` are fetched from ``query``. A spike
    is kept when its time lies strictly between ``t_start_sec`` and ``t_end_sec``
    and its depth coordinate (``spike_positions[:, 1]``) is below the limit
    hard-coded for that shank. Kept spike times are converted from samples to
    seconds with the session's sampling rate (looked up for the module-level
    ``PROBE_NUM``) and shifted so that the session at position ``o`` in
    ``session_dates`` starts at ``o * (t_end_sec - t_start_sec)``.

    Parameters
    ----------
    session_dates : sequence
        Session names, in the order they should appear on the time axis.
    query : restrictable table expression
        Must expose ``spike_amplitudes``, ``spike_times`` and ``spike_positions``
        and be restrictable by ``session_name`` and ``shank``.
    shank_num : int
        Shank to extract; only shanks 0-3 have a depth limit defined.
    t_start_sec, t_end_sec : number
        Bounds (seconds) of the window taken from every session.

    Returns
    -------
    all_spikes : pandas.DataFrame
        One row per spike with columns ``amps``, ``depths_um`` and ``times_s``.
    session_breaks : numpy.ndarray
        Start time (s) of every session except the first.

    Raises
    ------
    ValueError
        If no depth limit is defined for ``shank_num``.
    """
    depth_limit_by_shank = {0: 4790, 1: 4790, 2: 3500, 3: 3500}
    if shank_num not in depth_limit_by_shank:
        raise ValueError(f'no depth limit defined for shank {shank_num}')
    depth_limit = depth_limit_by_shank[shank_num]

    #session_dates = session_dates[0:10] # testing
    session_offset_sec = t_end_sec - t_start_sec
    all_spikes = []
    session_breaks = []
    for o, d in enumerate(tqdm(session_dates)):
        k = (query & dict(session_name=d, shank=shank_num)).proj()
        dat = (query & k).fetch('spike_amplitudes', 'spike_times', 'spike_positions', as_dict=True)
        fs = (EphysRecording.ProbeSetting() & f'session_name = "{d}"' & f'probe_num = {PROBE_NUM}').fetch1('sampling_rate')
        # spike_times are stored in samples, so express the window in samples once per session
        first_sample = t_start_sec * fs
        last_sample = t_end_sec * fs
        session_start = session_offset_sec * o
        for unit in dat:
            times = unit['spike_times']
            depths = unit['spike_positions'][:, 1]
            in_window = np.logical_and(times > first_sample, times < last_sample)
            inds2grab = np.logical_and(in_window, depths < depth_limit)
            if not np.any(inds2grab):
                continue
            # A fresh dict per unit: re-using a single dict made every appended
            # entry alias the same object, i.e. the last unit's spikes.
            all_spikes.append({
                'amps': unit['spike_amplitudes'][inds2grab],
                'depths_um': depths[inds2grab],
                'times_s': times[inds2grab] / fs + session_start - t_start_sec,
            })
        session_breaks.append(session_start)

    if all_spikes:
        # one row per unit holding arrays -> one row per spike
        all_spikes = pd.DataFrame(all_spikes).apply(lambda col: col.explode())
    else:
        # keep the expected columns even when no spike matched
        all_spikes = pd.DataFrame(columns=['amps', 'depths_um', 'times_s'])
    session_breaks = np.array(session_breaks[1:])
    return all_spikes, session_breaks

In [None]:
# Gather the spike table of every shank, in SHANK_NUMS order.
# `session_breaks` is reassigned on every iteration, so it ends up holding the
# value returned by the last call.
shank_spikes = []
for s in SHANK_NUMS:
    all_spikes_on_shank, session_breaks = get_spike_data(
        session_dates=session_dates,
        query=query,
        shank_num=s,
        t_start_sec=T_START_SEC,
        t_end_sec=T_END_SEC,
    )
    shank_spikes.append(all_spikes_on_shank)

In [None]:
# Show the first five rows of the first entry of shank_spikes.
shank_spikes[0].head(n=5)

In [None]:
from pathlib import Path

# Export the per-shank spike tables and the session boundaries to disk.
savepath = Path(r'/home/mmelin/data/JC131combinedspikes')
savepath.mkdir(parents=True, exist_ok=True)  # the writers below do not create a missing directory
# shank_spikes was filled in SHANK_NUMS order, so name each file after its actual
# shank number rather than its position in the list.
for i, spks in zip(SHANK_NUMS, shank_spikes):
    spks.to_csv(savepath / f'shank_{i}_spikes.csv')

np.save(savepath / 'session_breaks.npy', session_breaks)

In [None]:
import sys  # explicit import: sys.path is used below and sys is not imported explicitly elsewhere in this notebook

# Make the local DREDGE checkout importable.
sys.path.append('/home/joao/lib/dredge/dredge-python/')
from dredge.dredge_ap import register

# Estimate motion separately for every shank, in the order of shank_spikes.
shank_motion_estimates = []
for all_spikes_on_shank in shank_spikes:
    # The columns of the spike table (amps, depths_um, times_s) are passed as keyword
    # arguments; bin_s=1 presumably selects 1-second time bins (confirm against DREDGE).
    motion_est, _ = register(**all_spikes_on_shank, bin_s=1)
    shank_motion_estimates.append(motion_est)

In [None]:
from spks.viz import plot_drift_raster

# One figure per shank: spike raster, dashed lines at the session breaks and the
# estimated displacement drawn around the vertical centre of the axes.
for i, shank in enumerate(SHANK_NUMS):
    all_spikes = shank_spikes[i]
    motion_est = shank_motion_estimates[i]

    plt.figure(figsize=(10, 4))
    plot_drift_raster(
        all_spikes['times_s'],
        all_spikes['depths_um'],
        all_spikes['amps'],
        n_spikes_to_plot=50_000,
        rasterized=True,
        cmap='gray_r',
        clim=(0, 100),
    )

    # span the session-break markers over the current y-range
    y_low, y_high = plt.gca().get_ylim()
    plt.vlines(
        session_breaks,
        y_low,
        y_high,
        linewidth=.5,
        linestyles='--',
        colors='black',
        label='Session breaks',
    )

    # shift the displacement trace(s) to the middle of the y-range so they are visible
    offset = np.mean(plt.ylim())
    plt.plot(
        motion_est.time_bin_centers_s,
        offset + motion_est.displacement.T,
        color='red',
        lw=1,
        alpha=.8,
    )
    plt.ylabel('Depth along shank (um)')
    plt.xlabel('Time (s)')
    plt.legend()
    #plt.savefig(SAVEPATH / f'JC131_across_session_drift_shank_{shank}.pdf', bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# Session edges on the concatenated time axis: 0, every break between sessions, and
# the end of the last session. Without the closing edge the last session has no
# (start, end) pair and is silently left out of the per-session averages.
# Cast to int because the edges are used as slice indices.
lims = np.concatenate([
    np.array([0]),
    session_breaks,
    np.array([len(session_dates) * session_offset_sec]),
]).astype(int)

def compute_intersession_drift(motion_est, lims):
    """Average the estimated displacement per session and difference consecutive sessions.

    Parameters
    ----------
    motion_est : object with a ``displacement`` array
        Time must be the last axis of ``displacement`` (a 1-D array, or a 2-D
        array with one row per spatial bin).
    lims : sequence of numbers
        Session edges used as indices along the time axis; consecutive pairs
        delimit one session each. This equals seconds only when the motion
        estimate uses 1-second time bins starting at 0.

    Returns
    -------
    avg_pos : list
        Mean displacement within each ``(start, end)`` window, averaged over all
        spatial bins when ``displacement`` is 2-D.
    diffs : numpy.ndarray
        Differences between consecutive entries of ``avg_pos``.
    """
    displacement = np.asarray(motion_est.displacement)
    avg_pos = []
    for start, end in zip(lims[:-1], lims[1:]):
        # Slice the LAST axis: for a 2-D estimate the first axis indexes spatial
        # bins, so slicing it would pick depth bins instead of time bins.
        # int() because slice bounds must be integers even if lims holds floats.
        window = displacement[..., int(start):int(end)]
        avg_pos.append(np.mean(window))  # mean position per session
    diffs = np.diff(avg_pos)
    return avg_pos, diffs

# One (avg_pos, diffs) pair per shank, in the order of shank_motion_estimates.
#intersession_positions, intersession_drifts = [(compute_intersession_drift(m, lims)) for m in shank_motion_estimates]
intersession_drifts = []
for shank_motion_est in shank_motion_estimates:
    intersession_drifts.append(compute_intersession_drift(shank_motion_est, lims))

In [None]:
# One figure per shank with the mean position of every session.
# intersession_drifts follows SHANK_NUMS order, so pair them up: `shank` must be the
# shank being plotted, otherwise every figure is written to the same file name
# (the value `shank` was left with by an earlier loop).
for shank, (pos, delta) in zip(SHANK_NUMS, intersession_drifts):
    fig, ax = plt.subplots(figsize=(12,2))
    # x values are session indices (0, 1, 2, ...)
    plt.plot(np.arange(len(pos)),pos, color='black', linewidth=.8, alpha=.4)
    plt.scatter(np.arange(len(pos)),pos, color='red')
    plt.ylim(-50,50)
    plt.xlabel('Days from first session')
    plt.ylabel('Shank position (um)')
    plt.gca().spines[['right', 'top']].set_visible(False)

    plt.savefig(SAVEPATH / f'JC131_shank_pos_shank_{shank}.pdf', bbox_inches='tight', dpi=500)

In [None]:
# Overlay the per-session mean positions of the first and third entries of
# intersession_drifts in a single figure.
fig, ax = plt.subplots(figsize=(12, 2))
cols = ['black', 'red', 'red', 'red']
marker = ['o', 'o', 'o', 'o']
drifts_2_plot = [intersession_drifts[0], intersession_drifts[2]]
for i, (pos, delta) in enumerate(drifts_2_plot):
    session_index = np.arange(len(pos))  # x values are session indices
    ax.plot(session_index, pos, color='black', linewidth=.8, alpha=.4)
    ax.scatter(session_index, pos, color=cols[i], marker=marker[i], s=18)

ax.set_ylim(-50, 50)
ax.set_xlabel('Days from first session')
ax.set_ylabel('Shank position (um)')
ax.spines[['right', 'top']].set_visible(False)
fig.savefig(SAVEPATH / 'JC131_shank_pos_all_shanks.pdf', bbox_inches='tight', dpi=500)

In [None]:
# Violin plot of the session-to-session position differences, one violin per entry
# of intersession_drifts, with the individual values jittered on top and a bar at
# the median.
xvals = []
x = 0
cols = ['grey', 'grey', 'red', 'red']
labs = ['Shank 0', 'Shank 1', 'Shank 2', 'Shank 3']
for i, (pos, delta) in enumerate(intersession_drifts):
    # horizontal jitter so overlapping points stay visible
    scatter_positions = np.random.normal(x, .03, len(delta))
    #fig, ax = plt.subplots(figsize=(2,4))
    parts = plt.violinplot([delta], [x], showextrema=False, showmedians=False)
    plt.scatter(scatter_positions, delta, color='black', s=4, alpha=.5)
    _, median, _ = np.percentile(delta, [25, 50, 75])
    plt.hlines(median, x - .1, x + .1, color='black', linestyle='-', lw=2)
    for pc in parts['bodies']:
        pc.set(facecolor=cols[i], edgecolor='black', alpha=1)
    xvals.append(x)
    x += .6

plt.ylim(-40, 40)
plt.xlabel('')
plt.ylabel('Drift between sessions (um)')
plt.xticks(xvals, labs)
plt.gca().spines[['right', 'top']].set_visible(False)
#plt.savefig(SAVEPATH / f'JC131_drift_violin_shank_{i}.pdf', bbox_inches='tight', dpi=500)
plt.savefig(SAVEPATH / 'JC131_drift_violin_all_shanks.pdf', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Sanity check: compare the per-spike depths with the per-unit depths for ONE shank.
# The spike table is selected explicitly; previously this relied on whatever
# `all_spikes` was left holding by the last plotting loop, which need not be the
# shank used in the query below.
shank_to_check = 0
spikes_to_check = shank_spikes[SHANK_NUMS.index(shank_to_check)]
plt.hist(spikes_to_check['depths_um'])
plt.xlabel('depth (from kilosort)')
plt.show()
plt.hist((query & f'shank = {shank_to_check}').fetch('depth'))
plt.xlabel('depth (from avg waveform)')
plt.show()
# spike times in table order; shows how the sessions are laid out on the time axis
plt.plot(spikes_to_check['times_s'])

In [None]:
# Channel coordinates of the first probe configuration matching both restrictions.
probe_config = ProbeConfiguration() & 'configuration_id = 3' & 'probe_id = 20403312753'
xy = probe_config.fetch('channel_coords')[0]

In [None]:
# Scatter the channel coordinates: every row of xy.T is passed as one positional
# argument (presumably x then y — TODO confirm the column layout of channel_coords).
plt.scatter(*xy.T)