In [59]:
import os
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from example.cbnu.utils import get_interval
from scipy.io import loadmat
from scipy.signal import find_peaks
from sklearn.cluster import k_means

## Parameter settings

In [51]:
# Recording / protocol parameters.
sample_rate = 25000  # presumably Hz — TODO confirm
num_trials = 40  # consecutive triggers per stimulus delay
num_delays = 11  # number of stimulus delays tested
step_size_delays = 5  # spacing between successive delays (plotted as ms)
# Root directory of the recording. Set the CBNU_DATA_PATH environment variable
# to run the notebook on another machine without editing this cell.
data_path = os.environ.get(
    'CBNU_DATA_PATH',
    r'C:\Users\bodor\Documents\Korea\experiment\alternating_pulses_in_corners'
    r'\5uA_1ms_1Hz_cathodic\Stim_Location_Green(ch61)_Blue(ch57)')
cell_path = os.path.join(data_path, 'spiketimes')

## Load data

In [105]:
# List directories in sorted order: os.listdir order is arbitrary, and the
# analysis below indexes trigger_times[0], so sorting keeps runs reproducible.
cell_names = sorted(os.listdir(cell_path))
output_path = os.path.join(data_path, 'plots')
os.makedirs(output_path, exist_ok=True)

# Spike timestamps per cell, keyed by file name without its extension; each
# file is read for its 'timestamps' variable.
spike_times = {}
for cell_name in cell_names:
    spike_times[os.path.splitext(cell_name)[0]] = loadmat(
        os.path.join(cell_path, cell_name), squeeze_me=True)['timestamps']

# Load every file whose name contains 'trigger_times'. Dividing by 1e6
# presumably converts microseconds to seconds — TODO confirm.
trigger_times = [np.loadtxt(os.path.join(data_path, filename)) / 1e6
                 for filename in sorted(os.listdir(data_path))
                 if 'trigger_times' in filename]

# Shared keyword arguments for get_peaks: window before/after each trigger
# (seconds), number of histogram bins, and an optional allow-list of cells.
plot_kwargs = {'pre': 0.01, 'post': 0.09, 'num_bins': 50,
               'cells_to_plot': None}

## PSTH peak detection and plotting

