In [None]:
import keras
import tensorflow as tf
import keras.layers as layers
from keras.preprocessing.image import load_img, img_to_array, array_to_img
import tensorflow_probability as tfp
import numpy as np
from math import floor, ceil
from tensorflow.python.ops import math_ops
from tensorflow import math, random, shape
import os
from keras.losses import MeanSquaredError, BinaryCrossentropy
from keras.optimizers import Nadam, SGD, Adam, Adamax
from keras.activations import sigmoid
from tensorflow import convert_to_tensor as tens
from keras import backend as K
from cv2 import getGaborKernel as Gabor
from functools import reduce
from matplotlib import pyplot as plt
from math import sqrt
import itertools
import re
from random import shuffle, seed
from tensorflow.keras.utils import Sequence
from keras.constraints import NonNeg
from keras.regularizers import l1,l2,l1_l2
from keras.initializers import RandomNormal
import pickle
import time

In [None]:
# Samples per batch for every data generator in this notebook.
BATCH_SIZE = 32
# NOTE(review): the next three constants are not referenced in the visible cells.
EXCITATORY_SYNAPSES_WANTED = 40
INHIBITORY_SYNAPSES_WANTED = 15
PRESYNAPTIC_THRESHOLD = .001
# Synapse count per type: the wiring layer creates 2*N_EXC synapses; the first N_EXC are
# treated as excitatory and the rest as inhibitory when plotting.
N_EXC = 639
# Spike-matrix image size (rows = cochlear neurons, columns = time steps), used as model input shape.
SHD_NEURONS = 700
SHD_MAT_TIME = 1400


# Class names: English digits (indices 0-9) followed by German digits (indices 10-19).
LABELS = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
LABELS += ["null", "eins", "zwei", "drei", "vier", "fünf", "sechs", "sieben", "acht", "neun"]
# Only the first NLABELS classes (the English digits) are used by the generators below.
NLABELS = 10

# Load the SHD spike-matrix images as Keras data generators

