In [None]:
%matplotlib widget
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd

from labdata.schema import *
import pylab as plt


# Embed fonts as TrueType (Type 42) rather than Type 3 in exported PDF/PS figures.
plt.matplotlib.rcParams['pdf.fonttype'] = 42
plt.matplotlib.rcParams['ps.fonttype'] = 42
savepath = Path('../../figures/figure4')
savepath.mkdir(parents=True, exist_ok=True)

# Parameter-set ids used to select spike-sorting and unit-criteria results below.
KILOSORT_PARAMS_SET = 5
METRICS_PARAMS_SET = 1

# Every procedure whose type starts with "chronic" for each inserted probe, newest first.
procedures_query = ProbeInsertion() * Procedure() & 'procedure_type LIKE "chronic%"'
procedures = pd.DataFrame(
    procedures_query.fetch('probe_id', 'subject_name', 'procedure_datetime', 'procedure_type',
                           order_by='procedure_datetime desc', as_dict=True)
).drop_duplicates()
# Earliest chronic procedure of any type (not only implants).
first_implant = procedures.procedure_datetime.min()

# Sort probe ids by their first implant, most recent first.
implants = procedures[procedures.procedure_type == 'chronic implant']
# NOTE: groupby sorts by probe_id, so this frame is ordered by id, not by date.
first_implants = implants.groupby('probe_id')['procedure_datetime'].min().reset_index()
sorted_prbs = first_implants.sort_values('procedure_datetime', ascending=False).probe_id.values

In [None]:
# Curated display order for the summary figure (replaces the implant-date order computed above).
sorted_prbs = [
    '22420006863',
    '22420007284',
    '22420007982',
    '22420007912',
    '22420006801',
    '22420007684',
    '22420007691',
    '22420007301',
    '22420007362',
    '22420007231',
    '22420007032',
    '20403317093',
    '20403317493',
    '20403312592',
    '20097916762',
    '20097916141',
    '20403314292',
    '20403312442',
    '20403312213',
    '20403312751',
    '20403312621',
    '20403312753',
    '20097916222',
    '20097916182',
    '19454421152',
    '20097902741',
    '20097902851',
]

# Bold probe-generation prefix per probe_type. `\bf` is a matplotlib mathtext font
# switch, so it renders without a LaTeX installation (text.usetex is not enabled here).
PROBE_LABEL_PREFIX = {
    '2013': r'$\bf{NP2}$ ',
    '24': r'$\bf{NP2a}$ ',
}
DEFAULT_PROBE_LABEL_PREFIX = r'$\bf{NP1}$ '  # any other probe_type


def _probe_label(probe_id):
    """Tick label for one probe: generation prefix + 'x' + last four digits of its id."""
    probe_type = (Probe() & f'probe_id = "{probe_id}"').fetch('probe_type')[0]
    prefix = PROBE_LABEL_PREFIX.get(probe_type, DEFAULT_PROBE_LABEL_PREFIX)
    return f'{prefix}x{probe_id[-4:]}'


probe_labels = [_probe_label(probe_id) for probe_id in sorted_prbs]

In [None]:
# Fill the UnitCount table for any keys not yet computed, using 4 worker processes.
UnitCount.populate(processes=4, display_progress=True)

In [None]:
from datetime import timedelta

# Layout of the recording-summary figure.
clims = (0, 300)      # colour-scale limits for the single-unit counts
bar_height = .9       # thickness of each probe's implant bar (y units)
bar_distance = 1.2    # vertical spacing between consecutive probes (y units)
use_yscale = False    # True: encode unit count as y position inside the bar instead of jitter
jitter_rng = np.random.default_rng(0)  # seeded so the jittered scatter is reproducible

