In [16]:
import os
import subprocess
import warnings
from collections import OrderedDict

import numpy as np
import pandas
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import ConnectionPatch
from matplotlib.ticker import FormatStrFormatter, ScalarFormatter
from McsPy.McsData import RawData
from tridesclous.datasource import get_all_channel_data

## Helper functions

In [4]:
def remove_nan(array):
    """Return the entries of ``array`` that are not NaN.

    Boolean-mask indexing is used, so the result is always a new 1-D array
    (multi-dimensional input is flattened in C order).
    """
    keep = np.logical_not(np.isnan(array))
    return array[keep]

## Parameter settings

In [5]:
# Folder holding the stimulus-time workbook and the raw recording.
# Set the MEA_INPUT_PATH environment variable to use another location instead
# of editing this cell, e.g. the previously used network path
# \\Physiology-csk\mea\MEA1060\TEST\20191129_SpikeSorting_WT_PNW12_Male\Righteye
input_path = os.environ.get(
    'MEA_INPUT_PATH', 'C:\\Users\\user\\Documents\\Bodo_Rueckauer\\Data\\raw\\All')
# Excel workbook: sheet 0 = stimulus onset times, sheet 1 = spike times per
# sorted unit ('자극시점' is Korean for "stimulus time points").
times_filename = '자극시점.xlsx'
# Raw MEA recording; converted to HDF5 further below.
traces_filename = 'DATA_spont_light_electric_ch55.mcd'

## Load data

In [6]:
times_filepath = os.path.join(input_path, times_filename)
traces_filepath = os.path.join(input_path, traces_filename)
output_path = os.path.join(input_path, 'plots')
# exist_ok avoids a race between a separate existence check and the creation.
os.makedirs(output_path, exist_ok=True)

# Sheet 0: one column of stimulus onset times per stimulus condition.
trigger_sheet = pandas.read_excel(times_filepath, sheet_name=0, header=1, 
                                  index_col=0, skiprows=1)

# Section -> condition (-> sub-condition) -> trigger times. Every leaf is a
# placeholder that is filled below from the spreadsheet column of that name.
trigger_times = OrderedDict({
    'Full-field': OrderedDict(
        {'Full-field': None}),
    'Moving bar': OrderedDict(
        {'Moving_L>R': None, 'Moving_R>L': None, 'Moving_T>B': None, 
         'Moving_B>T': None, 'Moving_LT>RB': None, 'Moving_LB>RT': None,
         'Moving_RT>LB': None, 'Moving_RB>LT': None}),
    'Electric': OrderedDict(
        {'Cathodic': OrderedDict(
            {'10uA_0.5ms': None, '30uA_0.5ms': None, '50uA_0.5ms': None, 
             '10uA_1ms': None, '30uA_1ms': None, '50uA_1ms': None, 
             '10uA_2ms': None, '30uA_2ms': None, '50uA_2ms': None, 
             '10uA_4ms': None, '30uA_4ms': None, '50uA_4ms': None}),
         'Anodic': OrderedDict(
             {'10uA_0.5ms': None, '30uA_0.5ms': None, '50uA_0.5ms': None, 
              '10uA_1ms': None, '30uA_1ms': None, '50uA_1ms': None, 
              '10uA_2ms': None, '30uA_2ms': None, '50uA_2ms': None, 
              '10uA_4ms': None, '30uA_4ms': None, '50uA_4ms': None})})})

# Per-section options forwarded to plot_cell: 'pre'/'post' window around each
# trigger (same units as the trigger times), fixed y-limits (None = automatic),
# histogram bin count, and the cells to plot (None = all cells).
plot_kwargs = {
    'Full-field': {'pre': 0.5, 'post': 6, 'ymin': None, 'ymax': None, 
                   'num_bins': 100, 'cells_to_plot': ['ch_71a', 'ch_71b']},
    'Moving bar': {'pre': 0.5, 'post': 5, 'ymin': None, 'ymax': None, 
                   'num_bins': 100, 'cells_to_plot': ['ch_71a', 'ch_71b']},
    'Electric': {'pre': 1, 'post': 1, 'ymin': None, 'ymax': None, 
                 'num_bins': 100, 'cells_to_plot': ['ch_71a', 'ch_71b']}}


