In [None]:
# Spike count correlations
from labdata import *
from labdata.schema import *

def cv_noise_correlations(responses, num_repeats = 100, fold_fraction = 0.5):
    ''' 
    Spike count (noise) correlations averaged over random trial subsamples.

    On each repeat, a random subset of trials is drawn without replacement and
    the pairwise Pearson correlation between cells is computed on that subset
    only; the correlation matrices are then averaged across repeats.

    responses : array (NCELLS x NTRIALS), e.g. spike counts per trial
    num_repeats : number of random trial subsets to average over (default 100)
    fold_fraction : fraction of the trials used in each subset (default 0.5)

    Returns an (NCELLS x NCELLS) float array in which only the strict upper
    triangle holds correlations; the diagonal and the lower triangle are NaN
    so that every pair appears exactly once.

    A cell with constant counts within a subset yields NaN for that repeat,
    which propagates to the average for all pairs that include the cell.
    '''
    # Work on a float array; the input is never modified. No z-scoring is
    # needed because the Pearson correlation is invariant to the mean and
    # scale of each row.
    x = np.asarray(responses, dtype = float)
    ntrials = x.shape[1]
    nsubset = int(ntrials*fold_fraction)

    per_repeat = []
    for _ in range(num_repeats):
        subset = np.random.choice(np.arange(ntrials), nsubset, replace = False)
        # the correlation must be computed on the selected trials only
        per_repeat.append(np.atleast_2d(np.corrcoef(x[:, subset], rowvar = True)))
    rsc = np.stack(per_repeat).mean(axis = 0)

    # tril_indices includes the diagonal, so this blanks self-pairs as well as
    # the duplicated lower triangle
    rsc[np.tril_indices(rsc.shape[0])] = np.nan
    return rsc

# Parameter tables for the pairwise spike count correlation analysis: which
# session / sorting / unit criteria to use, and which trial conditions and
# time windows to count spikes in.
@droplets.dropletsschema
class DropletsSpikeCountParameters(dj.Manual): # manual entries that prepare spike count correlations across neurons in a session
    '''Manually inserted parameter set, one row per (session, parameter_set_num, unit_criteria_id).'''
    definition = '''
    -> Session
    parameter_set_num : int         # spike sorting parameter set
    unit_criteria_id  : int        # unit criteria id
    '''
    class Condition(dj.Part):
        '''One named trial condition and its spike counting window for a parameter set.'''
        # NOTE(review): the attribute comment below describes the name as
        # 'state name;condition', but DropletsSpikeCountsCorrelations.make splits
        # condition_name on ':' (event) and then on '|' and '=' (trial restrictions),
        # e.g. 'stim_onset:intensity_values=20|rewarded=1'.
        # condition_times is read in make as time[0] = offset added to the event
        # time and time[1] = window duration.
        definition = '''
        -> master
        condition_name : varchar(48)  # 'state name;condition'
        ---

        condition_times : blob  # [offset, duration]
        '''
        
@droplets.dropletsschema
class DropletsSpikeCountsCorrelations(dj.Computed):
    '''
    Pairwise spike count correlations for one DropletsSpikeCountParameters.Condition.

    Spikes of every unit that passes the unit criteria are counted in a fixed
    window relative to a trial event, on the trials selected by the condition,
    and the counts are passed to cv_noise_correlations.
    '''
    # settings forwarded to cv_noise_correlations
    num_repeats = 100
    fold_fraction = 0.7
    # NOTE(review): the original comment read "session is not a key here so we can
    # compute the correlations per parameter"; the primary key is inherited from
    # DropletsSpikeCountParameters.Condition.
    definition = '''
    -> DropletsSpikeCountParameters.Condition
    ---
    num_trials               : int
    num_repeats              : int
    spike_count_correlations : longblob  # spike count correlations
    unit_ids                 : longblob  # unit id
    probe_ids                : longblob  # id of the probe
    '''
    def make(self,key):
        '''
        Compute and insert the correlation matrix for the condition(s) matching ``key``.

        Units are ordered by probe, then shank, then depth; ``unit_ids`` and
        ``probe_ids`` are stored in the same order as the rows/columns of
        ``spike_count_correlations``.
        '''
        unitssp = (SpikeSorting.Unit() & (UnitCount.Unit() & key & 'passes = 1').proj()).get_spike_times(
            include_metrics = True,
            extra_keys = ['shank','depth'])
        trial_events = droplets.get_trial_state_times(key)

        conditions,times = (DropletsSpikeCountParameters.Condition() & key).fetch('condition_name','condition_times')

        spks = [u['spike_times'] for u in unitssp]
        unit_ids = np.array([u['unit_id'] for u in unitssp])
        probe_ids = np.array([u['probe_num'] for u in unitssp])
        shanks = np.array([u['shank'] for u in unitssp])
        depths = np.array([u['depth'] for u in unitssp])
        idx = []
        # sort by depth and probe
        for iprobe in np.unique(probe_ids):
            ii = np.where(probe_ids == iprobe)[0]
            for ishank in np.unique(shanks[ii]):
                iishank = ii[shanks[ii] == ishank]
                iishank = iishank[np.argsort(depths[iishank])]
                idx.extend(iishank)

        probe_ids = probe_ids[idx]
        unit_ids = unit_ids[idx]
        spks = [spks[i] for i in idx]

        for cond,window in zip(conditions,times):
            # condition_name is '<event column>:<column>=<value>|<column>=<value>...'
            event,restriction = cond.split(':')
            # a trial is kept only if it matches every '<column>=<value>' term
            trial_selection = np.all(np.vstack([trial_events[k].values== float(v)
                                                for k,v in [a.split('=') for a in restriction.split('|')]]),axis = 0)
            offset,duration = window[0],window[1]
            # the window start times are the same for every unit: compute them once
            window_starts = trial_events[trial_selection][event].values + offset
            cnts = [[np.count_nonzero((s >= t) & (s < (t + duration))) for t in window_starts]
                    for s in spks]
            x = np.stack(cnts)  # units x trials

            rsc = cv_noise_correlations(x, num_repeats = self.num_repeats, fold_fraction = self.fold_fraction)
            self.insert1(dict(key,
                              num_trials = x.shape[1],
                              num_repeats = self.num_repeats,
                              spike_count_correlations = rsc.astype(np.float32),
                              unit_ids = unit_ids,
                              probe_ids = probe_ids))
    

    