In [None]:
class SHDDirectoryGeneratorSequence(Sequence):
    """Keras ``Sequence`` yielding batches of spike-matrix images with binary (0/1) labels.

    Files are collected from several directories, each mapped to one label. Each image is
    read with ``load_img``; only its first channel is kept, divided by 255.

    Parameters
    ----------
    dct_label : dict
        Maps a directory path to its label; the set of labels must be exactly {0, 1}.
    dtype : str
        Regular expression searched in each file name (``re.findall``); matching files are
        used. Note that ``'.'`` in the default matches any character.
    balance : bool
        If True, the larger class is randomly subsampled (seeded) down to the size of the
        smaller class, keeping the original file order.
    randomize : bool
        If True, the file list is shuffled (seeded) before the train/validation split.
    noise : float
        If non-zero, ``NoiseLayer(noise)`` is applied to every loaded image.
    random_seed : int
        Seed for Python's ``random`` module, used for balancing and shuffling.
    validation_split : float or None
        Fraction of files reserved for validation (taken from the tail of the list).
    is_validation : bool
        If True serve the validation tail, otherwise the training head.
    batch_size : int
        Number of samples per batch.

    Raises
    ------
    Exception
        If the labels are not exactly {0, 1}.
    ValueError
        If no files remain after filtering, balancing and splitting.
    """

    def __init__(self, dct_label, dtype='.jpeg', balance=True, randomize=True, noise=0, random_seed=1331, validation_split=None, is_validation=False, batch_size=32):
        if set(dct_label.values()) != {0, 1}:
            raise Exception("Labels should be 0 and 1.")
        self.directories = dct_label
        self.dtype = dtype
        self.randomize = randomize
        self.seed = random_seed
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.batch_size = batch_size
        self.x = self.y = None
        self.balance = balance  # honour the caller's choice
        self.noise = NoiseLayer(noise) if noise else None

        self._create_files()
        print(self)

    def _list_directory(self, directory, label):
        """Return ``(path, label)`` pairs for the files in ``directory`` whose name matches ``self.dtype``."""
        names = os.listdir(directory)
        prefix = directory if directory.endswith('/') else directory + '/'
        return [(prefix + name, label) for name in names if re.findall(self.dtype, name)]

    @staticmethod
    def _subsample(items, k):
        """Keep a random subset of ``k`` items in their original order (uses the global ``random`` state)."""
        if k >= len(items):
            return list(items)  # nothing to drop: consume no randomness
        keep_mask = [1] * k + [0] * (len(items) - k)
        shuffle(keep_mask)
        return [item for item, keep in zip(items, keep_mask) if keep]

    def _create_files(self):
        """Collect, balance, shuffle and split the files; store paths in ``x`` and labels in ``y``."""
        lbls_dct = {0: [], 1: []}
        for directory, label in self.directories.items():
            lbls_dct[label] += self._list_directory(directory, label)

        if self.balance:
            seed(self.seed)
            n_min = min(len(lbls_dct[0]), len(lbls_dct[1]))
            # Only the larger class is actually subsampled, so the random stream is used once.
            for label in (0, 1):
                lbls_dct[label] = self._subsample(lbls_dct[label], n_min)
        files = lbls_dct[0] + lbls_dct[1]

        if self.randomize:
            seed(self.seed)
            shuffle(files)
        if self.validation_split:
            cut = floor(len(files) - len(files) * self.validation_split)
            files = files[cut:] if self.is_validation else files[:cut]

        if not files:
            raise ValueError(f"No files matching {self.dtype!r} found in {list(self.directories)}.")
        self.x, self.y = zip(*files)
        self.y = np.array(self.y)

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return ceil(self.y.shape[0] / self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(images, labels)`` numpy arrays."""
        start = idx * self.batch_size
        stop = (idx + 1) * self.batch_size
        batch_x = [self.getSpikeTrain(file) for file in self.x[start:stop]]
        return np.array(batch_x), self.y[start:stop]

    def getSpikeTrain(self, file):
        """Load one image, keep its first channel scaled by 1/255, and optionally add noise."""
        image = load_img(file)
        image = img_to_array(image)[:, :, 0] / 255
        if self.noise is not None:
            image = self.noise(image)
        return image

    def __str__(self):
        return ("Training" if not self.is_validation else "Validation") + f" generator: Got {len(self.x)} files and 2 categories"

In [None]:
class SHDDirectoryGeneratorSequenceCategories(Sequence):
    """Multi-class Keras ``Sequence`` over spike-matrix images stored in labelled folders.

    ``dct_label`` maps each folder to its category. Labels are one-hot vectors ordered
    by the sorted set of categories. Files whose name matches the regex ``dtype`` are
    used. With ``randomize`` they are shuffled using ``random_seed`` (Python ``random``).
    With ``validation_split`` the tail fraction is served when ``is_validation`` is True,
    otherwise the head. Each image's first channel is returned divided by 255.
    """

    def __init__(self, dct_label, dtype='.jpeg', randomize=True, random_seed=1331, validation_split=None, is_validation=False, batch_size=32):
        self.directories = dct_label
        self.categories = sorted(set(dct_label.values()))
        self.dtype = dtype
        self.randomize = randomize
        self.seed = random_seed
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.batch_size = batch_size
        self.x = self.y = None
        self._create_files()

    def _create_files(self):
        """Build the (path, one-hot label) list, shuffle/split it, store it and report it."""
        entries = []
        for folder, target in self.directories.items():
            names = os.listdir(folder)
            for name in names:
                if not re.findall(self.dtype, name):
                    continue
                sep = '' if folder[-1] == '/' else '/'
                one_hot = [1. if cat == target else 0. for cat in self.categories]
                entries.append((folder + sep + name, one_hot))

        if self.randomize:
            seed(self.seed)
            shuffle(entries)

        if self.validation_split:
            boundary = floor(len(entries) - len(entries) * self.validation_split)
            if self.is_validation:
                entries = entries[boundary:]
            else:
                entries = entries[:boundary]

        paths, targets = zip(*entries)
        self.x = paths
        self.y = np.array(targets)
        print(self)

    def __len__(self):
        """Batches per epoch, counting a trailing partial batch."""
        return ceil(len(self.y) / self.batch_size)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(images, one_hot_labels)`` batch."""
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        targets = self.y[lo:hi]
        images = np.array([self.getSpikeTrain(path) for path in self.x[lo:hi]])
        return images, targets

    def getSpikeTrain(self, file):
        """Read an image file and return its first channel divided by 255."""
        return img_to_array(load_img(file))[:, :, 0] / 255

    def __str__(self):
        role = "Validation" if self.is_validation else "Training"
        return role + f" generator: Got {len(self.x)} files and {len(self.categories)} categories"

In [None]:
class NoiseLayer(keras.layers.Layer):
    """Add random spikes to a spike matrix.

    Each element is independently set to 1 with probability ``p``. Existing values are
    kept: the output is the element-wise maximum of the input and a Bernoulli(p) mask
    with the input's full shape and dtype.

    NOTE(review): noise is skipped only when the layer is called with ``training=True``
    and applied otherwise (including ``training=None``). This looks inverted for
    augmentation but is preserved, because the generator calls the layer directly
    without a ``training`` argument.
    """

    def __init__(self, p=.003):
        super().__init__(name="NoiseLayer")
        self.p = p

    def call(self, inputs, training=None):
        if training is True:
            return inputs
        inputs = tf.convert_to_tensor(inputs)
        # Cast the mask to the input dtype: tf ops refuse to mix float images with int masks.
        noise = self.get_rand(tf.shape(inputs), dtype=inputs.dtype)
        return tf.maximum(inputs, noise)

    def get_rand(self, shape, dtype=tf.float32):
        """Return a 0/1 tensor of exactly ``shape`` where each entry is 1 with probability ``p``."""
        return tf.cast(tf.random.uniform(shape) < self.p, dtype)

In [None]:
# LABELS = ["NoI", "I"]
# INums = {3, 5, 6, 8, 9}
# train_ds = SHDDirectoryGeneratorSequence({f"../input/matrixshd/train/{i}": int(i in INums) for i in range(10)}, validation_split=0.2, batch_size=BATCH_SIZE)
# valid_ds = SHDDirectoryGeneratorSequence({f"../input/matrixshd/train/{i}": int(i in INums) for i in range(10)}, validation_split=0.2, is_validation=True, batch_size=BATCH_SIZE)
# test_ds = SHDDirectoryGeneratorSequence({f"../input/matrixshd/test/{i}": int(i in INums) for i in range(10)}, validation_split=0, is_validation=False, batch_size=BATCH_SIZE)

In [None]:
# Train and validation read the same folders with the same (default) seed, so their
# 80/20 split is complementary; the test folders are used whole (validation_split=0).
train_dirs = {f"../input/matrixshd/train/{i}": i for i in range(10)}
test_dirs = {f"../input/matrixshd/test/{i}": i for i in range(10)}
train_ds = SHDDirectoryGeneratorSequenceCategories(train_dirs, validation_split=0.2, batch_size=BATCH_SIZE)
valid_ds = SHDDirectoryGeneratorSequenceCategories(train_dirs, validation_split=0.2, is_validation=True, batch_size=BATCH_SIZE)
test_ds = SHDDirectoryGeneratorSequenceCategories(test_dirs, validation_split=0, is_validation=False, batch_size=BATCH_SIZE)

In [None]:
# Preview the first training batch: one spike-matrix image per subplot, titled with its label.
plt.figure(figsize=(40,40))
image, label = train_ds[0]
labels = np.where(label==1)[1]  # one-hot rows -> class indices
n_rows = ceil(len(image) / 4)   # the batch may be shorter than BATCH_SIZE
for i in range(len(image)):
    plt.subplot(n_rows, 4, i+1)
    plt.imshow(image[i], aspect='auto', origin='lower', cmap="binary")
    plt.title(LABELS[labels[i]])
plt.show()

# Use David Beniaguev's (selfishgene) trained L5PC model

In [None]:
# Kaggle dataset folder that contains the pre-trained neuron models.
dataset_folder = '/kaggle/input/single-neurons-as-deep-nets-nmda-test-data'

models_folder  = os.path.join(dataset_folder, 'Models')

# Which pre-trained model file to load (an alternative is left commented out).
model_name = "NMDA_TCN__DWT_8_224_217__model.h5"
# model_name = 'NMDA_TCN__DWT_7_128_153__model.h5'

model_filename  = os.path.join(models_folder, model_name)



old_model = keras.models.load_model(model_filename)
# old_model.summary()

# Rebuild the network on a fresh Input so that exactly two outputs can be exposed.
# NOTE(review): assumes layers[0] is the InputLayer and that layers[-3] / layers[-2] are
# the spike and soma-voltage heads (CallNeuron later treats output 0 as APs and the last
# output as voltage) — confirm against old_model.summary().
inp = keras.Input(shape=old_model.layers[0].input.shape[1:])
x = old_model.layers[1](inp)
for layer in old_model.layers[2:-3]:
    x = layer(x)
# Both heads are applied to the same trunk output x; old_model.layers[-1] is not reused.
output = old_model.layers[-3](x),old_model.layers[-2](x) 
L5PC_model = keras.Model(inp, output, name="L5PC")

# Freeze every layer so the pre-trained weights are never updated.
for layer in L5PC_model.layers:
    layer.trainable = False

L5PC_model.summary()

In [None]:
class WiringLayer(keras.layers.Layer):
    """Fixed random wiring from presynaptic input channels to synapses.

    Implemented as a frozen, bias-free, kernel-size-1 ``Conv1D`` whose kernel has shape
    ``(1, n_presynaptic, filters)``. For each output synapse (filter), exactly one
    randomly chosen presynaptic channel has weight 1 and all others 0, so every output
    channel is a copy of one input channel. Inputs are ``(batch, time, n_presynaptic)``.

    Parameters
    ----------
    filters : int
        Number of output synapses.
    seed : int or None
        Seed for the wiring draw. ``None`` (default) draws from NumPy's global RNG.
    """

    def __init__(self, filters=1278, seed=None):
        super().__init__(name="WiringLayer")
        self.filters = filters
        self.wiring_seed = seed

    def build(self, shape):
        n_inputs = shape[-1]
        self.conv = keras.layers.Conv1D(self.filters, 1, use_bias=False)
        self.conv.build(shape)
        # One presynaptic source per synapse, drawn uniformly (sources may repeat).
        if self.wiring_seed is None:
            sources = np.random.randint(0, n_inputs, size=self.filters)
        else:
            sources = np.random.RandomState(self.wiring_seed).randint(0, n_inputs, size=self.filters)
        weights = np.zeros((1, n_inputs, self.filters))
        weights[0, sources, np.arange(self.filters)] = 1
        self.conv.set_weights([weights])
        self.conv.trainable = False
        # Mark the layer as built so a later __call__ does not rebuild it and redraw the wiring.
        super().build(shape)

    def call(self, x):
        return self.conv(x)

In [None]:
# Toy example of the wiring: 20 presynaptic neurons mapped onto 30 synapses over 50 steps.
synapses = 30
presynaptic = 20
timesteps = 50  # not called `time`: that would shadow the `time` module imported at the top

wiring_layer = WiringLayer(synapses)
wiring_layer.build((1, timesteps, presynaptic))  # Keras build shapes include the batch dimension

# example input
inp = np.zeros((1, timesteps, presynaptic))
inp[:, 5:10, 0] = 1
inp[:, 20:25, 0] = 1
inp[:, 12:18, 1] = 1
inp[:, 20:timesteps, presynaptic//2] = 1

result = wiring_layer(inp)
weights = wiring_layer.get_weights()[0]  # kernel of shape (1, presynaptic, synapses)

plt.figure(figsize=(10, 10))
plt.suptitle("Example of Direct Random Wiring (One Presynaptic Neuron per Synapse)")

plt.subplot(2,2,1)
plt.imshow(weights[0], cmap="binary", aspect="auto")
plt.title("Wiring Layer Weights (Convolution Weights)")
plt.ylabel("preSynaptic Neuron")
plt.xlabel("Synapse (kernel)")

plt.subplot(2,2,2)
plt.title("Spike Train (Example Input)")
plt.imshow(inp[0,:,:].T, aspect='auto', cmap="binary")
plt.ylabel("preSynaptic Neuron")
plt.xlabel("Time (ms)")
plt.xticks([])
plt.yticks([])

plt.subplot(2,2,4)
plt.title("Synapse Train (Neuron input)")
plt.ylabel("Synapse")
plt.xlabel("Time (ms)")
plt.imshow(result[0,:,:].numpy().T, aspect='auto', cmap="binary")

plt.show()

# Neuron Calling Layer

In [None]:
class CallNeuron(keras.layers.Layer):
    """Run a fixed-window neuron model over an arbitrarily long synaptic input.

    ``neuronModule`` takes inputs of shape ``(batch, neuronTime, synapses)`` and returns
    outputs whose first element is used as APs and whose last element is used as voltage.

    The long input is cut into windows of ``neuronTime`` steps that start every
    ``neuronTime - spare`` steps. The first window's outputs are kept whole. For every
    later window, the first ``spare`` steps (which overlap the previous window) are
    discarded. The input is zero-padded at the end so that the last window is complete
    and the whole input is covered. Both outputs are concatenated along time and flattened.

    Parameters
    ----------
    neuronModule : keras.Model
        The neuron surrogate to call on every window.
    spare : int
        Warm-up steps overlapping the previous window; must be < ``neuronTime``.
    name : str or None
        Layer name; defaults to ``"CallNeuron-<neuronModule.name>"``.

    Raises
    ------
    ValueError
        If ``spare >= neuronTime``, or (in ``build``) if the synapse count does not match.
    """

    def __init__(self, neuronModule, spare=150, name=None):
        super().__init__(name=name or f"CallNeuron-{neuronModule.name}")
        self.neuron = neuronModule
        self.spare = spare
        self.neuronTime = neuronModule.input.shape[-2]
        self.synapses = neuronModule.input.shape[-1]
        self.fullTime = None
        self.timePerRun = self.neuronTime - self.spare
        if self.timePerRun <= 0:
            raise ValueError(f"spare ({self.spare}) must be smaller than the neuron window ({self.neuronTime}).")
        self.times = None
        self.paddingSize = 0
        # Stateless helper created once instead of on every call.
        self.flatten = layers.Flatten()

    def build(self, shape):
        print("The neuron-mimicing module is:", self.neuron.name)
        if shape[-1] != self.synapses:
            raise ValueError(f"Wrong number of synapses! Expected {self.synapses}, got {shape[-1]}.")
        self.fullTime = shape[-2]
        # n windows cover n * timePerRun + spare steps; take the fewest that reach fullTime
        # (at least one window, even for inputs shorter than a single window).
        n_windows = max(1, ceil((self.fullTime - self.spare) / self.timePerRun))
        self.paddingSize = n_windows * self.timePerRun + self.spare - self.fullTime
        self.times = [(self.timePerRun * i, self.timePerRun * i + self.neuronTime) for i in range(n_windows)]
        print("times:", self.times)
        print("paddingSize:", self.paddingSize)
        print("timePerRun:", self.timePerRun)
        print("fullTime:", self.fullTime)
        super().build(shape)

    @tf.autograph.experimental.do_not_convert
    def call(self, inputs, training=None):
        if self.paddingSize:
            # Zero-pad the time axis of (batch, time, synapses) so the last window is full.
            inputs = tf.pad(inputs, [[0, 0], [0, self.paddingSize], [0, 0]])
        aps, vs = [], []
        for k, (begin, end) in enumerate(self.times):
            outputs = self.neuron(inputs[:, begin:end])
            ap, v = outputs[0], outputs[-1]
            if k > 0:
                # Drop the warm-up steps that overlap the previous window.
                ap, v = ap[:, self.spare:], v[:, self.spare:]
            aps.append(ap)
            vs.append(v)
        return self.flatten(tf.concat(aps, axis=-2)), self.flatten(tf.concat(vs, axis=-2))

# Full Model: Random Wiring → Synapse Dropout → L5PC Neuron

In [None]:
def create_model(drop_rate=.8):
    """Build the pipeline: SHD spike matrix -> random synaptic wiring -> L5PC neuron.

    Parameters
    ----------
    drop_rate : float
        Probability of deleting each synaptic input value. The deletion is applied on
        every call, not only during training. ``0``/``None`` disables it.

    Returns
    -------
    keras.Model
        Input of shape ``(SHD_NEURONS, SHD_MAT_TIME)``; outputs ``(synaptic_input, ap, v)``.
        ``synaptic_input`` is the (possibly thinned) wiring output of shape
        ``(batch, SHD_MAT_TIME, 2*N_EXC)``; ``ap`` and ``v`` come from ``CallNeuron``.
    """
    inp = keras.Input(shape=(SHD_NEURONS, SHD_MAT_TIME))
    # The wiring layer works on (batch, time, presynaptic), so swap the two non-batch axes.
    converted = WiringLayer(2*N_EXC)(tf.transpose(inp, perm=[0,2,1]))
    synaptic_input = converted
    if drop_rate:
        # Binary keep-mask. keras Dropout would also rescale surviving values by
        # 1/(1-rate) (1 -> 5 at rate .8), which corrupts the neuron's synaptic input.
        synaptic_input = keras.layers.Lambda(
            lambda x: x * tf.cast(tf.random.uniform(tf.shape(x)) >= drop_rate, x.dtype),
            name="SynapseDropout",
        )(converted)
    ap, v = CallNeuron(L5PC_model)(synaptic_input)
    return keras.Model(inp, (synaptic_input, ap, v))

In [None]:
# Build the full pipeline with create_model's default synapse drop rate (drop_rate=.8).
model = create_model()

In [None]:
# Print the layer-by-layer summary of the pipeline model.
model.summary()

In [None]:
def _activity_window(sample, window=10, on_threshold=5, off_threshold=2):
    """Locate the first burst of activity in one spike matrix indexed ``sample[neuron, time]``.

    For time step ``ms``, the activity level is the sum over neurons, averaged over the
    steps ``ms:ms + window``. ``start`` is the first step whose level exceeds
    ``on_threshold``. ``end`` is the first later step whose level drops below
    ``off_threshold``. Either is None when never reached, so slicing with it runs to the edge.
    """
    start = None
    for ms in range(sample.shape[-1]):
        level = sample[:, ms:ms + window].sum(axis=0).mean()
        if start is None and level > on_threshold:
            start = ms
        elif start is not None and level < off_threshold:
            return start, ms
    return start, None


plt.figure(figsize=(30,10))
# NOTE: this rebinds the global `model` to a pipeline without synapse dropout; the
# dataset-transform cell further down uses whatever `model` is bound to last.
model = create_model(0)
image, label = train_ds[0]
converted, ap, v = model(image)
# (batch, time, synapse) -> (batch, synapse, time) so synapses are plotted on the y-axis.
converted = np.transpose(converted.numpy(), (0,2,1))
labels = np.where(label==1)[1]
for i in range(min(4, len(image))):  # the first batch may hold fewer than 4 samples
    start, ending = _activity_window(image[i])

    # Row 1: input spike matrix.
    plt.subplot(3, 4, i+1)
    plt.imshow(image[i], aspect='auto', origin='lower', cmap="binary")
    plt.title("Input - " + LABELS[labels[i]])
    if not i: plt.ylabel("Cochlear Neurons")
    else: plt.yticks([])
    plt.xticks([])

    # Row 2: synaptic input, with mean activity per synapse inside the activity window
    # for the excitatory (first N_EXC) and inhibitory (remaining) synapses.
    exc = str(round(converted[i, :N_EXC, start:ending].sum(axis=1).mean(), 2))
    inh = str(round(converted[i, N_EXC:, start:ending].sum(axis=1).mean(), 2))
    plt.subplot(3, 4, 4+i+1)
    plt.imshow(converted[i], cmap="binary", aspect="auto")
    plt.title(f"Synapse Train [{start}, {ending}] (Exc.{exc}, Inh.{inh})")
    if not i: plt.ylabel("L5PC Synapse")
    else: plt.yticks([])
    plt.xticks([])

    # Row 3: neuron predictions.
    # NOTE(review): the AP output is rescaled (*120-80) and v offset by -67.7, presumably to
    # overlay both on one mV-like axis — confirm these constants.
    plt.subplot(3, 4, 8+i+1)
    plt.title("NN Prediction")
    plt.plot(ap[i]*120-80, label=f"AP {str(round(ap[i].numpy().max()*100))}%")
    plt.plot(v[i]-67.7, label="v")
    if not i: plt.ylabel("Soma Voltage (normalised AP %)")
    else: plt.yticks([])
    plt.xlabel("Time (ms)")
    plt.ylim((-90, 70))
    plt.legend()
plt.show()
    

# Run the Dataset Through the Model and Save the Results

In [None]:
main_dir = '../working/shd'
# exist_ok: on a re-run the directory already exists (it keeps the zip archives).
os.makedirs(main_dir, exist_ok=True)
how_many_to_transform_per_label = None  # set to zero or None to convert all the dataset
datasets_to_convert = [(train_ds, "training"), (valid_ds, "validation"), (test_ds, "test")]
datasets_to_convert = datasets_to_convert[1:]  # change at your will

In [None]:
import shutil


def zip_and_delete(directory, zip_name, to_zip=True):
    """Archive ``directory`` as ``<zip_name>.zip`` and then delete the directory.

    Parameters
    ----------
    directory : str
        Directory whose contents are archived (it becomes the archive root).
    zip_name : str
        Archive path without extension; ``shutil.make_archive`` appends ``.zip``.
    to_zip : bool
        When False, do nothing: the directory is neither archived nor deleted.

    Returns
    -------
    str or None
        Path of the created archive, or None when ``to_zip`` is False.
    """
    if not to_zip:
        return None
    archive = shutil.make_archive(zip_name, 'zip', directory)
    # Report the real archive file; zip_name has no extension and may equal `directory`,
    # which is deleted next.
    print(f'Done Zipping! Check for {archive}.')
    shutil.rmtree(directory)
    print('Done erasing photos!')
    return archive

In [None]:
def _save_sample(sample_dir, image, matrix, spike_prediction, voltage_prediction):
    """Save one sample's input image, synaptic matrix and neuron predictions as .npy files."""
    os.makedirs(sample_dir, exist_ok=True)
    for file_name, array in (("image.npy", image),
                             ("matrix.npy", matrix),
                             ("spikePrediction.npy", spike_prediction),
                             ("voltagePrediction.npy", voltage_prediction)):
        np.save(os.path.join(sample_dir, file_name), array)


# Per-label cap from the config cell; 0 / None means "convert everything".
limit = how_many_to_transform_per_label or None
for ds, dir_name in datasets_to_convert:
    out_dir = f'{main_dir}/{dir_name}'
    counts = {lbl: 0 for lbl in range(NLABELS)}  # samples saved so far, per label
    for lbl in range(NLABELS):
        os.makedirs(f'{out_dir}/{lbl}', exist_ok=True)

    for img, label in ds:
        if limit and all(n >= limit for n in counts.values()):
            break  # every label already reached the cap
        converted, aps, vs = (t.numpy() for t in model(img))
        labels = np.where(label == 1)[1]  # one-hot rows -> class indices
        for i in range(converted.shape[0]):
            lbl = labels[i]
            if limit and counts[lbl] >= limit:
                continue
            _save_sample(f'{out_dir}/{lbl}/{counts[lbl]}', img[i], converted[i], aps[i], vs[i])
            counts[lbl] += 1
    print(f'{dir_name}: saved {sum(counts.values())} samples, per label: {counts}')
    zip_and_delete(out_dir, out_dir, to_zip=True)