# Initial Setup

In [1]:
import warnings
warnings.filterwarnings("ignore")
import sys
from herdingspikes.hs2 import HSDetection
from herdingspikes.probe import HierlmannVisapyEmulationProbe
import matplotlib.pyplot as plt
import h5py
import numpy as np
from PIL import Image

%matplotlib inline

# Seed NumPy's global random number generator so NumPy-based randomness is repeatable
np.random.seed(0)

In [2]:
# raw data location (the same file name is used by the probe and np.load cells below)
data_path = 'visapy_data.npy'

In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils import data

# GPU index this notebook pins itself to. Selecting it unconditionally raises
# on machines without CUDA or with fewer than GPU_INDEX + 1 devices, so only
# select it when it really exists and otherwise keep PyTorch's default device.
GPU_INDEX = 3
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)
else:
    print("CUDA device", GPU_INDEX, "not available; using the default device")

# Seed PyTorch's RNG for repeatable runs (kept last so the cell still
# displays the returned Generator).
torch.manual_seed(0)

<torch._C.Generator at 0x7f3e7e96d690>

In [4]:
import cv2
import copy

# Spike Detection

In [5]:
# detection parameters
# NOTE(review): the meaning and units of these values are defined by
# herdingspikes' HSDetection; confirm against its documentation before tuning.
to_localize = True
cutout_start = 14
cutout_end = 36
threshold = 24
file_directory = 'results/'
file_name = 'ProcessedSpikes_visapy'

# Build the probe from the configured recording path instead of repeating the
# file name, so changing `data_path` is enough to switch recordings.
Probe = HierlmannVisapyEmulationProbe(data_path)
H = HSDetection(Probe, to_localize, cutout_start, cutout_end, threshold,
                maa=0, maxsl=12, minsl=3, ahpthr=0, out_file_name=file_name,
                file_directory_name=file_directory, save_all=True)

# Run detection on the raw recording, then load the detected spikes back into
# `H` (they are read from `H.spikes` further down).
H.DetectFromRaw()
H.LoadDetected()

File size: 122400102
Number of channels: 102
# Sampling rate: 32000
# Localization On
# Not Masking any Channels
# Writing out extended detection info
# Number of recorded channels: 102
# Analysing frames: 1200001, Seconds:37.50003125
# Frames before spike in cutout: 32
# Frames after spike in cutout: 70
# tcuts: 44 59
# tInc: 50000
# Analysing 50000 frames; from -44 to 50059
# Analysing 50000 frames; from 49956 to 100059
# Analysing 50000 frames; from 99956 to 150059
# Analysing 50000 frames; from 149956 to 200059
# Analysing 50000 frames; from 199956 to 250059
# Analysing 50000 frames; from 249956 to 300059
# Analysing 50000 frames; from 299956 to 350059
# Analysing 50000 frames; from 349956 to 400059
# Analysing 50000 frames; from 399956 to 450059
# Analysing 50000 frames; from 449956 to 500059
# Analysing 50000 frames; from 499956 to 550059
# Analysing 50000 frames; from 549956 to 600059
# Analysing 50000 frames; from 599956 to 650059
# Analysing 50000 frames; from 649956 to 700059

# Loading and Setting up Ground Truths

## Waveforms

In [6]:
# Load the raw recording from the location configured in `data_path` rather
# than repeating the file name.
# NOTE(review): this rebinds `data`, which the setup cell imported as
# `torch.utils.data`; any later use of that module has to re-import it or
# refer to it by its full name.
data = np.load(data_path)

In [7]:
# View the recording as (frames, channels). Only the channel count is fixed;
# the number of frames is inferred from the array size (-1) so recordings of a
# different length reshape without editing this cell.
n_channels = 102
frame_set = data.reshape((-1, n_channels))
print("Frame shape: ", frame_set.shape)
print("Whole Dataset: ", frame_set)

Frame shape:  (1200001, 102)
Whole Dataset:  [[  4 -18   3 ...   5 -11  -4]
 [  3 -26  -4 ...   6 -13   6]
 [  3 -25  -6 ...   5 -18   3]
 ...
 [  1  32  -5 ...   7   2   0]
 [-17  27  -9 ...  12  -1  11]
 [-13  21  -8 ...  15   0  10]]


