In [304]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import seaborn
import networkx as nx
import pandas as pd
import scipy.stats as stt
import scipy as sp
import re
import statsmodels.api as sm

# Default seaborn colour palette (computed before the seaborn.set call below).
clrs = seaborn.color_palette()

# Global plot styling: larger fonts and the "ticks" axes style.
seaborn.set(font_scale=1.5,style='ticks')



# Make the local `mecll` analysis package importable.
# NOTE(review): absolute, machine-specific path — consider reading it from an
# environment variable / config cell so the notebook runs on other machines.
sys.path.append("/Users/yves/Documents/Code/mec_ephys/packages/")
import mecll


In [496]:
# Root directory of the 2020-12-17 ephys session.
# NOTE(review): absolute local path; consider a configurable DATA_DIR.
ephys_data_root = "/Users/yves/Documents/ephys_data/2020_12_17//"

In [497]:
# load ephys data (spike times, cluster assignments, sync pulses) via mecll
ephys_dataset = mecll.hpc.load_ephys_data(ephys_data_root)

# Disabled: session-specific offset; the same constant is currently hard-coded
# where spike times are aligned to the behavioural clock.
#if '12_17' in ephys_data_root:
#    ephys_dataset.offset = 800400

/Users/yves/Documents/ephys_data/2020_12_17//events/Rhythm_FPGA-107.0/TTL_1/timestamps.npy


In [498]:
# Behavioural log for the same session (2020-12-17); the commented-out path is
# the 2020-12-15 session.
#path = "/Users/yves/Documents/ephys_data/data/line_loop_ephys1/'456674_10'-2020-12-15-134119.txt"
path = "/Users/yves/Documents/ephys_data/data/line_loop_ephys1/'456674_10'-2020-12-17-151542.txt"



In [499]:
behaviour_dset = mecll.hpc.load_behavioural_data(path)

In [500]:
# Port-transition counts, indexed by task (entries 0 and 1 are used below).
transition_count_dicts = mecll.hpc.get_transitions_count_dict(behaviour_dset.dat_dict['port'],
                                                              behaviour_dset.task_times)


In [501]:
# Port sequence of each task, presumably reconstructed from the most frequent
# transitions (see mecll.hpc.get_seq_from_transitions).
seq0 = mecll.hpc.get_seq_from_transitions(transition_count_dicts[0])
seq1 = mecll.hpc.get_seq_from_transitions(transition_count_dicts[1])

seqs = [seq0,seq1]



In [502]:
# Per-task lists of [start, end] intervals (see output). Units presumably ms:
# poke times are multiplied by 1000 before being range-checked against them.
behaviour_dset.task_times

[[[0, 358317], [1401702, 1866505]], [[358317, 1401702], [1866504, 5128921.0]]]

In [503]:
# Fit a mapping between the behavioural clock (A) and the ephys clock (B) from
# their shared sync-pulse times.
# NOTE(review): only every second ephys rsync timestamp is used ([::2]) —
# presumably the ephys side records both edges of each pulse; confirm.
aligner = mecll.rsync.Rsync_aligner(behaviour_dset.rsync_times_behaviour,ephys_dataset.rsync_times_spike[::2],units_A=1,units_B=1)

In [504]:
# Map ephys spike times onto the behavioural clock, then add a hard-coded offset.
# NOTE(review): 800400 is a session-specific magic number (the commented-out code
# in the loading cell sets it for the 2020_12_17 session); move it to a named
# config constant.
aligned_spike_times = aligner.B_to_A(ephys_dataset.unaligned_spike_times) + 800400#+ ephys_dataset.offset

In [505]:
# Leftover sanity check (disabled).
#np.nanmax(aligned_spike_times)/83230564.34827931

In [506]:
# get the distance matrices for task 0's port sequence
# (recomputed for the selected task_nr further down)
task_distance = mecll.hpc.run_rsa.line_distance_matrix(seq0)
spatial_distance = mecll.hpc.run_rsa.get_spatial_distance_matrix(mecll.poke_pos,seq0)

