# Robustness of Crayfish Nervous System to Environmental pH
## Step 1: Format Data

In [1]:
import neo, os
import numpy as np
from mne import create_info
from mne.io import RawArray

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Recordings excluded from conversion.
stop_list = ['171208_pH74_1_1.abf', '171208_pH72_1_1.abf']

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main body.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## Convert every Axon (.abf) recording found in raw/ into an MNE Raw
## object and save it alongside the original as <name>_raw.fif.

## Gather all .abf files in raw/ that are not on the stop list.
files = [f for f in os.listdir('raw') if f.endswith('abf') and f not in stop_list]

for f in files:

    ## Read the recording. The single-element unpacking below raises if the
    ## file does not hold exactly one segment with exactly one analog signal.
    block = neo.AxonIO(filename='raw/%s' %f).read_block()
    [recordings] = block.segments
    [signal] = recordings.analogsignals

    ## Rescale the samples and transpose the result before building the Raw.
    ## NOTE(review): the 1e-6 factor presumably converts uV to V -- confirm
    ## against the acquisition units.
    data = (np.asarray(signal, dtype=np.float64) * 1e-6).T

    ## Fix recording issue on Day 2: flip the sign of the signal.
    if f.startswith('171212'):
        data *= -1

    ## Describe the single channel for MNE.
    sfreq = float(signal.sampling_rate)
    ch_names = ['nerve']
    ch_types = 'bio'
    info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)

    ## Wrap the samples in a Raw object and write it to disk.
    raw = RawArray(data, info, verbose=False)
    out_name = f.replace('.abf','_raw.fif')
    raw.save('raw/%s' %out_name, overwrite=True, verbose=False)

print('Done.')

Done.


## Step 2: Preprocess Data

In [2]:
import os
import numpy as np
from mne import Epochs, make_fixed_length_events
from mne.io import Raw
from pandas import DataFrame, concat
from spike_sorting import find_threshold, peak_finder

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O: file-name prefixes of the recording sessions to process.
sessions = ['171208', '171212']

## Band-pass filter edges passed to Raw.filter.
l_freq, h_freq = 300, 3000

## Epoch length, in seconds.
duration = 9.99

## Spike detection: `k` is handed to find_threshold; spikes whose
## amplitude is not below `reject` (uV) are discarded.
k = 5
reject = 200

## Spike sorting parameters.
n_clusters, n_refs = 5, 10

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## For each session: band-pass filter and epoch every recording, compute
## one detection threshold for the whole session, detect spikes epoch by
## epoch, and collect them in a long-format table (one row per spike).
## The tables of all sessions are concatenated and written to spikes.csv.

dataset = []
for session in sessions:
    
    ## Locate files.
    files = sorted([f for f in os.listdir('raw') if f.startswith(session) and f.endswith('fif')])
    msg = '\nSession = %s' %session
    print('%s\n%s' %(msg,'-'*len(msg)))
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Assemble recordings.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    print('Preprocessing data')
    
    data = []
    for f in files:
        
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Load and prepare data.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        ## Load raw.
        raw = Raw('raw/%s' %f, preload=True, verbose=False)
        
        ## Filter data.
        raw = raw.filter(l_freq, h_freq, picks=[0], method='fir', phase='zero', 
                         fir_design='firwin', verbose=False)
        
        ## Crop raw (remove filter artifact).
        raw = raw.crop(0.05)
        
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Epoching.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        
        ## Make events of equal length.
        events = make_fixed_length_events(raw, 1, duration=duration)
        
        ## Make epochs.
        epochs = Epochs(raw, events, tmin=0, tmax=duration, baseline=None,
                        picks=[0], preload=True, verbose=False)
        
        ## Append.
        data.append( epochs.get_data().squeeze() )
        
    ## Concatenate recordings. The shape unpacking below requires every
    ## recording in the session to yield the same number of epochs.
    data = np.array(data)
    data *= 1e6 # Convert to uV.
    times = epochs.times
    
    ## Print metadata.
    n_recordings, n_epochs, n_times = data.shape
    print('  Data = %s recordings, %s epochs, %0.2fs' %(n_recordings, n_epochs, n_times / raw.info['sfreq']))
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Spike detection.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    print('Spike Detection')
    
    ## Set threshold (one value per session, computed from all recordings).
    threshold = find_threshold(data, k)
    print('  Threshold = %0.2f uV' %threshold)
    
    ## Iterate over recordings and epochs.
    spikes = []
    for i, f in enumerate(files):
        
        ## Extract metadata from the file name once per recording.
        ## The name is split on '_' into five fields; the first, second and
        ## fourth are kept as session, pH and recording. The last two
        ## characters of the pH field are read as tenths (e.g. 'pH74' -> 7.4).
        ## A separate name is used so the outer `session` loop variable is
        ## not rebound.
        f_session, pH, _, recording, _ = f.split('_')
        pH = float(pH[-2:]) / 10
        
        for j in range(n_epochs):
            
            ## Detect spikes.
            ## NOTE(review): assumes peak_finder returns array-likes of
            ## sample indices and magnitudes -- confirm in spike_sorting.
            peak_loc, peak_mag = peak_finder(data[i,j], threshold)
            
            ## Convert sample indices to seconds and offset by the epoch
            ## position. Test for the presence of peaks by length: testing
            ## np.any(peak_loc) is False for a lone peak at sample index 0,
            ## which would then be stored without its epoch offset.
            if len(peak_loc) > 0: peak_loc = times[peak_loc] + j * times.max()
            
            ## Store as DataFrame. Append.
            df = DataFrame( np.vstack([peak_loc, peak_mag]).T, columns=('Time','Amplitude') )
            for column, value in zip(['Epoch','Recording','pH','Session'], [j+1, recording, pH, f_session]):
                df.insert(0, column, value)
            spikes.append(df)
            
    ## Concatenate DataFrames.
    spikes = concat(spikes)
    
    ## Amplitude rejection: keep only spikes below the rejection level.
    spikes = spikes[spikes.Amplitude < reject]
    print('  N spikes = %s' %spikes.shape[0])
    
    ## Append.
    dataset.append(spikes)
    
## Concatenate DataFrames.
spikes = concat(dataset)
spikes.to_csv('spikes.csv', index=False)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Compute counts.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## Summarise the spike table per (Session, pH, Recording, Epoch): number
## of spikes and mean spike amplitude, written to counts.csv.
## NOTE(review): groups are formed from rows of `spikes`, so an epoch with
## no retained spikes has no row here rather than a row with Count = 0 --
## confirm downstream analyses account for that.

## Grouping keys and GroupBy object.
columns = ['Session','pH','Recording','Epoch']
gb = spikes.groupby(columns)

## Number of spikes per group, stored in a 'Count' column.
counts = gb.Amplitude.count().reset_index(name='Count')

## Mean spike amplitude per group.
amplitude = gb.Amplitude.mean().reset_index()

## Join both summaries on the grouping keys and save.
counts = counts.merge(amplitude, on=columns)
counts.to_csv('counts.csv', index=False)

print('Done.')


Session = 171208
-----------------
Preprocessing data
  Data = 35 recordings, 12 epochs, 9.99s
Spike Detection
  Threshold = 17.60 uV
  N spikes = 77124

Session = 171212
-----------------
Preprocessing data
  Data = 39 recordings, 12 epochs, 9.99s
Spike Detection
  Threshold = 18.06 uV
  N spikes = 160525
Done.