## Ground Truths

In [8]:
# Ground-truth spike table, cast to integers. Further down, column 0 is read
# as a 1-based neuron id and column 1 is compared against detected spike times.
ground_truth = np.loadtxt('ViSAPy_ground_truth.gdf').astype(int)

# Soma positions: keep the first two columns and lay them out as two rows,
# so somapos[0] and somapos[1] hold one coordinate each for every neuron.
somapos = np.loadtxt('ViSAPy_somapos.gdf')
somapos = np.vstack((somapos[:, 0], somapos[:, 1]))

print(ground_truth.shape)
print(somapos.shape)

(64441, 2)
(2, 56)


## Generate Channel Time Combinations

In [9]:
# Pull the 't', 'ch', 'x' and 'y' columns of the detected spikes out as lists
list_times, list_ch = H.spikes['t'].tolist(), H.spikes['ch'].tolist()
x, y = H.spikes['x'].tolist(), H.spikes['y'].tolist()

print(len(x), len(list_times))

# Every (time, channel) combination on which a spike was detected
list_times_ch = list(zip(list_times, list_ch))

19341 19341


## Electrode Positions

In [10]:
# Files holding the electrode coordinates.
# NOTE(review): the x coordinates are read from 'z_integer.npy' -- presumably
# the probe plane is z/y in the source data; confirm this is intentional.
x_path, y_path = 'z_integer.npy', 'y_integer.npy'
elec_x_int, elec_y_int = (np.load(path) for path in (x_path, y_path))

In [11]:
# Scale factor chosen so the largest gap between consecutive sorted y values
# becomes 9; the same factor is applied to both axes.
ky = 9 / np.diff(np.sort(elec_y_int)).max()

# Scale, then shift each axis so its minimum sits at 0
elec_x = elec_x_int * ky
elec_x = elec_x - elec_x.min()
elec_y = elec_y_int * ky
elec_y = elec_y - elec_y.min()

In [12]:
# One [x, y] row per electrode, in electrode order
datapoints = np.asarray(
    [[elec_x[idx], elec_y[idx]] for idx in range(len(elec_x))]
)
print(datapoints)