In [507]:
def get_binned_spikes(spks,bin_size=10,samples_per_ms=30.):
    """ Takes in spike times of units and returns a binary spike
        raster (no smoothing) at the specified resolution.

        Arguments:
        ===================================

        spks: list with one array of spike times per unit, in ephys
              samples. 1D arrays and (n, 1) column arrays are both
              accepted; NaN entries are ignored.

        bin_size: bin width in ms

        samples_per_ms: ephys sampling rate in samples per ms
                        (default 30., i.e. 30 kHz)

        Returns:
        ===================================

        array of shape (len(spks), n_bins) where entry [i, b] is 1 if
        unit i fired at least once in bin b and 0 otherwise. n_bins is
        just large enough to hold the latest spike (0 if there are none).
    """
    # convert each unit's spike times to bin indices: samples -> ms -> bins
    bin_idx = []
    for u in spks:
        t = np.ravel(np.asarray(u, dtype=float))
        t = t[~np.isnan(t)]
        bin_idx.append(np.floor(t/samples_per_ms/bin_size).astype("int"))

    # size the raster from the largest index actually used, so a spike that
    # lands exactly on a bin edge still fits (ceil of the max would not)
    n_bins = max((int(b.max())+1 for b in bin_idx if b.size), default=0)

    spk_arr = np.zeros([len(spks),n_bins])
    for i,b in enumerate(bin_idx):
        spk_arr[i,b] = 1

    return spk_arr


In [508]:
def get_unit_spike_lists(spkT,spkC,unit_ids=None):
    """ Takes in essentially the kilosort/phy output and returns
        lists with the spike times per unit

        Arguments:
        ===================================

        spkT: spike times, one entry (row) per spike

        spkC: cluster membership of each spike, aligned with spkT

        unit_ids: cluster ids of the units you want to extract.
                  Defaults to every cluster id present in spkC
                  (cluster ids are not assumed to be contiguous).

        Returns:
        ===================================

        list with, for each id in unit_ids, that unit's spike times with
        NaN entries removed (empty if the id has no spikes)
    """
    if unit_ids is None:
        # use the ids actually present; np.arange(n_clusters) would miss ids
        # >= n_clusters and produce empty lists for any gaps in the numbering
        unit_ids = np.unique(spkC)

    spks = []
    for uid in unit_ids:
        # select by row index so column-shaped (n, 1) inputs keep their layout
        unit_times = spkT[np.nonzero(spkC==uid)[0]]
        keep = np.nonzero(~np.isnan(unit_times))[0]
        spks.append(unit_times[keep])
    return spks

In [509]:
#ephys_dataset.spike_clusters

In [510]:
# Keep only units whose cluster_quality label is 'good'.
# NOTE(review): these are positions in cluster_quality that are used as cluster
# ids below — only correct if cluster_quality is indexed by cluster id; confirm.
unit_ids = np.where(ephys_dataset.cluster_quality=='good')[0]
nUnits = len(unit_ids)  # count before units without spikes are dropped

bin_size = 1  # ms
spks = get_unit_spike_lists(aligned_spike_times, ephys_dataset.spike_clusters, unit_ids=unit_ids)
# drop units that have no (non-NaN) spikes
spks = [i for i in spks if len(i)>0]

# pass the bin_size defined above instead of a duplicated literal
spike_arr = get_binned_spikes(spks,bin_size=bin_size)

In [536]:
task_nr = 0  # which task (index into seqs / task_times) is analysed below

In [537]:
# get the distance matrices for the selected task's port sequence
task_distance = mecll.hpc.run_rsa.line_distance_matrix(seqs[task_nr])
spatial_distance = mecll.hpc.run_rsa.get_spatial_distance_matrix(mecll.poke_pos,seqs[task_nr])

In [538]:
n_neurons = spike_arr.shape[0]
# one row per port of the *selected* task; must match len(sorted(seqs[task_nr]))
# because ports are indexed by their position in that sorted list
n_ports = len(seqs[task_nr])

In [539]:
# Accumulators: summed per-poke mean activity per (port, neuron), and the number
# of pokes contributing to each port.
firing_in_ports = np.zeros([n_ports,n_neurons])
poke_counter = np.zeros(n_ports)

In [540]:
# Ports of the selected task in ascending order; a port's row in firing_in_ports
# is its position in this list.
# NOTE(review): the distance matrices are built from seqs[task_nr] (sequence
# order) — confirm they are indexed in this same sorted order, otherwise rows of
# the correlation and distance matrices will not correspond.
sorted_seq = sorted(seqs[task_nr])

In [541]:
# Average each neuron's activity in a window around every in-range poke of the
# selected task, accumulated per port (row = position of the port in sorted_seq).
# The accumulators are re-created here so re-running this cell does not
# double-count pokes.
firing_in_ports = np.zeros([n_ports,n_neurons])
poke_counter = np.zeros(n_ports)