# Session, spike sorting parameter set and unit criteria analysed in this notebook.
key = {'subject_name': 'JC140',
       'session_name': "20231206_140953",
       'parameter_set_num': 5,
       'unit_criteria_id': 1}

# Each condition is 'event:column=value|column=value' with its [offset, duration]
# counting window (seconds relative to the event).
condition_windows = [('stim_onset:intensity_values=20|rewarded=1', [0.04,1]),
                     ('stim_onset:intensity_values=-20|rewarded=1', [0.04,1]),
                     ('response_onset:response=1|rewarded=1', [0,1-0.04]),
                     ('response_onset:response=-1|rewarded=1', [0,1-0.04])]
condition_names = [name for name,window in condition_windows]
condition_times = [window for name,window in condition_windows]

# One-off inserts of the parameter set and its conditions (already done; kept for reference):
# DropletsSpikeCountParameters().insert1(key)
# for name,times in zip(condition_names,condition_times):
#     DropletsSpikeCountParameters.Condition.insert1(dict(key,
#                                                         condition_name = name,
#                                                         condition_times = times))

# Compute every correlation entry that is not in the table yet.
DropletsSpikeCountsCorrelations().populate(display_progress = True)



In [None]:
# Spike count correlations
from labdata import *
from labdata.schema import *

# Select the session and load the spike times and metrics of the passing units.
key = {'subject_name': "JC140",
       'session_name': "20231206_140953",
       'parameter_set_num': 5,
       'unit_criteria_id': 1}
session = (Session() & key).proj().fetch1()
# units flagged as passing for this session and unit criteria
unitsel = (UnitCount.Unit() & session & key & 'passes = 1').proj()
# loading the spike times might take some time
unit_table = SpikeSorting.Unit() & unitsel
unitssp = unit_table.get_spike_times()
metrics_table = UnitMetrics() & unitsel
unitprop = pd.DataFrame(metrics_table.fetch())

In [None]:
%matplotlib ipympl
import pylab as plt

# Reload the session units (same queries as the previous cell) and open the
# interactive raster browser, with the trial events marked on the raster axis.
session = (Session() & key).proj().fetch1()
unitsel = (UnitCount.Unit() & session & key & 'passes = 1').proj()
# loading the spike times might take some time
unitssp = (SpikeSorting.Unit() & unitsel).get_spike_times()
unitprop = pd.DataFrame((UnitMetrics() & unitsel).fetch())

trial_events = droplets.get_trial_state_times(session)

spike_trains = [u['spike_times'] for u in unitssp]
(triggered_spikes,
 selected_trials,
 events_to_extract) = droplets.trigger_droplets_spiketimes_to_states(spike_trains, trial_events)
# an `events_to_extract = times_to_extract` argument was tried here and is left out

# keep only the events that come with a reference time
referenced_events = [d for d in events_to_extract if d['ref'] is not None]
event_names = [d['label'] for d in referenced_events]
event_times = [d['ref'] for d in referenced_events]

ax1,ax2 = droplets.ephys.interactive_plot_rasters(triggered_spikes, trial_events, selected_trials, unitprop)

if 'event_times' in dir():
    # vertical line plus a small label at the top of the raster for each event
    ax1.vlines(event_times,0,np.sum(selected_trials),'k',lw = 1)
    for event_time,event_label in zip(event_times,event_names):
        ax1.text(event_time, np.sum(selected_trials), event_label, fontsize = 4, rotation = 0, va = 'bottom')

In [None]:
# Example unit: rasters (bottom axes) and smoothed firing rate (top axes) for
# rewarded trials with intensity -20 (black) and +20 (red).
from spks.event_aligned import plot_raster
from spks import binary_spikes, alpha_function

icell = 719
cellspks = triggered_spikes[icell]
was_rewarded = trial_events.rewarded == 1
ltrials = np.where((trial_events.intensity_values == -20) & was_rewarded)[0]
rtrials = np.where((trial_events.intensity_values == 20) & was_rewarded)[0]

fig = plt.figure(figsize = [4.15, 3.48])
ax0 = fig.add_axes([0.1,0.1,0.8,0.3])
# at most 100 trials per condition; the red rasters are stacked above the black ones
lshown = ltrials[:100]
rshown = rtrials[:100]
plot_raster([cellspks[i] for i in lshown],colors = 'k')
plot_raster([cellspks[i] for i in rshown],offset=len(lshown),colors='#d62728')
if 'event_times' in dir():
    plt.vlines(event_times,0,200,'k',lw = 1)
    for event_time,event_label in zip(event_times,event_names):
        plt.text(event_time, 200, event_label, fontsize = 4, rotation = 0, va = 'bottom')

ax1 = fig.add_axes([0.1,0.5,0.8,0.4],sharex = ax0)
# smoothing kernel, with the decay expressed in bins
binsize = 0.01
t_decay = 0.05 
t_rise = 0.0001
decay = t_decay/binsize
kern = alpha_function(int(decay*15), t_rise=t_rise, t_decay=decay, srate=1./binsize)
tmin = -1
tmax = 2.8
edges = np.arange(tmin,tmax+binsize/2,binsize)
ed = edges[:-1]+np.diff(edges[:2])/2   # bin centers
sprate = binary_spikes(cellspks, edges = edges, kernel = kern)/binsize

