<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Load-data" data-toc-modified-id="Load-data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Load data</a></span></li></ul></div>

In [None]:
import IPython
# IPython.Application.instance().kernel.do_shutdown(True)

# %matplotlib notebook 
#import mpld3
#mpld3.enable_notebook()

%matplotlib widget

import os
import sys
import json
import time
import datetime
import pycwt
import statistics
import random
import pickle
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
import sklearn as sk
import tkinter as tk
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn import decomposition
from sklearn.decomposition import PCA
from tkinter import *
from tkinter import ttk
from sklearn import preprocessing
from datetime import date
import matplotlib.dates as mdates

from neurodsp.rhythm import sliding_window_matching
from neurodsp.utils.download import load_ndsp_data
from neurodsp.plts.rhythm import plot_swm_pattern
from neurodsp.plts.time_series import plot_time_series
from neurodsp.utils import set_random_seed, create_times
# Import listed chormap
from matplotlib.colors import ListedColormap
import matplotlib.dates as md
from matplotlib import colors as mcolors
# Scipy
from scipy import signal
from scipy import ndimage
# TKinter for selecting files
from tkinter import Tk     # from tkinter import Tk for Python 3.x
from tkinter.filedialog import askdirectory

# Add the parent directory to the import path so the project module below
# (presumably Neurogram_short lives in ../) can be imported
sys.path.append("../")

# Own libraries. The star import presumably also provides the other project helpers
# used later without an explicit import (e.g. convertDfType, gen_edo) — verify in Neurogram_short.
from Neurogram_short import * # Recording, MyWavelet, MyWaveforms

# NOTE(review): KMP_DUPLICATE_LIB_OK is an Intel OpenMP runtime setting that tolerates the
# runtime being loaded more than once. It is set only after numpy/scipy/sklearn were imported
# above — confirm it still takes effect this late.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

## Load data

In [None]:
# Load the recording (possibly split across several files) from `path`, then make sure
# the directory used for saving figures exists.
try:
    Tk().withdraw()  # keep the root window from appearing
except tk.TclError:
    # No display available (e.g. headless server). Nothing in this cell needs the root window.
    print('Tk root window could not be created (no display); continuing without it')
dir_name = '../datasets/'
path = '../datasets' # Port B

map_path = '../datasets/cambridge/recording/forearm/forearm_map.csv'

# When using port A: channels=range(0,32,1) by default port B:range(32,64,1)
# Start and dur in samples
# feinstein: channels=[0]
time_start = time.time()
load_from_file=True
downsample = 3             # Only used when loading from raw
start= 0                   # SC chronic: 7*30000              #2*60*10000 # 0*30000 ## start=2*60*10000
dur= None                  # None: whole recording. SC chronic: 4*30000           # EMG: 35*30000
port = 'Port B'
record = Recording.open_record(path, start=start, dur=dur,
                               load_from_file=load_from_file,
                               load_multiple_files=True,
                               downsample=downsample,
                               port=port,  # Select recording port
                               map_path=map_path,
                               verbose=0)

# Create directory to save figures (no-op if it already exists)
os.makedirs(os.path.join(path, 'figures'), exist_ok=True)
print("Time elapsed: {} seconds".format(time.time()-time_start))


In [None]:
# Timestamp (ddmmYYYY_HHMMSS) appended to saved figure names so runs do not overwrite each other
now = datetime.datetime.now()
current_time = f"{now:%d%m%Y_%H%M%S}"

In [None]:
## Configuration

In [None]:
# Available filters
options_filter = [
    "None",
    "butter",
    "fir",    # Binomial Weighted Average Filter
]

# Available spike-detection methods
options_detection = [
    # Caution: also detects cardiac spikes, so use cardiac_window with it. Slower method.
    "get_spikes_threshCrossing",
    # Python implementation of get_spikes(). Faster.
    "get_spikes_method",
    # Smallest-of constant false-alarm rate (SO-CFAR) filter
    "so_cfar",
]

# Threshold polarity for detection
options_threshold = ["positive", "negative", "both_thresh"]

In [None]:
# Configure the processing pipeline and start the run-configuration log (config_text),
# which later cells keep appending to.
record.apply_filter = options_filter[1]      # "butter"
record.detect_method = options_detection[1]  # "get_spikes_method"
record.thresh_type = options_threshold[0]    # "positive"
# Select channel position/number in intan (not channel number in device)
record.channels = [2,4]  # Pairs 2&4, 27&30
record.path = path

