In [None]:
import keras
import tensorflow as tf
import keras.layers as layers
from keras.preprocessing.image import load_img
from keras.preprocessing.image import array_to_img
import numpy as np
from math import floor, ceil
from tensorflow.python.ops import math_ops
from tensorflow import math, random, shape
import os
from keras.losses import MeanSquaredError, BinaryCrossentropy
from keras.optimizers import Nadam, SGD, Adam, Adamax
from keras.activations import sigmoid
from tensorflow import convert_to_tensor as tens
from keras import backend as K
from cv2 import getGaborKernel as Gabor
from functools import reduce
from matplotlib import pyplot as plt
from math import sqrt
import itertools
import re
from random import shuffle, seed
from tensorflow.keras.utils import Sequence
from keras.constraints import NonNeg
from keras.regularizers import l1,l2,l1_l2
from keras.initializers import RandomNormal
import pandas as pd
import tables

# Download the SHD dataset and convert its spike trains to raster images

In [None]:
import os
import urllib.request
import gzip, shutil
from tensorflow.keras.utils import get_file


# Local cache location used by Keras' get_file for the downloaded archives.
cache_dir = os.path.expanduser("~/data")
cache_subdir = "hdspikes"
print("Using cache dir: %s"%cache_dir)

# The remote directory with the data files
base_url = "https://compneuro.net/datasets"

# Retrieve MD5 hashes from remote. Each line of md5sums.txt is expected to be
# "<md5> <filename>"; lines that do not split into exactly two fields are ignored.
# The context manager closes the HTTP response, and the timeout stops the cell
# from hanging forever if the server does not answer.
with urllib.request.urlopen("%s/md5sums.txt"%base_url, timeout=60) as response:
    md5_listing = response.read().decode('utf-8')

file_hashes = {
    fields[1]: fields[0]
    for fields in (line.split() for line in md5_listing.splitlines())
    if len(fields) == 2
}