# mean +/- s.e.m. across trials for each condition
for trials,color in ((ltrials,'k'),(rtrials,'#d62728')):
    m = sprate[trials].mean(axis = 0)
    s = sprate[trials].std(axis = 0)/np.sqrt(len(trials))
    plt.plot(ed,m,color = color)
    plt.fill_between(ed,m-s, m+s,color = color,alpha = 0.4)

# bars marking the two counting windows (after the first and after the last event)
plt.plot([0.04, 1],[15,15],'k')
last_event = event_times[-1]
plt.plot([last_event, last_event+1-0.04],[15,15],'k');

plt.savefig('raster_example.pdf')

In [None]:
%matplotlib ipympl
import pylab as plt
from scipy.stats import gaussian_kde
# Probe number of every loaded unit, in the order of `unitssp`.
probe_nums = np.array([u['probe_num'] for u in unitssp])
# Close all open figures before creating the new one.
plt.close('all')

def unit_2d_mask(x,sel0,sel1):
    '''
    Boolean (len(x) x len(x)) mask selecting pairs of units by label.

    Element [i, j] is True when x[i] == sel0 and x[j] == sel1.
    x is expected to be a 1-D numpy array of labels (e.g. probe ids).
    '''
    is_row = x == sel0
    is_col = x == sel1
    # outer AND: rows labelled sel0 crossed with columns labelled sel1
    return np.logical_and.outer(is_row, is_col)
    
# Grid of |correlation| densities: one row per table entry, one column per
# reference probe. The within-probe distribution is the filled curve and each
# across-probe distribution is a line.
binsize = 0.01
fig = plt.figure(figsize = [11.31,  6.28])
cnt = 0
ed = np.linspace(-0.001,0.2,100)
res = pd.DataFrame((DropletsSpikeCountsCorrelations*DropletsSpikeCountParameters & key).fetch())
probe_list = [0,1,2,3,4,5]
for i in range(len(res)):
    entry_probe_ids = res['probe_ids'].iloc[i]
    entry_rsc = res['spike_count_correlations'].iloc[i]
    for iprobe0 in probe_list:
        cnt += 1
        a = fig.add_subplot(4,6,cnt)
        # within-probe pairs
        mask = unit_2d_mask(entry_probe_ids,iprobe0,iprobe0)
        xx = entry_rsc[mask]
        xx = xx[np.isfinite(xx)]
        absxx = np.abs(xx)
        ft = gaussian_kde(absxx,  bw_method='scott')
        plt.fill_between(ed,ft.pdf(ed)/100,0,color='lightgray',edgecolor = 'k',alpha = 0.5,
                        label = f'{iprobe0} - {iprobe0}')
        plt.xlim([ed[0],ed[-1]])
        # mean (red) and median (blue) of the within-probe |correlations|
        plt.vlines(np.mean(absxx),0,0.1,color='r',lw = 1,linestyle='-')
        plt.vlines(np.median(absxx),0,0.1,color='b',lw = 1,linestyle='-')

        # across-probe pairs; the smaller probe number always indexes the rows
        # because only the upper triangle of the matrix is filled
        for iprobe in probe_list:
            if iprobe == iprobe0:
                continue
            first,second = np.min([iprobe0,iprobe]),np.max([iprobe0,iprobe])
            mask = unit_2d_mask(entry_probe_ids,first,second)
            xx = entry_rsc[mask]
            xx = xx[np.isfinite(xx)]
            ft = gaussian_kde(np.abs(xx),  bw_method='scott')
            plt.plot(ed,ft.pdf(ed)/100,lw = 1,
                     label = f'{iprobe0} - {iprobe}')
        plt.ylim([0,0.1])    
        plt.vlines(0,0,np.max(plt.ylim()),'k',lw = 0.3,linestyle='--')
        plt.legend(fontsize = 3)
        # tick labels only on row 3 and in the first column
        if i!=3:
            plt.xticks([])
        if iprobe0 != 0:
            plt.yticks([])

In [None]:
# Fetch one table entry per condition, in a fixed order, into a DataFrame.
DropletsSpikeCountsCorrelations()  # bare expression; nothing is displayed because it is not the last line
conditions = ['stim_onset:intensity_values=-20|rewarded=1',
              'stim_onset:intensity_values=20|rewarded=1',
              'response_onset:response=-1|rewarded=1',
              'response_onset:response=1|rewarded=1']
correlation_entries = DropletsSpikeCountsCorrelations*DropletsSpikeCountParameters
results = pd.DataFrame([(correlation_entries & dict(key,condition_name = n)).fetch1()
                        for n in conditions])
results # 1 and 3 are ipsi

In [None]:
# Display the per-condition results table.
results

In [None]:
# Signed spike count correlations per probe pair, shown as split violins.
# Everything is read from `results`, whose rows follow the fixed `conditions`
# order, so masks, data and printed names always refer to the same row.
probe_nums = np.array([u['probe_num'] for u in unitssp])
areas = ['ORB','MD','LP','SCm','STR','DG'] 
binsize = 0.01
ed = np.linspace(-0.3,0.3,100)
violindata = []   # [condition][reference probe] -> list of arrays, one per probe pair
violinlabel = []  # matching 'AREA-AREA' labels
for i in range(len(results)): 
    cond_probe_ids = results['probe_ids'].iloc[i]
    cond_rsc = results['spike_count_correlations'].iloc[i]
    violindata.append([])
    violinlabel.append([])
    for iprobe0 in [0,1,2,3,4,5]:
        # within-probe pairs come first
        mask = unit_2d_mask(cond_probe_ids,iprobe0,iprobe0)
        xx = cond_rsc[mask]
        xx = xx[np.isfinite(xx)]
        violindata[-1].append([xx])
        violinlabel[-1].append([f'{areas[iprobe0]}-{areas[iprobe0]}'])
        for iprobe in [0,1,2,3,4,5]:            
            if iprobe == iprobe0:
                continue
            # only the upper triangle is filled, so the smaller probe number indexes the rows
            mask = unit_2d_mask(cond_probe_ids,np.min([iprobe0,iprobe]),np.max([iprobe0,iprobe]))
            xx = cond_rsc[mask]
            xx = xx[np.isfinite(xx)]
            violindata[-1][-1].append(xx)
            violinlabel[-1][-1].append(f'{areas[iprobe0]}-{areas[iprobe]}')