# The log is rebuilt from scratch here, so re-running this cell does not accumulate
# entries. Each setting is logged exactly once.
config_text = [
    'Load_from_file %s' %load_from_file,
    'Filter: %s'%record.apply_filter,
    'Detection: %s'%record.detect_method,
    'Threhold type: %s'%record.thresh_type,
    'Channels: %s' %record.channels,
    'Downsampling: %s' %downsample,
    'Port %s' %(port),
    'Start %s, Dur: %s' %(start,dur),
]
# NOTE(review): the analysis cell passes `time_marks` (remarkable timestamps, in s) to
# compute_rolling_metrics, but no visible cell defines it. Define it here unless
# Neurogram_short provides it.

group = '1'  # Group label used in the saved metrics-figure filename

print('SELECTED GENERAL CONFIGURATION:')
print('Filter: %s'%record.apply_filter)
print('Detection: %s'%record.detect_method)
print('Threhold type: %s'%record.thresh_type)
print('Channels: %s' %record.channels)
print('-------------------------------------')

record.select_channels(record.channels) # keep_ch_loc=True if we want to display following the map. Otherwise follow the order provided by selected channels.
print('map_array: %s' %record.map_array)
print('ch_loc: %s' %record.ch_loc)
print('filter_ch %s' %record.filter_ch)
print('column_ch %s' %record.column_ch)

#### Select visualization options:  

In [None]:
# Subplot grid used by the per-channel figures of the analysis cell
# (alternative: derive it from len(record.filter_ch)).
record.num_rows, record.num_columns = 2, 1
# Channel number (looked up through map_array) of the first selected location; used for single-channel plots
plot_ch = int(record.map_array[record.ch_loc[0]])
print(plot_ch, record.num_rows, record.num_columns, sep='\n')
save_figure = True

##### Gain

In [None]:
# Signal gain. In the visible cells it is only written to the configuration log.
gain = 1
config_text.append(f'Gain: {gain}')

##### Maximum bpm

In [None]:
# Maximum heart rate (beats per minute): general max in rat HR (current neurograms are at ~180 bpm).
# record.bpm sets the minimum beat spacing in hr_detection_config.
bpm = 300
record.set_bpm(bpm)
config_text.append(f'BPM: {bpm}')

##### Spike detection config

In [None]:
# Detection settings passed to record.pipeline_peak_extraction in the analysis cell.
# Sample counts are derived from record.fs (samples per second).

# Neural spike detection
spike_detection_config = {
    'general': {
        'cardiac': False,
        # Window (samples before, samples after) around each detected spike: 2 ms on each side
        'spike_window': [int(0.002 * record.fs), int(0.002 * record.fs)],
        # Min & max amplitude of accepted spikes
        'min_thr': [0, 200],
        # Accepted range, in seconds, from zero crossing to waveform maximum
        'half_width': [0.1/1000, 10/1000],
        # Detection constant (presumably the threshold multiplier) — TODO confirm
        'C': 3,
        # Forwarded to find_peaks(): minimal horizontal distance (>= 1 sample) between
        # neighbouring peaks; smaller peaks are removed first until it holds for all peaks.
        'find_peaks_args': {
            'distance': int(0.0018 * 2 * record.fs),
        },
    },
    # SO-CFAR-only parameters
    'cfar': {
        'nstd_cfar': 3,
        'wdur': 1501 / record.fs,  # window duration in seconds (1501 samples)
        'gdur': 10 / record.fs,    # guard duration in seconds (10 samples)
    },
}

# Cardiac (heart-rate) detection
hr_detection_config = {
    'general': {
        'cardiac': True,
        # Minimum separation between beats: samples per beat at record.bpm, minus a 0.01 s buffer
        'window': int(record.fs / (record.bpm / 60)-(0.01*record.fs)),
        # Samples around a heartbeat in which neural spikes are discarded (the beat may deform the signal)
        'spike_window': [int(0.03*record.fs), int(0.03*record.fs)],
        'min_thr': [0, 0],     # not used
        'half_width': [0, 0],  # not used
        'C': 2,
        # Forwarded to find_peaks(): same minimum beat separation as 'window'
        'find_peaks_args': {
            'distance': int(record.fs / (record.bpm / 60)-(0.01*record.fs)),
        },
    },
}

