In [1]:
import os
from collections import OrderedDict
import numpy as np
import pandas
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import skimage.measure
from scipy.stats import percentileofscore
from example.cbnu.utils import get_interval
%matplotlib inline

## Helper functions

In [2]:
def remove_nan(array):
    """Drop every NaN entry of ``array``.

    Uses a boolean mask, so for an ndarray input the result is a 1-D array
    containing the non-NaN values in their original order.
    """
    keep = np.logical_not(np.isnan(array))
    return array[keep]

## Parameter settings

In [3]:
# Folder holding the stimulus-timing workbook; results are written to its
# 'output' subfolder. Set the DIRECTION_SELECTIVITY_PATH environment variable
# to use another location without editing the notebook.
input_path = os.environ.get(
    'DIRECTION_SELECTIVITY_PATH',
    'C:\\Users\\user\\Documents\\Bodo_Rueckauer\\Data\\direction_selectivity')
# "자극시점" means "stimulus time points": sheet 0 holds the trigger times,
# sheet 1 the spike times per channel.
times_filename = '자극시점.xlsx'

## Load data

In [4]:
times_filepath = os.path.join(input_path, times_filename)
output_path = os.path.join(input_path, 'output')
# exist_ok replaces the racy "check if exists, then create" pattern.
os.makedirs(output_path, exist_ok=True)

# Stimulus name -> direction index. Below, only its length is used; the
# direction index of each trigger is its column position in the trigger sheet.
# NOTE(review): presumably the sheet columns follow this same order — confirm.
label_map = {'Moving_R>L': 0, 'Moving_RT>LB': 1, 'Moving_T>B': 2,
             'Moving_LT>RB': 3, 'Moving_L>R': 4, 'Moving_LB>RT': 5,
             'Moving_B>T': 6, 'Moving_RB>LT': 7}

num_directions = len(label_map)
num_trials = 3  # Set to None to read every trial row of the trigger sheet.

# One angle (radians) per direction, evenly spaced over the full circle.
angles = 2 * np.pi * np.arange(num_directions) / num_directions

In [5]:
# Trigger times live in sheet 0; the direction columns start at column 2.
# nrows limits how many trial rows are read (None reads them all).
trigger_sheet = pandas.read_excel(
    times_filepath,
    sheet_name=0,
    header=1,
    skiprows=1,
    nrows=num_trials,
    usecols=list(range(2, 2 + num_directions)),
)

In [6]:
# Trigger (stimulus onset) times: one row per trial, one column per direction.
trigger_times = trigger_sheet.to_numpy()

if num_trials is None:
    # Rows are trials (the sheet was read with nrows=num_trials); shape[1]
    # would be the number of directions, not trials.
    num_trials = trigger_times.shape[0]

# Flatten trial by trial and take the gap to the next trigger as each
# stimulus' duration. NOTE(review): assumes stimuli were presented in this
# row-major order — confirm.
all_trigger_times = np.ravel(trigger_times)
trigger_durations = np.diff(all_trigger_times)
# The final trigger has no successor; give it the median duration.
trigger_durations = np.concatenate([trigger_durations,
                                    [np.median(trigger_durations)]])

spike_sheet = pandas.read_excel(times_filepath, sheet_name=1, header=0)

# Spike times per channel: keep only columns whose name contains 'ch_' and
# drop their NaN entries (presumably padding of shorter columns).
spike_times_cells = OrderedDict(
    (cell_name, remove_nan(cell_data.to_numpy()))
    for cell_name, cell_data in spike_sheet.items()
    if 'ch_' in cell_name)

In [7]:
def snr(data):
    """Signal-to-noise ratio of repeated binned responses.

    Axis 0 is averaged over (repetitions/trials) and the last axis is the time
    axis, e.g. ``data`` indexed as (trial, direction, time bin); the result has
    one value per entry of the remaining axes (e.g. per direction).

    signal = variance over time of the trial-averaged response
    noise  = trial-averaged variance over time of each single trial

    Where the noise variance is zero (every trial is flat), the signal
    variance is zero as well; 0 is returned there instead of NaN (and a
    divide-by-zero RuntimeWarning).
    """
    signal_var = np.var(np.mean(data, 0), -1)
    noise_var = np.mean(np.var(data, -1), 0)
    return np.divide(signal_var, noise_var,
                     out=np.zeros_like(signal_var, dtype=float),
                     where=noise_var > 0)