# plt.savefig('pairwise_corr_JC140.pdf')

def _plot_split_violins(icond, violindata = violindata, violinlabel = violinlabel, results = results):
    '''
    Split violins comparing condition `icond` (left half, black) with condition
    `icond+1` (right half, red), one subplot per reference probe.

    Returns the figure.
    '''
    fig = plt.figure(figsize = [ 6.2 , 10.03])
    print(results.condition_name.iloc[icond],results.condition_name.iloc[icond+1])
    outline_parts = ['cmaxes', 'cmins', 'cbars', 'cmeans']
    for iprobe in [0,1,2,3,4,5]:
        fig.add_subplot(6,1,iprobe+1)
        vpsleft = plt.violinplot(violindata[icond][iprobe],showmeans=True,side='low');
        vpsright = plt.violinplot(violindata[icond+1][iprobe],showmeans=True,side='high');
        for pc in vpsleft['bodies']:
            pc.set_facecolor('gray')
            pc.set_edgecolor('black')
            pc.set_alpha(1)
        for pc in vpsright['bodies']:
            pc.set_facecolor('lightgray')
            pc.set_edgecolor('#d62728')
            pc.set_alpha(1)
        for part in outline_parts:
            vpsleft[part].set_color('black')
            vpsright[part].set_color('#d62728')
        plt.xticks([1,2,3,4,5,6],violinlabel[icond][iprobe]);
        plt.ylim([-1,1])
        plt.yticks([-1,0,1])
    return fig

# conditions 2 and 3
icond = 2
fig = _plot_split_violins(icond)
# plt.savefig('violins_condition_response.pdf')

# conditions 0 and 1
icond = 0
fig = _plot_split_violins(icond)
# plt.savefig('violins_condition_stim.pdf')

In [None]:
# Display the results table again.
results

In [None]:
# Display the results table (repeated cell).
results

In [None]:
# Absolute spike count correlations per probe pair as split violins, with a
# two-sample KS test between the two conditions of each pair.
from scipy.stats import ks_2samp, mannwhitneyu, ranksums, ttest_ind

probe_nums = np.array([u['probe_num'] for u in unitssp])
areas = ['ORB','MD','LP','SCm','STR','DG'] 
binsize = 0.01
ed = np.linspace(-0.3,0.3,100)
violindata = []   # [condition][reference probe] -> list of arrays, one per probe pair
violinlabel = []  # matching 'AREA-AREA' labels
for i in range(len(results)): 
    cond_probe_ids = results['probe_ids'].iloc[i]
    cond_rsc = results['spike_count_correlations'].iloc[i]
    violindata.append([])
    violinlabel.append([])
    for iprobe0 in [0,1,2,3,4,5]:
        mask = unit_2d_mask(cond_probe_ids,iprobe0,iprobe0)
        xx = np.abs(cond_rsc[mask])
        xx = xx[np.isfinite(xx)]
        violindata[-1].append([xx])
        violinlabel[-1].append([f'{areas[iprobe0]}-{areas[iprobe0]}'])
        for iprobe in [0,1,2,3,4,5]:            
            if iprobe == iprobe0:
                continue
            # only the upper triangle is filled, so the smaller probe number indexes the rows
            mask = unit_2d_mask(cond_probe_ids,np.min([iprobe0,iprobe]),np.max([iprobe0,iprobe]))
            xx = np.abs(cond_rsc[mask])
            xx = xx[np.isfinite(xx)]
            violindata[-1][-1].append(xx)
            violinlabel[-1][-1].append(f'{areas[iprobe0]}-{areas[iprobe]}')

nsamples = 2000 #for test

def _plot_abs_violins(icond, title_prefix, filename, show_counts = False,
                      nsamples = nsamples, seed = 0):
    '''
    Split violins of |rsc| for condition `icond` (left, black) against
    condition `icond+1` (right, red), one subplot per reference probe.

    A '*' marks probe pairs whose two distributions differ (KS test, p < 0.05).
    The test runs on at most `nsamples` values drawn without replacement from
    each distribution with a seeded generator, so the figure is reproducible
    and pairs with fewer than `nsamples` values no longer raise.
    With show_counts, the number of values of the left distribution is printed
    next to each violin. The figure is saved to `filename` and returned.
    '''
    rng = np.random.default_rng(seed)
    fig = plt.figure(figsize = [ 6.2 , 10.03])
    print(results.condition_name.iloc[icond],results.condition_name.iloc[icond+1])
    outline_parts = ['cmaxes', 'cmins', 'cbars', 'cmeans']
    for iprobe in [0,1,2,3,4,5]:
        fig.add_subplot(6,1,iprobe+1)
        vpsleft = plt.violinplot(violindata[icond][iprobe],showmeans=True,side='low');
        vpsright = plt.violinplot(violindata[icond+1][iprobe],showmeans=True,side='high');
        for i,(x,y) in enumerate(zip(violindata[icond][iprobe],violindata[icond+1][iprobe])):
            if len(x) and len(y):
                # named `ks` so that the `res` DataFrame of an earlier cell is not overwritten
                ks = ks_2samp(rng.choice(x,min(nsamples,len(x)),replace = False),
                              rng.choice(y,min(nsamples,len(y)),replace = False))
                if ks.pvalue < 0.05:
                    plt.text(i+1,0.7,'*')
            if show_counts:
                plt.text(i+0.5,0.9,len(x),fontsize = 5)
        for pc in vpsleft['bodies']:
            pc.set_facecolor('gray')
            pc.set_edgecolor('black')
            pc.set_alpha(1)
        for pc in vpsright['bodies']:
            pc.set_facecolor('lightgray')
            pc.set_edgecolor('#d62728')
            pc.set_alpha(1)
        for part in outline_parts:
            vpsleft[part].set_color('black')
            vpsright[part].set_color('#d62728')
        plt.xticks([1,2,3,4,5,6],violinlabel[icond][iprobe]);
        plt.ylim([0,1])
        plt.yticks([0,1])
    # the title is set on the current (last) subplot, as in the original cell
    condition_pair = (results.condition_name.iloc[icond],results.condition_name.iloc[icond+1])
    plt.title(f'{title_prefix} {condition_pair}')
    plt.savefig(filename)
    return fig