# Window half-width: 100 ms either side of the poke, expressed in bins of spike_arr.
half_window = int(100 // bin_size)
for poke in behaviour_dset.dat_dict['port']:
    # poke[0] is the port id and poke[2] the poke time; the *1000 suggests
    # seconds -> ms (TODO confirm against the behavioural log format)
    if poke[0] not in sorted_seq:
        continue
    if not mecll.hpc.check_in_range(behaviour_dset.task_times[task_nr].copy(),1000*poke[2]):
        continue
    t = int(poke[2] * 1000 / bin_size)  # poke time as a column index into spike_arr
    # clip at 0: a negative slice bound would wrap around to the end of the raster
    lo = max(t - half_window, 0)
    hi = max(t + half_window, 0)
    poke_spk = spike_arr[:, lo:hi]
    if poke_spk.shape[1] == 0:
        # poke lies outside the raster; its mean would be NaN and poison the
        # port's running sum
        continue
    port_poked = sorted_seq.index(poke[0])
    # mean over bins = fraction of bins in the window that contain a spike
    firing_in_ports[port_poked] += np.nanmean(poke_spk,axis=1)
    poke_counter[port_poked] += 1


    


In [542]:
def remove_diagonal(A):
    """ Returns A with its main diagonal removed.

        A must be square in its first two axes, i.e. shape (n, n, ...).
        The result has shape (n, n-1, ...): row i holds A[i, j] for every
        j != i, in order. Trailing axes are kept unchanged.
    """
    n = A.shape[0]
    off_diag = ~np.eye(n, dtype=bool)
    # reshape explicitly instead of squeezing: np.squeeze would also collapse
    # the (n-1) axis when n == 2, and the -1 reshape fails when n == 1
    return A[off_diag].reshape((n, n-1) + A.shape[2:])



In [543]:
#mean_port_firing_rate

In [544]:
# Average activity per port: summed per-poke means divided by the number of pokes.
# Fail loudly if some port was never poked, instead of silently producing NaN rows
# that would corrupt the correlation matrix and the regression.
assert (poke_counter > 0).all(), "some ports of the selected task have no pokes in range"
mean_port_firing_rate = firing_in_ports/poke_counter[:,None]

In [545]:
# Port-by-port correlation of population activity vectors (rows are ports).
# NOTE(review): named "task1" but holds the correlations for seqs[task_nr].
task1_corrs = np.corrcoef(mean_port_firing_rate)

In [546]:
# NOTE: despite its name this is the identity — z-scoring of the regressors is
# currently disabled (scipy.stats.zscore, imported as stt, would re-enable it).
zscore = lambda x: x

In [547]:
# Regress off-diagonal port-port correlations on an intercept, spatial distance
# (x1 in the summary) and task distance (x2).
# NOTE(review): np.corrcoef is symmetric, so each port pair appears twice in y;
# if the distance matrices are symmetric too, the 30 observations are not
# independent and the reported p-values are optimistic — consider using only the
# upper triangle.
X = np.vstack([np.ones_like(remove_diagonal(spatial_distance).flatten()),
               zscore(remove_diagonal(spatial_distance).flatten()),
               zscore(remove_diagonal(task_distance).flatten()),
               #zscore(remove_diagonal(task_2_distance).flatten())
              ]).T
y = remove_diagonal(task1_corrs).flatten()
res = sm.OLS(y,X,hasconst=True).fit()

In [548]:

res.summary()

0,1,2,3
Dep. Variable:,y,R-squared:,0.058
Model:,OLS,Adj. R-squared:,-0.011
Method:,Least Squares,F-statistic:,0.8388
Date:,"Thu, 27 Jan 2022",Prob (F-statistic):,0.443
Time:,15:21:56,Log-Likelihood:,3.2238
No. Observations:,30,AIC:,-0.4476
Df Residuals:,27,BIC:,3.756
Df Model:,2,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
const,0.7472,0.238,3.138,0.004,0.259,1.236
x1,-0.0012,0.001,-1.254,0.221,-0.003,0.001
x2,-0.0475,0.044,-1.069,0.295,-0.139,0.044

0,1,2,3
Omnibus:,6.511,Durbin-Watson:,2.057
Prob(Omnibus):,0.039,Jarque-Bera (JB):,4.774
Skew:,0.788,Prob(JB):,0.0919
Kurtosis:,4.155,Cond. No.,981.0
