In [None]:
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import numpy as np
from collections import defaultdict
import seaborn as sns
from numpy.fft import fft, ifft, fftfreq
import scipy
from scipy import interpolate
import pandas as pd
from FWHM import FWHM
from scipy.signal import savgol_filter
from scipy.signal import welch
from scipy.integrate import simps
import pickle
import os, sys, json
import gc
from scipy.interpolate import interp1d
from joblib import dump, load

# Ephys loading

In [None]:
# Run the Intan RHD reader script in this kernel's namespace. It presumably
# defines `load_file`, which the next cell calls and which is not imported
# anywhere else in this notebook.
%run importrhdutilities.py

In [None]:
# Load the Intan recording. The path is plain string concatenation with the
# working directory, so the 'PATH' placeholder must begin with a separator.
filename = ''.join([os.getcwd(), 'PATH'])
result, data_present = load_file(filename)
# Amplifier channels; later cells index this as raw_data[channel][sample].
raw_data = result['amplifier_data']

In [None]:
# Stimulus onsets = rising edges of the first digital input: every sample index
# i at which the line is False at i-1 and True at i.
stim = result['board_dig_in_data'][0]
fs = 20000
stim_arr = np.asarray(stim)
rising_edge = (stim_arr[:-1] == False) & (stim_arr[1:] == True)  # elementwise, same tests as before
onsets = (np.flatnonzero(rising_edge) + 1).tolist()
# Onset times in seconds.
onset_time = np.array(onsets)/fs

# DA data loading

In [None]:
# DA trace (presumably FSCV dopamine) from an Excel export.
# NOTE(review): both series are read from the placeholder header 'NAME';
# substitute the concentration and time column headers.
da_path = os.getcwd() + 'PATH'  # plain concatenation: 'PATH' must start with a separator
DA = pd.read_excel(da_path, engine='openpyxl')
DA_data, t_DA = DA['NAME'], DA['NAME']

In [None]:
# Fill NaN gaps in the DA trace: fit a linear interpolator on the samples where
# DA_data is present (extrapolating past the ends) and evaluate it at every
# non-NaN time stamp. The result is aligned with t_DA[~np.isnan(t_DA)], which is
# not the full t_DA when t_DA itself contains NaN.
da_present = ~np.isnan(DA_data)
x_non_nan = t_DA[da_present]
y_non_nan = DA_data[da_present]
f = interp1d(x_non_nan, y_non_nan, kind='linear', fill_value="extrapolate")
t_present = ~np.isnan(t_DA)
DA_data_interp = f(t_DA[t_present])

# Pin Mapping

In [None]:
# Mice mapping: pin_map[pin] = [mapped index, 0].
# NOTE: every mapping cell rebinds `pin_map`; the one executed last is in effect.
pin_map = {pin: [mapped, 0] for pin, mapped in enumerate([0, 4, 2, 6, 7, 8, 5, 3, 1])}



In [None]:
# Rat Mapping 1: pin_map[pin] = [mapped index, 0].
# NOTE: every mapping cell rebinds `pin_map`; the one executed last is in effect.
pin_map = {pin: [mapped, 0] for pin, mapped in enumerate(
    [13, 11, 14, 12, 8, 10, 9, 7, 3, 6, 5, 4, 0, 2, 1])}

In [None]:
# Rat mapping 2: pin_map[pin] = [mapped index, 0].
# NOTE: every mapping cell rebinds `pin_map`; the one executed last is in effect.
pin_map = {pin: [mapped, 0] for pin, mapped in enumerate(
    [8, 14, 9, 4, 13, 12, 11, 10, 6, 2, 1, 0, 3, 5, 7])}



# Filtering

In [None]:
# Keep the [trim_start, trim_end) seconds of every channel.
# TIME1 / TIME2 are placeholders for the trim times in seconds.
trim_start = TIME1
trim_end = TIME2
fs = 20000
first_sample, last_sample = int(trim_start*fs), int(trim_end*fs)
raw_data_copy = raw_data.copy()
raw_data_cut = raw_data_copy[:, first_sample:last_sample]

In [None]:
# Band-pass every channel 1-300 Hz (order-2 Butterworth, causal), then apply a
# common-median reference: subtract the across-channel median at each sample.
order = 2
low = 1
high = 300
fs = 20000
# Second-order sections avoid the round-off problems of (b, a) coefficients when
# a band edge (1 Hz) is tiny relative to fs; sosfilt applies the same causal
# filter. axis=-1 filters each channel (row) of raw_data_cut in a single call.
sos = signal.butter(order, [low, high], 'bp', fs=fs, output='sos')
filtered = signal.sosfilt(sos, raw_data_cut, axis=-1)
med = np.median(filtered, axis=0)
filtered = filtered - med

# Data blanking (zero periodic artifact windows; trimming itself is done under Filtering)

In [None]:
# NOTE: this binds a second name to the same array object (no copy is made).
# The in-place zeroing through `data_to_process` in the next cell therefore
# also changes `filtered`.
data_to_process = filtered