# conditions 2 and 3
icond = 2
fig = _plot_abs_violins(icond, 'Abs response condition', 'violins_abs_condition_response.pdf')

# conditions 0 and 1; also annotate the number of values per violin
icond = 0
fig = _plot_abs_violins(icond, 'Abs stimuli condition', 'violins_abs_condition_stim.pdf',
                        show_counts = True)

In [None]:
# Number of finite entries in the correlation matrix of the first row of `results`.
np.sum(np.isfinite(results['spike_count_correlations'].iloc[0]))

In [None]:
# Fetch the correlation matrix of one condition. The restriction includes `key`
# so that fetch1 still matches exactly one row once the table holds entries for
# other sessions or parameter sets.
rsc,probe_ids,unit_ids = (DropletsSpikeCountsCorrelations() & key &
       "condition_name = 'stim_onset:intensity_values=20|rewarded=1'").fetch1('spike_count_correlations',
                                                                              'probe_ids','unit_ids')
%matplotlib ipympl

# Per-probe ordering of units by distance from the probe's mean unit position,
# followed by an image of the full correlation matrix.
# NOTE(review): `indices` is not used after this loop in this cell, so the
# ordering is never applied to `rsc`. The loop variables `iprobe`, `ii` and
# `pos` keep their values from the last probe and are read by later cells.
indices = np.arange(len(unit_ids))
for iprobe in np.unique(probe_ids):
    ii = (probe_ids == iprobe)
    # presumably the fetched rows come back in the order of unit_ids[ii]; verify, the query does not enforce it
    pos = (UnitMetrics() & [dict(key,
                                 probe_num = iprobe,
                                 unit_id = u) for u in unit_ids[ii]]).fetch('position')
    pos = np.stack(pos)
    # permute this probe's indices by increasing distance from the mean position
    indices[ii] = indices[ii][np.argsort(np.linalg.norm(pos-np.mean(pos,axis = 0),axis = 1))]
import pylab as plt
plt.figure()
plt.imshow(rsc,cmap = 'RdBu_r',clim = [-0.5,0.5])
# number of units per probe, used to draw the boundaries between probes
ncells = []
for iprobe in np.unique(probe_ids):
    ncells.append(np.sum(probe_ids == iprobe))
# boundaries at the cumulative unit counts; assumes the matrix rows are grouped by ascending probe id
plt.hlines(np.cumsum(np.hstack([0,ncells])),0,len(rsc),color = 'k',lw = 0.5)
plt.vlines(np.cumsum(np.hstack([0,ncells])),0,len(rsc),color = 'k',lw = 0.5)
# unit_ids[ii]
# plt.figure()
N = rsc.shape[0]
# both axes run from N down to 0
plt.axis([N,0,N,0])
# plt.plot(np.argsort(probe_ids))
# plt.plot(np.argsort(probe_ids))


In [None]:
# response
# Within-probe correlations for the two response conditions: joint 2D histogram
# (top row) and marginal densities (bottom row), one column per probe.
# The matrices and the probe id of every row/column are read from `results`;
# the `rsc` variable of the previous cell is a single 2-D matrix and cannot be
# indexed per condition.
fig = plt.figure(figsize = [10.15,  3.69])

# NOTE(review): assumes response=-1 is the 'left response' and response=1 the
# 'right response' (rows 2 and 3 of `results`) - confirm.
left_row = results[results.condition_name == 'response_onset:response=-1|rewarded=1'].iloc[0]
right_row = results[results.condition_name == 'response_onset:response=1|rewarded=1'].iloc[0]
# probe ids in the order of the stored matrices (not the order of `unitssp`)
response_probe_ids = left_row['probe_ids']

for p,n in zip(np.unique(response_probe_ids),['ORB','MD','LP','SCm','STR','DG']):
    edges = np.arange(-1,1.01,0.05)
    mask = unit_2d_mask(response_probe_ids,p,p)
    xx = left_row['spike_count_correlations'][mask]
    yy = right_row['spike_count_correlations'][mask]
    cnts = np.histogram2d(xx,yy,bins = edges)
    fig.add_subplot(2,6,(p+1))
    plt.imshow(cnts[0].T[::-1],extent = [edges[0],edges[-1],edges[0],edges[-1]],
               origin='upper',
               cmap=plt.cm.rainbow, norm=plt.matplotlib.colors.LogNorm(),
              aspect = 'equal')
    plt.title(f'{n}-{n} {np.sum(np.isfinite(xx))} pairs',fontsize = 7)
    # unity line and axes through zero
    plt.plot([-1,1],[-1,1],'k',lw = 0.5)
    plt.plot([0,0],[-1,1],'k',lw = 0.5)
    plt.plot([-1,1],[0,0],'k',lw = 0.5)
    maxax  = 1
    plt.axis([-maxax,maxax,-maxax,maxax])
    plt.xticks([-0.5,0,0.5],[])
    plt.yticks([-0.5,0,0.5],[])
    if p == 5:
        cax = fig.add_axes([0.93,0.65,0.008,0.1])
        plt.colorbar(shrink=0.3,cax = cax)
    if p == 0:
        plt.xlabel('left response',fontsize = 6)
        plt.ylabel('right response',fontsize = 6)
        plt.xticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)
        plt.yticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)

    fig.add_subplot(2,6,6+(p+1))
    edges = np.arange(-1,1.01,0.01)
    xx = xx[np.isfinite(xx)]
    ft = gaussian_kde(xx,  bw_method='scott')
    plt.fill_between(edges,ft.pdf(edges)/100, 0, color='lightgray', 
                     edgecolor = 'k', alpha = 1, label = 'left')
    yy = yy[np.isfinite(yy)]
    ft = gaussian_kde(yy,  bw_method='scott')
    plt.plot(edges, ft.pdf(edges)/100,'r',alpha = 1,label = 'right')
    # mean markers use the colour of their curve: black for left, red for right
    plt.vlines(np.mean(xx),0.05,0.055,color='k',lw = 2,linestyle='-')
    plt.vlines(np.mean(yy),0.055,0.06,color='r',lw = 2,linestyle='-')
    plt.axis('tight')
    plt.ylim([0,0.07])
    plt.xlim([-0.5,0.5])
    if p == 5:
        plt.legend(fontsize = 5)
    if not p==0:
        plt.yticks([0,0.025,0.05],[])
        plt.xticks([-0.5,0,0.5],[])
    else:
        plt.xlabel('spike count correlation',fontsize=6)
        plt.ylabel('fraction',fontsize = 6)
        plt.yticks([0,0.025,0.05],[0,0.025,0.05], fontsize = 5)
        plt.xticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)