[[ 59.54058193 179.98346095]
 [  0.         233.97886677]
 [ 74.42572741 242.97794793]
 [ 59.54058193  17.99816233]
 [ 29.77029096  35.99632466]
 [ 59.54058193 143.98621746]
 [ 44.65543645 242.97794793]
 [ 14.88514548 296.97243492]
 [ 74.42572741  44.99540582]
 [  0.         107.9898928 ]
 [ 29.77029096 143.98621746]
 [  0.          17.99816233]
 [ 59.54058193  53.99448698]
 [ 44.65543645  80.99264931]
 [ 74.42572741 170.98437979]
 [ 44.65543645 206.98070444]
 [  0.         197.98162328]
 [ 14.88514548 278.97427259]
 [ 44.65543645 296.97243492]
 [ 14.88514548  44.99540582]
 [ 59.54058193  89.99173047]
 [  0.         215.9797856 ]
 [ 74.42572741 278.97427259]
 [ 14.88514548  26.99724349]
 [ 44.65543645  44.99540582]
 [ 44.65543645  98.99081164]
 [ 74.42572741 152.98621746]
 [ 29.77029096   0.        ]
 [ 74.42572741  26.99724349]
 [  0.          53.99448698]
 [ 44.65543645 134.98713629]
 [ 44.65543645  62.99356815]
 [ 59.54058193 287.97335375]
 [ 59.54058193 215.9797856 ]
 [ 59.54058193

# Generate Images - ONLY RUN THE FIRST TIME!

# Loading Images

In [13]:
# Creates a correspondence between HS2 and Ground Truth
#
# A detection at time t is matched to a ground-truth spike at time j when
# j - 2 <= t < j + 2 (the half-open window range(j-2, j+2)). If several
# ground-truth rows qualify, the one appearing first in `ground_truth` wins.
# NOTE(review): the window is asymmetric (t == j + 2 is excluded); confirm
# that this is intended rather than a symmetric +/-2 window.
#
# Results:
#   new_list_times_ch : matched (time, channel) pairs, in detection order
#   new_location      : (time, channel) -> (somapos[0, n], somapos[1, n])
#   neuron_channels   : zero-based neuron index n -> list of (time, channel)

new_list_times_ch = []
new_location = {}
neuron_channels = {}

# Index the ground truth once instead of rescanning it for every detection:
# for each time covered by some spike's window, remember the first
# ground-truth row that covers it. Rows are visited in order, so setdefault
# keeps exactly the row the nested scan would have stopped at.
first_gt_row_at = {}
for ind, j in enumerate(ground_truth[:, 1]):
    for t in range(j - 2, j + 2):
        first_gt_row_at.setdefault(t, ind)

for val in list_times_ch:
    ind = first_gt_row_at.get(val[0])
    if ind is None:
        # No ground-truth spike near this detection: leave it unmatched
        continue
    # Ground-truth neuron ids are 1-based; shift to index somapos columns
    neuron = ground_truth[ind, 0] - 1
    spike_key = (val[0], val[1])
    new_list_times_ch.append(spike_key)
    new_location[spike_key] = (somapos[0, neuron], somapos[1, neuron])
    neuron_channels.setdefault(neuron, []).append(spike_key)

In [14]:
# Load the pre-rendered frame image of every matched spike, keyed by its
# (time, channel) pair, as a grayscale array scaled to the 0..1 range.
image_data = {}

for spike_time, spike_ch in new_list_times_ch:
    filename = "frames_small_attention/frame_"+str(spike_time)+"_"+str(spike_ch)+".png"
    temp_pic = cv2.imread(filename)
    if temp_pic is None:
        # cv2.imread signals a missing/unreadable file by returning None;
        # fail here with the offending path instead of inside cvtColor.
        raise FileNotFoundError("Could not read frame image: " + filename)
    # NOTE(review): cv2.imread yields BGR channel order, so COLOR_RGB2GRAY
    # swaps the red/blue weights. Harmless if the frames are already gray;
    # left unchanged so previously computed inputs stay comparable.
    temp_pic = cv2.cvtColor(temp_pic, cv2.COLOR_RGB2GRAY)
    temp_pic = temp_pic/255
    image_data[(spike_time, spike_ch)] = temp_pic

# Removing False Triggers from Neuron List

In [15]:
# Filter only relevant frames
from collections import Counter

# For every neuron keep only the detections made on electrodes that picked
# that neuron up more than 80 times; rarer neuron/electrode pairings are
# dropped. Kept detections are grouped electrode by electrode.
updated_neuron_channels = {}
for neuron, hits in neuron_channels.items():
    hits_per_elec = Counter(np.array(hits)[:, 1])
    kept_elecs = [elec for elec, n_hits in hits_per_elec.items() if n_hits > 80]
    updated_neuron_channels[neuron] = [
        hit for elec in kept_elecs for hit in hits if hit[1] == elec
    ]

In [16]:
# Flatten the per-neuron lists into one list of (time, channel) pairs
list_loc_ch = [
    hit
    for hits in updated_neuron_channels.values()
    for hit in hits
]
print(len(list_loc_ch))

8977


In [18]:
import random

# Shuffle reproducibly, then split 75% / 5% / 20% into train / val / test.
# NOTE(review): the shuffle is in place, so re-running this cell without
# rebuilding list_loc_ch produces a different split.
random.seed(0)
random.shuffle(list_loc_ch)

n_samples = len(list_loc_ch)
train_lim = int(0.75 * n_samples)
val_lim = int(0.80 * n_samples)

list_times_ch_train, list_times_ch_val, list_times_ch_test = (
    list_loc_ch[:train_lim],
    list_loc_ch[train_lim:val_lim],
    list_loc_ch[val_lim:],
)

print(len(list_times_ch_train),len(list_times_ch_val), len(list_times_ch_test))

6732 449 1796


In [21]:
# Show which neuron indices have an entry after filtering
updated_neuron_channels.keys()

dict_keys([46, 41, 12, 49, 27, 43, 45, 28, 52, 53, 22, 21, 10, 34, 5, 35, 11, 50, 48, 47, 54, 18, 38, 9, 32, 13, 42, 1, 40, 31, 37, 29, 2, 39, 55, 24, 17, 0, 6, 8, 26, 7, 23, 36, 14, 3, 51, 25, 30, 15, 19, 33, 44, 20, 16, 4])