In [None]:
# Zero an r-sample window every 2000 samples (0.1 s at fs = 20000) across
# (stride+1)*10 periods for each of five trains. Train i starts at sample
# i*stride*fs + pos1 - 30*i, i.e. trains are `stride` seconds apart.
# NOTE(review): `pos1` and `ch_NO` are not assigned anywhere in this notebook
# (only `CH_NO` is, further down); they must already exist in the kernel.
stride = 120
n_periods = (stride+1)*10
for i in range(5):
    pos = i*stride*fs + pos1 - 30*i
    r = 500  # samples zeroed per period (25 ms at fs = 20000); reused by the next cell
    for ch in range(ch_NO):
        trace = data_to_process[ch]  # row view: writes go into data_to_process
        for n in range(n_periods):
            onset = pos + 2000*n
            trace[onset:onset + r] = 0

In [None]:
# Zero the same r-sample window (r comes from the previous cell) every 2000
# samples for the first 50 periods after `pos1`, directly in `filtered`.
for ch in range(ch_NO):
    trace = filtered[ch]  # row view: writes go into filtered
    for n in range(50):
        begin = pos1 + 2000*n
        trace[begin:begin + r] = 0

# Artifact removal

In [None]:
# Frequency-domain artifact removal.
# Each channel is processed in 5 s windows (N samples) advanced by 2.5 s. In the
# centred spectrum (bin width fs/N = 0.2 Hz) the 11 bins around every multiple
# of 10 Hz (about +/-1 Hz, 0 Hz included) are deleted and re-filled by PCHIP
# interpolation from the remaining bins. Only samples [1.5 s, 4 s) of each
# inverse-transformed window are kept, so consecutive windows tile the output:
# yprobe[ch][0] corresponds to 1.5 s into `data_to_process`.
N = 5*20000
T = 1/20000
CH_NO = 15
data_to_process = filtered