In [8]:
# Names of the channel columns of the spike sheet, in sheet order.
all_cells = list(filter(lambda column: 'ch_' in column, spike_sheet.columns))

In [11]:
# For every cell, plot direction-tuning summaries and save them as a PNG:
#   row 1: polar tuning from total spike counts (left) and from the leading SVD
#          component of the direction x time matrix (right); thin blue lines are
#          single trials, black the trial average, red the vector sum;
#   row 2: permutation null distributions of both vector-sum magnitudes;
#   rows 3-4: mean counts, leading SVD component, direction x time matrix.
angles_deg = ['', '0', '45', '90', '135', '180', '225', '270', '315']
cmap = 'autumn'
# Presumably Hz, with spike and trigger times in seconds — confirm.
sample_rate = 25000
cells_to_plot = all_cells  # ['ch_71a', 'ch_71b', 'ch_72a', 'ch_72b']
num_cells = len(cells_to_plot)
# The same window (the shortest stimulus duration) is analysed after every trigger.
min_duration = np.min(trigger_durations)
min_ticks = int(sample_rate * min_duration)
num_bins = 32
num_permutations = 1000
num_permuations = num_permutations  # Old misspelled name, kept as an alias.
snr_threshold = 0.6  # Directions above this SNR get a marker in the SVD plot.
random_seed = 0  # Seeds the permutation tests so p-values are reproducible.
projection = np.exp(1j * angles)  # One unit vector per direction.