def get_and_gunzip(origin, filename, md5hash=None):
    """Download a gzipped file via Keras' cache and return the path of its decompressed copy.

    The archive is fetched with ``get_file`` into ``cache_dir/cache_subdir``
    (module-level settings). ``md5hash`` is forwarded as ``md5_hash`` so Keras can
    verify the download. The decompressed file is stored next to the archive with
    the ``.gz`` suffix stripped. It is (re)extracted only when it is missing or
    older than the archive.

    Args:
        origin: URL of the ``.gz`` file.
        filename: Name to cache the archive under; must end in ``.gz``.
        md5hash: Optional expected MD5 checksum of the archive.

    Returns:
        Path to the decompressed file.

    Raises:
        ValueError: if the cached archive path does not end in ``.gz``.
    """
    gz_file_path = get_file(filename, origin, md5_hash=md5hash,
                            cache_dir=cache_dir, cache_subdir=cache_subdir)
    if not gz_file_path.endswith(".gz"):
        # The original stripped the last three characters unconditionally, which
        # silently produced a wrong output path for non-.gz names.
        raise ValueError("Expected a .gz archive, got %s" % gz_file_path)
    hdf5_file_path = gz_file_path[:-len(".gz")]

    # mtime (content modification) rather than ctime, which on Unix also changes
    # on mere metadata updates such as chmod.
    needs_extract = (not os.path.isfile(hdf5_file_path)
                     or os.path.getmtime(gz_file_path) > os.path.getmtime(hdf5_file_path))
    if needs_extract:
        print("Decompressing %s"%gz_file_path)
        # Decompress into a temporary file and move it into place afterwards.
        # An interrupted run then never leaves a truncated file that looks up to date.
        tmp_path = hdf5_file_path + ".part"
        with gzip.open(gz_file_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(tmp_path, hdf5_file_path)
    return hdf5_file_path



# Download the Spiking Heidelberg Digits (SHD) dataset: each gzipped split is
# fetched into the local cache, checked against its published MD5 and gunzipped.
files = ["shd_train.h5.gz", "shd_test.h5.gz"]

for fn in files:
    origin = "%s/%s"%(base_url,fn)
    expected_md5 = file_hashes[fn]
    hdf5_file_path = get_and_gunzip(origin, fn, md5hash=expected_md5)
    print(hdf5_file_path)

In [None]:
def open_shd_split(gz_name):
    """Ensure one SHD split is downloaded/decompressed and open it read-only.

    Args:
        gz_name: Archive name under ``base_url``, e.g. "shd_train.h5.gz".

    Returns:
        dict with the decompressed file path ("name"), the open PyTables handle
        ("fileh"), and the nodes ``root.spikes.units`` ("units"),
        ``root.spikes.times`` ("times") and ``root.labels`` ("labels").
    """
    # Look up the hash for *this* archive. The original reused the stale loop
    # variable `fn` ("shd_test.h5.gz"), so the train archive was checked against
    # the test file's MD5.
    path = get_and_gunzip("%s/%s"%(base_url, gz_name), gz_name, md5hash=file_hashes[gz_name])
    fileh = tables.open_file(path, mode='r')
    return {
        "name": path,
        "fileh": fileh,
        "units": fileh.root.spikes.units,
        "times": fileh.root.spikes.times,
        "labels": fileh.root.labels,
    }


train_ds = open_shd_split("shd_train.h5.gz")
test_ds = open_shd_split("shd_test.h5.gz")  # "name" is now the test path, not the train path
train_filename = train_ds["name"]
test_filename = test_ds["name"]
# Keep the module-level path pointing at the test split, as before.
hdf5_file_path = test_filename

In [None]:
# At this point we can visualize some of the data.
# hdf5_file_path is the test split, which test_ds already holds open; reuse that
# handle instead of opening a second one that is never closed.
fileh = test_ds["fileh"]
units = fileh.root.spikes.units
times = fileh.root.spikes.times
labels = fileh.root.labels

# This is how we access spikes and labels
index = 0
print("Times (s):", times[index])  # spike times are stored in seconds (they are scaled by 1000 to get ms later)
print("Unit IDs:", units[index])
print("Label:", labels[index])

# A quick raster plot for three randomly chosen samples; the generator is seeded
# so the same samples are shown on every run.
rng = np.random.default_rng(0)
idx = rng.integers(len(times), size=3)
fig, axs = plt.subplots(1, 3, figsize=(16, 4))
for i, k in enumerate(idx):
    ax = axs[i]
    ax.scatter(times[k], 700 - units[k], color="k", alpha=0.33, s=2)
    ax.set_title("Label %i" % labels[k])
    ax.axis("off")
plt.show()

In [None]:
# Pick the first training sample whose latest spike is after 1.3 s. The search
# starts from sample 0 instead of continuing from whatever `i` was left over from
# an earlier cell. It fails loudly instead of running off the end of the dataset.
LONG_SAMPLE_S = 1.3
i = next(
    (n for n, sample_times in enumerate(train_ds["times"])
     if sample_times.size > 0 and sample_times.max() > LONG_SAMPLE_S),
    None,
)
if i is None:
    raise ValueError("No training sample has a spike after %.1f s" % LONG_SAMPLE_S)
plt.scatter(train_ds["times"][i], 700 - train_ds["units"][i], color="k", alpha=0.33, s=2)

In [None]:
# Compare the raw spike times (seconds), the times truncated to integer
# milliseconds, and the resulting binary raster for sample `i`.
sample_times = train_ds["times"][i]
sample_units = train_ds["units"][i]
time_bins = (1000 * sample_times).astype('int')  # truncate to 1 ms bins

fig, axes = plt.subplots(3, 1, figsize=(20, 10))

axes[0].scatter(sample_times, 700 - sample_units, color="k", alpha=0.33, s=2)
axes[0].set_xlim((0, 1.5))

axes[1].scatter(time_bins, 700 - sample_units, color="k", alpha=0.33, s=2)
axes[1].set_xlim((0, 1500))

matrix = np.zeros((700, 1500))
# Spikes at or after 1500 ms fall outside the raster; drop them instead of raising IndexError.
in_window = time_bins < matrix.shape[1]
matrix[700 - sample_units[in_window] - 1, time_bins[in_window]] = 1
axes[2].imshow(matrix, cmap='binary', vmin=0, vmax=1, aspect='auto', origin='lower')
axes[2].set_xlim((0, 1500))

# RMS error (in ms) introduced by truncating spike times to integer ms.
# The original took np.sum before np.mean, so it printed sqrt(sum of squared
# errors) rather than the root-mean-square.
quantisation_error_ms = time_bins - 1000 * sample_times
print(np.sqrt(np.mean(quantisation_error_ms ** 2)))

In [None]:
# Output root for the converted images; exist_ok keeps the cell re-runnable.
os.makedirs('../working/shd', exist_ok=True)

In [None]:
from PIL import Image  

In [None]:
def shd_to_images(processor, ds, out_dirc, verbose=True, image_format="jpeg"):
    """Render every sample of an SHD split to an image file, one folder per label.

    Files are written as ``<out_dirc>/<label>/<sample index>.<image_format>``.
    Non-RGB images are converted to RGB before saving.

    Args:
        processor: Callable ``(times, units) -> 2-D array`` (e.g.
            ``convert_spikes_to_matrix``) that produces the pixel data.
        ds: Split dict with "times", "units" and "labels" entries.
        out_dirc: Root output directory.
        verbose: If True (the default), print one line per written file.
            Pass False to avoid thousands of output lines.
        image_format: File extension; Pillow chooses the encoder from it.
            NOTE: the default JPEG is lossy and may blur a binary raster;
            "png" is lossless.
    """
    for i in range(ds["times"].shape[0]):
        label_dir = os.path.join(out_dirc, str(ds["labels"][i]))
        # exist_ok replaces the racy "check then create" pattern.
        os.makedirs(label_dir, exist_ok=True)
        new_filename = os.path.join(label_dir, "%d.%s" % (i, image_format))
        img = Image.fromarray(processor(ds["times"][i], ds["units"][i]))
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img.save(new_filename)
        if verbose:
            print('Converted:', new_filename)

In [None]:
def convert_spikes_to_matrix(time, unit, nTime=1400, nUnits=700):
    """Rasterise one spike train into an ``(nUnits, nTime)`` array.

    Each spike time is multiplied by 1000 and truncated to an integer column
    (1 ms bins when ``time`` is in seconds). Each spike sets its pixel to 255.
    Unit ``u`` is placed in row ``nUnits - 1 - u``, so the unit axis is flipped.

    Spikes whose column or row falls outside the matrix are dropped. The
    original raised IndexError for late spikes and silently wrapped for
    negative indices.

    Args:
        time: 1-D array of spike times.
        unit: 1-D array of unit indices, same length as ``time``.
        nTime: Number of time bins (columns).
        nUnits: Number of units (rows).

    Returns:
        float64 array of shape ``(nUnits, nTime)`` containing 0 and 255.
    """
    time = np.asarray(time)
    unit = np.asarray(unit)
    matrix = np.zeros((nUnits, nTime))
    cols = (1000 * time).astype('int')
    # Cast to signed int so out-of-range unsigned unit ids cannot wrap around.
    rows = nUnits - 1 - unit.astype('int')
    in_range = (cols >= 0) & (cols < nTime) & (rows >= 0) & (rows < nUnits)
    matrix[rows[in_range], cols[in_range]] = 255
    return matrix

In [None]:
def run_convert(ds, path, processor=None):
    """Create ``path`` and write every sample of ``ds`` into it as images.

    Args:
        ds: Split dict with "times", "units" and "labels" entries.
        path: Output directory. It is created if missing, and re-running no
            longer fails with FileExistsError.
        processor: Optional ``(times, units) -> array`` rasteriser. If None,
            ``convert_spikes_to_matrix`` is used (looked up at call time).
    """
    if processor is None:
        processor = convert_spikes_to_matrix
    os.makedirs(path, exist_ok=True)
    shd_to_images(processor, ds, path)

In [None]:
# Rasterise both splits into their own subfolders, zip the whole tree into
# ../working/shd.zip, then remove the unzipped image folders.
for split_ds, split_dir in ((train_ds, '../working/shd/train'),
                            (test_ds, '../working/shd/test')):
    run_convert(split_ds, split_dir)

shutil.make_archive("../working/shd", 'zip', "../working/shd")
shutil.rmtree("../working/shd")