def _trigger_column(sheet, name, seen):
    """Return the column of ``sheet`` holding the next occurrence of ``name``.

    The Cathodic and Anodic blocks share stimulus names. pandas renames
    repeated headers to ``name.1``, ``name.2``, ..., so the n-th lookup of a
    name reads the n-th such column; always reading ``name`` would hand every
    repeated block the trigger times of the first one.
    NOTE(review): assumes the workbook lists the repeated blocks in the same
    order as ``trigger_times`` (Cathodic before Anodic) — confirm.

    ``seen`` counts previous lookups per name and is updated in place. When a
    renamed column is missing, the first column is reused with a warning.
    """
    occurrence = seen.get(name, 0)
    seen[name] = occurrence + 1
    if occurrence:
        mangled = '{}.{}'.format(name, occurrence)
        if mangled in sheet.columns:
            return sheet[mangled]
        warnings.warn("Trigger sheet has no column '{}'; reusing column '{}' "
                      "for a repeated stimulus name.".format(mangled, name))
    return sheet[name]


_columns_seen = {}
for section in trigger_times.values():
    for key, subsection in section.items():
        if isinstance(subsection, dict):
            for subkey in subsection:
                subsection[subkey] = remove_nan(
                    _trigger_column(trigger_sheet, subkey, _columns_seen).to_numpy())
        else:
            section[key] = remove_nan(
                _trigger_column(trigger_sheet, key, _columns_seen).to_numpy())

# Sheet 1: one column of spike times per sorted unit, headed 'ch_...'
# (e.g. 'ch_71a'); all other columns are ignored.
spike_sheet = pandas.read_excel(times_filepath, sheet_name=1, header=0)

spike_times = OrderedDict(
    (cell_name, remove_nan(cell_data.to_numpy()))
    for cell_name, cell_data in spike_sheet.items()
    # str() guards against non-string headers such as numeric column labels.
    if 'ch_' in str(cell_name))
    

### Convert the MCD recording to HDF5

In [7]:
# Path of the HDF5 version of the recording: same directory and base name as
# the raw file, with an '.h5' extension. `ext` selects the conversion below.
basename, ext = os.path.splitext(traces_filepath)
traces_filepath_h5 = ''.join((basename, '.h5'))

In [34]:
# Convert the raw recording to HDF5 with the external MCDataConv tool.
# Assumes MCDataConv writes <basename>.h5, i.e. traces_filepath_h5, which is
# what load_h5 reads. The conversion is skipped when that file already exists
# (delete it to force a fresh conversion); this also lets traces_filename
# point directly at an .h5 file.
if not os.path.exists(traces_filepath_h5):
    if ext in {'.mcd', '.msrs'}:
        conversion_source = traces_filepath
    elif ext == '.msrd':
        # .msrd recordings are converted via the .msrs file with the same base name.
        conversion_source = basename + '.msrs'
    else:
        raise NotImplementedError(
            "No HDF5 conversion for '{}' files; expected .mcd, .msrs or .msrd."
            .format(ext))
    # check=True makes a failed conversion raise here instead of surfacing
    # later as a confusing error while loading the HDF5 file.
    subprocess.run(["MCDataConv", "-t", "hdf5", conversion_source], check=True)

In [8]:
def load_h5(path):
    """Load the electrode stream of a single-recording MCS HDF5 file.

    Parameters
    ----------
    path : str
        Path of the HDF5 file to read.

    Returns
    -------
    tuple
        ``(traces, sample_rate)`` as returned by ``get_all_channel_data`` for
        the first analog stream whose ``data_subtype`` is 'Electrode'.
        Presumably traces are (samples x channels) — the plotting cell
        indexes them as ``traces[:, channel]``; TODO confirm.

    Raises
    ------
    AssertionError
        If the file holds more than one recording or no electrode stream.
    """
    raw = RawData(path)
    recordings = raw.recordings
    assert len(recordings) == 1, \
        "Can only handle a single recording per file."

    # First analog stream tagged as electrode data; other streams are ignored.
    electrode_stream = next(
        (candidate for candidate in recordings[0].analog_streams.values()
         if candidate.data_subtype == 'Electrode'),
        None)
    assert electrode_stream is not None, "Electrode data not found."

    channel_traces, rate = get_all_channel_data(electrode_stream)
    return channel_traces, rate


traces, sample_rate = load_h5(path=traces_filepath_h5)

Recording_0 <HDF5 group "/Data/Recording_0" (1 members)>
Stream_0 <HDF5 group "/Data/Recording_0/AnalogStream/Stream_0" (3 members)>
ChannelData <HDF5 dataset "ChannelData": shape (60, 44315000), type "<i4">
ChannelDataTimeStamps <HDF5 dataset "ChannelDataTimeStamps": shape (1, 3), type "<i8">
InfoChannel <HDF5 dataset "InfoChannel": shape (60,), type "|V108">
Stream_1 <HDF5 group "/Data/Recording_0/AnalogStream/Stream_1" (3 members)>
ChannelData <HDF5 dataset "ChannelData": shape (3, 44315000), type "<i4">
ChannelDataTimeStamps <HDF5 dataset "ChannelDataTimeStamps": shape (1, 3), type "<i8">
InfoChannel <HDF5 dataset "InfoChannel": shape (3,), type "|V108">