fig, ax = plt.subplots(figsize=(8, 4))
prbs, heights = [], []
num_sessions = []
for prb_ind, prb in enumerate(sorted_prbs):
    height = prb_ind * bar_distance
    is_implant = (procedures.probe_id == prb) & (procedures.procedure_type == 'chronic implant')
    implant_times = procedures[is_implant].procedure_datetime.values
    explants = pd.DataFrame((Procedure * ProbeExtraction() & f'probe_id = "{prb}"').fetch())
    explant_times = explants.procedure_datetime.values

    # One grey bar per implant, spanning to the first explant after it,
    # or to tomorrow if the probe has not been explanted since.
    for implant in implant_times:
        valid_explants = explant_times[explant_times > implant]
        if len(valid_explants) == 0:
            explant = np.datetime64(datetime.now() + timedelta(days=1))
        else:
            explant = valid_explants[np.argmin(valid_explants - implant)]
        plt.barh([height], np.array(explant - implant), left=[implant], height=bar_height,
                 color='#999999', edgecolor='none', alpha=1, clip_on=False)

    # Mark failed extractions once per probe (not once per implant), just after the date.
    for _, extraction in explants.iterrows():
        if extraction.extraction_successful == 0:
            plt.plot(extraction.procedure_datetime + timedelta(days=3), height, 'x',
                     markeredgecolor='gray', markersize=3, clip_on=False)

    prbs.append(prb)
    heights.append(height)

    # Number of single units ('sua') per recording session for this probe.
    units_query = Session() * UnitCount() * EphysRecording.ProbeSetting() & dict(
        probe_id=prb,
        unit_criteria_id=METRICS_PARAMS_SET,
        parameter_set_num=KILOSORT_PARAMS_SET)
    recording_dates, num_single_units = units_query.fetch('session_datetime', 'sua',
                                                          order_by='session_datetime')
    num_sessions.append(len(num_single_units))
    if num_sessions[-1] > 270:
        # Print the recording span (in days) of heavily recorded probes.
        print(np.array(recording_dates[-1] - recording_dates[0], dtype='timedelta64[D]'))

    sua = np.array(num_single_units)
    if use_yscale and len(sua) > 0:
        # Min-max normalise the counts; constant counts map to 0 instead of dividing by 0.
        span = sua.max() - sua.min()
        y_rel = (sua - sua.min()) / span if span > 0 else np.zeros(len(sua))
        y_pos = y_rel * 0.6 + 0.1 + height - bar_height / 2
    else:
        # Gaussian vertical jitter, clipped so points stay inside the probe's bar.
        offset = np.clip(jitter_rng.normal(0, 0.15, len(sua)), -bar_height / 2, bar_height / 2)
        y_pos = offset + height
    scat = plt.scatter(recording_dates, y_pos, s=1, c=sua, vmin=clims[0], vmax=clims[1],
                       cmap='inferno', alpha=1, clip_on=False)

# Monthly grid lines without date labels (use '%b-%Y' to show them).
ax.xaxis.set_major_locator(plt.matplotlib.dates.MonthLocator())
ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter(''))
plt.xticks(fontsize=7)
plt.yticks(fontsize=7)
plt.grid(which='major', axis='x', linestyle='--')
ax.set_axisbelow(True)
plt.ylabel('Probe number', fontsize=8)

# Start the x axis at the earliest implant of any probe. first_implants is ordered by
# probe_id, so its first row is not necessarily the earliest implant.
start = first_implants.procedure_datetime.min()
plt.xlim([start - timedelta(days=2), datetime.now()])
plt.yticks(heights, probe_labels)  # label probes

# 4-week horizontal scale bar, since the date tick labels are blank.
plt.plot([start, start + timedelta(weeks=4)], [10, 10], color='k', clip_on=False)

cbar = plt.colorbar(scat, shrink=0.2, ticks=[0, 150, 300])
cbar.solids.set_edgecolor("face")
cbar.set_label('Single units', fontsize=8)
ax.spines[['right', 'top', 'left', 'bottom']].set_visible(False)
ax.xaxis.set_tick_params(width=0)
ax.yaxis.set_tick_params(width=0)
ax.set_ylim([-1, max(heights) + 1])

fig.savefig(savepath / ('recording_summary_yscaled.pdf' if use_yscale else 'recording_summary.pdf'))

## Reimplant figure and so on


In [None]:
# Quick check of the x-limits of the currently active axes.
plt.gca().get_xlim()

In [None]:
# Re-implant comparison: one (subject, session) pair per re-implanted probe.
SUBJECTS = ['MM009', 'MM010', 'MM011']
SESSIONS = [
    '2023-10-16/001',  # TODO: change back to 2023-10-17/001
    '2023-10-25/001',
    '2023-11-21/001',  # TODO: 2023-11-28/001 was used originally, but it still needs to be sorted
]

# Sorting parameter set and unit-criteria id used to select units for these sessions.
SORTING_PARAMETER_NUM = 5
CRITERIA_ID = 1