# Artifact/noise detection
noise_detection_config = {
    'general': {
        'cardiac': False,
        # Minimum separation between artifacts, in samples
        'window': 10,
        # Samples around an artifact in which neural spikes are discarded: 0.1 s on each side
        'spike_window': [int(0.1*record.fs), int(0.1*record.fs)],
        'C': 20,
        'find_peaks_args': {
            'distance': 10,  # minimal distance between artifact peaks, in samples
        },
    },
}

config_text.append(
    'spike_detection_config: %s  ||  hr_detection_config: %s ||  noise_detection_config: %s '
    % tuple(json.dumps(cfg) for cfg in (spike_detection_config, hr_detection_config, noise_detection_config))
)

#### Final initializations (No need to change)  

In [None]:
# Result containers, reset every time this cell runs
record.rolling_metrics = pd.DataFrame()
summary_columns = ['Max_spike_rate', 'Min_spike_rate', 'Max_amplitude_sum', 'Min_amplitude_sum']
record.summary = pd.DataFrame(columns=summary_columns)
record.summary.index.name = 'channel'
record.sig2noise = []  # signal-to-noise ratio of each channel

# Empty per-channel frames for the wavelet-decomposition outputs (four independent frames)
neural_wvl, neural_wvl_denoised, other_wvl, substraction_wvl = (
    pd.DataFrame(columns=record.filter_ch) for _ in range(4)
)

### Plot raw signal

In [None]:
# Time/frequency overview of the raw (unreferenced, unfiltered) recording for the plot channel
record.plot_freq_content(
    record.original,
    int(plot_ch),
    nperseg=512,
    max_freq=4000,
    ylim=[-500, 500],
    dtformat='%H:%M:%S',
    figsize=(10, 10),
    savefigpath=f'{record.path}/figures/{port}_ch{plot_ch}_original-{current_time}.png',
    show=False,
)

### Channel referencing

In [None]:
# Re-reference the selected channels against either the mean of a channel group
# (ref_ch_name == 'mean', i.e. common-average reference over `channels`) or a single
# channel 'ch_<ref_ch_name>'.
radial_chanels = ['ch_1', 'ch_2', 'ch_3', 'ch_4', 'ch_26', 'ch_27', 'ch_28', 'ch_29', 'ch_30','ch_31']

channels = radial_chanels  # alternative: record.filter_ch
ref_ch_name = 'mean'       # 'mean' or a channel number

if ref_ch_name != 'mean':
    ref_ch = record.original['ch_%s' % ref_ch_name]
else:
    all_ch_list = list(filter(lambda name: name.startswith('ch_'), channels))
    ref_ch = record.original[all_ch_list].mean(axis=1)

record.referenced = record.original[record.filter_ch].sub(ref_ch, axis=0)
record.referenced['seconds'] = record.original['seconds']
# From here on record.recording is the referenced signal (same object, renamed)
record.recording = record.referenced
record.recording.name = 'referenced'
config_text.append(f'ref_ch: {ref_ch_name}')


In [None]:
# Time/frequency overview of the re-referenced signal for the plot channel (displayed inline)
record.plot_freq_content(
    record.referenced,
    int(plot_ch),
    nperseg=512,
    max_freq=4000,
    ylim=[-750, 750],
    dtformat='%H:%M:%S',
    figsize=(10, 10),
    savefigpath=f'{record.path}/figures/{port}_ch{plot_ch}_ref{ref_ch_name}-{current_time}.png',
    show=True,
)
    

### Filtering

#### Bandwidth filter

In [None]:
# Filter parameters keyed by filter name. 'W' is the [low, high] band in Hz;
# the upper edge must stay below fs/2 (Nyquist).
filt_config = {
    'W': [400, 2000],
    'None': {},
    'butter': {'N': 4, 'btype': 'bandpass'},  # order, filter type ('bandpass' / 'hp' / 'lowpass')
    'fir': {'n': 4},
    'notch': {'quality_factor': 30},
}

# The butter keys mirror scipy.signal.butter's arguments (presumably forwarded to it by
# record.filter); Wn shares the very same list object as 'W'.
filt_config['butter'].update(Wn=filt_config['W'], fs=record.fs)