## Plotting

In [20]:
num_rows = 8
num_columns = 8
# The four corner positions of the 8x8 grid get no label (presumably the MEA
# has no electrodes there).
to_skip = {(0, 0), (0, num_columns - 1), (num_rows - 1, 0),
           (num_rows - 1, num_columns - 1)}

# Map electrode labels '<column><row>' (1-based, e.g. '12' ... '87') to
# consecutive channel indices, walking the grid column by column.
label_map = {
    label: channel_idx
    for channel_idx, label in enumerate(
        '{}{}'.format(col + 1, row + 1)
        for col in range(num_columns)
        for row in range(num_rows)
        if (row, col) not in to_skip)}

def _window_length(requested, median_interval, default=1):
    """Return ``requested``; if it is None, ``default`` capped at ``median_interval``.

    A NaN ``median_interval`` (fewer than two triggers) keeps ``default``,
    because ``default > nan`` is False.
    """
    if requested is not None:
        return requested
    return median_interval if default > median_interval else default


def _trace_segment(traces, start, stop):
    """Return ``(x, samples)`` for the sample indices ``start <= i < stop``.

    Indices outside ``traces`` are dropped instead of wrapping around (a
    negative index would read from the end of the recording; one past the
    end would raise). ``x`` is each kept sample's offset from ``start``, so a
    window clipped at the recording edge stays aligned with complete windows.
    """
    lo = max(start, 0)
    hi = max(min(stop, len(traces)), lo)
    return np.arange(lo, hi) - start, traces[lo:hi]


def _auto_ylim(samples, ymin=None, ymax=None):
    """Fill in y-limits that are None from ``samples``.

    A missing limit defaults to the sample extreme. If that extreme lies
    beyond +-1e-3, it is replaced by the 2nd (min) or 98th (max) percentile
    of the samples with the same sign, presumably so that a few outliers do
    not squash the traces.
    """
    if ymin is None:
        ymin = np.min(samples)
        if ymin < -1e-3:
            ymin = np.percentile(samples[samples < 0], 2)
    if ymax is None:
        ymax = np.max(samples)
        if ymax > 1e-3:
            ymax = np.percentile(samples[samples > 0], 98)
    return ymin, ymax