# One x-axis label per session: 'Reimplant 1', 'Reimplant 2', ...
labels = [f'Reimplant {n}' for n in range(1, len(SUBJECTS) + 1)]

In [None]:
fig = plt.figure(figsize=[4, 3])
su_amps = []  # one array per session: amplitudes of units passing the unit criteria
mu_amps = []  # one array per session: negated amplitudes of the other units (drawn below 0)
for sname, sdate in zip(SUBJECTS, SESSIONS):
    query_dict = dict(parameter_set_num=SORTING_PARAMETER_NUM,
                      unit_criteria_id=CRITERIA_ID,
                      subject_name=sname,
                      session_name=sdate)
    keys = (UnitCount() & query_dict).fetch(as_dict=True)
    units = UnitCount.Unit * UnitMetrics * EphysRecording.ProbeSetting & keys
    # Fetch once and derive both printed counts from the arrays instead of re-querying.
    passing, amps = units.fetch('passes', 'spike_amplitude')
    print(f'There are {len(passing)} total units for {sname} on {sdate}')
    print(f'There are {np.sum(passing)} single units for {sname} on {sdate}')

    # Only units with a positive amplitude are plotted.
    has_amp = amps > 0
    passing, amps = passing[has_amp], amps[has_amp]
    su_amps.append(amps[passing == 1])
    mu_amps.append(-amps[passing == 0])

# Mirrored violins at x = 1, 2, ...: passing units above zero, the rest (negated) below.
su_parts = plt.violinplot(su_amps, showmedians=True)
mu_parts = plt.violinplot(mu_amps, showmedians=True)
for parts in (su_parts, mu_parts):
    for body in parts['bodies']:
        body.set_edgecolor('black')
        body.set_facecolor('gray')
        body.set_alpha(1)
    for key in ('cmaxes', 'cmins', 'cbars', 'cmedians'):
        parts[key].set_color('black')
plt.xticks(range(1, len(labels) + 1), labels)

# Unit counts written above / below each violin.
for pos, (su, mu) in enumerate(zip(su_amps, mu_amps), start=1):
    plt.text(pos, 450, len(su))
    plt.text(pos, -450, len(mu))
fig.savefig(savepath / 're_implants.pdf')

In [None]:
# Stats on su_amps and mu_amps: unit counts and median amplitude per re-implant session.
# mu_amps are stored negated (for the mirrored violin plot), so flip the sign back here.
amp_stats = pd.DataFrame({
    'session': labels,
    'n_single_units': [len(a) for a in su_amps],
    'n_other_units': [len(a) for a in mu_amps],
    'median_su_amplitude': [np.median(a) if len(a) else np.nan for a in su_amps],
    'median_other_amplitude': [np.median(-a) if len(a) else np.nan for a in mu_amps],
})
amp_stats

In [None]:
# Noise statistics: per-channel MAD values for each re-implant session.
from labdata import chronic_paper as paper

dot_rng = np.random.default_rng(0)  # seeded so the dot jitter is reproducible
xvals = []       # per session: horizontal jitter for each channel's dot
probe_mads = []  # per session: flattened per-channel MAD values
plt.figure()
for sname, sdate in zip(SUBJECTS, SESSIONS):
    query_dict = dict(subject_name=sname,
                      session_name=sdate,
                      parameter_set_num=SORTING_PARAMETER_NUM)
    noise_stats = paper.SortingChannelMAD() & query_dict
    session_mads = np.vstack(noise_stats.fetch('mad')).flatten()
    xvals.append(dot_rng.normal(scale=.05, size=len(session_mads)))
    probe_mads.append(session_mads)

parts = plt.violinplot(probe_mads, showmedians=True, showextrema=True, showmeans=False)
for body in parts['bodies']:
    body.set_edgecolor('black')
    body.set_facecolor('gray')
    body.set_alpha(1)
for key in ('cmaxes', 'cmins', 'cbars', 'cmedians'):
    parts[key].set_color('black')

# Overlay each channel as a dot, jittered around its violin (violins sit at x = 1, 2, ...).
for pos, (jitter, session_mads) in enumerate(zip(xvals, probe_mads), start=1):
    plt.plot(pos + jitter, session_mads, '.k')
plt.xticks(range(1, len(labels) + 1), labels)
plt.ylabel('Channel MAD')
plt.ylim([0, 30]);