In [106]:
def get_peaks(_spike_times, _trigger_times, path, _delay, _sample_rate,
              _cell_name, save_plot, **_kwargs):
    """Compute a PSTH for one cell and return the times of its main peaks.

    Spikes around every trigger are aligned to that trigger, converted from
    seconds to milliseconds and pooled into one histogram. Local maxima of the
    bin counts above ``mean + 2 * std`` are peak candidates; the (up to) two
    tallest candidates are returned.

    Parameters
    ----------
    _spike_times : array-like
        Spike timestamps of one cell, on the same clock (seconds) as the
        triggers.
    _trigger_times : array-like
        Trigger timestamps (seconds) of the trials to pool. Must be non-empty:
        the first and last entries bound the analysed section.
    path : str
        Output directory for the PSTH image (used only if ``save_plot``).
    _delay : int or float
        Stimulus delay; used only in the output file name.
    _sample_rate : int or float
        Unused; kept for interface compatibility.
    _cell_name : str
        Cell identifier; used only in the output file name.
    save_plot : bool
        If True, write ``PSTH_<cell>_<delay>.png`` into ``path``.
    **_kwargs
        ``num_bins`` (default ``'auto'``), ``pre`` and ``post`` (window before
        and after each trigger, in seconds; default ``min(1, median
        inter-trigger interval)``). Other keys are ignored.

    Returns
    -------
    list
        Left bin edges (ms, relative to the trigger) of up to two peaks,
        ordered by increasing peak height, so the tallest peak is last.
    """
    num_bins = _kwargs.get('num_bins', None)
    if num_bins is None:
        num_bins = 'auto'

    # Default window: 1 s, capped at the median inter-trigger interval.
    median_interval = np.median(np.diff(_trigger_times))
    pre = _kwargs.get('pre', None)
    if pre is None:
        pre = min(1, median_interval)
    post = _kwargs.get('post', None)
    if post is None:
        post = min(1, median_interval)

    # Cut out the span covered by all triggers once, presumably so the
    # per-trigger lookups below search a smaller array.
    spike_times_section = get_interval(_spike_times, _trigger_times[0] - pre,
                                       _trigger_times[-1] + post)

    spike_times_zerocentered = []
    for trigger_time in _trigger_times:
        x = get_interval(spike_times_section, trigger_time - pre,
                         trigger_time + post)
        # Align to the trigger and convert seconds to ms out of place: if
        # get_interval returns a view (not visible here), in-place -= / *=
        # would corrupt spike_times_section and the caller's spike array.
        spike_times_zerocentered.append((np.asarray(x) - trigger_time) * 1e3)

    figure = Figure()
    canvas = FigureCanvas(figure)
    axes = figure.subplots(1, 1)
    axes.set_xlabel("Time [ms]")
    counts, bin_edges, _ = axes.hist(np.concatenate(spike_times_zerocentered),
                                     num_bins, histtype='stepfilled',
                                     facecolor='k', align='left')

    # Peak threshold: mean + 2 std of the bin counts (median + k * MAD would
    # be a more outlier-robust alternative).
    min_height = np.mean(counts) + 2 * np.std(counts)
    peak_idxs, _ = find_peaks(counts, height=min_height)
    # Indices of the two tallest peaks, ordered by increasing height.
    max_peak_idxs = peak_idxs[np.argsort(counts[peak_idxs])][-2:]

    ymax = axes.get_ylim()[1]
    peak_times = []
    # First returned peak is marked green, second blue.
    for peak_idx, color in zip(max_peak_idxs, ('g', 'b')):
        peak_time = bin_edges[peak_idx]
        axes.vlines(peak_time, 0, ymax, color=color)
        peak_times.append(peak_time)

    if save_plot:
        pre_ms = 1e3 * pre
        post_ms = 1e3 * post
        axes.set_xlim(-pre_ms, post_ms)
        axes.vlines(0, 0, ymax, color='r')  # trigger onset
        axes.hlines(min_height, -pre_ms, post_ms, color='y')  # peak threshold
        figure.subplots_adjust(wspace=0, hspace=0)
        filepath = os.path.join(path, 'PSTH_{}_{}.png'.format(_cell_name, _delay))
        canvas.print_figure(filepath, bbox_inches='tight', dpi=100)

    return peak_times

# For every stimulus delay, collect the PSTH peak times of all selected cells
# (`peaks`) and, per cell, the spacing between its two peaks (`peak_diffs`).
peaks = [[] for _ in range(num_delays)]
peak_diffs = [[] for _ in range(num_delays)]
cells_to_plot = plot_kwargs.get('cells_to_plot', None)
for cell_name, cell_spikes in spike_times.items():
    # Optional allow-list: skip cells that were not requested.
    if cells_to_plot is not None and cell_name not in cells_to_plot:
        continue
    for delay_idx in range(num_delays):
        # Assumes triggers are grouped in consecutive blocks of `num_trials`
        # per delay — TODO confirm against the stimulation protocol.
        first_trial = delay_idx * num_trials
        delay_triggers = trigger_times[0][first_trial:first_trial + num_trials]
        cell_peaks = get_peaks(cell_spikes, delay_triggers, output_path,
                               step_size_delays * delay_idx, sample_rate,
                               cell_name, False, **plot_kwargs)
        num_found = len(cell_peaks)
        if num_found == 2:
            peak_diffs[delay_idx].append(np.abs(cell_peaks[1] - cell_peaks[0]))
        elif num_found == 1 and delay_idx == 0:
            # NOTE(review): at zero delay the single peak's absolute latency
            # is recorded rather than a difference — confirm this is intended.
            peak_diffs[delay_idx].append(cell_peaks[0])
        peaks[delay_idx].extend(cell_peaks)