# plt.savefig('rsc_per_condition_response_single_area.pdf')

In [None]:
# Display the results table to check the row order before plotting.
results

In [None]:
# stimulus 
# Within-probe correlations for rows 0 and 1 of `results`: joint 2D histogram
# (top row) and marginal densities (bottom row), one column per probe.
fig = plt.figure(figsize = [10.15,  3.69])
icond = 0
area_names = ['ORB','MD','LP','SCm','STR','DG']
probe_nums = results['probe_ids'].iloc[icond]
maxax  = 0.5
guide_lines = (([-1,1],[-1,1]),   # unity line
               ([0,0],[-1,1]),    # x = 0
               ([-1,1],[0,0]))    # y = 0
for p,n in zip(np.unique(probe_nums),area_names):
    # ---- top row: joint histogram of the two conditions
    edges = np.arange(-1,1.01,0.05)
    mask = unit_2d_mask(probe_nums,p,p)
    xx = results['spike_count_correlations'].iloc[icond][mask]
    yy = results['spike_count_correlations'].iloc[icond+1][mask]
    joint_counts = np.histogram2d(xx,yy,bins = edges)[0]
    fig.add_subplot(2,6,(p+1))
    plt.imshow(joint_counts.T[::-1],
               extent = [edges[0],edges[-1],edges[0],edges[-1]],
               origin='upper',
               cmap=plt.cm.rainbow,
               norm=plt.matplotlib.colors.LogNorm(),
               aspect = 'equal')
    plt.title(f'{n}-{n} {np.sum(np.isfinite(xx))} pairs',fontsize = 7)
    for line_x,line_y in guide_lines:
        plt.plot(line_x,line_y,'k',lw = 0.5)
    plt.axis([-maxax,maxax,-maxax,maxax])
    plt.xticks([-0.5,0.5],[])
    plt.yticks([-0.5,0.5],[])
    if p == 5:
        # colorbar in its own small axes, next to the last column
        cax = fig.add_axes([0.93,0.65,0.008,0.1])
        plt.colorbar(shrink=0.3,cax = cax)
    if p == 0:
        # only the first column gets axis labels and tick labels
        plt.xlabel(results.condition_name.iloc[icond],fontsize = 6)
        plt.ylabel(results.condition_name.iloc[icond+1],fontsize = 6)
        plt.xticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)
        plt.yticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)

    # ---- bottom row: density of each condition
    fig.add_subplot(2,6,6+(p+1))
    edges = np.arange(-1,1.01,0.01)
    xx = xx[np.isfinite(xx)]
    ft = gaussian_kde(xx,  bw_method='scott')
    plt.fill_between(edges,ft.pdf(edges)/100, 0, color='lightgray',
                     edgecolor = 'k', alpha = 1, label = results.condition_name.iloc[icond])
    yy = yy[np.isfinite(yy)]
    ft = gaussian_kde(yy,  bw_method='scott')
    plt.plot(edges, ft.pdf(edges)/100,'r',alpha = 1,label = results.condition_name.iloc[icond+1])
    # mean of each distribution, in the colour of its curve
    plt.vlines(np.mean(yy),0.045,0.05,color='r',lw = 2,linestyle='-')
    plt.vlines(np.mean(xx),0.05,0.055,color='k',lw = 2,linestyle='-')
    plt.axis('tight')
    plt.ylim([0,0.06])
    plt.xlim([-maxax,maxax])
    if p == 5:
        plt.legend(fontsize = 5)

    if p == 0:
        plt.xlabel('spike count correlation',fontsize=6)
        plt.ylabel('fraction',fontsize = 6)
        plt.legend(fontsize = 6)

        plt.yticks([0,0.025,0.05],[0,0.025,0.05], fontsize = 5)
        plt.xticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)
    else:
        plt.yticks([0,0.025,0.05],[])
        plt.xticks([-0.5,0,0.5],[])
plt.savefig('rsc_per_condition_stimulus_single_area.pdf')    