def _bin_spikes(spike_times_cell, start, duration, sample_rate, num_bins):
    """Count one cell's spikes after a trigger in ``num_bins`` time bins.

    The spikes ``get_interval`` returns for ``start .. start + duration`` are
    rasterized onto a 0/1 mask with one entry per sample tick (several spikes
    in one tick count once), then summed in blocks of ``ticks // num_bins``;
    blocks beyond the first ``num_bins`` are dropped.
    """
    num_ticks = int(sample_rate * duration)
    spike_times = np.asarray(
        get_interval(spike_times_cell, start, start + duration))
    spike_ticks = ((spike_times - start) * sample_rate).astype(int)
    # num_ticks is rounded down, so a spike in the final fractional tick (or
    # exactly at the window end) would otherwise index past the mask.
    spike_ticks = spike_ticks[(spike_ticks >= 0) & (spike_ticks < num_ticks)]
    spike_mask = np.zeros(num_ticks)
    spike_mask[spike_ticks] = 1
    return skimage.measure.block_reduce(
        spike_mask, (num_ticks // num_bins,))[:num_bins]


def _leading_direction_vector(direction_by_time):
    """First right singular vector of the transposed (time x direction) matrix.

    The input has one row per direction, so the result has one entry per
    direction. A singular vector's sign is arbitrary, so it is flipped to make
    its entries sum to a non-negative value.
    """
    _, _, vh = np.linalg.svd(direction_by_time.transpose(), full_matrices=False)
    leading = vh[0]
    return -leading if np.sum(leading) < 0 else leading


def _plot_closed(ax, values, *args, **kwargs):
    """Plot ``values`` against ``angles``, closing the curve to the first direction."""
    ax.plot(np.append(angles, angles[0]), np.append(values, values[0]),
            *args, **kwargs)


def _plot_null_distribution(ax, null_values, observed):
    """Histogram a permutation null distribution, mark ``observed``, return its p-value."""
    # One-sided p-value: 1 minus the percentile rank of the observed value.
    p_value = 1 - percentileofscore(null_values, observed) / 100
    ax.hist(null_values, 'auto', histtype='stepfilled')
    ax.vlines(observed, 0, ax.get_ylim()[1], 'r')
    ax.text(observed, 0, "p={:.2f}".format(p_value), color='r',
            horizontalalignment='center', verticalalignment='top')
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('DSi')
    return p_value


def _style_direction_axis(ax):
    """Remove x ticks and label each image row with its direction in degrees."""
    ax.set_xticks([])
    ax.set_yticks(np.arange(num_directions))
    # angles_deg starts with a blank entry; the per-direction labels follow it.
    ax.set_yticklabels(angles_deg[1:])


for cell_label in cells_to_plot:
    fig = Figure(figsize=(10, 14))
    canvas = FigureCanvas(fig)
    ax00 = fig.add_subplot(4, 2, 1, projection='polar')
    ax01 = fig.add_subplot(4, 2, 2, projection='polar')
    ax10 = fig.add_subplot(4, 2, 3)
    ax11 = fig.add_subplot(4, 2, 4)
    ax20 = fig.add_subplot(4, 2, 5)
    ax21 = fig.add_subplot(4, 2, 6)
    ax30 = fig.add_subplot(4, 2, 7)
    ax31 = fig.add_subplot(4, 2, 8)
    # A fresh generator per cell keeps each cell's p-values reproducible
    # regardless of which other cells are plotted.
    rng = np.random.default_rng(random_seed)

    # Spike counts per (trial, direction, time bin).
    spike_times_cell = spike_times_cells[cell_label]
    spike_sums = np.zeros((num_trials, num_directions, num_bins))
    for trial_idx, trigger_times_trial in enumerate(trigger_times):
        for direction_idx, start in enumerate(trigger_times_trial):
            spike_sums[trial_idx, direction_idx] = _bin_spikes(
                spike_times_cell, start, min_duration, sample_rate, num_bins)

        # Single-trial tuning curves (thin blue lines).
        _plot_closed(ax00, np.sum(spike_sums[trial_idx], -1), 'b', linewidth=0.5)
        mat_trial = spike_sums[trial_idx] / max(1, np.max(spike_sums[trial_idx]))
        _plot_closed(ax01, _leading_direction_vector(mat_trial), 'b',
                     linewidth=0.5)

    # Trial-averaged tuning from total spike counts, with its vector sum.
    mean_spike_counts = np.mean(np.sum(spike_sums, -1), 0)
    _plot_closed(ax00, mean_spike_counts, 'k', marker='o')
    vectorsum = np.dot(projection, mean_spike_counts)
    dsi = np.abs(vectorsum)
    ax00.plot((0, np.angle(vectorsum)), (0, dsi), color='r')

    # Trial-averaged direction x time matrix, normalized to a peak of 1.
    mat = np.mean(spike_sums, 0)
    peak = np.max(mat)
    if peak > 0:  # A silent cell would otherwise give NaNs and break the SVD.
        mat = mat / peak
    vv = _leading_direction_vector(mat)
    ax21.matshow(np.expand_dims(vv, -1), cmap=cmap)
    ax31.matshow(mat, cmap=cmap)
    snr_ = snr(spike_sums)
    reliable = snr_ > snr_threshold
    _plot_closed(ax01, vv, 'k')
    ax01.scatter(angles[reliable], vv[reliable], marker='o', color='k')
    vectorsum1 = np.dot(projection, vv)
    dsi1 = np.abs(vectorsum1)
    ax01.plot((0, np.angle(vectorsum1)), (ax01.get_ylim()[0], dsi1), color='r')

    # Null distribution 1: shuffle which direction each mean count belongs to.
    counts_shuffled = np.copy(mean_spike_counts)
    dsis_permuted = []
    for _ in range(num_permutations):
        rng.shuffle(counts_shuffled)
        dsis_permuted.append(np.abs(np.dot(projection, counts_shuffled)))
    p = _plot_null_distribution(ax10, dsis_permuted, dsi)

    # Null distribution 2: shuffle the rows (directions) of the matrix.
    mat_shuffled = np.copy(mat)
    dsis_permuted = []
    for _ in range(num_permutations):
        rng.shuffle(mat_shuffled)
        dsis_permuted.append(
            np.abs(np.dot(projection, _leading_direction_vector(mat_shuffled))))
    p1 = _plot_null_distribution(ax11, dsis_permuted, dsi1)

    ax20.matshow(np.expand_dims(mean_spike_counts, -1), cmap=cmap)
    for direction_ax in (ax20, ax21, ax31):
        _style_direction_axis(direction_ax)
    ax31.set_xlabel('Time')
    ax01.set_ylim(None, 1)
    ax30.set_axis_off()

    canvas.print_figure(os.path.join(output_path, cell_label + '.png'),
                        bbox_inches='tight')

  
