# Extractor de features de templates

In [2]:
from configparser import ConfigParser, ExtendedInterpolation

import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import copy
import gc

from matplotlib import gridspec
from tqdm import tqdm
from scipy.signal import find_peaks, butter, filtfilt
from ipywidgets import interact, interact_manual, IntSlider, Dropdown, fixed

%matplotlib inline

# Configuration parsed from ../../config.ini with ExtendedInterpolation
# (``${section:key}`` references). ``read()`` returns the list of files it
# managed to parse, which is what this cell displays as its output.
# NOTE(review): ``config`` is not read by any of the visible cells.
config = ConfigParser(interpolation=ExtendedInterpolation())
config.read('../../config.ini')

['../../config.ini']

## Adquisición de datos y parámetros generales

In [3]:
# General parameters
samplerate = 20000.0  # acquisition rate [frames per second]

exp_name = 'MR-262'
# Input files are derived from the experiment name (layout <exp>/<exp>.<suffix>)
# so that changing ``exp_name`` alone is enough to process another experiment.
sorting_file = '{0}/{0}.result.hdf5'.format(exp_name)
templates_file = '{0}/{0}.templates-1.hdf5'.format(exp_name)

## Preparación de directorios de guardado

In [4]:
def check_exp_dir(exp, results_dir='results'):
    """Return the results directory of experiment ``exp``, creating it if needed.

    Parameters
    ----------
    exp : str
        Experiment name, e.g. ``'MR-262'``.
    results_dir : str, optional
        Root folder for all results. Defaults to ``'results'``, relative to the
        current working directory, as before.

    Returns
    -------
    str
        The path ``<results_dir>/<exp>``, which is guaranteed to exist.
    """
    exp_dir = os.path.join(results_dir, exp)
    # A single makedirs call creates the root and the experiment folder. With
    # exist_ok=True it neither fails when they already exist nor races between
    # an existence check and the mkdir.
    os.makedirs(exp_dir, exist_ok=True)
    return exp_dir

def check_templates_dir(exp):
    """Return the ``templates`` sub-folder of experiment ``exp``, creating it if needed.

    The parent experiment directory is obtained from (and created by)
    ``check_exp_dir(exp)``.

    Returns
    -------
    str
        Path ``<experiment dir>/templates``, which is guaranteed to exist.
    """
    temp_dir = os.path.join(check_exp_dir(exp), 'templates')
    # exist_ok=True: no failure if it already exists, and no check/mkdir race.
    os.makedirs(temp_dir, exist_ok=True)
    return temp_dir

## Definición de función para extracción de features

In [5]:
def _extract_window(temp, center, half):
    """Return ``temp[center - half : center + half + 1]`` zero-padded to ``2 * half + 1`` frames."""
    start = max(center - half, 0)
    stop = min(center + half + 1, len(temp))
    pad_left = start - (center - half)
    pad_right = (center + half + 1) - stop
    if pad_left == 0 and pad_right == 0:
        # Window lies fully inside the template: plain slice (keeps the input dtype).
        return temp[start:stop]
    return np.concatenate((np.zeros(pad_left), temp[start:stop], np.zeros(pad_right)))


def _shoulder_peak(signal, lo, hi, height):
    """Return ``(frame, amplitude)`` of the tallest peak of ``signal[lo:hi]`` reaching ``height``.

    The frame is expressed in ``signal`` coordinates. Returns ``(nan, nan)``
    when there is no such peak.
    """
    peaks, _ = find_peaks(signal[lo:hi], height=height)
    if peaks.size == 0:
        return np.nan, np.nan
    peaks = peaks + lo
    best = peaks[np.argmax(signal[peaks])]
    return best, signal[best]


def _interp_crossing(signal, idx, level):
    """Fractional frame where the segment from ``idx - 1`` to ``idx`` reaches ``level``."""
    slope = signal[idx] - signal[idx - 1]
    intercept = signal[idx] - slope * idx
    return (level - intercept) / slope