In [None]:
# response
# Rows 2 and 3 of `results` are the response conditions. The plotting code is
# wrapped in a function so that it can be reused for any pair of rows.
def _plot_rsc_condition_pair(results, icond, filename = None,
                             area_names = ('ORB','MD','LP','SCm','STR','DG'), maxax = 0.5):
    '''
    Within-probe spike count correlations for rows `icond` and `icond+1` of `results`.

    Top row: joint 2D histogram of the two conditions per probe (log colour scale).
    Bottom row: density of each condition (gray fill: icond, red line: icond+1),
    with the mean of each distribution marked in the matching colour.
    The mask for both rows is built from the `probe_ids` of row `icond`.
    The figure is saved to `filename` when given, and returned.
    '''
    fig = plt.figure(figsize = [10.15,  3.69])
    probe_nums = results['probe_ids'].iloc[icond]
    name0 = results.condition_name.iloc[icond]
    name1 = results.condition_name.iloc[icond+1]
    for p,n in zip(np.unique(probe_nums),area_names):
        edges = np.arange(-1,1.01,0.05)
        mask = unit_2d_mask(probe_nums,p,p)
        xx = results['spike_count_correlations'].iloc[icond][mask]
        yy = results['spike_count_correlations'].iloc[icond+1][mask]
        cnts = np.histogram2d(xx,yy,bins = edges)
        fig.add_subplot(2,6,(p+1))
        plt.imshow(cnts[0].T[::-1],extent = [edges[0],edges[-1],edges[0],edges[-1]],
                   origin='upper',
                   cmap=plt.cm.rainbow, norm=plt.matplotlib.colors.LogNorm(),
                   aspect = 'equal')

        plt.title(f'{n}-{n} {np.sum(np.isfinite(xx))} pairs',fontsize = 7)
        # unity line and axes through zero
        plt.plot([-1,1],[-1,1],'k',lw = 0.5)
        plt.plot([0,0],[-1,1],'k',lw = 0.5)
        plt.plot([-1,1],[0,0],'k',lw = 0.5)
        plt.axis([-maxax,maxax,-maxax,maxax])
        plt.xticks([-0.5,0.5],[])
        plt.yticks([-0.5,0.5],[])
        if p == 5:
            cax = fig.add_axes([0.93,0.65,0.008,0.1])
            plt.colorbar(shrink=0.3,cax = cax)
        if p == 0:
            plt.xlabel(name0,fontsize = 6)
            plt.ylabel(name1,fontsize = 6)
            plt.xticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)
            plt.yticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)

        fig.add_subplot(2,6,6+(p+1))
        edges = np.arange(-1,1.01,0.01)
        xx = xx[np.isfinite(xx)]
        ft = gaussian_kde(xx,  bw_method='scott')
        plt.fill_between(edges,ft.pdf(edges)/100, 0, color='lightgray',
                         edgecolor = 'k', alpha = 1, label = name0)
        yy = yy[np.isfinite(yy)]
        ft = gaussian_kde(yy,  bw_method='scott')

        plt.plot(edges, ft.pdf(edges)/100,'r',alpha = 1,label = name1)
        plt.vlines(np.mean(yy),0.045,0.05,color='r',lw = 2,linestyle='-')
        plt.vlines(np.mean(xx),0.05,0.055,color='k',lw = 2,linestyle='-')
        plt.axis('tight')
        plt.ylim([0,0.06])
        plt.xlim([-maxax,maxax])
        if p == 5:
            plt.legend(fontsize = 5)

        if not p==0:
            plt.yticks([0,0.025,0.05],[])
            plt.xticks([-0.5,0,0.5],[])
        else:
            plt.xlabel('spike count correlation',fontsize=6)
            plt.ylabel('fraction',fontsize = 6)
            plt.legend(fontsize = 6)

            plt.yticks([0,0.025,0.05],[0,0.025,0.05], fontsize = 5)
            plt.xticks([-0.5,0,0.5],[-0.5,0,0.5], fontsize = 5)
    if filename is not None:
        plt.savefig(filename)
    return fig

icond = 2
probe_nums = results['probe_ids'].iloc[icond]
fig = _plot_rsc_condition_pair(results, icond, 'rsc_per_condition_response_single_area.pdf')

In [None]:

from labdata import *
from labdata.schema import *
# Populate the UnitMetrics table, using 10 processes.
UnitMetrics.populate(display_progress = True,processes = 10)

In [None]:
# Waveforms and metrics of the first 5 units of one probe, and that probe's channel map.
# NOTE: `iprobe` and `ii` are leftovers from the per-probe loop of an earlier
# cell (the last probe it iterated over); run that cell first.
first_unit_keys = [dict(key, probe_num = iprobe, unit_id = u)
                   for u in unit_ids[ii][:5]]
waveform_table = SpikeSorting().Waveforms()*UnitMetrics() & first_unit_keys
waves = pd.DataFrame(waveform_table.fetch())

probe_key = dict(key,probe_num = iprobe)
chmap = (SpikeSorting & probe_key).fetch1('channel_coords')

In [None]:
# Active electrodes of the first fetched unit.
waves.iloc[0].active_electrodes

In [None]:
import pylab as plt
# Median waveform of each fetched unit drawn at its channel locations, with
# consecutive units offset by 1500 along x. Only active electrodes are drawn.
plt.figure()

for i,w in waves.iterrows():
    t = np.linspace(0,50,w.waveform_median.shape[0])
    for ie,(ch,n) in enumerate(zip(chmap,w.waveform_median.T)):
        if ie in w.active_electrodes:
            # x: time axis shifted to the channel x coordinate; y: scaled amplitude at the channel y coordinate
            plt.plot(t+1500*i+ch[0],n/20+ch[1])
    # mark the estimated unit position: x comes from position[0] (the original
    # used position[1] for both coordinates)
    # NOTE(review): assumes `position` is ordered (x, y) like the rows of `chmap` - confirm
    plt.text(w.position[0]+1500*i,w.position[1],'*')

In [None]:

# First two channel coordinates, and the distance of each from the first channel.
chmap[:2],np.linalg.norm(chmap - chmap[0],axis = 1)[:2]

In [None]:
# Fetch the unit positions probe by probe. After the loop, `iprobe`, `ii` and
# `pos` refer to the last probe; other cells read these variables.
indices = np.arange(len(unit_ids))

for iprobe in np.unique(probe_ids):
    ii = (probe_ids == iprobe)
    probe_unit_keys = [dict(key, probe_num = iprobe, unit_id = u)
                       for u in unit_ids[ii]]
    pos = np.stack((UnitMetrics() & probe_unit_keys).fetch('position'))