# Number of windows, counting only windows that lie completely inside the data.
# The former int(np.round((seconds - 4)/2.5 + 1)) could count a truncated final
# window. That made the FFT shorter than N and the cell raised an error; when
# the old formula did not fail, it gives the same count as this one.
# (`round` shadows the builtin, but later cells read it, so the name is kept.)
n_seconds = len(data_to_process[0])/fs
round = max(int((n_seconds - 5)//2.5) + 1, 0)

# Indices (into the centred spectrum) to delete: 50*n + 45 .. 50*n + 55, n = 0..1998.
peak_to_cut = (50*np.arange(1999)[:, np.newaxis] + np.arange(45, 56)).ravel()

# The frequency axis and its kept bins are the same for every window.
xf_sym = np.fft.fftshift(fftfreq(N, T))
xf_cut = np.delete(xf_sym, peak_to_cut)

yprobe = []
for ch in range(CH_NO):
    kept_segments = []
    for n in range(round):
        window1 = int(2.5*n*fs)
        window2 = int((2.5*n+5)*fs)
        yf_sym = np.fft.fftshift(fft(data_to_process[ch][window1:window2]))
        yf_cut = np.delete(yf_sym, peak_to_cut)
        # PchipInterpolator only accepts real y (recent SciPy raises ValueError
        # for complex input), so interpolate the real and imaginary parts as two
        # columns and recombine them.
        pp = interpolate.PchipInterpolator(
            xf_cut, np.column_stack((yf_cut.real, yf_cut.imag)), axis=0)
        re_im = pp(xf_sym)
        yn = re_im[:, 0] + 1j*re_im[:, 1]
        yi = ifft(np.fft.ifftshift(yn))
        kept_segments.append(yi[int(1.5*fs):4*fs])
    yprobe.append(np.concatenate(kept_segments) if kept_segments else np.array([]))
yprobe = np.real(yprobe)

# Spike Sorting

In [None]:
import spikeinterface.full as si
import probeinterface as pi
from probeinterface.plotting import plot_probe
from probeinterface import Probe
import spikeinterface.preprocessing as spre
import mountainsort5 as ms5
from mountainsort5.util import create_cached_recording
from tempfile import TemporaryDirectory
from spikeinterface.core import concatenate_recordings
from spikeinterface import extractors as se

In [None]:
# Probe geometry: contact positions read from an Excel sheet (presumably one row
# per contact with x/y in um — confirm), drawn as 10-wide square contacts.
# 'PROBEMAP' is appended to the working directory by plain concatenation.
file = f"{os.getcwd()}PROBEMAP"
mapping = pd.read_excel(file).to_numpy()

probe = Probe(ndim=2, si_units='um')
probe.set_contacts(positions=mapping, shapes='square', shape_params={'width': 10})

In [None]:
# Wire probe contact k to recording channel k, for k = 0 .. ch_NO-1.
# NOTE(review): `ch_NO` is not assigned anywhere in this notebook (only
# `CH_NO = 15` is); confirm it is defined before this cell runs.
channel_indices = np.arange(0, ch_NO)
probe.set_device_channel_indices(channel_indices)
print(probe.device_channel_indices)

In [None]:
# Attach the probe built above to the recording and keep the result as raw_rec.
# NOTE(review): `clean_recording` is not created anywhere in this notebook; it
# must already exist in the kernel (presumably a spikeinterface recording).
raw_rec = clean_recording.set_probe(probe)

In [None]:
# Whiten the probe-attached recording (float32) and sort it with MountainSort5
# scheme 2. Radii are in the probe's units (um, per the Probe above).
sub_recording = raw_rec
recording_preprocessed = spre.whiten(sub_recording, dtype=np.float32)
ms5_params = ms5.Scheme2SortingParameters(
    phase1_detect_channel_radius=150,
    detect_channel_radius=50,
    training_duration_sec=60,
    phase1_npca_per_channel=10,
    phase1_detect_threshold=5.5,
    detect_threshold=5.5,
)
sorting = ms5.sorting_scheme2(recording_preprocessed, sorting_parameters=ms5_params)

In [None]:
# Assign each sorted unit to the channel whose mean sample value at the unit's
# first (up to) 500 spike indices is the most negative.
# NOTE(review): spike indices come from `sorting` but are looked up in
# `filtered`; this assumes both share the same sample timeline — confirm.
sorting_data = filtered
ch_to_unit = {}  # unit_id -> channel (a plain dict: values are ints, not dicts)
for unit_id in sorting.get_unit_ids():
    # Only the first 500 spikes were ever used; slice instead of scanning all of them.
    peaks = np.asarray(sorting.get_unit_spike_train(unit_id))[:500]
    # Mean over those spikes for channels 0..ch_NO-1 in one indexing step.
    amp_per_channel = np.asarray(sorting_data)[:ch_NO, peaks].mean(axis=1)
    # argmin returns the first minimum, as min(dict, key=...) did.
    ch_to_unit[unit_id] = int(np.argmin(amp_per_channel))
ch_to_unit

In [None]:
# Re-key sorter units as (assigned channel, per-channel unit number). Units that
# land on the same channel are numbered 0, 1, 2, ... in sorter order.
# `mean_waveform_pre` appears to hold per-spike waveforms (later cells take its
# len() and index it per spike), despite its name.
# NOTE(review): `waveform` is not created in this notebook; it must be indexable
# as waveform[channel, unit_id] with the *sorter* unit id.
spike_time_pre = {}
mean_waveform_pre = {}
for unit_id in sorting.get_unit_ids():
    channel = ch_to_unit[unit_id]
    unit = 0
    while (channel, unit) in spike_time_pre:
        unit += 1
    spike_time_pre[channel, unit] = sorting.get_unit_spike_train(unit_id)
    mean_waveform_pre[channel, unit] = waveform[channel, unit_id]

# Curation

In [None]:
# PCA (3 components) of the standardised per-spike waveforms of every unit.
# PC1 vs PC2 is scattered on the subplot row of the unit's channel; units with
# fewer than 3 spikes (too few for 3 components) get an empty panel.
pca = PCA(n_components=3)
components = {}
aligned_ch = mean_waveform_pre
# Figure size used to be re-set on every loop iteration; set it once here.
# squeeze=False keeps axs indexable even when ch_NO == 1.
fig, axs = plt.subplots(ch_NO, 1, sharex='col', sharey='row', figsize=(5, 50), squeeze=False)
axs = axs[:, 0]

for channel, unit_id in spike_time_pre.keys():
    ax = axs[channel]
    ax.tick_params(axis='both', which='major', labelsize=10)
    if len(aligned_ch[channel, unit_id]) < 3:
        ax.set_ylim([-20, 20])
        continue
    aligned_std = StandardScaler().fit_transform(aligned_ch[channel, unit_id])
    components[channel, unit_id] = pca.fit_transform(aligned_std)
    scores = components[channel, unit_id]
    ax.scatter(scores[:, 0], scores[:, 1], s=3, c='black')
    ax.set_xlim([-40, 40])
    ax.set_ylim([-40, 40])

In [None]:
# Spike curation for one sorted unit. The cell splits the unit's PCA scores into
# k clusters with k-means and overlays each cluster's member waveforms. It then
# prints quality metrics per cluster: SNR, firing rate, FWHM, peak-to-valley
# time and ISI violations.
colors = {
    0: 'steelblue',
    1: 'orange',
    2: 'seagreen',
    3: 'red',
    4: 'purple',
    5: 'gray',
    6: 'black',
    7: 'pink',
}
# The palette used to be bound only to `color` while the plots read `colors`
# (NameError); keep the old name as an alias.
color = colors
fs = 20000
# Samples per millisecond of the 5x-upsampled mean waveform (FWHM/PVT units).
upsampled_per_ms = 5*fs/1000
# Signal the spike sample indices refer to (noise windows, recording length).
# This cell previously read `input`, which is never assigned in the notebook.
# It therefore resolved to the Python builtin, which is not subscriptable. The
# commented-out `input = filtered` in the table cell indicates `filtered`.
recording_data = filtered

# NOTE(review): choose the channel to curate here. Valid channels are
# 0..ch_NO-1, so `ch_NO` itself has no entry in spike_time_pre.
channel = ch_NO
unit_id = 0
k = 1  # number of k-means clusters
spk_time = spike_time_pre[channel, unit_id]
spk_wave = mean_waveform_pre[channel, unit_id]
comp = components[channel, unit_id]
timing = {}
compo = {}
kmeans = KMeans(n_clusters=k, init='k-means++', max_iter=500, n_init=10, random_state=0).fit(comp)

# PC1 vs PC2, coloured by cluster.
fig = plt.figure(figsize=(10, 7))
for cluster in range(k):
    clust = np.where(kmeans.labels_ == cluster)[0]
    timing[cluster] = spk_time[clust]
    compo[cluster] = comp[clust]
    plt.scatter(comp[clust, 0], comp[clust, 1], s=30, c=colors[cluster], alpha=1)
    plt.xlim([-50, 50])
    plt.ylim([-50, 50])

single_unit_time = defaultdict(list)
single_unit_waveform = defaultdict(list)

# One panel per cluster: member waveforms, their mean (white) and the level
# halfway between the mean's max and min (red). The empty figure that used to be
# created right before plt.subplots has been removed.
fig, axs = plt.subplots(1, k+1, sharex='col', sharey='row', figsize=(50, 20))
t = np.linspace(0, 4, 80)  # x axis in ms; assumes 80-sample waveforms
for cluster in range(k):
    clust = np.where(kmeans.labels_ == cluster)[0]
    for index in clust:
        single_unit_time[channel, unit_id, cluster].append(spk_time[index])
        single_unit_waveform[channel, unit_id, cluster].append(spk_wave[index])
        axs[cluster].plot(t, spk_wave[index], color=colors[cluster], linewidth=6)
    mean_spike = np.mean(single_unit_waveform[channel, unit_id, cluster], axis=0)
    half_max = (max(mean_spike) + min(mean_spike))/2
    axs[cluster].plot(t, mean_spike, color='white', linewidth=12)
    axs[cluster].axhline(half_max, color='red', linewidth=12)
    axs[cluster].tick_params(axis='both', which='both', labelbottom=True, labelleft=True, labelsize=50)
plt.ylim(-200, 100)
plt.axis('off')

for cluster in range(k):
    spike_times = single_unit_time[channel, unit_id, cluster]
    mean_waveform = np.mean(single_unit_waveform[channel, unit_id, cluster], axis=0)
    average = signal.resample(mean_waveform, 5*len(t))
    p = np.argmax(average)  # first sample of the maximum
    v = np.argmin(average)  # first sample of the minimum
    firing_rate = len(spike_times)/(len(recording_data[0])/fs)
    fwhm = FWHM(average)  # evaluated once (it used to be called twice)
    if not fwhm:
        fwhm = 0
    pvt = np.abs(p - v)
    isi_single = np.diff(spike_times)
    # Number of ISIs shorter than 2 ms divided by the spike count.
    isi_violation = np.count_nonzero(isi_single < 2*fs/1000)/len(spike_times)

    # Per-spike SNR: peak-to-peak of the mean waveform divided by the RMS of the
    # 8 samples lying 36..29 samples before the spike.
    peak_to_peak = max(mean_waveform) - min(mean_waveform)
    noise = []
    SNR = []
    peak = []
    for stamp in spike_times:
        peak_select = int(stamp)
        if peak_select < 36:
            # The baseline window would start before sample 0: a negative slice
            # start wraps to the end of the recording or yields an empty window.
            continue
        noise_select = recording_data[channel][peak_select-36:peak_select-28]
        rms_noise = np.sqrt(np.mean(noise_select**2))
        SNR.append(peak_to_peak/rms_noise)
        peak.append(peak_to_peak)
        noise.append(rms_noise)
    noise_med = np.median(noise)
    SNR_med = np.median(SNR)
    peak_med = np.median(peak)
    plt.show()
    print('')
    print("Channel:%d |"%(channel),"cluster:%d |"%(cluster), "SNR:%f |"%(SNR_med), "peak_to_peak signal:%f |"%(peak_med),"rms_noise:%f |"%(noise_med),'firing rate:%fHz '%(firing_rate))

    print('FWHM:%fms |'%(fwhm/upsampled_per_ms), 'pvt:%fms |'%(pvt/upsampled_per_ms), 'ISI violation:%f'%(isi_violation*100)+"%")

    print('')

    fwhm_ms = fwhm/upsampled_per_ms
    pvt_ms = pvt/upsampled_per_ms
    if fwhm_ms < 0.15 or fwhm_ms > 0.75:
        print("fwhm did not pass!")
    if pvt_ms < 0.15 or pvt_ms > 0.85:
        print("peak-to-valley did not pass!")
    if isi_violation > 0.015:
        # The threshold is 1.5 %; the message used to say 1 %.
        print("ISI violation over 1.5%!")
    if firing_rate < 0.1:
        print('firing rate smaller than 0.1 Hz!')
    if SNR_med < 4:
        print("low spike SNR!")
    print('*******************************')

    print('')


In [None]:
# Overlay every waveform of the selected k-means cluster (black) with their mean
# (white). Store the cluster's spike times and waveforms under
# (channel, cluster) for the recording cell that follows.
cluster = 0
single_unit_time = defaultdict(list)
single_unit_waveform = defaultdict(list)
clust = np.where(kmeans.labels_ == cluster)[0]
plt.figure(figsize=(5, 7))
# Previously `clust` was computed but never iterated. Only the single `index`
# left over from the curation cell was stored, so the "mean" was one waveform.
for index in clust:
    single_unit_time[channel, cluster].append(spk_time[index])
    single_unit_waveform[channel, cluster].append(spk_wave[index])
    plt.plot(t, spk_wave[index], color='black', linewidth=1)  # colors[6] of the curation palette
plt.plot(t, np.mean(single_unit_waveform[channel, cluster], axis=0), color='white', linewidth=5)
plt.xlim(0.5, 3.5)
plt.ylim(-200, 150)
plt.axis('off')

In [None]:
# Run once per dataset: accumulators for curated units, keyed by (channel, unit).
spike_time_all, mean_waveform_all = {}, {}

In [None]:
# Record the curated cluster as the lowest unit number not yet used on `channel`,
# then plot its mean waveform.
unit = next(u for u in range(len(spike_time_all) + 1) if (channel, u) not in spike_time_all)
spike_time_all[channel, unit] = single_unit_time[channel, cluster]
mean_waveform_all[channel, unit] = np.mean(single_unit_waveform[channel, cluster], axis=0)
plt.plot(mean_waveform_all[channel, unit], linewidth=15)
plt.axis('off')

In [None]:
# Quality metrics for every curated unit, shown as a table with one row per unit.
# The definitions match the curation cell:
# - SNR: peak-to-peak of the mean waveform / RMS of a pre-spike baseline window.
# - FWHM and PVT: in ms, from the 5x-upsampled mean waveform.
# - ISI violation: ISIs < 2 ms divided by the spike count.
# This cell previously read the undefined name `input` (the Python builtin). The
# commented-out `input = filtered` indicates `filtered` is the intended data.
recording_data = filtered
upsampled_per_ms = 5*fs/1000  # samples per ms after 5x upsampling (100 at 20 kHz)
pvt = {}
firing_rate = {}
fwhm = {}
isi_violation = {}
noise_med = {}
peak_med = {}
SNR_med = {}
for ch, cluster in spike_time_all.keys():
    spike_times = spike_time_all[ch, cluster]
    mean_waveform = mean_waveform_all[ch, cluster]
    upsampled_mean = signal.resample(mean_waveform, 5*len(mean_waveform))
    p = np.argmax(upsampled_mean)
    v = np.argmin(upsampled_mean)
    firing_rate[ch, cluster] = len(spike_times)/(len(recording_data[0])/fs)
    pvt[ch, cluster] = np.abs(p - v)/upsampled_per_ms
    isi_single = np.diff(spike_times)
    # Same denominator as the curation cell (spike count); max() guards empty units.
    isi_violation[ch, cluster] = np.count_nonzero(isi_single < 2*fs/1000)/max(len(spike_times), 1)
    peak_to_peak = max(mean_waveform) - min(mean_waveform)
    SNR = []
    peak = []
    noise = []
    for stamp in spike_times:
        peak_select = int(stamp)
        if peak_select < 36:
            # Baseline window would start before sample 0 (negative slice wraps).
            continue
        noise_select = recording_data[ch][peak_select-36:peak_select-28]
        rms_noise = np.sqrt(np.mean(noise_select**2))
        SNR.append(peak_to_peak/rms_noise)
        peak.append(peak_to_peak)
        noise.append(rms_noise)
    noise_med[ch, cluster] = np.median(noise)
    SNR_med[ch, cluster] = np.median(SNR)
    peak_med[ch, cluster] = np.median(peak)
    width = FWHM(upsampled_mean)
    fwhm[ch, cluster] = width/upsampled_per_ms if width else 'NA'

# All metric dicts were filled in spike_time_all key order, so columns align.
chan = [key[0] for key in spike_time_all]
clusterid = [key[1] for key in spike_time_all]
df = pd.DataFrame({
    'channel': chan,
    'cluster': clusterid,
    # firing rate is available in `firing_rate` but is not shown in the table
    'FWHM': list(fwhm.values()),
    'PVT': list(pvt.values()),
    'ISI violation': [f'{x*100:.2f}%' for x in isi_violation.values()],
    'SNR': list(SNR_med.values()),
    'signal': list(peak_med.values()),
})
df = df.round(2)
df

In [None]:
# Spike raster: one row per curated unit. Channels are taken in the order below,
# with up to two units per channel. The x axis shows 200 s of samples starting
# at time_ephys (relative to the trimmed data).
# NOTE(review): this channel order has 12 before 9, whereas the power and
# correlation cells use 9 before 12 — confirm which is the probe order.
injection_time = 129
time_fscv = injection_time-15
time_ephys = time_fscv + onset_time[0] - 1.5
fs = 20000
plt.figure(figsize=(40, 5))
ax = plt.gca()
unit_no = len(spike_time_all)
pos = 0
color1 = [72/255, 148/255, 162/255]
for ch in [1, 4, 5, 6, 7, 2, 0, 14, 8, 13, 3, 12, 9, 10, 11]:
    for clust in range(2):
        if (ch, clust) in spike_time_all:
            # Row extent in axes fractions (top row = first unit).
            ymin = (unit_no - pos)*2/unit_no/2
            ymax = ((unit_no - pos)*2 + 1)/unit_no/2
            # One LineCollection per unit instead of one axvline artist per spike.
            # The x-axis transform keeps x in data units and y in axes fractions,
            # exactly as axvline(x, ymin, ymax) did.
            ax.vlines(spike_time_all[ch, clust], ymin - 0.05, ymax - 0.05,
                      colors=[color1], linewidth=3, transform=ax.get_xaxis_transform())
            pos = pos + 1
plt.axis('off')
plt.xlim(int((time_ephys-trim_start)*fs), int((time_ephys-trim_start+200)*fs))
plt.show()

# LFP analysis

In [None]:
# Timing (seconds) used by the LFP cells below.
injection_time = 55
duration = 300
time_fscv = injection_time - onset_time[0] - 15
# NOTE(review): algebraically time_ephys == injection_time - 16.5 here (the
# onset term cancels), whereas the raster cell uses time_fscv = injection_time - 15.
time_ephys = time_fscv + onset_time[0] - 1.5

In [None]:
# Baseline-corrected spectrograms (1-150 Hz shown) for selected channels.
# For each channel, the time-averaged PSD over background_start..background_end s
# is subtracted from every column of its spectrogram of t0..t0+duration s.
# NOTE(review): t0 indexes yprobe as t0*fs without subtracting trim_start,
# whereas the power cell uses (time_ephys - trim_start)*fs — confirm the offset.
t0 = time_ephys
first_section = yprobe
data_to_plot = yprobe
background_start = 20
background_end = 30
nperseg = 4000
noverlap = 3500
vmin = 0
vmax = 5

chs = [0, 8, 2, 1, 6, 3, 4, 5]
row = len(chs)
col = 1  # previously undefined (NameError); axs[x] below needs a single column
# Created directly at the final size (it used to be built at 200x45 and then
# resized); squeeze=False keeps axs indexable for a single channel.
fig, axs = plt.subplots(row, col, sharex='col', figsize=(25, 20), squeeze=False)
axs = axs[:, 0]
for x, ch in enumerate(chs):
    background = first_section[ch][int(background_start*fs):int(background_end*fs)]
    _, _, Sxx_background = signal.spectrogram(background, fs=fs, nperseg=nperseg, noverlap=noverlap, mode='psd')
    avg_background = np.mean(Sxx_background, axis=1, keepdims=True)
    segment = np.real(data_to_plot[ch][int(t0*fs):int((t0+duration)*fs)])
    f, t, Sxx = signal.spectrogram(segment, fs=fs, nperseg=nperseg, noverlap=noverlap, mode='psd')
    Sxx_corrected = Sxx - avg_background
    cax = axs[x].imshow(Sxx_corrected, aspect='auto', origin='lower', extent=[min(t), max(t), min(f), max(f)],
                        cmap="plasma", interpolation='bilinear', vmin=vmin, vmax=vmax)
    # Injection marker in the spectrogram's own time axis (seconds from t0).
    axs[x].axvline(x=injection_time - time_fscv, color='white', linestyle='--', linewidth=3)
    axs[x].set_ylim([1, 150])
    cbar = fig.colorbar(cax, ax=axs[x])
    cbar.ax.tick_params(labelsize=30)
    axs[x].tick_params(axis='both', which='major', labelbottom=True, labelleft=True, labelsize=30)

plt.tight_layout()
plt.show()


# Power calculation

In [None]:
# Reload the DA trace and fill NaN gaps by linear interpolation. The result is
# evaluated at the non-NaN entries of t_DA, so it is aligned with
# t_DA[~np.isnan(t_DA)]. 'PATH' is appended to the working directory by plain
# concatenation and must start with a separator.
DA = pd.read_excel(f"{os.getcwd()}PATH", engine='openpyxl')
DA_data = DA['NAME']
t_DA = DA['NAME']

da_present = ~np.isnan(DA_data)
x_non_nan, y_non_nan = t_DA[da_present], DA_data[da_present]
f = interp1d(x_non_nan, y_non_nan, kind='linear', fill_value="extrapolate")
DA_data_interp = f(t_DA[~np.isnan(t_DA)])

In [None]:
# Moving-average smoothing (window_size samples) of the interpolated DA trace,
# plotted over the FSCV window time_fscv..end_fscv (seconds).
window_size = 30
f_DA = 10
time_fscv = 462
end_fscv = 712
length = end_fscv - time_fscv
time_ephys = time_fscv + onset_time[0] - 1.5
window = np.ones(int(window_size))/float(window_size)
smoothed_data = np.convolve(DA_data_interp, window, 'same')
# DA_data_interp was evaluated only at the non-NaN entries of t_DA, so plot it
# against those same times. Plotting against the full t_DA fails with a length
# mismatch whenever t_DA contains NaN (e.g. blank trailing Excel rows).
t_DA_valid = t_DA[~np.isnan(t_DA)]
plt.figure(figsize=(6, 4))
plt.plot(t_DA_valid, smoothed_data, linewidth=5, color='seagreen')
plt.xlim(time_fscv, end_fscv)
plt.ylim(0, 80)
plt.tick_params(axis='both', which='major', labelsize=10)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.axis('off')

In [None]:
# Re-filter the trimmed raw data into the 60-100 Hz band (order-2 Butterworth,
# causal) and re-reference to the across-channel median. Then rerun the
# frequency-domain artifact removal, so that `yprobe` holds the band-limited,
# cleaned signal used by the power cell.
# NOTE(review): the artifact-window zeroing applied earlier to the 1-300 Hz
# `filtered` is not repeated on this new `filtered` array.
order = 2
low = 60
high = 100

fs = 20000
# SOS form: same causal filter as (b, a) + lfilter, but numerically more robust;
# axis=-1 filters every channel (row) at once.
sos = signal.butter(order, [low, high], 'bp', fs=fs, output='sos')
filtered = signal.sosfilt(sos, raw_data_cut, axis=-1)
med = np.median(filtered, axis=0)
filtered = filtered - med

# Frequency-domain artifact removal: 5 s windows (N samples) advanced by 2.5 s.
# In the centred spectrum (0.2 Hz bins) the 11 bins around each multiple of
# 10 Hz are deleted and re-filled by PCHIP interpolation. Samples [1.5 s, 4 s)
# of each inverse-transformed window are kept, so yprobe[ch][0] corresponds to
# 1.5 s into `data_to_process`.
N = 5*20000
T = 1/20000
CH_NO = 15
data_to_process = filtered
# Only windows lying completely inside the data. The former rounding formula
# could count a truncated final window and then raise an error; when it did not
# fail it gives the same count. (`round` shadows the builtin but is read later.)
n_seconds = len(data_to_process[0])/fs
round = max(int((n_seconds - 5)//2.5) + 1, 0)
# Indices (into the centred spectrum) to delete: 50*n + 45 .. 50*n + 55, n = 0..1998.
peak_to_cut = (50*np.arange(1999)[:, np.newaxis] + np.arange(45, 56)).ravel()
xf_sym = np.fft.fftshift(fftfreq(N, T))
xf_cut = np.delete(xf_sym, peak_to_cut)

yprobe = []
for ch in range(CH_NO):
    kept_segments = []
    for n in range(round):
        window1 = int(2.5*n*fs)
        window2 = int((2.5*n+5)*fs)
        yf_sym = np.fft.fftshift(fft(data_to_process[ch][window1:window2]))
        yf_cut = np.delete(yf_sym, peak_to_cut)
        # PchipInterpolator only accepts real y (recent SciPy raises ValueError
        # for complex input): interpolate real and imaginary parts as two columns.
        pp = interpolate.PchipInterpolator(
            xf_cut, np.column_stack((yf_cut.real, yf_cut.imag)), axis=0)
        re_im = pp(xf_sym)
        yn = re_im[:, 0] + 1j*re_im[:, 1]
        yi = ifft(np.fft.ifftshift(yn))
        kept_segments.append(yi[int(1.5*fs):4*fs])
    yprobe.append(np.concatenate(kept_segments) if kept_segments else np.array([]))
yprobe = np.real(yprobe)
round

In [None]:
# Mean-square power of `yprobe` in consecutive `step`-second bins over `length`
# seconds, starting at time_fscv - 1.5 s (relative to trim_start).
# - Bins with power >= 10000 are treated as artifacts: set to NaN, then filled by
#   linear interpolation (extrapolated at the ends).
# - Each trace is then smoothed with a `window_size`-bin moving average.
# `delta_power` keeps its old name; the band actually analysed is whatever the
# preceding filter cell selected.
# NOTE(review): this channel order has 9 before 12, whereas the raster cell uses
# 12 before 9 — confirm which is the probe order.
channel_order = [1, 4, 5, 6, 7, 2, 0, 14, 8, 13, 3, 9, 12, 10, 11]
delay = onset_time[0]
time_ephys = time_fscv + onset_time[0] - 1.5 - delay
start_point = int((time_ephys - trim_start)*fs)
step = 0.1
fs = 20000
n_bins = int(np.round(length/step))
bin_len = int(np.round(step*fs))  # integer samples per bin: no float drift in bin edges
# Start time (s) of each bin. The old np.linspace(0, length, n_bins) had a spacing
# of length/(n_bins-1), which is slightly larger than `step`.
time = np.arange(n_bins)*step

delta_power = {}
for ch in channel_order:
    bin_power = np.empty(n_bins)
    for n in range(n_bins):
        begin = start_point + n*bin_len
        segment = yprobe[ch][begin:begin + bin_len]
        signal_power = np.mean(np.square(segment))
        bin_power[n] = signal_power if signal_power < 10000 else np.nan
    present = ~np.isnan(bin_power)
    x_non_nan = time[present]
    y_non_nan = bin_power[present]
    f = interp1d(x_non_nan, y_non_nan, kind='linear', fill_value="extrapolate")
    delta_power[ch] = f(time)

power_smooth = {}
window_size = 30
window = np.ones(int(window_size))/float(window_size)
for ch in channel_order:
    power_smooth[ch] = np.convolve(delta_power[ch], window, 'same')

ch = 11

plt.figure(figsize=(5, 2))
plt.plot(time, power_smooth[ch], linewidth=3, color='cadetblue')
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.tick_params(axis='both', which='major', labelbottom=True, labelleft=True, labelsize=12)

In [None]:
# Normalised cross-correlation between the DA trace `DA_cut` and each channel's
# smoothed band power. Lags are in power bins (100 ms each).
# NOTE(review): `DA_cut` is not assigned anywhere in this notebook; it must be
# defined beforehand on the same 100 ms grid as power_smooth.
channel_order = [1, 4, 5, 6, 7, 2, 0, 14, 8, 13, 3, 9, 12, 10, 11]
fixed_lag = 2  # lag (bins) at which corr_ch samples each channel's correlation
corr_ch = []
lag_max_corr = {}
corr_ch_shift = []
corr_normalized = {}
da_centered = DA_cut - np.mean(DA_cut)
for ch in channel_order:
    LFP = power_smooth[ch]
    lfp_centered = LFP - np.mean(LFP)
    # np.correlate(a, v, 'full')[i] is the lag i - (len(v) - 1). Deriving lags
    # from both lengths stays correct if they differ; for equal lengths this is
    # the old arange(-len+1, len).
    lags = np.arange(-(len(LFP) - 1), len(DA_cut))
    corr = np.correlate(da_centered, lfp_centered, mode='full')
    denom = np.sqrt(np.sum(da_centered**2) * np.sum(lfp_centered**2))
    corr_normalized[ch] = corr / denom
    best = np.argmax(corr_normalized[ch])
    lag_max_corr[ch] = lags[best]

    plt.figure(figsize=(10, 2))
    plt.plot(lags, corr_normalized[ch], label='Cross-Correlation')
    plt.axvline(lag_max_corr[ch], color='red', linestyle='--', label=f'Max Correlation at lag {lag_max_corr[ch]}')
    plt.title('channel%d'%(ch))
    plt.xlabel('Lags of bins (100ms)')
    plt.ylabel('Correlation Coefficient')
    plt.legend()

    print(max(corr_normalized[ch]))
    # Store plain floats (they used to be 1-element arrays), so max() and the
    # '%f' formatting below receive scalars.
    corr_ch.append(corr_normalized[ch][lags == fixed_lag].item())
    corr_ch_shift.append(corr_normalized[ch][best])
    plt.show()

ch_str = [str(c) for c in channel_order]
plt.figure(figsize=[10, 1])
plt.plot(ch_str, corr_ch, linewidth=5, color='purple')
plt.xlabel('Channel No (from bottom to tip)')
plt.ylabel('Correlation Coefficient')
# The first six entries of channel_order are reported as Ctx, the rest as NAc.
print("NAc:%f"%(max(corr_ch[6:16])))
print("Ctx:%f"%(max(corr_ch[0:6])))


cov_ch = [np.cov(DA_cut, power_smooth[ch])[0, 1] for ch in channel_order]
plt.figure(figsize=[10, 2])
plt.plot(ch_str, cov_ch, linewidth=5, color='purple')
plt.xlabel('Channel No (from bottom to tip)')
plt.ylabel('Covariance')  # this plot shows covariance, not a correlation coefficient

# Plotting PSD

In [None]:
# Welch PSD (nperseg = 4000) of channel `ch` in two 3 s windows of yprobe:
# 16-19 s in black and 28-31 s in red. Both are smoothed with a 3-point linear
# Savitzky-Golay filter and shown in dB up to 100 Hz.
# NOTE(review): `ch` is whatever the previous cell left it as (the last channel
# of its loop); assign it explicitly here to choose the channel.
yprobe = np.real(yprobe)
fs = 20000
fig, axs = plt.subplots(figsize=(5, 5))  # final size set at creation

f, Pxx_den = signal.welch(yprobe[ch][16*fs:19*fs], fs, nperseg=4000)
f2, Pxx_den2 = signal.welch(yprobe[ch][int(28*fs):int(31*fs)], fs, nperseg=4000)
# Each spectrum is smoothed once (a duplicated smoothing line was removed).
Pxx_smooth = signal.savgol_filter(Pxx_den, 3, 1)
Pxx_smooth2 = signal.savgol_filter(Pxx_den2, 3, 1)
axs.plot(f, 10*np.log10(Pxx_smooth), c='black', linewidth=4)
axs.plot(f, 10*np.log10(Pxx_smooth2), c='red', linewidth=4)
axs.tick_params(axis='both', which='major', labelsize=20)
axs.set_xlim([0, 100])
axs.set_ylim([-10, 30])
axs.set_xlabel('Frequency (Hz)')
axs.set_ylabel('PSD (dB)')

plt.show()