# Plot drift across sessions

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
import labdatatools as ldt
import spks
from natsort import natsorted
from pathlib import Path

from tqdm import tqdm

import matplotlib
# Export text in PDF/PS figures as TrueType (fonttype 42) rather than Type 3 fonts,
# and use a larger default font size for manuscript figures.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})
matplotlib.rcParams['font.size'] = 18

## 1. Get the data from GDrive

In [None]:
SUBJECT = 'JC131'
DATATYPE = 'kilosort2.5'
DPATH = Path(ldt.labdata_preferences['paths'][0])
SAVEPATH = Path.home() / 'chronic_manuscript_figures'

files = ldt.rclone_list_files(SUBJECT)
files.head()

# Sessions that have spike-sorted (DATATYPE) output; built once instead of
# re-filtering `files` for every orientation session.
sorted_sessions = set(files.loc[files.datatype == DATATYPE, 'session'])

# Orientation sessions, restricted to those that also have kilosort data.
ori_sessions = [sess for sess in np.unique(files[files.datatype == 'orientation'].session)
                if sess in sorted_sessions]
print(len(ori_sessions))

# Natural sort of the session folder names, then discard the last one, which was
# recorded much later than the others.
# NOTE(review): this unconditionally drops the final session for this subject;
# revisit if SUBJECT changes.
ori_sessions = np.array(natsorted(ori_sessions)[:-1])


In [None]:
def sessionfolder_to_datetime(folder, folderformat='%Y%m%d_%H%M%S'):
    """Parse a session folder name (e.g. '20230105_093015') into a naive datetime.

    folder: the session folder name.
    folderformat: strptime format the folder name follows.
    """
    return datetime.datetime.strptime(folder, folderformat)

# NOTE(review): computed before sessions with a different channel map are dropped
# in a later cell, so this may not line up one-to-one with the final ori_sessions.
datetime_scale = list(map(sessionfolder_to_datetime, ori_sessions))

In [None]:
#ori_sessions = ori_sessions[:5] #truncate for testing

In [None]:
#for date in ori_sessions:
#    ldt.rclone_get_data(subject=SUBJECT, session=date, datatype=DATATYPE, excludes=['**.bin']) #spike sorting
#    ldt.rclone_get_data(subject=SUBJECT, session=date, includes=['**imec*.ap.meta'], excludes=['**.bin']) #meta files

In [None]:
# Drop sessions whose channel map (probe site coordinates) differs from the first
# session's, so depths are comparable across days.
# Explicit imports instead of `from spks import *`: these are the spks names used
# by this notebook (Clusters is used when loading kilosort output below).
from spks import Clusters, read_spikeglx_meta

coords = None  # reference channel map, taken from the first session
keep_idx = []  # indices into ori_sessions of sessions that match the reference
metapaths = []  # .ap.meta path of each kept session, aligned with keep_idx
for i, date in enumerate(ori_sessions):
    metapath = Path(ldt.get_filepath(subject=SUBJECT, session=date,
                                     subfolders=['ephys_*', '*'], filename='*imec*.ap.meta'))
    meta = read_spikeglx_meta(metapath)
    if coords is None:
        # NOTE(review): if the first session is itself the odd one out, every
        # other session is dropped; confirm the first session is representative.
        coords = meta['coords']
    if np.array_equal(coords, meta['coords']):
        keep_idx.append(i)
        metapaths.append(metapath)
    else:
        print(f'Dropping {date} due to different channelmap')

ori_sessions = ori_sessions[keep_idx]
    

## 2. Load spike positions, amplitudes, and depths

In [40]:
def get_all_clusters(kilosort_paths):
    """Load one spks `Clusters` object (with template features) per kilosort folder.

    Each path is printed before it is loaded, as a progress indicator.
    Returns a list in the same order as `kilosort_paths`.
    """
    def _load(folder):
        print(folder)
        return Clusters(folder, load_template_features=True)

    return [_load(folder) for folder in kilosort_paths]

def compute_session_offsets_and_srate(metapaths):
    """Read SpikeGLX meta files and return cumulative session boundaries and sample rates.

    metapaths: iterable of .meta file paths, in session order.

    Returns (session_breaks, srates). session_breaks starts at 0 and then holds the
    running sum of each file's 'fileTimeSecs'. srates holds each file's 'sRateHz'.
    """
    metas = [read_spikeglx_meta(path) for path in metapaths]
    srates = [meta['sRateHz'] for meta in metas]

    breaks = [0]
    running_total = 0
    for meta in metas:
        running_total += meta['fileTimeSecs']
        breaks.append(running_total)
    return breaks, srates

def get_subset_of_spikes(clus,
                         srates,
                         shank_num,
                         mindepth=None,
                         maxdepth=None,
                         spike_fraction=None,
                         t_start_sec=0,
                         t_end_sec=None):
    """Select spikes from one shank across sessions and lay them end to end in time.

    clus: per-session cluster objects exposing `spike_times`, `spike_clusters`,
        `spike_positions` (column 1 is used as depth), `spike_amplitudes` and
        `cluster_info` (a table with `shank` and `cluster_id` columns).
    srates: per-session divisor converting `spike_times` to seconds
        (presumably the sampling rate in Hz).
    shank_num: keep only spikes from clusters whose `shank` equals this.
    mindepth, maxdepth: if given, keep only depths strictly inside (mindepth, maxdepth).
    spike_fraction: if given, randomly keep len // spike_fraction spikes per session
        (i.e. roughly 1 in `spike_fraction`), preserving temporal order.
    t_start_sec, t_end_sec: keep only spikes with t_start_sec <= t <= t_end_sec
        (seconds within each session). t_end_sec=None means no upper limit.

    Returns (amps, times, depths, session_breaks):
        amps: absolute spike amplitudes.
        times: spike times. Each session is shifted so it starts at 0 relative to
            t_start_sec and is offset to follow the previous session.
        depths: spike depths.
        session_breaks: [0] followed by the offset at which each next session begins.
    """
    amps, depths, times = [], [], []
    session_breaks = [0]
    fileoffset_seconds = 0
    for i, c in enumerate(clus):
        depth = c.spike_positions[:, 1]
        amp = np.abs(c.spike_amplitudes)
        spike_times = c.spike_times / srates[i]
        cluster_ids_on_desired_shank = c.cluster_info[c.cluster_info.shank == shank_num].cluster_id.values

        # One combined mask instead of repeatedly re-indexing every array per criterion.
        keep = np.isin(c.spike_clusters, cluster_ids_on_desired_shank)
        keep &= spike_times >= t_start_sec
        if t_end_sec is not None:
            keep &= spike_times <= t_end_sec
        if maxdepth is not None:
            keep &= depth < maxdepth
        if mindepth is not None:
            keep &= depth > mindepth
        # TODO: do something to throw out noise spikes?

        amp = amp[keep]
        depth = depth[keep]
        # Make this session start at zero, then shift it after the previous sessions.
        spike_times = spike_times[keep] - t_start_sec + fileoffset_seconds

        if spike_fraction is not None and len(amp):
            # Subsample THIS session's spikes (the accumulated list must not be used here).
            # Sort the indices so spike times stay in temporal order.
            n_keep = min(len(amp), int(len(amp) // spike_fraction))
            subselect_inds = np.sort(np.random.choice(len(amp), n_keep, replace=False))
            amp = amp[subselect_inds]
            spike_times = spike_times[subselect_inds]
            depth = depth[subselect_inds]

        amps.append(amp)
        depths.append(depth)
        times.append(spike_times)

        # The next session starts just after this one's last kept spike. If no spike
        # survived, advance by the requested window length (when known) so the
        # cumulative offset is kept instead of being reset.
        if len(spike_times):
            last_time = spike_times.max()
        elif t_end_sec is not None:
            last_time = fileoffset_seconds + (t_end_sec - t_start_sec)
        else:
            last_time = fileoffset_seconds
        fileoffset_seconds = last_time + 1e-6
        session_breaks.append(fileoffset_seconds)

    amps = np.concatenate(amps) if amps else np.array([])
    times = np.concatenate(times) if times else np.array([])
    depths = np.concatenate(depths) if depths else np.array([])

    assert amps.shape == depths.shape == times.shape
    return amps, times, depths, session_breaks

In [None]:
# Method 1: load every session's clusters into memory up front (more memory intensive).
kilosort_paths = [DPATH / SUBJECT / session / DATATYPE / 'imec0'
                  for session in ori_sessions]

clus = get_all_clusters(kilosort_paths)
# Per-session sample rates. The recording-length session_breaks computed here are
# replaced by get_subset_of_spikes' spike-based breaks in the plotting cell.
session_breaks, srates = compute_session_offsets_and_srate(metapaths)


In [42]:
from dredge.dredge_ap import register
from spks.viz import plot_drift_raster

# Depth windows (um) in which drift is estimated, one entry per brain region:
# (label used in the output filename, mindepth, maxdepth,
#  y position at which the estimated displacement trace is drawn on the raster).
DRIFT_REGIONS = [
    ('V1', 4300, 4900, 4650),
    ('thalamus', 1500, 2100, 1800),
]

SAVEPATH.mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders

for shank in [0, 1, 2, 3]:
    for region, mindepth, maxdepth, trace_offset_um in DRIFT_REGIONS:
        amps, times, depths, session_breaks = get_subset_of_spikes(
            clus, srates, shank, mindepth=mindepth, maxdepth=maxdepth,
            t_start_sec=60, t_end_sec=300)
        motion_est, _ = register(amps, depths, times)

        plt.figure(figsize=(20, 8))
        plot_drift_raster(times, depths, amps, rasterized=True)
        plt.vlines(session_breaks, *plt.gca().get_ylim(), linestyles='--', colors='black')

        # Draw the displacement against the time-bin centres (seconds) so it lines up
        # with the raster's time axis. Fall back to the bin index if they are missing.
        displacement = motion_est.displacement.T
        trace_t = getattr(motion_est, 'time_bin_centers_s', None)
        if trace_t is None:
            plt.plot(trace_offset_um + displacement)
        else:
            plt.plot(trace_t, trace_offset_um + displacement)

        filename = SAVEPATH / f'motion_over_days_shank_{shank}_{region}.pdf'
        plt.savefig(filename, format='pdf', dpi=500, bbox_inches='tight')

In [None]:
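## method 2: slower
## NOTE: `get_all_spikes` is not defined in this notebook; this cell is kept for reference only
## and needs that helper (loading spikes per kilosort path) to be restored before it can run.
#kilosort_paths = [DPATH / SUBJECT / date / DATATYPE / 'imec0' for date in ori_sessions]
#SHANK = 0
#amps, times, depths, session_breaks = get_all_spikes(kilosort_paths, metapaths, SHANK, mindepth=3000, t_start_sec=60, t_end_sec=300)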