# unit_ids[ii]
# plt.figure()

# plt.plot(np.argsort(probe_ids))

In [None]:
# Median waveform and channel coordinates of one example unit (probe 5, unit 18),
# then the position estimated from the waveform.
from spks.waveforms import waveforms_position

example_unit_key = dict(key,
                        probe_num = 5,
                        unit_id = 18)
w = (SpikeSorting.Waveforms() & example_unit_key).fetch1('waveform_median')
channel_coords = (SpikeSorting() & example_unit_key).fetch1('channel_coords')
# waveforms_position is called with a leading cluster axis of length 1
waveforms_position(w.reshape([1,*w.shape]),channel_coords)

In [None]:
from spks.waveforms import estimate_active_channels
# Channels whose peak-to-peak amplitude around the waveform center exceeds
# `madthresh` times a threshold derived from the median peak-to-peak amplitude.
madthresh = 3
cluster_waveforms_mean = w.reshape([1,*w.shape]).copy()

nclusters,nsamples,nchannels = cluster_waveforms_mean.shape
N = int(nsamples/3)
center = nsamples//2
peak_amp = []
for mwave in cluster_waveforms_mean:
    # peak-to-peak amplitude per channel within +/- N samples of the center
    window = mwave[center-N:center+N,:]
    peak_amp.append(window.max(axis=0) - window.min(axis=0))
# median of all peak-to-peak amplitudes, divided by 0.6745
channel_mad = np.median(peak_amp)/0.6745

activeidx = [np.where(amp>channel_mad*madthresh)[0] for amp in peak_amp]
activeidx

In [None]:

from spks.utils import bandpass_filter
# Same active-channel estimate as the previous cell, after band-pass filtering
# the waveform. Reuses `nsamples`, `N` and `madthresh` from that cell.
cluster_waveforms_mean = w.reshape([1,*w.shape]).copy()

filtered = bandpass_filter(cluster_waveforms_mean[0].T,30000,
                           1000,10000)
cluster_waveforms_mean[0,:,:] = filtered.T

center = nsamples//2
peak_amp = []
for mwave in cluster_waveforms_mean:
    # peak-to-peak amplitude per channel within +/- N samples of the center
    window = mwave[center-N:center+N,:]
    peak_amp.append(window.max(axis=0) - window.min(axis=0))
channel_mad = np.median(peak_amp)/0.6745

activeidx = [np.where(amp>channel_mad*madthresh)[0] for amp in peak_amp]

# all channels in black, active channels in colour, one unit of vertical offset per channel
plt.figure()
channel_offsets = 1*np.arange(cluster_waveforms_mean.shape[-1])
plt.plot(cluster_waveforms_mean[0]+channel_offsets,'k');
plt.plot(cluster_waveforms_mean[0][:,activeidx[0]]+
         channel_offsets[activeidx[0]]);

In [None]:
# Indices of the active channels, one array per cluster.
activeidx

In [None]:
from sklearn.metrics import euclidean_distances
# Display the UnitMetrics rows of the units selected by `ii` on probe `iprobe`.
# NOTE: `iprobe` and `ii` are leftovers from a per-probe loop in an earlier cell.
UnitMetrics() & [dict(key,
                                 probe_num = iprobe,
                                 unit_id = u) for u in unit_ids[ii]]
# euclidean_distances(pos,pos.mean(axis = 1))

In [None]:
# Step-by-step version of DropletsSpikeCountsCorrelations.make for a single
# condition, for debugging. Units are left in the order returned by the query
# (no probe / shank / depth sorting).
key = {'subject_name': 'JC140', 'session_name': '20231206_140953', 'parameter_set_num': 5, 'unit_criteria_id': 1}
# fetch1 needs exactly one row, so a single condition is selected by name
condition_name = 'stim_onset:intensity_values=20|rewarded=1'
# same settings as the computed table
num_repeats = DropletsSpikeCountsCorrelations.num_repeats
fold_fraction = DropletsSpikeCountsCorrelations.fold_fraction

unitssp = (SpikeSorting().Unit() & (UnitCount.Unit() & key & 'passes = 1').proj()).get_spike_times()
trial_events = droplets.get_trial_state_times(key)

# the conditions live in the Condition part table (attribute `condition_name`)
cond,times = (DropletsSpikeCountParameters.Condition() & dict(key, condition_name = condition_name)
              ).fetch1('condition_name','condition_times')

spks = [u['spike_times'] for u in unitssp]
event,restriction = cond.split(':')
# a trial is kept only if it matches every '<column>=<value>' term
trial_selection = np.all(np.vstack([trial_events[k].values== float(v) 
                                    for k,v in [a.split('=') for a in restriction.split('|')]]),axis = 0)
# window start per selected trial: event time plus offset; times[1] is the window duration
window_starts = trial_events[trial_selection][event].values+times[0]
cnts = [[np.count_nonzero((s>=t) & (s<(t+times[1]))) for t in window_starts] for s in spks]
x = np.stack(cnts)  # units x trials
rsc = cv_noise_correlations(x, num_repeats = num_repeats, fold_fraction = fold_fraction)

# the row that make() would insert
dict(key,
     condition_name = cond,
     num_trials = x.shape[1],
     num_repeats = num_repeats,
     spike_count_correlations = rsc,
     unit_ids = np.array([u['unit_id'] for u in unitssp]),
     probe_ids = np.array([u['probe_num'] for u in unitssp]))


In [None]:
# Last row of the correlation matrix.
rsc[-1]

In [None]:
from labdata.schema import *
# One-off insert of a unit criteria definition (kept commented out for reference):
# UnitCountCriteria().insert1(dict(unit_criteria_id = 6,sua_criteria = 'isi_contamination < 0.1 & amplitude_cutoff < 0.5 & spike_duration > 0.1 & spike_amplitude > 50'))
# Populate the UnitCount table, using 6 processes.
UnitCount.populate(display_progress=True, processes = 6)