def get_template(temp, samplerate=20000.0):
    """Extract waveform features from a single 1-D spike template.

    Steps:

    1. ``temp`` is normalised **in place** so that its dominant excursion becomes
       -1. If the positive excursion dominates, the template is also sign-flipped
       (watch out for inverted cells, e.g. 826 and 827). Callers that plot
       ``temp`` afterwards therefore see the normalised template.
    2. A 101-frame window centred on the most negative peak is extracted and
       zero-padded where it leaves the template.
    3. Features are computed on that window:
       * the tallest positive "shoulder" peaks before and after the trough,
         searched on a low-pass filtered copy;
       * the frames where the window crosses -0.5 on its way down and back up,
         refined by linear interpolation;
       * an SNR estimate ``10 * log10(max|x| / mean|x|)`` in dB.

    Parameters
    ----------
    temp : np.ndarray
        1-D floating-point template. Modified in place (normalised).
    samplerate : float, optional
        Sampling rate in frames per second, used to convert frame offsets to
        milliseconds. The default 20000.0 reproduces the previous hard-coded
        factor of 20 frames/ms.

    Returns
    -------
    dict
        'temp'          : the 101-frame window (trough at frame 50).
        'center_frame'  : trough frame in the original template.
        'feat_frames'   : (pre-shoulder frame, post-shoulder frame,
                           falling -0.5 crossing, rising -0.5 crossing), in
                           window coordinates.
        'sh_amp0', 'sh_amp1'   : filtered shoulder amplitudes.
        'sh_time0', 'sh_time1' : shoulder times relative to the trough [ms].
        'cross_time0', 'cross_time1' : -0.5 crossing times relative to the trough [ms].
        'cross_diff'    : cross_time1 - cross_time0 [ms].
        'SNR'           : SNR of the window [dB].
        Features that cannot be determined (no shoulder peak, or no proper
        -0.5 crossing inside the window) are NaN.
    """
    half = 50             # half-width of the analysis window [frames]
    pos_peak_thr = 0.05   # minimum height of a shoulder peak (normalised units)
    neg_peak_thr = 0.5    # minimum depth of the main negative peak
    cross_thr = -0.5      # level used for the crossing times
    frames_per_ms = samplerate / 1000.0

    # Normalise in place so the dominant excursion becomes -1.
    if np.max(-temp) >= np.max(temp):
        temp /= np.max(-temp)
    else:
        temp /= -np.max(temp)

    neg_peaks, _ = find_peaks(-temp, height=neg_peak_thr)
    if neg_peaks.size > 0:
        p_min = neg_peaks[np.argmin(temp[neg_peaks])]
    else:
        # find_peaks never reports the first or last sample as a peak. Fall back
        # to the global minimum so a trough at the template edge does not crash.
        p_min = np.argmin(temp)

    p_temp = _extract_window(temp, p_min, half)

    # Low-pass filter the window so that noise does not create spurious shoulders.
    b, a = butter(2, 0.4)
    filtered = filtfilt(b, a, p_temp)

    pp0, sh_amp0 = _shoulder_peak(filtered, 0, half, pos_peak_thr)
    pp1, sh_amp1 = _shoulder_peak(filtered, half + 1, len(filtered), pos_peak_thr)
    sh_time0 = (pp0 - half) / frames_per_ms   # NaN propagates when no peak
    sh_time1 = (pp1 - half) / frames_per_ms

    # Falling crossing: first frame at or before the trough below cross_thr.
    # pc0 == 0 means either no sample is below the threshold, or the window
    # already starts below it. There is then no preceding sample to interpolate
    # from (p_temp[-1] would wrap around to the end of the window).
    pc0 = np.argmax(p_temp[:half + 1] < cross_thr)
    _pc0 = _interp_crossing(p_temp, pc0, cross_thr) if pc0 > 0 else np.nan

    # Rising crossing: first frame after the trough back above cross_thr.
    rising = p_temp[half + 1:] > cross_thr
    if rising.any():
        pc1 = np.argmax(rising) + half + 1
        _pc1 = _interp_crossing(p_temp, pc1, cross_thr)
    else:
        _pc1 = np.nan

    # Peak-to-mean absolute amplitude of the window, in dB.
    SNR = 10 * np.log10(np.max(np.abs(p_temp)) / np.mean(np.abs(p_temp)))

    # Times relative to the negative peak [ms]: (frame - half) / frames_per_ms.
    cross_time0 = (_pc0 - half) / frames_per_ms
    cross_time1 = (_pc1 - half) / frames_per_ms

    return {
        'temp': p_temp,
        'center_frame': p_min,
        'feat_frames': (pp0, pp1, _pc0, _pc1),
        'sh_amp0': sh_amp0,
        'sh_amp1': sh_amp1,
        'sh_time0': sh_time0,
        'sh_time1': sh_time1,
        'cross_time0': cross_time0,
        'cross_time1': cross_time1,
        'cross_diff': cross_time1 - cross_time0,
        'SNR': SNR,
    }

## Función de gráficos de templates por célula