config_text.append('filt_config: %s' % json.dumps(filt_config))

##### Apply filter

In [None]:
# Filter the currently selected signal with the configured method and time it
time_start = time.time()
signal2filter = record.recording
config_text.append(f'signal2filter: {signal2filter.name}')
record.filter(signal2filter, record.apply_filter, **filt_config[record.apply_filter])
# record.filter presumably stores its output in record.filtered; downcast float64 -> float32 to save memory
record.filtered = convertDfType(record.filtered, typeFloat='float32')
print(f"Time elapsed: {time.time() - time_start} seconds")

##### Plot filtered signal

In [None]:
# Time/frequency overview of the filtered signal (record.filtered) for the plot channel.
# The filename records the filter actually applied (record.apply_filter), e.g. "..._butter_filtering-...".
text_label = 'Filtered'
text = 'Channels after %s filtering'%record.apply_filter
record.plot_freq_content(record.filtered, int(plot_ch), nperseg=512, max_freq=4000, ylim=[-100, 100], dtformat='%H:%M:%S',
                         figsize=(10, 10),
                         savefigpath='%s/figures/%s_ch%s_%s_filtering-%s.png' %(record.path, port, plot_ch, record.apply_filter, current_time))

#### Notch filtering

### NOISE: Envelope derivative operator (EDO)

In [None]:
# Envelope derivative operator (EDO) of each selected channel of record.recording,
# gathered into a frame that the analysis cell uses as its noise signal.
edo_columns = {}
for ch in record.filter_ch:
    x = record.recording[ch].to_numpy()
    x_e = gen_edo(x)
    edo_columns[ch] = x_e
noise_edo = pd.DataFrame(edo_columns)

noise_edo['seconds'] = np.asarray(record.recording['seconds'])
# Datetime index built from the 'seconds' column (seconds -> nanoseconds)
noise_edo.index = pd.DatetimeIndex(noise_edo.seconds * 1e9)
noise_edo.index.name = 'time'
noise_edo.name = 'noise_edo'

config_text.append(f'signal to EDO: {record.recording.name}')

# Optional: noise_edo = convertDfType(noise_edo, typeFloat='float32') to downcast from float64

In [None]:
# Time/frequency overview of the EDO noise envelope for the plot channel.
# Saved next to the other figures only when save_figure is set. The original placeholder
# `if save_figure: pass` never saved it.
edo_plot_kwargs = {}
if save_figure:
    edo_plot_kwargs['savefigpath'] = '%s/figures/%s_ch%s_noise_edo-%s.png' % (record.path, port, plot_ch, current_time)
record.plot_freq_content(noise_edo, int(plot_ch), nperseg=512, max_freq=4000, dtformat='%H:%M:%S.%f',
                         figsize=(10, 10), show=False, **edo_plot_kwargs)


## Signal analysis

In [None]:
# Configure

# Select the signal that will be used to extract the neural spikes from
# (alternatives: record.filtered, neural_wvl, substraction_wvl, neural_wvl_denoised)
record.signal2analyse = record.recording
# Signal handed to MyWaveforms in the analysis cell (presumably the one waveforms are cut from)
record.signal2extract = record.signal2analyse
print('Analysing signal: %s' %record.signal2analyse.name)

# Signal used for artifact detection
noise_signal = noise_edo

config_text.append('signal2analyse: %s' %record.signal2analyse.name)
config_text.append('signal2extract: %s' %record.signal2extract.name)
try:
    config_text.append('noise_signal: %s '%noise_signal.name)
    print(noise_signal.name)
except AttributeError:
    # noise_signal has no .name (e.g. None or a frame that was never named): log it as absent.
    # Only AttributeError is caught so that unrelated errors (and KeyboardInterrupt) still surface.
    config_text.append('noise_signal: %s' %'None')

# True: detect "other" events with hr_detection_config; False: with noise_detection_config
cardiac_noise_config = False
consider_noise = True

#record.manual_thres = [10, -10]

dtformat = '%M:%S.%f'  # time-axis format passed to compute_rolling_metrics

verbose = True


In [None]:
#-------------------------------------------------------------
# Run the peak-extraction pipeline on every selected channel:
# detect "other" events (cardiac or noise, per cardiac_noise_config) and neural spikes,
# store spike amplitudes/locations in record.recording and compute the rolling metrics
# drawn in the "Metrics evolution" figure.
#-------------------------------------------------------------
time_start_analysis=time.time()