def plot_cell(_traces, _spike_times, _trigger_times, path, _title, 
              stimulus_type, _sample_rate, _cell_name, **_kwargs):
    """Plot trial traces, a spike raster and a histogram for one cell and stimulus.

    The figure stacks, top to bottom:
    - the trace around each trigger (x axis in samples);
    - a raster of the cell's spikes aligned to each trigger;
    - a histogram of all aligned spike times (x axis in seconds).
    A red line links t = 0 in the histogram to the trigger sample in the
    first trace. The figure is saved as
    ``<path>/<cell>_<title>_<stimulus>.png``.

    Parameters
    ----------
    _traces : np.ndarray
        1-D trace of the channel the cell was recorded on, indexed by sample.
    _spike_times : np.ndarray
        Spike times of the cell, in seconds.
    _trigger_times : np.ndarray
        Stimulus onset times in seconds, one per trial. If empty, a warning
        is issued and nothing is plotted.
    path : str
        Output directory.
    _title : str
        Section title used in the file name.
    stimulus_type : str
        Stimulus name used in the file name ('>' is replaced by '-').
    _sample_rate : float
        Samples per second of ``_traces``.
    _cell_name : str
        Cell identifier used in the file name.
    **_kwargs
        Optional settings:
        - ``pre`` / ``post``: seconds before/after each trigger; when None,
          1 s capped at the median inter-trigger interval.
        - ``ymin`` / ``ymax``: trace y-limits; automatic when None.
        - ``num_bins``: histogram bins; 'auto' when None.
        Other keys are ignored.
    """
    if len(_trigger_times) == 0:
        # Nothing to align to; indexing the first trigger below would fail.
        warnings.warn("No trigger times for {} / {} / {}; no plot written."
                      .format(_cell_name, _title, stimulus_type))
        return

    num_bins = _kwargs.get('num_bins', None)
    if num_bins is None:
        num_bins = 'auto'

    # The median inter-trigger interval caps the default pre/post windows.
    # With a single trigger there is no interval; NaN keeps the defaults.
    if len(_trigger_times) > 1:
        median_interval = np.median(np.diff(_trigger_times))
    else:
        median_interval = np.nan
    pre = _window_length(_kwargs.get('pre', None), median_interval)
    post = _window_length(_kwargs.get('post', None), median_interval)

    # Shared y-limits for all trials, derived from the span covering every trial.
    _, span = _trace_segment(
        _traces, int((_trigger_times[0] - pre) * _sample_rate),
        int((_trigger_times[-1] + post) * _sample_rate))
    ymin, ymax = _auto_ylim(span, _kwargs.get('ymin', None),
                            _kwargs.get('ymax', None))

    # Pre-select the spikes inside the overall window, then split per trial.
    in_span = ((_spike_times >= _trigger_times[0] - pre)
               & (_spike_times < _trigger_times[-1] + post))
    spikes_in_span = _spike_times[in_span]

    num_trials = len(_trigger_times)
    figure = Figure()
    canvas = FigureCanvas(figure)
    # One row per trial trace, then the raster row, then the histogram row.
    axes = figure.subplots(num_trials + 2, 1)
    axes[-1].set_xlabel("Time [s]")
    color = 'k'

    spike_times_zerocentered = []
    for i, trigger_time in enumerate(_trigger_times):
        t_pre = trigger_time - pre
        t_post = trigger_time + post
        x, trace = _trace_segment(_traces, int(t_pre * _sample_rate),
                                  int(t_post * _sample_rate))
        axes[i].set_ylim(ymin, ymax)
        # Trace axes are in samples: use the sample rate passed in, not a global.
        axes[i].set_xlim(0, (pre + post) * _sample_rate)
        axes[i].plot(x, trace, color=color, linewidth=0.1)
        if i > 0:
            axes[i].axis('off')

        in_trial = (spikes_in_span >= t_pre) & (spikes_in_span < t_post)
        # Out-of-place subtraction also works for integer spike times.
        spike_times_zerocentered.append(spikes_in_span[in_trial] - trigger_time)

    axes[-2].eventplot(spike_times_zerocentered, color=color, linewidths=0.5,
                       lineoffsets=-1)
    axes[-1].hist(np.concatenate(spike_times_zerocentered), num_bins,
                  histtype='stepfilled', facecolor=color)

    fmt = ScalarFormatter()
    fmt.set_scientific(True)
    fmt.set_powerlimits((-3, 4))
    axes[0].yaxis.set_major_formatter(fmt)
    axes[0].spines['top'].set_visible(False)
    axes[0].spines['right'].set_visible(False)
    axes[0].spines['bottom'].set_visible(False)
    axes[0].set_xticks([])
    axes[-2].set_xlim(-pre, post)
    axes[-2].axis('off')
    axes[-1].set_xlim(-pre, post)
    axes[-1].yaxis.set_major_formatter(FormatStrFormatter('%d'))
    axes[-1].spines['top'].set_visible(False)
    axes[-1].spines['right'].set_visible(False)
    # Red line from t = 0 in the histogram to the trigger sample in the first trace.
    axes[-1].add_artist(ConnectionPatch((0, 0), (pre * _sample_rate, ymax), 
                                        'data', 'data', axes[-1], axes[0], 
                                        color='r'))

    figure.subplots_adjust(wspace=0, hspace=0)
    # Presumably because '>' (used in moving-bar names) is not valid in
    # Windows file names.
    stimulus_type = stimulus_type.replace('>', '-')
    filepath = os.path.join(path, '{}_{}_{}.png'.format(_cell_name, _title, 
                                                        stimulus_type))
    canvas.print_figure(filepath, bbox_inches='tight', dpi=200)


# For every sorted cell, plot each stimulus condition of every section whose
# 'cells_to_plot' option includes the cell (all cells when it is None).
for cell_name, cell_spikes in spike_times.items():
    # Looked up on first use, so cells that are never plotted need no
    # label_map entry.
    channel_data = None
    for title, section in trigger_times.items():
        kwargs = plot_kwargs[title]
        cells_to_plot = kwargs.get('cells_to_plot', None)
        # The check does not depend on the stimulus, so it runs once per section.
        if cells_to_plot is not None and cell_name not in cells_to_plot:
            continue
        if channel_data is None:
            # Characters 3-4 of a name like 'ch_71a' are the electrode label '71'.
            channel_data = traces[:, label_map[cell_name[3:5]]]
        for key, subsection in section.items():
            if isinstance(subsection, dict):
                # Nested section (e.g. Electric -> Cathodic): one plot per sub-condition.
                for subkey, subkey_triggers in subsection.items():
                    plot_cell(channel_data, cell_spikes, subkey_triggers,
                              output_path, '{} ({})'.format(title, key), 
                              subkey, sample_rate, cell_name, **kwargs)
            else:
                plot_cell(channel_data, cell_spikes, subsection, output_path,
                          title, key, sample_rate, cell_name, **kwargs)