In [115]:
figure = Figure()
canvas = FigureCanvas(figure)
axes = figure.subplots(1, 1)
# One notched box per stimulus delay; boxplot places them at x = 1..num_delays.
axes.boxplot(peak_diffs, notch=True, patch_artist=True)
axes.set_xticks(np.arange(1, num_delays + 1))
axes.set_xticklabels(step_size_delays * np.arange(num_delays))
axes.set_xlabel("Stimulus delay [ms]")
axes.set_ylabel("Response delay [ms]")
# Reference line: response delay equal to stimulus delay.
axes.plot([1, num_delays], [0, step_size_delays * (num_delays - 1)])
# Explicit extension (as for the PSTH files) so the output format does not
# depend on rcParams['savefig.format'].
canvas.print_figure(os.path.join(output_path, 'delay_diffs.png'),
                    bbox_inches='tight', dpi=100)

In [114]:
def _box_style(color):
    """Return boxplot keyword arguments that draw every box element in `color`."""
    return dict(boxprops=dict(facecolor=color, color=color),
                capprops=dict(color=color), whiskerprops=dict(color=color),
                flierprops=dict(color=color, markeredgecolor=color))


figure = Figure()
canvas = FigureCanvas(figure)
axes = figure.subplots(1, 1)
colors = ['b', 'g']
# Per-delay centroid (ms, as a plain float) of each cluster; cluster 0 is the
# one with the smaller centroid. NaN marks delays that could not be clustered.
cluster_means = [[], []]
for i, delay_peaks in enumerate(peaks):
    target_delay = i * step_size_delays
    delay_peaks = np.array(delay_peaks)
    # Down-weight peaks more than 30 ms after the current stimulus delay and
    # ignore peaks before the trigger.
    weights = np.ones_like(delay_peaks)
    weights[delay_peaks > target_delay + 30] = 0.1
    weights[delay_peaks < 0] = 0
    num_clusters = 1 if i == 0 else 2
    if len(delay_peaks) < num_clusters:
        # k_means needs at least as many samples as clusters; leave a gap.
        cluster_means[0].append(np.nan)
        cluster_means[1].append(np.nan)
        continue
    samples = np.expand_dims(delay_peaks, -1)
    if i == 0:
        # Zero delay: a single cluster, shared by both response curves.
        centroids, _, _ = k_means(samples, 1, sample_weight=weights,
                                  init=np.array([[0.0]]), n_init=1)
        mean = float(centroids[0, 0])
        axes.boxplot(delay_peaks, notch=True, patch_artist=True, positions=[i],
                     **_box_style(colors[0]))
        cluster_means[0].append(mean)
        cluster_means[1].append(mean)
    else:
        # Two clusters, seeded at the trigger and at the stimulus delay.
        centroids, labels, _ = k_means(
            samples, 2, sample_weight=weights,
            init=np.array([[0.0], [float(target_delay)]]), n_init=1)
        # Relabel so that cluster 0 always has the smaller centroid.
        if centroids[0, 0] > centroids[1, 0]:
            centroids = centroids[::-1]
            labels = 1 - labels
        for cluster_id in (0, 1):
            c = colors[cluster_id]
            mean = float(centroids[cluster_id, 0])
            axes.boxplot(delay_peaks[labels == cluster_id], notch=True,
                         patch_artist=True, positions=[i], **_box_style(c))
            # A single point needs a marker; a bare line of length 0 is invisible.
            axes.plot(i, mean, marker='o', color=c)
            cluster_means[cluster_id].append(mean)
last_delay_idx = num_delays - 1
axes.set_xticks(np.arange(num_delays))
axes.set_xticklabels(step_size_delays * np.arange(num_delays))
axes.set_xlabel("Stimulus delay [ms]")
axes.set_ylabel("Response times [ms]")
# Reference line: response time equal to stimulus delay.
axes.plot([0, last_delay_idx], [0, step_size_delays * last_delay_idx], colors[1])
axes.plot(cluster_means[0], colors[0], linestyle='--')
axes.plot(cluster_means[1], colors[1], linestyle='--')
# Subtract the mean latency of the early cluster to show relative shifts.
offset = np.nanmean(cluster_means[0])
axes.plot(np.array(cluster_means[0]) - offset, colors[0], linestyle=':')
axes.plot(np.array(cluster_means[1]) - offset, colors[1], linestyle=':')
axes.hlines(0, 0, last_delay_idx, colors[0])
canvas.print_figure(os.path.join(output_path, 'peaks.png'), bbox_inches='tight', dpi=100)