# NOTE(review): num_clus, time_marks and run are used below but not assigned in any visible
# cell (they may come from `from Neurogram_short import *`). Check them up front so a fresh
# kernel fails immediately with a clear message, not with a NameError after part of the
# channel loop (or, for `run`, the whole loop) has already executed.
_required_names = ['num_clus', 'time_marks'] + (['run'] if save_figure else [])
_missing_names = [name for name in _required_names if name not in globals()]
if _missing_names:
    raise NameError('Define %s before running the analysis cell' % ', '.join(_missing_names))

#-------------------------------------------------------------
# Initialize figures
#-------------------------------------------------------------
fig, axes = plt.subplots(record.num_rows, record.num_columns, figsize=(15, 5), sharex=True)
fig.suptitle('Identified peaks', fontsize=16, family='serif')

fig4, axes_metric = plt.subplots(2,1,figsize=(12, 10))
fig4.suptitle('Metrics evolution', fontsize=16, family='serif')

# The waveform figure is only created for multi-panel layouts; the axes grids are flattened
# so panels can be indexed linearly.
if (record.num_rows*record.num_columns)>1:
    fig2, axes_wv = plt.subplots(record.num_rows, record.num_columns, figsize=(15, 8), sharex=True)
    fig2.suptitle('Waveforms', fontsize=16, family='serif')
    axes = axes.flatten()
    axes_wv = axes_wv.flatten()

# Settings for the non-neural events detected alongside the spikes
if cardiac_noise_config:
    other_detection_config = hr_detection_config
else:
    other_detection_config = noise_detection_config

save_waveforms = []
for n, j in enumerate(record.ch_loc):

    if verbose:
        print(n)
        print(j)

    try:
        if record.signal2analyse.name=='ica':
            ch = 'ch_ica'
        else:
            ch = 'ch_%s'%int(record.map_array[j])

        # `noise` is not used below. The lookup's only effect is a KeyError (reported as
        # "channel ... not found") when ch is missing from noise_signal. The pipeline itself
        # receives the whole noise_signal.
        if len(noise_signal)>0:
            noise = noise_signal[ch]
        else:
            noise = []
        print(ch)

        #-------------------------------------------------------------
        # Start pipeline of cardiac and neural peaks identification
        #-------------------------------------------------------------
        hr_idx, hr_vector, spikes_idx, waves, spikes_vector, spikes_vector_loc, index_first_edge = \
                                record.pipeline_peak_extraction(ch, noise_signal, other_detection_config,
                                                                spike_detection_config,
                                                                consider_noise=consider_noise, verbose=verbose)
        # Create waveform object for processing of waveforms
        waveforms = MyWaveforms(waves, record.signal2extract, record.fs, spikes_vector_loc, num_clus, record.path)

        #-------------------------------------------------------------
        # Compute overall rolling metrics
        #-------------------------------------------------------------
        # Too few spikes: mark the channel as NaN and skip metrics
        if len(spikes_idx) < 5:
            try:
                record.rolling_metrics['%s' %ch] = np.nan
            except TypeError:
                print('%s was not processed and will not appear in rolling_metrics dataframe' %ch)
            continue
        else:
            record.recording['spikes_amplitudes_%s' %ch] = spikes_vector
            record.recording['spikes_locations_%s' %ch] = spikes_vector_loc
            record.rolling_metrics['%s' %ch] = record.compute_rolling_metrics(record.signal2analyse,axes_metric, ch,
                                                                              window=10, units='s', dtformat=dtformat,
                                                                              show_plot=True, time_marks=time_marks)   # Returns only [metric_spikes_rate_%s' %ch]

    except KeyError:
        print('channel %s not found' %int(record.map_array[j]))

print ('Analysis done! Time elapsed: {} seconds'.format(time.time()-time_start_analysis))

if save_figure:
    fig4.savefig('%s/figures/metrics_evolution-%s_%s_group%s.svg' %(path, current_time, run, group), facecolor='w')

# Stop execution here. In IPython this raises SystemExit, which is reported as a notice;
# presumably it is meant to keep "Run All" from continuing past this cell.
sys.exit()