In [6]:
def plot_template(temp, resp, index, save=True):
    """Plot a template with its extracted features and optionally save it as PNG.

    The top panel shows the full template ``temp``. The selected centre frame
    is marked, and the +-50-frame analysis window is shaded. The bottom panel
    shows the analysis window ``resp['temp']`` with:
    * its trough at frame 50;
    * the shoulder peaks (green);
    * the -0.5 crossing frames (purple).

    Parameters
    ----------
    temp : np.ndarray
        Full template, as passed to (and normalised by) ``get_template``.
    resp : dict
        Feature dict returned by ``get_template``. It must still contain the
        'temp', 'center_frame', 'feat_frames', 'sh_amp0' and 'sh_amp1' keys.
    index : str
        Template label, used in the title and as the PNG file name.
    save : bool, optional
        If True (default), the figure is written to
        ``<check_templates_dir(exp_name)>/<index>.png``.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.gridspec.GridSpec)
        The caller is responsible for closing the figure.
    """
    fig = plt.figure(figsize=(6, 6))
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 3])

    p_min = resp['center_frame']

    # Top: whole template with the selected trough and analysis window.
    ax0 = fig.add_subplot(gs[0])
    ax0.plot(temp)
    ax0.plot(p_min, temp[p_min], 'o')
    ax0.axvspan(p_min - 50, p_min + 50, alpha=0.3, color='green')
    ax0.set_title('{}, {} templates'.format(exp_name, index))

    # Bottom: analysis window and the trough of the action potential.
    ax1 = fig.add_subplot(gs[1])
    ax1.plot(resp['temp'])
    ax1.plot(50, resp['temp'][50], 'o')

    # Positive shoulder peaks (NaN frames are simply not drawn).
    ax1.plot(resp['feat_frames'][0], resp['sh_amp0'], 'o', color='green')
    ax1.plot(resp['feat_frames'][1], resp['sh_amp1'], 'o', color='green')

    # Interpolated frames where the normalised potential crosses -0.5.
    ax1.plot(resp['feat_frames'][2], -0.5, 'o', color='purple')
    ax1.plot(resp['feat_frames'][3], -0.5, 'o', color='purple')
    ax1.set_title('Template selected')
    fig.tight_layout()

    if save:
        # Previously the path was computed but the figure was never written.
        save_file = os.path.join(check_templates_dir(exp_name),
                                 '{}.png'.format(index))
        fig.savefig(save_file)

    return fig, gs

## Extractor de features y almacenamiento de figuras
Esta celda de código genera un archivo .csv (`results/<exp_name>/temp_features.csv`) con los features extraídos del template de cada célula, además de almacenar las figuras de los templates en la carpeta `results/<exp_name>/templates`.

In [7]:
# Features of every template are appended one row at a time to the CSV, so the
# rows written so far survive an interrupted run. A copy of each (normalised)
# template and its full feature dict is kept in ``vis_temp``, which the
# interactive viewer cell reads.
vis_temp = {}

with h5py.File(templates_file, 'r') as pot:
    temp_x = pot['temp_x'][:].ravel()
    temp_y = pot['temp_y'][:].ravel()  # template (cell) identifier of every stored value
    temp_data = pot['temp_data'][:].ravel()

# Every identifier from 0 up to the largest one present. Using max() instead
# of temp_y[-1] does not rely on temp_y being sorted.
cell_indexes = np.arange(temp_y.max() + 1)

indexes = ['temp_{}'.format(i) for i in cell_indexes]
columns = ['sh_amp0', 'sh_amp1', 'sh_time0', 'sh_time1',
           'cross_time0', 'cross_time1', 'cross_diff', 'SNR']

save_file = os.path.join(check_exp_dir(exp_name), 'temp_features.csv')

# Start a fresh CSV containing only the header row.
pd.DataFrame(columns=columns).to_csv(save_file, mode='w')

for idx, cell in tqdm(zip(cell_indexes, indexes), total=len(indexes)):
    temp = temp_data[temp_y == idx]  # boolean indexing -> independent copy
    if temp.size == 0:
        # No stored values for this identifier; get_template cannot process an
        # empty array (np.max of an empty array raises).
        continue

    # NOTE: get_template normalises ``temp`` in place.
    temp_feat = get_template(temp)
    vis_temp[cell] = (temp, copy.deepcopy(temp_feat))

    # Template figure (plot_template builds it; we only release it here).
    fig, _ = plot_template(temp, temp_feat, cell)
    plt.close(fig)

    # Only the scalar features go to the CSV.
    row = pd.DataFrame([{key: temp_feat[key] for key in columns}],
                       columns=columns, index=[cell])
    with open(save_file, 'a') as output:
        row.to_csv(output, header=False)
        

 13%|█▎        | 280/2106 [01:00<06:35,  4.61it/s]


KeyboardInterrupt: 

## Visualizador de templates
Una vez ejecutada la celda anterior, se pueden visualizar los templates de las células para su selección manual y el descarte de templates inválidos. En el futuro se podrán incorporar al script extractor otros criterios de descarte automático.

In [8]:
def vis_template(cell_key):
    """Redraw one stored template for the interactive viewer.

    ``vis_temp[cell_key]`` is expected to hold the template in position 0 and
    its feature dict in position 1, in the form consumed by ``plot_template``.
    """
    stored = vis_temp[cell_key]
    template, features = stored[0], stored[1]
    plot_template(template, features, cell_key)

interact(vis_template, cell_key=Dropdown(options=vis_temp.keys()));

interactive(children=(Dropdown(description='cell_key', options=('temp_0', 'temp_1', 'temp_2', 'temp_3', 'temp_…