In [None]:
import keras
import tensorflow as tf
import keras.layers as layers
from keras.preprocessing.image import load_img, img_to_array, array_to_img
import tensorflow_probability as tfp
import numpy as np
from math import floor, ceil
from tensorflow.python.ops import math_ops
from tensorflow import math, random, shape
import os
from keras.losses import MeanSquaredError, BinaryCrossentropy
from keras.optimizers import Nadam, SGD, Adam, Adamax
from keras.activations import sigmoid
from tensorflow import convert_to_tensor as tens
from keras import backend as K
from cv2 import getGaborKernel as Gabor
from functools import reduce
from matplotlib import pyplot as plt
from math import sqrt
import itertools
import re
from random import shuffle, seed
from tensorflow.keras.utils import Sequence
from keras.constraints import NonNeg
from keras.regularizers import l1,l2,l1_l2
from keras.initializers import RandomNormal
import pickle
import time
import pandas as pd

In [None]:
# Number of samples per batch; passed as `batch_size` to the data generators below.
BATCH_SIZE = 32
# Default normalisation targets for the 'nExcitatory' / 'nInhibitory' outputs in create_model.
EXCITATORY_SYNAPSES_WANTED = 8
INHIBITORY_SYNAPSES_WANTED = 6
PRESYNAPTIC_THRESHOLD = .001
# Channel index separating the excitatory slice ([:N_EXC]) from the inhibitory slice ([N_EXC:]).
N_EXC = 639
# NOTE(review): LABELS is assigned twice in a row, so only the digit names survive
# (and a later cell overwrites LABELS again).
LABELS = ["EVEN", "ODD"]
LABELS = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
# NOTE(review): NLABELS is 2 although LABELS currently holds ten names — confirm which is intended.
NLABELS = 2

# Transform data to dataset

In [None]:
class SHDDirectoryGeneratorSequence(Sequence):
    """Keras ``Sequence`` serving batches of image files with binary (0/1) labels.

    Args:
        dct_label: mapping ``directory -> label``; the set of labels must be exactly ``{0, 1}``.
        dtype: pattern a file name must match to be used. It is passed to
            ``re.findall``, so it is interpreted as a regular expression.
        balance: when True, the larger class is randomly down-sampled to the
            size of the smaller one.
        randomize: when True, the file list is shuffled (seeded) before splitting.
        noise: falsy to disable, otherwise the ``p`` passed to ``NoiseLayer``.
        delay: falsy to disable, otherwise the ``p`` passed to ``GeometricDelay``.
        random_seed: seed for the ``random`` module used for balancing/shuffling.
        validation_split: fraction of files reserved for validation (falsy = no split).
        is_validation: serve the validation tail instead of the training head.
        batch_size: number of samples per batch.

    Raises:
        Exception: if the labels are not exactly 0 and 1.
        ValueError: if no matching files remain after balancing/splitting.
    """
    def __init__(self, dct_label, dtype='.jpeg', balance=True, randomize=True, noise=0, delay=0, random_seed=1331, validation_split=None, is_validation=False, batch_size=32):
        self.directories = dct_label
        self.dtype = dtype
        self.randomize = randomize
        self.seed = random_seed
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.batch_size = batch_size
        self.x = self.y = None
        # Honour the caller's choice (this used to be hard-coded to True).
        self.balance = balance
        self.noise = None if not noise else NoiseLayer(noise)
        self.delay = None if not delay else GeometricDelay(delay)
        if set(dct_label.values()) != {0,1}: raise Exception("Labels should be 0 and 1.")

        self._create_files()
        print(self)

    @staticmethod
    def _downsample(items, target):
        """Randomly keep `target` entries of `items` (original order preserved), using the global `random` state."""
        keep_mask = [1]*target + [0]*(len(items) - target)
        shuffle(keep_mask)
        return [item for item, keep in zip(items, keep_mask) if keep]

    def _create_files(self):
        """Collect (path, label) pairs, optionally balance/shuffle/split them, and fill ``self.x`` / ``self.y``."""
        lbls_dct = {0:[], 1:[]}
        for directory,label in self.directories.items():
            lbls_dct[label] += [(directory+ ('' if directory[-1] == '/' else '/') +i, label)
                                for i in os.listdir(directory) if re.findall(self.dtype, i)]
        if self.balance:
            seed(self.seed)
            target = min(len(lbls_dct[0]), len(lbls_dct[1]))
            # At most one class is larger than `target`, so at most one shuffle happens.
            for label in (0, 1):
                if len(lbls_dct[label]) > target:
                    lbls_dct[label] = self._downsample(lbls_dct[label], target)
        files = lbls_dct[0] + lbls_dct[1]

        if self.randomize:
            seed(self.seed)
            shuffle(files)
        if self.validation_split:
            # Training gets the head of the list, validation the remaining tail.
            cut = floor(len(files) - len(files)*self.validation_split)
            files = files[cut:] if self.is_validation else files[:cut]

        if not files:
            raise ValueError(f"No files matching {self.dtype!r} left in {list(self.directories)} after balancing/splitting.")
        self.x, self.y = zip(*files)
        self.y = np.array(self.y)

    def __len__(self):
        """Number of batches (the last one may be smaller than ``batch_size``)."""
        return ceil(self.y.shape[0] / self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(stacked images, labels)``."""
        start, stop = idx * self.batch_size, (idx + 1) * self.batch_size
        batch_x = [self.getSpikeTrain(file) for file in self.x[start:stop]]
        return np.stack(batch_x, axis=0), self.y[start:stop]

    def getSpikeTrain(self, file):
        """Load one image, keep its first channel scaled by 1/255, and apply the optional augmentations."""
        image = load_img(file)
        # Select channel 0 and re-insert a trailing channel axis of size 1.
        image = img_to_array(image)[:,:,np.newaxis,0] / 255
        if self.noise is not None:
            image = self.noise(image, training=not self.is_validation)
        if self.delay is not None:
            image = self.delay(image, training=not self.is_validation)
        return image

    def __str__(self):
        return ("Training" if not self.is_validation else "Validation") + f" generator: Got {len(self.x)} files and 2 categories"

In [None]:
class SHDDirectoryGeneratorSequenceCategories(Sequence):
    """Keras ``Sequence`` serving batches of image files with one-hot category labels.

    ``dct_label`` maps each directory to a category; the sorted set of categories
    defines the one-hot positions. File names are matched against ``dtype`` with
    ``re.findall`` (so it is treated as a regular expression).
    """
    def __init__(self, dct_label, dtype='.jpeg', randomize=True, random_seed=1331, validation_split=None, is_validation=False, batch_size=32):
        self.directories = dct_label
        self.categories = sorted(set(dct_label.values()))
        self.dtype = dtype
        self.randomize = randomize
        self.seed = random_seed
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.batch_size = batch_size
        self.x = self.y = None
        self._create_files()

    def _create_files(self):
        """Build the (path, one-hot) list, shuffle/split it, store it in ``self.x`` / ``self.y`` and print a summary."""
        files = []
        for directory, label in self.directories.items():
            matching = [name for name in os.listdir(directory) if re.findall(self.dtype, name)]
            for name in matching:
                separator = '' if directory[-1] == '/' else '/'
                one_hot = [1. if cat == label else 0. for cat in self.categories]
                files.append((directory + separator + name, one_hot))

        if self.randomize:
            seed(self.seed)
            shuffle(files)

        if self.validation_split:
            # Head of the list is training data, the tail is validation data.
            cut = floor(len(files) - len(files)*self.validation_split)
            files = files[cut:] if self.is_validation else files[:cut]

        self.x, self.y = zip(*files)
        self.y = np.array(self.y)
        print(self)

    def __len__(self):
        """Number of batches per epoch."""
        return ceil(self.y.shape[0] / self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(images, one-hot labels)``."""
        lo = idx * self.batch_size
        hi = (idx + 1) * self.batch_size
        images = [self.getSpikeTrain(path) for path in self.x[lo:hi]]
        return np.array(images), self.y[lo:hi]

    def getSpikeTrain(self, file):
        """Load one image and return its first channel (with a trailing size-1 axis) divided by 255."""
        pixels = img_to_array(load_img(file))
        return pixels[:,:,np.newaxis,0] / 255

    def __str__(self):
        role = "Validation" if self.is_validation else "Training"
        return role + f" generator: Got {len(self.x)} files and {len(self.categories)} categories"

In [None]:
class NoiseLayer(keras.layers.Layer):
    """Augmentation layer that overlays random "on" entries onto the input.

    Each element is independently selected with probability ``p``; the output
    is the element-wise maximum of the input and that 0/1 mask.

    The computation uses NumPy, so the layer is meant to be called eagerly on
    concrete arrays (as the data generator and the demo cell do).

    Args:
        p: per-element probability of forcing a noise entry.
    """
    def __init__(self, p=.003):
        super().__init__(name="NoiseLayer")
        self.p = p
        self.shape = None
    def call(self, inputs, training=None):
        """Return ``inputs`` with noise added, unless ``training`` is explicitly False.

        The check used to be inverted (noise was skipped when ``training=True``
        and applied at validation time). Noise is now an augmentation: applied
        when training, and also when no flag is given so that direct calls such
        as ``NoiseLayer(p)(x)`` keep adding noise.
        """
        if training is False:
            return inputs
        noise = np.random.rand(*inputs.shape) > (1 - self.p)
        # max(input, mask) switches the selected entries to at least 1 and leaves the rest untouched.
        return np.maximum(np.asarray(inputs), noise)

In [None]:
class GeometricDelay(keras.layers.Layer):
    """Augmentation layer that shifts the input later along its second-to-last axis.

    A delay is sampled from a geometric distribution with probability ``p``;
    that many zeros are padded at the start of the axis and the result is
    cropped back to the original length. Only 3-D and 4-D inputs are shifted;
    other ranks are returned unchanged.

    Args:
        p: probability parameter of the geometric distribution.
        to_plot: when True, the delay is applied even if ``training`` is not True.
    """
    def __init__(self, p=.02, to_plot=False):
        super().__init__(name="GeometricDelay")
        self.p = p
        self.to_plot = to_plot
        
    def call(self, inputs, training=None):
        if training is True or self.to_plot:
            # Original length of the shifted axis, used to crop after padding.
            fulltime = inputs.shape[-2]
            # NOTE(review): the Lambda is called on None and ignores its argument;
            # it only serves to invoke get_delay().
            delay = tf.keras.layers.Lambda(lambda x: self.get_delay())(None)
#             delay = self.get_delay()
            if len(inputs.shape) == 4:
                # Pad `delay` zeros in front of axis 2, then crop back to `fulltime`.
                inputs = layers.ZeroPadding2D(((0,0), (delay,0)))(inputs)[:, :, :fulltime, :]
            elif len(inputs.shape) == 3:
                # Same shift for 3-D inputs, along axis 1.
                inputs = layers.ZeroPadding1D((delay,0))(inputs)[:, :fulltime, :]
            return inputs
        else: return inputs
        
    def get_delay(self):
        """Sample one delay from Geometric(p) and return it as an int32 scalar tensor."""
#         return np.random.geometric(self.p, 1)
        return tf.cast(tfp.distributions.Geometric(probs=[self.p]).sample()[0], dtype=tf.int32)

In [None]:
# Binary task: digits in INums get label 1 ("I"), every other digit gets label 0 ("NoI").
LABELS = ["NoI", "I"]
INums = {3, 5, 6, 8, 9}
# Training generator: head (80%) of the shuffled train files, delay augmentation with p=.01, no noise layer.
train_ds = SHDDirectoryGeneratorSequence({f"../input/matrixshd/train/{i}": int(i in INums) for i in range(10)}, validation_split=0.2, batch_size=BATCH_SIZE, noise=False, delay=.01)
# Validation generator: remaining tail of the same (identically seeded) file list, no augmentation.
valid_ds = SHDDirectoryGeneratorSequence({f"../input/matrixshd/train/{i}": int(i in INums) for i in range(10)}, validation_split=0.2, is_validation=True, batch_size=BATCH_SIZE)
# Test generator: validation_split=0 is falsy, so no split is applied to the test files.
test_ds = SHDDirectoryGeneratorSequence({f"../input/matrixshd/test/{i}": int(i in INums) for i in range(10)}, validation_split=0, is_validation=False, batch_size=BATCH_SIZE)

In [None]:
# train_ds = SHDDirectoryGeneratorSequenceCategories({f"../input/matrixshd/train/{i}": i%10 for i in range(10)}, validation_split=0.2, batch_size=BATCH_SIZE)
# valid_ds = SHDDirectoryGeneratorSequenceCategories({f"../input/matrixshd/train/{i}": i%10 for i in range(10)}, validation_split=0.2, is_validation=True, batch_size=BATCH_SIZE)

In [None]:
# train_ds = SHDDirectoryGeneratorSequenceCategories({f"../input/matrixshd/train/{i}": i==0 for i in range(NLABELS)}, validation_split=0.2, batch_size=BATCH_SIZE)
# valid_ds = SHDDirectoryGeneratorSequenceCategories({f"../input/matrixshd/train/{i}": i==0 for i in range(NLABELS)}, validation_split=0.2, is_validation=True, batch_size=BATCH_SIZE)

Binary

In [None]:
# Preview the first training batch: one subplot per sample, four per row, titled with its label name.
plt.figure(figsize=(40,40))
image, label = train_ds.__getitem__(0)
# NOTE(review): assumes the first batch really contains BATCH_SIZE samples.
for i in range(BATCH_SIZE):
    plt.subplot(BATCH_SIZE//4, 4, i+1)
    plt.imshow(image[i,:,:,0],aspect='auto',origin='lower', cmap="binary")
    # label[i] is 0 or 1 and indexes the two-entry LABELS list.
    plt.title(LABELS[label[i]])

Categorical

In [None]:
# plt.figure(figsize=(40,40))
# for _, (image, label) in enumerate(train_ds):
#     for i in range(BATCH_SIZE):
#         plt.subplot(BATCH_SIZE//4, 4, i+1)
#         plt.imshow(image[i,:,:,0],aspect='auto',origin='lower', cmap="binary")
#         plt.title(LABELS[np.where(label[i] == 1.)[0][0]])
#     break

# Use David Beniaguev's (selfishgene) trained L5PC model

In [None]:
class Module:
    """Named holder for a frozen, truncated copy of a saved Keras model.

    Attributes are ``name`` (display name) and ``module`` (the rebuilt model).
    """
    def __init__(self, path, module_name, name):
        self.name = name
        self.module = self.create_module(path, module_name)

    def create_module(self, path, module_name):
        """Load ``<path>/Models/<module_name>`` and rebuild it without its last two layers.

        The saved model's layers are re-applied to a fresh input: layer 1 first,
        then layers ``[2:-3]``, then layer ``-3`` as the output. All layers of
        the resulting model are frozen and its summary is printed.
        """
        model_filename = os.path.join(path, 'Models', module_name)
        old_model = keras.models.load_model(model_filename)
        source_layers = old_model.layers

        inp = keras.Input(shape=source_layers[0].input.shape[1:])
        hidden = source_layers[1](inp)
        for middle_layer in source_layers[2:-3]:
            hidden = middle_layer(hidden)
        output = source_layers[-3](hidden)
        module = keras.Model(inp, output)

        # Freeze everything: the module is used as a fixed function.
        for frozen_layer in module.layers:
            frozen_layer.trainable = False

        module.summary()
        return module

In [None]:
# When True, the next cell loads a list of neuron modules; otherwise a single module.
multiple_neurons = False
# Root folder of the dataset; Module looks for model files in its 'Models' subfolder.
dataset_folder = '/kaggle/input/single-neurons-as-deep-nets-nmda-test-data'

In [None]:
# `neurons` is either a list of Module objects or a single Module, depending on `multiple_neurons`.
if multiple_neurons:
    neurons = [Module(dataset_folder, "NMDA_TCN__DWT_8_224_217__model.h5", "Module_8_layers"), 
           Module(dataset_folder, "NMDA_TCN__DWT_7_292_169__model.h5", "Module_7_layers"), 
           Module(dataset_folder, "NMDA_TCN__DWT_9_256_241__model.h5", "Module_9_layers")]
else:
    neurons = Module(dataset_folder, "NMDA_TCN__DWT_7_128_153__model.h5", "module_7_layers")

# Some custom layers and funcs to use

In [None]:
class ToBoolLayer(keras.layers.Layer):
    """Thresholding layer: hard step outside training, optional smooth sigmoid during training.

    Args:
        threshold: centre of the step / sigmoid.
        mult: slope multiplier for the plain sigmoid (``SigmoidThresholdEasier``).
        use_sigmoid: whether the sigmoid is applied while training.
        use_special_sigmoid: optional ``(pos_mult, neg_mult)`` pair selecting
            the asymmetric ``SigmoidThreshold`` instead of the plain one.
        name: layer name.
    """
    def __init__(self, threshold=0.5, mult=15, use_sigmoid=True, use_special_sigmoid=None, name="ToBoolLayer"):
        super().__init__(name=name)
        self.threshold = threshold
        self.mult = mult
        self.use_sigmoid = use_sigmoid
        self.special_sigmoid = use_special_sigmoid

    def build(self, shape):
        """Create ``self.sigmoid`` according to the constructor options."""
        special = self.special_sigmoid
        if special is None:
            if self.use_sigmoid:
                self.sigmoid = SigmoidThresholdEasier(threshold=self.threshold, mult=self.mult)
            return
        if len(special) != 2:
            raise Exception("Special sigmoid should be in format (pos_mult, neg_mult).")
        self.sigmoid = SigmoidThreshold(*special, self.threshold)

    def call(self, inputs, training=None):
        """Return ``ceil(inputs - threshold)`` when not training, else the sigmoid (if enabled) or the raw inputs."""
        if not training:
            return math_ops.ceil(inputs - self.threshold)
        if self.use_sigmoid:
            return self.sigmoid(inputs)
        return inputs

In [None]:
class SpikeProcessor(keras.layers.Layer):
    """Sums a slice of the last axis and divides by ``nSynapses``, returning a flattened tensor.

    Args:
        nSynapses: divisor applied to the summed activity.
        start: optional first index of the last-axis slice.
        end: optional end index (exclusive) of the last-axis slice.
        name: layer name.
    """
    def __init__(self, nSynapses, start=None, end=None, name="SpikeProcessor"):
        super().__init__(name=name)
        self.nSynapses = nSynapses
        self.start = start
        self.end = end

    def build(self, shape):
        pass

    def call(self, inputs):
        # Only index the third axis when a bound was actually given.
        if self.start is not None or self.end is not None:
            inputs = inputs[:, :, self.start:self.end]
        summed = math.reduce_sum(inputs, axis=-1, keepdims=True)
        return layers.Flatten()(summed / self.nSynapses)

In [None]:
class PredictNeuron(keras.layers.Layer):
    """Runs the inputs through one neuron module, or averages the outputs of several.

    ``neuron_module`` is either a single object exposing ``.module`` and
    ``.name``, or a list of such objects. With a list, the modules' outputs are
    stacked on a new last axis and averaged over it.
    """
    def __init__(self, neuron_module, name="NeuronPrediction"):
        super().__init__(name=name)
        self.neurons = neuron_module
        self.single = not isinstance(self.neurons, list)
        if not self.single:
            header = f"Averaging between {len(self.neurons)} neurons:\n"
            listing = "".join(f"\t- {neuron.name}\n" for neuron in self.neurons)
            print(header + listing)
        else:
            print("Neuron module is", self.neurons.name)

    def call(self, inputs):
        if self.single:
            return self.neurons.module(inputs)
        per_module = [entry.module(inputs) for entry in self.neurons]
        return math_ops.reduce_mean(tf.stack(per_module, axis=-1), axis=-1)

In [None]:
def loss_for_me(y_true, y_preds):
    """Mean squared error between the predictions and the constant 1 (``y_true`` is ignored)."""
    mse = MeanSquaredError()
    return mse(1, y_preds)

In [None]:
def spikes_preds_processing(preds, n_spikes=50):
    """Sum ``preds`` over the last axis (keeping it as size 1) and divide by ``n_spikes``."""
    total = math.reduce_sum(preds, axis=-1, keepdims=True)
    return total / n_spikes

In [None]:
def SigmoidThreshold(pos_mult=1, neg_mult=1, threshold=0):
    """Build a sigmoid centred at ``threshold`` whose slope differs on each side.

    The slope multiplier is ``pos_mult`` for inputs above the threshold and
    ``neg_mult`` for inputs below it.
    """
    def sigmoid_threshold(x):
        offset = threshold - x
        # sign(offset) is -1 above the threshold and +1 below, which selects pos_mult / neg_mult.
        side = K.sign(offset)
        multiplier = (pos_mult + neg_mult + (neg_mult - pos_mult) * side) / 2
        return 1 / (1 + math.exp(multiplier * offset))
    return sigmoid_threshold

In [None]:
# Visualise SigmoidThreshold on [0, 1): slope 10 above the threshold, slope 50 below it.
x = np.arange(0,1,0.01)
threshold = 1/16
pos_mult = 10
neg_mult = 50
plt.plot(x, SigmoidThreshold(pos_mult, neg_mult, threshold)(x))
plt.ylim(0,1)
# Red line: output level 0.5; green line: the threshold where the curve crosses it.
plt.axhline(0.5, color='r', linestyle='--')
plt.axvline(threshold, color='g', linestyle='--')
plt.show()

In [None]:
def SigmoidThresholdEasier(mult=1, threshold=0.5):
    """Build a sigmoid centred at ``threshold`` with its slope multiplied by ``mult``.

    The returned function maps ``x`` to ``1 / (1 + exp(mult * (threshold - x)))``,
    so it equals 0.5 exactly at ``x == threshold``.
    """
    def sigmoid_threshold(x):
        exponent = mult * (threshold - x)
        return 1 / (1 + math.exp(exponent))
    return sigmoid_threshold

What the new sigmoid looks like:

In [None]:
# Visualise SigmoidThresholdEasier on [0, 1) with a steep slope (75) and a low threshold (0.01).
x = np.arange(0, 1, 0.01)
threshold = 0.01
mult = 75
plt.plot(x, SigmoidThresholdEasier(mult, threshold)(x))
plt.ylim(0,1)
plt.xlim(0,1)
# Red line: output level 0.5; green line: the threshold where the curve crosses it.
plt.axhline(0.5, color='r', linestyle='--')
plt.axvline(threshold, color='g', linestyle='--')
plt.show()

This will be used later.

In [None]:
def set_num_syn_loss(syns_wanted_per_ms=50):
    """Build a loss that is the MSE between ``y_preds / syns_wanted_per_ms`` and 1 (``y_true`` is ignored)."""
    def num_syn_loss(y_true, y_preds):
        normalised = y_preds / syns_wanted_per_ms
        return MeanSquaredError()(1, normalised)
    return num_syn_loss

In [None]:
class ToNeuronInput(keras.layers.Layer):
    """Flattens the two leading non-batch axes, optionally reorders/repeats them, and left-pads to ``full`` steps.

    Args:
        full: target size of the padded region plus the repeated input block.
        padding: ``0`` for zero padding, otherwise a tensor whose leading two
            axes are flattened and whose first ``full - ms*times`` rows are
            used as padding (assumed rank 3 — TODO confirm against callers).
        new_order: optional indices used to reorder the flattened axis.
        name: layer name.
        part: length budget that determines how many times the input is repeated.
        batch_size: stored but not used by this layer.
    """
    def __init__(self, full, padding=0, new_order=None, name="NeuronInput", part=200, batch_size=BATCH_SIZE):
        super().__init__(name=name)
        self.full = full
        # Keep the caller's ordering (it used to be overwritten with None).
        self.new_order = new_order
        self.zero_padding = isinstance(padding, int) and padding == 0
        self.padding = padding
        self.part = part
        self.batch_size = batch_size

    def build(self, shape):
        self.ms = shape[1]*shape[2]
        self.times = self.part // self.ms
        # With the default padding=0 there is no tensor to prepare; touching
        # `.shape` on the int used to crash here.
        if not self.zero_padding:
            flat = tf.reshape(self.padding, (self.padding.shape[0]*self.padding.shape[1], self.padding.shape[-1]))
            self.padding = K.variable(flat[np.newaxis, :self.full-self.ms*self.times])

    def call(self, inputs):
        new_inp = layers.Reshape((self.ms, inputs.shape[-1]))(inputs)
        if self.new_order is not None:
            # Reorder along the flattened axis (second-to-last).
            new_inp = tf.gather(new_inp, self.new_order, axis=-2)
        if self.times > 1:
            new_inp = layers.Concatenate(axis=1)([new_inp for _ in range(self.times)])
        if self.zero_padding:
            new_inp = layers.ZeroPadding1D(padding=(self.full - self.ms*self.times, 0))(new_inp)
        else:
            # Repeat the stored padding for every sample in the batch and prepend it.
            new_inp = layers.Concatenate(axis=-2)([tf.tile(self.padding, [tf.shape(new_inp)[0], 1, 1]), new_inp])
        return new_inp

    @staticmethod
    @tf.function
    def gather(x, ind):
        """Gather ``ind`` from ``x`` along axis 0 (kept for callers using ``ToNeuronInput.gather``)."""
        # `x + 0` is kept from the original; presumably it forces a plain tensor — TODO confirm.
        return tf.gather(x + 0, ind)

# Loss

In [None]:
def MSE_RMS_SynapsesPerMS(wanted, size=N_EXC, batch_size=32, eps=1e-3):
    """Build a loss comparing the RMS of the predictions (and of ``1 - predictions``) with target values.

    The targets are ``sqrt(wanted / size)`` for the predictions and
    ``sqrt(1 - wanted / size)`` for their complement. ``y_true`` and
    ``batch_size`` are not used.
    """
    target_active = np.sqrt(wanted / size)
    target_silent = np.sqrt(1. - wanted / size)

    def mse_rms(X, mu):
        # RMS over the last axis, squared distance to `mu`, summed over the next axis, averaged over the rest.
        rms = tf.math.sqrt(tf.math.reduce_mean(tf.square(X + eps), axis=-1))
        squared_gap = tf.square(rms - mu)
        return tf.math.reduce_mean(tf.math.reduce_sum(squared_gap, axis=-1))

    def mse_rms_synapses_per_ms(y_true, y_preds):
        active_term = mse_rms(y_preds, target_active) ** 2
        silent_term = mse_rms(1. - y_preds, target_silent) ** 2
        return tf.math.sqrt((active_term + silent_term) / 2)

    return mse_rms_synapses_per_ms

In [None]:
def MeanSquaredErrorSynapsesPerMS(batch_size=32, y_true_for_real=False):
    """Build a per-sample MSE loss against ``y_true`` (if ``y_true_for_real``) or against the constant 1.

    ``batch_size`` is accepted but not used.
    """
    def mean_squared_error_synapses_per_ms(y_true, y_preds):
        target = y_true if y_true_for_real else 1
        return tf.reduce_mean(tf.square(target - y_preds), axis=-1)
    return mean_squared_error_synapses_per_ms

In [None]:
def MaxErrorSynapsesPerMS(batch_size=32, y_true_for_real=False):
    """Build a loss returning the mean over the last axis of ``(1 - y_preds) ** 2``.

    ``batch_size``, ``y_true_for_real`` and ``y_true`` are accepted but unused.

    NOTE(review): despite the name, no maximum enters the result. The original
    computed ``tf.reduce_max(y_preds, axis=-1)`` and discarded it; that dead
    computation was removed without changing the output. Confirm whether the
    loss was meant to use the maximum.
    """
    def max_error_synapses_per_ms(y_true, y_preds):
        squared_difference = tf.square(1-y_preds)
        return tf.reduce_mean(squared_difference, axis=-1)
    return max_error_synapses_per_ms

In [None]:
# Constants read by pre_synaptic_spike_regularization below:
# summed activity above this level is penalised ...
max_acceptable_spikes_per_ms = 3.0
# ... after dividing the excess by this scale and squaring it ...
max_acceptable_spikes_deviation = 20.0
# ... and the mean penalty is multiplied by this weight.
activity_reg_constant = 0.2 * 0.0028


def pre_synaptic_spike_regularization(activation_map):
    """Activity regulariser: quadratic penalty on summed activity above ``max_acceptable_spikes_per_ms``.

    The activations are summed over axis 2; the excess over the threshold is
    scaled by ``max_acceptable_spikes_deviation``, squared, averaged and
    weighted by ``activity_reg_constant``.

    NOTE(review): when used on the 4-D Conv2D output in create_model, axis 2 is
    the 1400-wide axis rather than the channel axis — confirm this is the axis
    intended for the sum.
    """
    summed = K.sum(activation_map, axis=2)
    # relu keeps only the part above the acceptable level.
    excess = K.relu(summed - max_acceptable_spikes_per_ms)
    penalty = K.square(excess / max_acceptable_spikes_deviation)
    return activity_reg_constant * K.mean(penalty)

# Noise Augmentation

In [None]:
# Demo of NoiseLayer on the first training batch: row 1 shows originals,
# row 2 the noisy versions, row 3 their difference (three samples each).
plt.figure(figsize=(90,30))
sigma = .01
noise = NoiseLayer(sigma)
for _, (image, label) in enumerate(train_ds):
    # Called without a training flag, so the layer adds noise here.
    imageNoi = noise(image[:,:,:,0])
    for i in range(3):
        x = image[i,:,:,0]
        plt.subplot(3,5,i+1)
        plt.title("Original")
        plt.imshow(x, cmap="binary")
        plt.axis("off")
        plt.subplot(3,5,5+i+1)
        plt.title(r"Noise ($\sigma$={})".format(sigma))
        plt.imshow(imageNoi[i], cmap="binary")
        plt.axis("off")
        plt.subplot(3,5, 11+i)
        plt.title("Difference")
        plt.imshow(imageNoi[i]-image[i,:,:,0], cmap="binary")
        plt.axis("off")
    # Only the first batch is visualised.
    break

# Delay Augmentation

In [None]:
# Demo of GeometricDelay on the first training batch: top row originals, bottom row delayed copies.
plt.figure(figsize=(40,10))
for _, (image, label) in enumerate(train_ds):
    # to_plot=True makes the layer apply the delay even though no training flag is passed.
    layer = GeometricDelay(.01, True)
    for i in range(5):
        # Keep a leading batch axis of size 1 so the 4-D branch of the layer is used.
        imageDel = layer(image[np.newaxis, i])
        plt.subplot(2,5,i+1)
        plt.title("Original")
        plt.imshow(image[i,:,:,0], cmap="binary")
#         plt.axis("off")
        plt.subplot(2,5,5+i+1)
        plt.title("Delay")
        plt.imshow(imageDel[0,:,:,0], cmap="binary")
#         plt.axis("off")
    # Only the first batch is visualised.
    break

# Noisy Optimizer
https://arxiv.org/abs/1511.06807

https://github.com/cpury/keras_gradient_noise

In [None]:
import inspect
import importlib

def add_gradient_noise(BaseOptimizer, keras=None):
    """
    Wrap a Keras-compatible optimizer class so that it adds decaying Gaussian
    noise to its gradients, following https://arxiv.org/abs/1511.06807.

    The returned class accepts two extra constructor arguments, noise_eta
    (default 0.3) and noise_gamma (default 0.55); the noise variance at step t
    is noise_eta / (1 + t) ** noise_gamma (equation 1 of the paper).

    If `keras` is not given, the Keras flavour is guessed from the optimizer's
    module: classes whose module starts with 'keras' use plain Keras, anything
    else uses tf.keras. Pass the imported module to choose explicitly.

    Raises ValueError if BaseOptimizer is not an optimizer class of that Keras.
    """
    if keras is None:
        # Guess which Keras to import from where the optimizer class lives.
        from_plain_keras = (
            hasattr(BaseOptimizer, '__module__')
            and BaseOptimizer.__module__.startswith('keras')
        )
        keras = importlib.import_module('keras' if from_plain_keras else 'tensorflow.keras')

    backend = keras.backend

    is_optimizer_class = inspect.isclass(BaseOptimizer) and issubclass(
        BaseOptimizer, keras.optimizers.Optimizer
    )
    if not is_optimizer_class:
        raise ValueError(
            'add_gradient_noise() expects a valid Keras optimizer'
        )

    def _shape_of(tensor):
        # Objects exposing `dense_shape` report their shape through it.
        if hasattr(tensor, 'dense_shape'):
            return tensor.dense_shape
        return backend.shape(tensor)

    class NoisyOptimizer(BaseOptimizer):
        def __init__(self, noise_eta=0.3, noise_gamma=0.55, **kwargs):
            super().__init__(**kwargs)
            with backend.name_scope(self.__class__.__name__):
                self.noise_eta = backend.variable(noise_eta, name='noise_eta')
                self.noise_gamma = backend.variable(noise_gamma, name='noise_gamma')

        def get_gradients(self, loss, params):
            base_grads = super().get_gradients(loss, params)

            # Variance decays with the optimizer's iteration counter.
            grad_dtype = backend.dtype(base_grads[0])
            step = backend.cast(self.iterations, grad_dtype)
            variance = self.noise_eta / ((1 + step) ** self.noise_gamma)

            noisy_grads = []
            for grad in base_grads:
                jitter = backend.random_normal(
                    _shape_of(grad),
                    mean=0.0,
                    stddev=backend.sqrt(variance),
                    dtype=grad_dtype
                )
                noisy_grads.append(grad + jitter)
            return noisy_grads

        def get_config(self):
            merged = dict(super().get_config())
            merged['noise_eta'] = float(backend.get_value(self.noise_eta))
            merged['noise_gamma'] = float(backend.get_value(self.noise_gamma))
            return merged

    NoisyOptimizer.__name__ = 'Noisy{}'.format(BaseOptimizer.__name__)

    return NoisyOptimizer

In [None]:
# Gradient-noise variants of the imported optimizers (extra kwargs: noise_eta, noise_gamma).
NoisySGD = add_gradient_noise(SGD)
NoisyNadam = add_gradient_noise(Nadam)
NoisyAdamax = add_gradient_noise(Adamax)

In [None]:
# Plot the noise-variance schedule eta / (1 + t) ** gamma for t = 0 .. n-1,
# the same formula NoisyOptimizer.get_gradients uses.
eta = 0.03
gamma = 0.55

n = 150

var = []
for t in range(n):
    var.append(eta/(1+t)**gamma)
plt.plot(np.arange(n), var)
plt.title(r'Noise $\sigma^{2} $ through epochs')
plt.xlabel('ephocs')
plt.ylabel(r'noise $\sigma^{2}$')
plt.show()

In [None]:
class SynapsePruner(keras.callbacks.Callback):
    """Callback that periodically zeroes all but the largest weights of one layer (percentile variant).

    For each row of the layer's kernel (after taking ``[:, 0, 0, :]``), entries
    in the first ``splitPoint`` columns that are not above the ``rate1``
    percentile of that row are set to zero; likewise for the remaining columns
    with ``rate2``.

    NOTE(review): a later cell defines another class with this same name, which
    replaces this one once it has run.

    Args:
        rate1: percentile (as a fraction in [0, 1]) for the first column group.
        rate2: percentile (as a fraction in [0, 1]) for the second column group.
        iterations: countdown length between prunings.
        axis: accepted but unused.
        splitPoint: column index separating the two groups.
        prune_layer: name of the layer whose weights are pruned.
    """
    def __init__(self, rate1=0.9921752738654147, rate2=0.9765258215962441, iterations=8, axis=-1, splitPoint = N_EXC, prune_layer="WiringLayer"):
        super().__init__()
        # np.percentile expects values in [0, 100].
        self.rate1 = rate1 * 100
        self.rate2 = rate2 * 100
        self.iterations = iterations
        self.curIterations = iterations
        self.split = splitPoint
        self.layer_to_prune = prune_layer
        print(self)

    def __str__(self):
        string = f"\nSynapsePruner callback:\nrate1 = {self.rate1} ({str(round((1-self.rate1/100)*N_EXC,2))} : {N_EXC})\n"
        string += f"rate2 = {self.rate2} ({str(round((1-self.rate2/100)*N_EXC, 2))} : {N_EXC})\n"
        string += f"every {self.iterations} iterations"
        return string

    def build(self, shape):
        pass

#     @tf.autograph.experimental.do_not_convert
    def on_train_batch_end(self, batch, logs=None):
        """Count down once per batch and prune when the counter has reached zero."""
        # NOTE(review): the counter runs iterations..0, so pruning happens every
        # `iterations + 1` batches — confirm this matches the printed message.
        if self.curIterations: self.curIterations -= 1
        else:
            self.curIterations = self.iterations
            self.prune()

    def prune(self):
        """Zero every kernel entry that is not above its row's percentile and write the weights back."""
        layer = self.model.get_layer(self.layer_to_prune)
        kernels, bias = layer.get_weights()
        kernels = kernels[:,0,0,:]
        p1_x = kernels[:, :self.split]
        p2_x = kernels[:, self.split:]
        # One percentile per row, broadcast across that row's columns.
        percentsP1 = np.percentile(p1_x, self.rate1, axis=-1)[:, np.newaxis]
        percentsP2 = np.percentile(p2_x, self.rate2, axis=-1)[:, np.newaxis]
        keep = np.concatenate([p1_x > percentsP1, p2_x > percentsP2], axis=-1)
        arr = kernels * keep
        # A plain string message: `print(...)` as the message evaluated to None.
        assert arr.shape == kernels.shape, "Oh no. something went wrong"
        layer.set_weights([arr[:,np.newaxis, np.newaxis, :], bias])

In [None]:
class SynapsePruner(keras.callbacks.Callback):
    """Callback that periodically keeps only the k largest kernel weights per column (k-max variant).

    The layer's kernel is reduced with ``[:, 0, 0, :]``; within each column of
    the first ``splitPoint`` columns only entries at least as large as the
    ``kmax1``-th largest value of that column are kept, and likewise with
    ``kmax2`` for the remaining columns. Everything else is set to zero.

    NOTE(review): an earlier cell defines a class with this same name; this
    definition replaces it.

    Args:
        kmax1: number of largest entries kept per column in the first group.
        kmax2: number of largest entries kept per column in the second group.
        iterations: countdown length between prunings.
        axis: accepted but unused.
        splitPoint: column index separating the two groups.
        prune_layer: name of the layer whose weights are pruned.
    """
    def __init__(self, kmax1=5, kmax2=15, iterations=8, axis=-1, splitPoint = N_EXC, prune_layer="WiringLayer"):
        super().__init__()
        self.kmax1 = kmax1
        self.kmax2 = kmax2
        self.iterations = iterations
        self.curIterations = iterations
        self.split = splitPoint
        self.layer_to_prune = prune_layer
        print(self)

    def __str__(self):
        string = f"\nSynapsePruner callback:\nkmax1 = {self.kmax1}\n"
        string += f"kmax2 = {self.kmax2}\n"
        string += f"every {self.iterations} iterations\n Layer to Prune: {self.layer_to_prune}"
        return string

    def build(self, shape):
        pass

#     @tf.autograph.experimental.do_not_convert
    def on_train_batch_end(self, batch, logs=None):
        """Count down once per batch and prune when the counter has reached zero."""
        # NOTE(review): the counter runs iterations..0, so pruning happens every
        # `iterations + 1` batches — confirm this matches the printed message.
        if self.curIterations: self.curIterations -= 1
        else:
            self.curIterations = self.iterations
            self.prune()

    def prune(self):
        """Zero every kernel entry below its column's k-th largest value and write the weights back."""
        layer = self.model.get_layer(self.layer_to_prune)
        kernels, bias = layer.get_weights()
        kernels = kernels[:,0,0,:]
        p1_x = kernels[:, :self.split]
        p2_x = kernels[:, self.split:]
        # Partition along the rows: row -k of the result holds each column's k-th largest value.
        kmax1 = np.partition(p1_x, -self.kmax1, axis=-2)[-self.kmax1]
        kmax2 = np.partition(p2_x, -self.kmax2, axis=-2)[-self.kmax2]
        arr1 = p1_x * (p1_x >= kmax1)
        arr2 = p2_x * (p2_x >= kmax2)
        arr = np.concatenate([arr1, arr2], axis=-1)
        # A plain string message: `tf.print(...)` is not a usable assertion message.
        assert arr.shape == kernels.shape, "Oh no. something went wrong"
        layer.set_weights([arr[:,np.newaxis, np.newaxis,:], bias])

In [None]:
class NSynapseRegularizer(keras.regularizers.Regularizer):
    """Kernel regulariser built from each weight's squared distance to a per-row percentile.

    The kernel is reduced with ``[:, 0, 0, :]`` and split at ``N_EXC`` into two
    column groups. For each group the value is
    ``mean(1 - (row_percentile - w) ** 2)``; the two values are added and
    multiplied by ``strength``.

    Args:
        strength: multiplier of the returned penalty.
        nSynapseEx: sets the first group's percentile to ``(1 - nSynapseEx / N_EXC) * 100``.
        nSynapseIn: sets the second group's percentile the same way (also relative to ``N_EXC``).
        threshold, mult: parameters of ``self.sigmoid``, which ``__call__`` does not use.
    """

    def __init__(self, strength=0.01, nSynapseEx=5, nSynapseIn=5, threshold = 0.07, mult = 75):
        self.strength = strength
        self.nSynapseEx = nSynapseEx
        self.nSynapseIn = nSynapseIn
        # Kept as attributes although __call__ does not use them.
        self.sigmoid = SigmoidThresholdEasier(mult, threshold)
        self.loss = lambda nSynapse, x: tf.reduce_mean(tf.math.pow(1 - x/nSynapse, 2))
        self.percentEx = (1 - nSynapseEx / N_EXC) * 100
        self.percentIn = (1 - nSynapseIn / N_EXC) * 100
        print(self.percentEx)

    def __call__(self, x):
        weights = x[:,0,0,:]

        def group_term(block, percent):
            # One percentile per row; transposing lets it broadcast against that row's entries.
            cutoff = tfp.stats.percentile(block, percent, axis=-1)
            return tf.reduce_mean(1 - tf.math.pow(cutoff - tf.transpose(block), 2))

        excitatory_term = group_term(weights[:, :N_EXC], self.percentEx)
        inhibitory_term = group_term(weights[:, N_EXC:], self.percentIn)
        return self.strength * (excitatory_term + inhibitory_term)

In [None]:
class CallNeuron(keras.layers.Layer):
    """Applies the neuron module(s) to a long input by running them on overlapping windows.

    The input is cut along axis 1 into windows of ``neuronTime`` steps taken
    every ``neuronTime - spare`` steps. The first window's output is kept in
    full; for every later window the first ``spare`` output steps are dropped.
    The pieces are concatenated along the second-to-last axis.

    Args:
        neuronModules: a single module wrapper or a list of them (objects exposing ``.module``).
        spare: overlap between consecutive windows.
        name: layer name.
    """
    def __init__(self, neuronModules, spare=150, name="CallNeuronLayer"):
        super().__init__(name=name)
        self.neurons = neuronModules
        self.spare = spare
        # Window length and channel count are read from the (first) module's input shape.
        self.neuronTime =  neuronModules.module.input.shape[-2] if not isinstance(neuronModules, list) else neuronModules[0].module.input.shape[-2]
        self.synapses = neuronModules.module.input.shape[-1] if not isinstance(neuronModules, list) else neuronModules[0].module.input.shape[-1]
        self.fullTime = None
        # Stride between consecutive window starts.
        self.timePerRun = self.neuronTime-self.spare
        self.times = None
        self.paddingSize = 0
        self.neuron = None
        
    def build(self, shape):
        """Validate the channel count and precompute the window boundaries for this input length."""
        self.neuron = PredictNeuron(self.neurons)
#         print("Module is:", self.neuron)
        if shape[-1] != self.synapses:
            raise Exception(f"Wrong number of synapses! Neuron synapse number is {self.synapses} and input shape is {shape}")
        self.fullTime = shape[-2]
        # Number of zero steps appended at the end of the input in call().
        self.paddingSize = self.fullTime%self.timePerRun
        # (start, end) of each window: starts advance by timePerRun, each window spans neuronTime steps.
        self.times = [((self.timePerRun)*i, (self.timePerRun)*i+self.neuronTime) for i in range(self.fullTime//self.timePerRun)]    
        
    @tf.autograph.experimental.do_not_convert
    def call(self, inputs, training=None):
        inputs = layers.ZeroPadding1D(padding=((0,self.paddingSize)))(inputs)
        # First window: full output; later windows: output with the first `spare` steps removed.
        outputs = layers.Concatenate(axis=-2)([self.neuron(inputs[:,self.times[0][0]:self.times[0][1]])]+[self.neuron(inputs[:,i:j,])[:,self.spare:] for i,j in self.times[1:]])
        return outputs

# Model

In [None]:
def create_model(optimizer, sigmoid_threshold=0.8, use_sigmoid=True, sigmoid_mult=25, dropout=.2, 
                 excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED, qSynapse=(0.05, 0.05)):
    """Build and compile the wiring -> neuron -> classifier model.

    The model has three outputs: ``nSpikes`` (sigmoid Dense prediction) and the
    ``nExcitatory`` / ``nInhibitory`` activity summaries used as auxiliary losses.

    Args:
        optimizer: optimizer passed to ``model.compile``.
        sigmoid_threshold, use_sigmoid, sigmoid_mult: options of the
            'preNeuronBool' ``ToBoolLayer``.
        dropout: dropout rate applied to the input and again before the neuron
            (falsy disables both).
        excitatory_wanted, inhibitory_wanted: divisors for the two ``SpikeProcessor`` outputs.
        qSynapse: loss weights of the two auxiliary outputs; the main output
            gets ``1 - sum(qSynapse)``.

    NOTE(review): ``L5PC_model`` is not defined in the cells shown before this
    function — it must exist before this is called. A later cell also defines
    another ``create_model``, which replaces this one.
    """
    inp = keras.Input(shape=(700, 1400,1))
    x = inp
    if dropout: x = layers.Dropout(dropout)(x)
    # The wiring layer now consumes `x`; it used to read `inp`, which silently
    # discarded the input dropout.
    x = layers.Conv2D(1278, (inp.shape[-3], 1), activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer")(x)
    # The kernel spans the whole first spatial axis, leaving it with size 1; drop it.
    x = K.squeeze(x, axis=-3)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(x)
    if dropout: x = layers.Dropout(dropout)(x)
    # Auxiliary outputs: normalised summed activity of the first N_EXC channels and of the rest.
    nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=N_EXC)(x)
    nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=N_EXC)(x)
    x = CallNeuron(L5PC_model, name="SpikeTrain")(x)
    x = ToBoolLayer(threshold=.2, 
                    use_sigmoid=False, 
                    name='postNeuronBool')(x)
    output = layers.Dense(1, activation="sigmoid", name="nSpikes")(x)
    model = keras.Model(inp, [output, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'nSpikes': MeanSquaredError(), 'nExcitatory': MeanSquaredErrorSynapsesPerMS(), 'nInhibitory': MeanSquaredErrorSynapsesPerMS()}, 
                  metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                  loss_weights=[1-sum(qSynapse), *qSynapse])
    return model

In [None]:
def create_model(optimizer, sigmoid_threshold=0.8, use_sigmoid=True, sigmoid_mult=25, dropout=.2, nSynapse=True, delay=.02, dropout2=False,
                     excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED, qSynapse=(0.05, 0.05)):
    """Build and compile the delay/dropout -> wiring -> neuron -> max-pool model.

    Args:
        optimizer: Keras optimizer passed to ``model.compile``.
        sigmoid_threshold, use_sigmoid, sigmoid_mult: forwarded to the
            pre-neuron ``ToBoolLayer``.
        dropout: rate of the Dropout applied to the input; falsy disables it.
        nSynapse: when truthy, two ``SpikeProcessor`` heads are added as extra
            outputs and trained with ``MaxErrorSynapsesPerMS``.
        delay: argument for ``GeometricDelay`` on the input; falsy disables it.
        dropout2: rate of the "preNeuronDrop" Dropout; falsy disables it.
        excitatory_wanted, inhibitory_wanted: forwarded to the two
            ``SpikeProcessor`` heads.
        qSynapse: loss weights of the (excitatory, inhibitory) heads; the
            "nSpikes" head gets ``1 - sum(qSynapse)``.

    Returns:
        A compiled ``keras.Model`` with one output ("nSpikes") or, when
        ``nSynapse`` is truthy, three outputs.
    """
    inp = keras.Input(shape=(700, 1400,1))

    # Optional perturbations of the raw input.
    x = GeometricDelay(delay)(inp) if delay else inp
    if dropout:
        x = layers.Dropout(dropout)(x)

    # Wiring: one non-negative kernel spanning the whole first image axis.
    wiring = layers.Conv2D(1278, (inp.shape[-3], 1),
                           activity_regularizer=pre_synaptic_spike_regularization,
                           kernel_constraint=keras.constraints.NonNeg(),
                           bias_constraint=keras.constraints.NonNeg(),
                           name="WiringLayer")
    x = wiring(x)
    x = layers.Reshape((x.shape[-2], x.shape[-1]))(x)

    # Thresholding of the wiring output before it reaches the neuron.
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')(x)
    if dropout2:
        x = layers.Dropout(dropout2, name="preNeuronDrop")(x)

    # Auxiliary heads for the synapse-count losses.
    if nSynapse:
        nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=N_EXC)(x)
        nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=N_EXC)(x)

    # Neuron, then a single max over the whole output trace.
    x = CallNeuron(L5PC_model, name="SpikeTrain")(x)
    x = layers.Reshape((x.shape[-1], 1))(x)
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2], name='maxPool')(x)
    output = layers.Flatten(name="nSpikes")(x)

    if not nSynapse:
        model = keras.Model(inp, output)
        model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=keras.metrics.BinaryAccuracy())
        return model

    model = keras.Model(inp, [output, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'nSpikes': MeanSquaredError(), 'nExcitatory': MaxErrorSynapsesPerMS(), 'nInhibitory': MaxErrorSynapsesPerMS()}, 
                  metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                  loss_weights=[1-sum(qSynapse), *qSynapse])
    return model

In [None]:
def create_model_cat(categories, optimizer, sigmoid_threshold=0.8, use_sigmoid=True, sigmoid_mult=25, dropout=.2, 
                     excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED, qSynapse=(0.05, 0.05)):
    """Build and compile the categorical (softmax) variant of the model.

    Pipeline: optional input dropout -> non-negative "WiringLayer" Conv2D ->
    squeeze -> ToBoolLayer -> CallNeuron(L5PC_model) -> Conv1D(NLABELS, 150,
    stride 100) -> max-pool over time -> Softmax named "nSpikes".

    Args:
        categories: accepted for API compatibility; the body does not read it
            (the number of output channels comes from the global ``NLABELS``).
        optimizer: Keras optimizer passed to ``model.compile``.
        sigmoid_threshold, use_sigmoid, sigmoid_mult: forwarded to the
            pre-neuron ``ToBoolLayer``.
        dropout: rate of the Dropout applied to the input; falsy disables it.
        excitatory_wanted, inhibitory_wanted: forwarded to the two
            ``SpikeProcessor`` heads.
        qSynapse: loss weights of the (excitatory, inhibitory) heads; the
            "nSpikes" head gets ``1 - sum(qSynapse)``.

    Returns:
        A compiled ``keras.Model`` with outputs
        ``[nSpikes, nExcitatory, nInhibitory]``.
    """
    inp = keras.Input(shape=(700, 1400,1))

    # Fix: start from the raw input so that `x` exists even when dropout is
    # disabled (previously a falsy `dropout` raised NameError below).
    x = inp
    if dropout: x = layers.Dropout(dropout)(x)

    # Wiring (NonNeg spelled the same way as in the other model builders).
    x = layers.Conv2D(1278, (inp.shape[-3], 1), activity_regularizer=pre_synaptic_spike_regularization, bias_constraint=keras.constraints.NonNeg(),
                      kernel_constraint=keras.constraints.NonNeg(), name="WiringLayer")(x)
    x = K.squeeze(x, axis=-3)

    # Thresholding of the wiring output before it reaches the neuron.
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')(x)

    # Auxiliary heads for the synapse-count losses.
    nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=N_EXC)(x)
    nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=N_EXC)(x)

    # Neuron
    x = CallNeuron(L5PC_model, name="SpikeTrain")(x)

    # To categorical prediction: windowed convolution, max over all windows, softmax.
    x = tf.expand_dims(x, -1)
    x = layers.Conv1D(NLABELS, 150, 100)(x)
    x = layers.MaxPooling1D(x.shape[-2], name='maxPool')(x)
    x = layers.Flatten()(x)
    output = layers.Softmax(name="nSpikes")(x)

    # Create and compile the model.
    model = keras.Model(inp, [output, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'nSpikes': 'categorical_crossentropy', 
                        'nExcitatory': MeanSquaredErrorSynapsesPerMS(), 
                        'nInhibitory': MeanSquaredErrorSynapsesPerMS()}, 
                  metrics={'nSpikes': keras.metrics.CategoricalAccuracy()}, 
                  loss_weights=[1-sum(qSynapse), *qSynapse])
    return model

In [None]:
# model = create_model(Nadam(), sigmoid_mult=50, nSynapse=True, use_sigmoid=True, sigmoid_threshold=0.2, dropout=.1, excitatory_wanted=8, inhibitory_wanted=6, qSynapse=(.08, .05))

In [None]:
# model = create_model(Nadam(), sigmoid_mult=50, nSynapse=True, use_sigmoid=True, sigmoid_threshold=0.2, dropout=.1, delay=False, excitatory_wanted=10, inhibitory_wanted=8, qSynapse=(.08, .05))

# Current Module

In [None]:
def create_model(optimizer, sigmoid_threshold=0.8, use_sigmoid=True, sigmoid_mult=25, dropout=.2, delay=.02, dropout2=False):
    """Build and compile the current single-output model.

    Pipeline: optional GeometricDelay -> optional input dropout -> non-negative
    "WiringLayer" Conv2D -> reshape -> batch norm -> sigmoid -> ToBoolLayer ->
    optional "preNeuronDrop" dropout -> CallNeuron(L5PC_model) -> max over the
    whole trace -> Flatten named "nSpikes".

    Args:
        optimizer: Keras optimizer passed to ``model.compile``.
        sigmoid_threshold, use_sigmoid, sigmoid_mult: forwarded to the
            pre-neuron ``ToBoolLayer``.
        dropout: rate of the Dropout applied to the input; falsy disables it.
        delay: argument for ``GeometricDelay`` on the input; falsy disables it.
        dropout2: rate of the pre-neuron Dropout; falsy disables it.

    Returns:
        A ``keras.Model`` compiled with mean squared error and binary accuracy.
    """
    inp = keras.Input(shape=(700, 1400,1))

    # Optional perturbations of the raw input.
    x = GeometricDelay(delay)(inp) if delay else inp
    if dropout:
        x = layers.Dropout(dropout)(x)

    # Wiring: one non-negative kernel spanning the whole first image axis.
    wiring = layers.Conv2D(1278, (inp.shape[-3], 1),
                           activity_regularizer=pre_synaptic_spike_regularization,
                           kernel_constraint=keras.constraints.NonNeg(),
                           bias_constraint=keras.constraints.NonNeg(),
                           name="WiringLayer")
    x = wiring(x)
    x = layers.Reshape((x.shape[-2], x.shape[-1]))(x)

    # Normalise, squash and threshold the wiring output.
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')(x)
    if dropout2:
        x = layers.Dropout(dropout2, name="preNeuronDrop")(x)

    # Neuron, then a single max over the whole output trace.
    x = CallNeuron(L5PC_model, name="SpikeTrain")(x)
    x = layers.Reshape((x.shape[-1], 1))(x)
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2], name='maxPool')(x)
    output = layers.Flatten(name="nSpikes")(x)

    model = keras.Model(inp, output)
    model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=keras.metrics.BinaryAccuracy())
    return model

In [None]:
class SliceLayer(tf.keras.layers.Layer):
    """Keras layer that returns a ``[start:end]`` slice of its input.

    For 4-D inputs the slice is taken on axis 3; for every other rank it is
    taken on axis 2 (the indexing pattern of the original implementation).

    Args:
        start: first index to keep, or None to start at the beginning.
        end: index to stop before, or None to run to the end.
        name: Keras layer name.
    """
    def __init__(self, start=None, end=None, name="SliceLayer"):
        super().__init__(name=name)
        self.start = start
        self.end = end

    def build(self, shape):
        # Remember the input rank; `call` uses it to pick the sliced axis.
        self.dims = len(shape)
    
    def call(self, inputs):
        # A slice object with None bounds is a no-op, so the layer now passes
        # the input through unchanged when neither bound is given (it used to
        # fall off the end of the function and return None).
        window = slice(self.start, self.end)
        if self.dims == 4:
            return inputs[:, :, :, window]
        return inputs[:, :, window]

In [None]:
class MeanSynapsesPerMsMetric:
    """Callable metric: count of predictions above 0.5, averaged down to a scalar.

    ``y_pred`` entries greater than 0.5 are counted along the last axis; the
    counts are then averaged over the last axis twice in succession.
    ``y_true`` is ignored.
    """
    def __init__(self, name="nSynapses"):
        self.name=name

    def __call__(self, y_true, y_pred):
        result = tf.cast(tf.math.greater(y_pred, 0.5), tf.float32)
        result = tf.math.reduce_sum(result, axis=-1)
        # Two successive means over the trailing axis.
        for _ in range(2):
            result = tf.math.reduce_mean(result, axis=-1)
        return result

In [None]:
def printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, exWanted, inWanted, qSynapse, synLoss):
    """Print a human-readable summary of a module's hyper-parameters.

    Args:
        dropout1: reported as the image dropout rate.
        dropout2: reported as the synapse dropout rate.
        sigmoid_threshold: reported as the synapse threshold.
        sigmoid_mult: reported as the training sigmoid multiplication.
        exWanted, inWanted: wanted excitatory / inhibitory synapse counts;
            their sum is reported as well.
        qSynapse: two-item sequence of (excitatory, inhibitory) loss rates.
        synLoss: object with a ``__name__`` attribute (e.g. a class or function).
    """
    def _summary_lines():
        # Lazily produced so each value is formatted right before it is printed.
        yield "~*~ Visual Module ~*~"
        yield f"Image Dropout Rate: {dropout1}"
        yield f"Synapse Threshold: {sigmoid_threshold}"
        yield f"Synapse Training Sigmoid Multiplication: {sigmoid_mult}"
        yield f"Synapses Wanted: Exc-{exWanted}; Inh-{inWanted}; Sum-{exWanted+inWanted}"
        yield f"Synapses Loss Rate: Exc-{qSynapse[0]}; Inh-{qSynapse[1]}"
        yield f"Synapses Dropout Rate: {dropout2}"
        yield f"Synapse Loss Function: {synLoss.__name__}"

    for line in _summary_lines():
        print(line)

In [None]:
def different_nSynapses(saccades=None, xaxis=True, use_sigmoid=True, pruning=True, dropout1=False, dropout2=False,
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, dropout=False, 
                        excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED,
                        qSynapse=(tf.Variable(0.1, trainable=False), tf.Variable(0.1, trainable=False)), 
                        optimizer=None, conv_shape=(16,16), non_neg=False,
                        synLoss=MeanSquaredErrorSynapsesPerMS, nSynapse=True, neurons=neurons, regular_sigmoid=True):
    """Build a model that exposes the excitatory/inhibitory synapse slices as outputs.

    Args:
        saccades, xaxis, pruning, to_bool, conv_shape: accepted for API
            compatibility; the body does not read them.
        dropout1, dropout2: only reported by ``printModule`` and folded into
            the module name; no layer is built from them.
            NOTE(review): the Dropout that is actually applied uses ``dropout``.
        use_sigmoid, sigmoid_threshold, sigmoid_mult: forwarded to ``ToBoolLayer``.
        dropout: rate of the Dropout applied to the input; falsy disables it.
        excitatory_wanted, inhibitory_wanted: arguments for the two ``synLoss``
            instances.
        qSynapse: loss weights of the (excitatory, inhibitory) outputs; the
            "nSpikes" output gets ``1 - qSynapse[0] - qSynapse[1]``.
        optimizer: Keras optimizer; ``None`` (default) creates a fresh
            ``SGD(momentum=.9)`` for this model.
        non_neg: when truthy the wiring kernel is constrained to be
            non-negative and the activity regularizer is attached.
        synLoss: callable returning a loss when given the wanted synapse count.
        nSynapse: when falsy only the "nSpikes" output is trained.
        neurons: forwarded to ``CallNeuron``.
        regular_sigmoid: apply a sigmoid activation after batch normalisation.

    Returns:
        ``(model, module_name)`` - the compiled ``keras.Model`` and a name
        string built from the reported hyper-parameters.
    """
    # Fix: an optimizer instance used as a default argument is created once at
    # definition time and would be shared by every model built without an
    # explicit optimizer. Create one per call instead.
    if optimizer is None:
        optimizer = SGD(momentum=.9)

    printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, excitatory_wanted, inhibitory_wanted, qSynapse, synLoss)
    module_name = "module_syn_" + "".join(
        str(i) for i in [dropout1, dropout2, sigmoid_threshold, sigmoid_mult, excitatory_wanted, inhibitory_wanted, qSynapse, synLoss])

    inp = keras.Input(shape=(700, 1400,1))
    x = inp
    if dropout: x = layers.Dropout(dropout)(x)
    
    # Wiring
    if non_neg: x = layers.Conv2D(1278, (inp.shape[-3], 1), activity_regularizer=pre_synaptic_spike_regularization, kernel_constraint=keras.constraints.NonNeg(), 
                      bias_constraint=keras.constraints.NonNeg(), name="WiringLayer")(x)
    else: x = layers.Conv2D(1278, (inp.shape[-3], 1),
                      bias_constraint=keras.constraints.NonNeg(), name="WiringLayer")(x)
    x = layers.Reshape((x.shape[-2], x.shape[-1]))(x)

    x = layers.BatchNormalization(name="BatchNorm")(x)
    if regular_sigmoid: x = layers.Activation(sigmoid)(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')(x)

    # Expose the two synapse populations as named outputs so they can carry
    # their own loss / metric.
    ExcitatorySynapses = SliceLayer(end=N_EXC, name="ExcSyns")(x)
    InhibitorySynapses = SliceLayer(start=N_EXC, name="InhSyns")(x)

    x = CallNeuron(neurons, name="SpikeTrain")(x)

    # A single max over the whole output trace.
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2], name="MaxPooling")(x)
    output = layers.Flatten(name="nSpikes")(x)
    
    model = keras.Model(inp, [output, ExcitatorySynapses, InhibitorySynapses])
    if not nSynapse:
        model.compile(optimizer=optimizer, loss={'nSpikes': MeanSquaredError()}, 
                      metrics={"nSpikes": keras.metrics.BinaryAccuracy(), "ExcSyns": MeanSynapsesPerMsMetric(), "InhSyns": MeanSynapsesPerMsMetric()})
    else:
        model.compile(optimizer=optimizer, 
                  loss=[MeanSquaredError(), synLoss(excitatory_wanted), synLoss(inhibitory_wanted)], 
                  metrics={"nSpikes": keras.metrics.BinaryAccuracy(), "ExcSyns": MeanSynapsesPerMsMetric(), "InhSyns": MeanSynapsesPerMsMetric()},
                  loss_weights=[1.0 -qSynapse[0] - qSynapse[1], qSynapse[0], qSynapse[1]])
    return model, module_name

In [None]:
# model_synapses, module_name = different_nSynapses(nSynapse=True, non_neg=True, dropout1=False, dropout2=False, use_sigmoid=True, regular_sigmoid=False, sigmoid_mult=2, synLoss=MSE_RMS_SynapsesPerMS, sigmoid_threshold=2, optimizer=Nadam(), qSynapse=(.1, .1))

In [None]:
# model_synapses.summary()

In [None]:
# history = model_synapses.fit_generator(train_ds, epochs=50, validation_data=valid_ds)#, callbacks=SynapsePruner(iterations=20))

# Benchmark

In [None]:
def build_defected_FCN_model(cat=1, depth=1, width=32, filters=1, kernel_size=400, stride_size=150, dropout=0, optimizer=None, l2_reg=1e-3):
    """Build and compile the convolution + fully-connected benchmark model.

    Pipeline: optional dropout -> Conv2D(filters, (700, kernel_size), strides
    (700, stride_size), relu) -> batch norm -> ``depth`` x [Dense(width) ->
    LeakyReLU -> batch norm] -> Dense(1, sigmoid) when ``filters > 1`` ->
    max over all positions -> Flatten.

    Args:
        cat: selects the metric only - BinaryAccuracy when 1, otherwise
            CategoricalAccuracy.
        depth: number of Dense/LeakyReLU/BatchNorm stacks.
        width: units per Dense layer in the stack.
        filters, kernel_size, stride_size: Conv2D configuration.
        dropout: input dropout rate; 0 disables it.
        optimizer: Keras optimizer; ``None`` (default) creates a fresh
            ``SGD(momentum=.9)`` for this model.
        l2_reg: L2 kernel regularisation strength of the Dense layers.

    Returns:
        A ``keras.Model`` compiled with mean squared error.
    """
    # Fix: an optimizer instance used as a default argument is created once at
    # definition time, so every model built in a sweep would share one
    # optimizer and its state. Create one per call instead.
    if optimizer is None:
        optimizer = SGD(momentum=.9)

    inp = keras.Input(shape=(700, 1400,1))
    x = inp
    if dropout != 0: x = layers.Dropout(dropout)(x)
    x = layers.Conv2D(filters, (700, kernel_size), (700, stride_size), activation="relu")(x)
    x = layers.BatchNormalization()(x)
    for d in range(depth):
        x = layers.Dense(units=width, activation='linear', kernel_regularizer=keras.regularizers.l2(l2_reg), name='FC_layer_%d' %(d + 1))(x)
        x = layers.LeakyReLU(alpha=0.3, name='LReLU_%d' %(d + 1))(x)
        x = layers.BatchNormalization(name='BN_layer_%d' %(d + 1))(x)
    # With several filters, collapse the channel axis to one sigmoid score per position.
    if filters > 1: x = layers.Dense(units=1, activation='sigmoid', kernel_regularizer=keras.regularizers.l2(l2_reg), name='Logits')(x)

    # Keep only the strongest response over all positions.
    x = layers.Reshape((x.shape[-2], 1))(x)
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2])(x)
    output = layers.Flatten()(x)
    model = keras.Model(inp, output)
    model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=keras.metrics.BinaryAccuracy() if cat==1 else keras.metrics.CategoricalAccuracy())
    return model

In [None]:
# Benchmark model with no Dense stack (depth=0), 10 convolution filters and 50% input dropout.
bench = build_defected_FCN_model(cat=1, depth=0, width=5, filters=10, dropout=.5)

In [None]:
# Show the layer-by-layer architecture and parameter counts of the benchmark model.
bench.summary()

In [None]:
# `fit` accepts generators / Sequence objects directly; `fit_generator` is deprecated.
bench.fit(train_ds, epochs=5, validation_data=valid_ds)

# David's Benchmark Check

In [None]:
# Sweep over depth x width x repetition of the benchmark model, training a
# fresh model for every combination and collecting its train/test scores.
short_run = False

if short_run:
    # simple cells feature extraction is extremely fast
    batch_size = 64
    num_epochs = 12
    l2_reg = 1e-2

    depths_list = [1,2,3]
    widths_list = [1,2,4,8,16,32,64]
    num_random_repetitions = [1,2,3,4,5,6,7,8]
else:
    batch_size = 64
    num_epochs = 7
    l2_reg = 1e-3

    depths_list = [1,2,3]
    widths_list = [1,2,4,8,16]
    num_random_repetitions = [1,2,3,4]

# NOTE(review): `batch_size` is not read anywhere in this cell -- presumably the
# batch size is fixed by the data generators; confirm.

results_list = []
# itertools.product iterates in the same order as the equivalent nested loops.
for depth, width, rand_rep in itertools.product(depths_list, widths_list, num_random_repetitions):
    training_start_time = time.time()

    # create model
    FCN_model = build_defected_FCN_model(depth=depth, width=width, l2_reg=l2_reg, dropout=.5, filters=10)

    # fit model (`fit`/`evaluate` accept generators; the *_generator variants are deprecated)
    print('-----------------------------------------------------------')
    train_history = FCN_model.fit(train_ds, epochs=num_epochs)
    print('----------------------------------------')

    # evaluate performance (verbose=0 keeps the silent behaviour of evaluate_generator)
    train_loss_and_metrics = FCN_model.evaluate(train_ds, verbose=0)
    test_loss_and_metrics  = FCN_model.evaluate(test_ds, verbose=0)
    print('----------------------------------------')
    training_time_min = (time.time() - training_start_time) / 60

    # print results
    print('----------------------------------------')
    print("training took %.2f minutes" % (training_time_min))
    print('----------------------------------------')
    print('model name is "%s"' %(FCN_model.name))
    print('----------------------------------------')
    print("Train loss = %.5f" %(train_loss_and_metrics[0]))
    print("Train accuracy = %.2f%s" %(100 * train_loss_and_metrics[1], '%'))
    print('----------------------------------------')
    print("Test loss = %.5f" %(test_loss_and_metrics[0]))
    print("Test accuracy = %.2f%s" %(100 * test_loss_and_metrics[1], '%'))
    print('-----------------------------------------------------------')

    # store results in dict
    results_dict = {
        'depth':             depth,
        'width':             width,
        'rand_rep':          rand_rep,
        'model_name':        FCN_model.name,
        'training_time_min': training_time_min,
        'train_history':     train_history.history,
        'Train loss':        train_loss_and_metrics[0],
        'Train accuracy':    100 * train_loss_and_metrics[1],
        'Test loss':         test_loss_and_metrics[0],
        'Test accuracy':     100 * test_loss_and_metrics[1],
    }

    # store in results list
    results_list.append(results_dict)

In [None]:
# Report how many benchmark models the sweep produced.
print(f'The total number of models that were trained is {len(results_list):d}')

In [None]:
# Persist the sweep results. The random suffix reduces (but does not rule out)
# overwriting the file of an earlier run with the same number of models.
results_path = 'results_list_%d_models_%d.pickle' %(len(results_list), np.random.randint(100))
# Use a context manager so the file handle is flushed and closed deterministically.
with open(results_path, "wb") as results_file:
    pickle.dump(results_list, results_file)

In [None]:
column_names = ['preprocessing_type','subsample','depth','width','Test accuracy','rand_rep']

# Build the frame in a single constructor call: `columns=` picks the listed
# keys out of every result dict and leaves columns with no matching key
# ('preprocessing_type', 'subsample') as NaN. Unlike filling a pre-allocated
# object-dtype frame cell by cell, this keeps depth/width/rand_rep numeric.
results_df = pd.DataFrame(results_list, columns=column_names)

results_df = results_df.astype({"Test accuracy": float})
results_df

In [None]:
# Best and mean test accuracy for every (depth, width) combination,
# aggregated over the random repetitions.
grouped_results_df = (
    results_df
    .groupby(['depth','width'])['Test accuracy']
    .agg(['max','mean'])
    .rename(columns={'max': 'Best Test accuracy', 'mean': 'Mean Test accuracy'})
    .reset_index()
)
grouped_results_df

# Plot

In [None]:
def plot_examples(model, plot_start=None, write_last=False, show_spikes=False, ds=valid_ds):
    """Plot the first ten samples of the first batch of ``ds`` in a 10 x 3 grid.

    Columns: the input image (titled with its label), the output of the
    "preNeuronBool" layer (titled with mean / std of its per-row sums) and the
    output of the "SpikeTrain" layer (titled with its sum).

    Args:
        model: Keras model containing "preNeuronBool" and "SpikeTrain" layers
            (and "nSpikes" when ``write_last`` is set).
        plot_start: optional x position of a vertical marker on the trace plot.
        write_last: prefix the trace title with the "nSpikes" output.
        show_spikes: append detected (start, end) spike intervals to the title.
        ds: iterable of ``(img, label)`` batches; only the first is used.
    """
    def _find_spikes(trace):
        # A spike starts at the first sample above 0.25 and ends at the next
        # sample below 0.1; scanning resumes right after each end.
        found = []
        cursor = 0
        while True:
            onset = np.where(trace[cursor:] > 0.25)[0]
            if not onset.shape[0]: break
            onset = cursor+onset[0]
            offset = np.where(trace[onset:] < 0.1)[0]
            if not offset.shape[0]: break
            offset = offset[0] + onset
            found.append((onset,offset))
            cursor = offset + 1
        return found

    plt.figure(figsize=(256, 200))
    how_many = 10
    for img, label in ds:
        inputs = K.function(model.input, model.get_layer("preNeuronBool").output)([img])
        outputs = K.function(model.input, model.get_layer("SpikeTrain").output)([img])
        if write_last: nAP = K.function(model.input, model.get_layer("nSpikes").output)([img])

        for i in range(how_many):
            # Column 1: the raw input image.
            plt.subplot(how_many, 3, 3*i+1)
            plt.imshow(img[i,:,:,0]/255, cmap='binary')
            plt.title(LABELS[label[i]], fontdict={'fontsize':200})
            plt.axis("off")

            # Column 2: the thresholded wiring output.
            plt.subplot(how_many, 3, 3*i+2)
            curr_input = np.squeeze(inputs[i], axis=-3) if len(inputs[i].shape)==3 else inputs[i]
            num_of_synapses = np.sum(curr_input, axis=-1)
            mean_synapses = round(num_of_synapses.mean(), 2)
            std_synapses = round(num_of_synapses.std(), 2)
            plt.title(f"\u03BC: {str(mean_synapses)},  \u03C3: {str(std_synapses)}", fontdict={'fontsize':200})
            plt.imshow(curr_input, cmap='binary', vmin=0, vmax=1)
            plt.axis("off")

            # Column 3: the neuron's output trace.
            plt.subplot(how_many, 3, 3*i+3)
            curr_output = outputs[i]
            if show_spikes:
                spikes = _find_spikes(curr_output)
            caption = ""
            if write_last:
                caption += f"Score: {str(round(nAP[i][0],2))},"
            caption += f"\u03A3: {str(round(curr_output.sum(),2))}"
            if show_spikes:
                caption += f", spikes:{spikes}"
            plt.title(caption, fontdict={'fontsize':150})
            plt.plot(curr_output, linewidth=5)
            plt.ylim(0,1)
            if plot_start: plt.axvline(plot_start, color='g', linestyle=':', linewidth=10.)
            plt.axis('on')
        break

In [None]:
def plot_examples_cat(model, plot_start=None, write_last=True, show_spikes=False, input_layer = "preNeuronBool"):
    """Plot the first ten samples of the first ``valid_ds`` batch in a 10 x 4 grid.

    Columns: the input image, the output of ``input_layer``, the output of the
    "SpikeTrain" layer and a bar chart of the "nSpikes" prediction. A red
    dashed line marks where the input activity ends.

    Args:
        model: Keras model containing ``input_layer``, "SpikeTrain" and
            "nSpikes" layers.
        plot_start: optional x position of a vertical marker on the trace plot.
        write_last: accepted for API compatibility; not read by the body.
        show_spikes: append detected (start, end) spike intervals to the title.
        input_layer: name of the layer shown in the second column.
    """
    def _activity_end(column_sums):
        # One past the last non-zero column when that column is positive;
        # otherwise the full length. Previously a scan loop that never hit its
        # `break` left `end_time` undefined (first sample) or stale (later ones).
        column_sums = np.asarray(column_sums)
        nonzero = np.flatnonzero(column_sums)
        if nonzero.size and column_sums[nonzero[-1]] > 0:
            return int(nonzero[-1]) + 1
        return column_sums.shape[0]

    plt.figure(figsize=(256, 200))
    
    how_many = 10
    for img, label in valid_ds:
        inputs = K.function(model.input, model.get_layer(input_layer).output)([img])
        outputs = K.function(model.input, model.get_layer("SpikeTrain").output)([img])
        prediction = K.function(model.input, model.get_layer("nSpikes").output)([img])

        for i in range(how_many):
            end_time = _activity_end(img[i,:,:,0].sum(axis=0))

            # Column 1: the raw input image.
            plt.subplot(how_many, 4, 4*i+1)
            plt.imshow(img[i,:,:,0]/255, cmap='binary')
            plt.axvline(end_time, color='r', linestyle="--")
            correct = label[i].tolist().index(1)
            plt.title(LABELS[correct], fontdict={'fontsize':200})
            plt.axis("off")
            
            # Column 2: the selected intermediate layer.
            plt.subplot(how_many, 4, 4*i+2)
            if len(inputs[i].shape)==3:
                curr_input = np.squeeze(inputs[i], axis=-3)
            else:
                curr_input = inputs[i]
            num_of_synapses = np.sum(curr_input, axis=-1)
            mean_synapses = round(num_of_synapses.mean(), 2)
            std_synapses = round(num_of_synapses.std(), 2)
            plt.title(f"\u03BC: {str(mean_synapses)},  \u03C3: {str(std_synapses)}", fontdict={'fontsize':200})
            plt.imshow(curr_input, cmap='binary', vmin=0, vmax=1)
            plt.axhline(end_time, color='r', linestyle="--")
            plt.axis("off")
            
            # Column 3: the neuron's output trace.
            plt.subplot(how_many, 4, 4*i+3)
            curr_output = outputs[i]
            if show_spikes:
                # A spike starts at the first sample above 0.25 and ends at the
                # next sample below 0.1; scanning resumes right after each end.
                spikes = []
                curr_index = 0
                while True:
                    start = np.where(curr_output[curr_index:] > 0.25)[0]
                    if not start.shape[0]: break
                    start = curr_index+start[0]
                    end = np.where(curr_output[start:] < 0.1)[0]
                    if not end.shape[0]: break
                    end = end[0] + start
                    spikes.append((start,end))
                    curr_index = end + 1
            plt.title(f"\u03A3: {str(round(curr_output.sum(),2))}" + (f", spikes:{spikes}" if show_spikes else ""), fontdict={'fontsize':150})
            plt.plot(curr_output, linewidth=5)
            plt.axvline(end_time, color='r', linestyle="--")
            plt.ylim(0,1)
            if plot_start: plt.axvline(plot_start, color='g', linestyle=':', linewidth=10.)
            plt.axis('on')
            
            # Column 4: per-class prediction. Green = correct class predicted,
            # red = wrong class predicted, yellow = correct class missed.
            plt.subplot(how_many, 4, 4*i+4)
            cur_prediction = prediction[i].tolist().index(max(prediction[i]))
            cur_prediction_list = [j==max(prediction[i]) for j in prediction[i]]
            plt.title(f"Correct: {LABELS[correct]} ({str(round(prediction[i][correct], 2))}), Prediction: {LABELS[cur_prediction]} ({str(round(max(prediction[i]),2))})", fontdict={'fontsize':150})
            colors = ["g" if (lbl and pred) else "r" if pred else "y" if lbl else "b" for lbl, pred in zip(label[i], cur_prediction_list)]
            plt.bar(range(label[i].shape[0]), prediction[i], tick_label = LABELS[:NLABELS], color=colors)
        break

In [None]:
def plot_statistics(model, input_layer="preNeuronBool"):
    """Plot summary statistics over the first four batches of ``train_ds``.

    Draws a 3 x 3 figure: histogram of ``input_layer`` values; histograms of
    the mean / std / max / min of the per-row synapse sums; the mean
    "SpikeTrain" output per label; and per-label histograms of the mean
    excitatory, inhibitory and total synapse counts. Each sample's
    ``input_layer`` output is truncated to the span where the input image is
    still active.

    Args:
        model: Keras model containing ``input_layer`` and "SpikeTrain" layers.
        input_layer: name of the layer whose output is analysed.
    """
    def _activity_end(column_sums):
        # One past the last non-zero column when that column is positive;
        # otherwise the full length. Previously a scan loop that never hit its
        # `break` left `end_time` undefined (first sample) or stale (later ones).
        column_sums = np.asarray(column_sums)
        nonzero = np.flatnonzero(column_sums)
        if nonzero.size and column_sums[nonzero[-1]] > 0:
            return int(nonzero[-1]) + 1
        return column_sums.shape[0]

    all_inputs = []
    synapses_mu = []
    synapses_std = []
    synapses_max = []
    synapses_min = []
    label_dct = {}
    synapses_mean_per_label = {}
    
    for i, (img, label) in enumerate(train_ds):
        if i > 3: break
        inputs = K.function(model.input, model.get_layer(input_layer).output)([img])
        outputs = K.function(model.input, model.get_layer("SpikeTrain").output)([img])
        all_inputs.append(inputs)
        labels = [LABELS[lbl] for lbl in label]
        for j in range(len(inputs)):
            if len(inputs[j].shape)==3:
                curr_input = np.squeeze(inputs[j], axis=-3)
            else:
                curr_input = inputs[j]
            end_time = _activity_end(img[j,:,:,0].sum(axis=0))
            curr_input = curr_input[:end_time,:]
            num_of_synapses = np.sum(curr_input, axis=-1)
            synapses_mu.append(num_of_synapses.mean())
            synapses_std.append(num_of_synapses.std())
            synapses_max.append(num_of_synapses.max())
            synapses_min.append(num_of_synapses.min())
            curr_output = outputs[j]
            if labels[j] not in label_dct:
                label_dct[labels[j]] = []
                synapses_mean_per_label[labels[j]] = {"Sum": [], "Ex":[], "Inh":[]}
            label_dct[labels[j]].append(curr_output)
            synapses_mean_per_label[labels[j]]["Sum"].append(np.sum(curr_input, axis=-1).mean())
            synapses_mean_per_label[labels[j]]["Ex"].append(np.sum(curr_input[:, :N_EXC], axis=-1).mean())
            synapses_mean_per_label[labels[j]]["Inh"].append(np.sum(curr_input[:, N_EXC:], axis=-1).mean())

    # Average output trace of every label that was seen.
    mean_dct = {lbl: np.stack(traces).mean(axis=0) for lbl, traces in label_dct.items()}

    plt.figure(figsize=(15,10))

    # Panels 1-5: plain histograms.
    simple_panels = [
        ('input values', np.array(all_inputs).ravel()),
        ('mean of sum synapses', synapses_mu),
        ('std of sum synapses', synapses_std),
        ('max sum synapses', synapses_max),
        ('min sum synapses', synapses_min),
    ]
    for position, (title, values) in enumerate(simple_panels, start=1):
        plt.subplot(3,3,position)
        plt.title(title)
        plt.hist(values)

    # Panel 6: mean neuron output per label.
    plt.subplot(3,3,6)
    plt.title(f'mean output of neuron (AP)')
    for lbl, value in mean_dct.items():
        plt.plot(value, label=lbl, linewidth=2.)
    plt.ylim(0,1)
    plt.legend()

    # Panels 7-9: per-label histograms (the first two share a fixed x range).
    seen_labels = list(synapses_mean_per_label.keys())
    per_label_panels = [
        (7, 'Excitatory Synapses Mean per Label', "Ex", (0,15)),
        (8, 'Inhibitory Synapses Mean per Label', "Inh", (0,15)),
        (9, 'All Synapses Mean per Label', "Sum", None),
    ]
    for position, title, key, x_range in per_label_panels:
        plt.subplot(3,3,position)
        plt.title(title)
        plt.hist([synapses_mean_per_label[lbl][key] for lbl in seen_labels], label=seen_labels)
        plt.legend()
        if x_range is not None:
            plt.xlim(x_range)

    plt.show()

In [None]:
def plot_statistics_cat(model, input_layer="preNeuronBool"):
    """Plot summary statistics (one-hot labels) over the first four ``train_ds`` batches.

    Draws a 3 x 3 figure: histogram of ``input_layer`` values; histograms of
    the mean / std / max / min of the per-row synapse sums; the mean
    "SpikeTrain" output per label; histograms of the mean excitatory and
    inhibitory synapse counts; and the first weight column of the "nSpikes"
    layer when that layer has weights. Each sample's ``input_layer`` output is
    truncated to the span where the input image is still active.

    Args:
        model: Keras model containing ``input_layer``, "SpikeTrain" and
            "nSpikes" layers.
        input_layer: name of the layer whose output is analysed.
    """
    def _activity_end(column_sums):
        # One past the last non-zero column when that column is positive;
        # otherwise the full length. Previously a scan loop that never hit its
        # `break` left `end_time` undefined (first sample) or stale (later ones).
        column_sums = np.asarray(column_sums)
        nonzero = np.flatnonzero(column_sums)
        if nonzero.size and column_sums[nonzero[-1]] > 0:
            return int(nonzero[-1]) + 1
        return column_sums.shape[0]

    all_inputs = []
    synapses_mu = []
    synapses_std = []
    synapses_max = []
    synapses_min = []
    excitatory_synapses = []
    inhibitory_synapses = []
    label_dct = {lbl:[] for lbl in LABELS[:NLABELS]}
    
    for i, (img, label) in enumerate(train_ds):
        if i > 3: break
        inputs = K.function(model.input, model.get_layer(input_layer).output)([img])
        outputs = K.function(model.input, model.get_layer("SpikeTrain").output)([img])
        all_inputs.append(inputs)
        # Column index of the 1 in every one-hot row gives the label.
        labels = [LABELS[lbl] for lbl in np.where(label == 1)[1].tolist()]
        for j in range(len(inputs)):
            if len(inputs[j].shape)==3:
                curr_input = np.squeeze(inputs[j], axis=-3)
            else:
                curr_input = inputs[j]
            end_time = _activity_end(img[j,:,:,0].sum(axis=0))
            curr_input = curr_input[:end_time,:]
            num_of_synapses = np.sum(curr_input, axis=-1)
            excitatory_synapses.append(np.sum(curr_input[:, :N_EXC], axis=-1).mean())
            inhibitory_synapses.append(np.sum(curr_input[:, N_EXC:], axis=-1).mean())
            synapses_mu.append(num_of_synapses.mean())
            synapses_std.append(num_of_synapses.std())
            synapses_max.append(num_of_synapses.max())
            synapses_min.append(num_of_synapses.min())
            label_dct[labels[j]].append(outputs[j])
            
    # Average output trace per label. Labels with no samples in the inspected
    # batches are skipped, since np.stack raises on an empty list.
    mean_dct = {lbl: np.stack(traces).mean(axis=0) for lbl, traces in label_dct.items() if traces}

    plt.figure(figsize=(15,10))

    # Panels 1-5: plain histograms.
    simple_panels = [
        ('input values', np.array(all_inputs).ravel()),
        ('mean of sum synapses', synapses_mu),
        ('std of sum synapses', synapses_std),
        ('max sum synapses', synapses_max),
        ('min sum synapses', synapses_min),
    ]
    for position, (title, values) in enumerate(simple_panels, start=1):
        plt.subplot(3,3,position)
        plt.title(title)
        plt.hist(values)

    # Panel 6: mean neuron output per label.
    plt.subplot(3,3,6)
    plt.title(f'mean output of neuron (AP)')
    for lbl, value in mean_dct.items():
        plt.plot(value, label=lbl, linewidth=2.)
    plt.ylim(0,1)
    plt.legend()

    # Panel 7: excitatory vs. inhibitory mean synapse counts.
    plt.subplot(3,3,7)
    plt.title('Synapses Mean (per ms per image)')
    plt.hist(excitatory_synapses, label='excitatory', color='b')
    plt.hist(inhibitory_synapses, label='inhibitory', color='r')
    plt.legend()

    # Panel 8: output-layer weights. A layer without weights (e.g. Softmax or
    # Flatten) returns an empty list, so only plot when there is something to index.
    plt.subplot(3,3,8)
    plt.title('Dense Layer Weights')
    output_weights = model.get_layer("nSpikes").get_weights()
    if output_weights:
        plt.plot(output_weights[0][:,0], color='r')

    plt.subplot(3,3,9)
    
    plt.show()

In [None]:
# Master switch: the plotting and weight-saving cells below only run when this is True.
PLOT = False

In [None]:
# Example grid for the binary model, with the "nSpikes" score in the trace titles.
# NOTE(review): `model` must already exist in the kernel -- confirm where it is created.
if PLOT: plot_examples(model, write_last=True)

In [None]:
# Example grid for the categorical (one-hot label) model.
if PLOT: plot_examples_cat(model)

In [None]:
# Summary statistics over the first training batches.
if PLOT: plot_statistics(model)

In [None]:
if PLOT: 
    # Inspect the learned wiring: how many synapses each input neuron drives,
    # and the raw weight matrix.
    plt.figure(figsize=(10,20))
    # Drop the two singleton kernel axes, leaving a 2-D (row, filter) matrix.
    weights = model.get_layer("WiringLayer").get_weights()[0][:,0,0,:]
    print(f"min: {weights.min()}, max: {weights.max()}\nmean: {weights.mean()}, sd: {weights.std()}\n")
    exc = weights[:,:N_EXC]
    inh = weights[:,N_EXC:]
    # Per row: number of weights above 0.01.
    # NOTE(review): 0.01 differs from PRESYNAPTIC_THRESHOLD -- confirm which is intended.
    amountExc = [(exc>0.01).sum(axis=-1)]
    amountInh = [(inh>0.01).sum(axis=-1)]

    # The histogrammed quantity (synapses per neuron) belongs on the x axis;
    # the y axis is the number of neurons falling in each bin.
    plt.subplot(3,1,1)
    plt.hist(amountExc)
    plt.title("Histogram of Excitatory Synapses per Neuron")
    plt.xlabel("Synapses per Neuron")
    plt.ylabel("Number of Neurons")

    plt.subplot(3,1,2)
    plt.hist(amountInh)
    plt.title("Histogram of Inhibitory Synapses per Neuron")
    plt.xlabel("Synapses per Neuron")
    plt.ylabel("Number of Neurons")

    plt.subplot(3,1,3)
    plt.title("Weights")
    plt.imshow(weights, cmap="Reds", vmin=weights.min(), vmax=weights.max())
    # Boundary between the excitatory and inhibitory columns.
    plt.axvline(N_EXC, color='r')
    plt.xlabel("synapse")
    plt.ylabel("neuron")

    plt.show()

# Save Weights

In [None]:
if PLOT:
    # Save the wiring layer's kernel (weight index 0) and bias (weight index 1)
    # to separate .npy files.
    for file_name, weight_index in (("weights_Nadam_84.npy", 0), ("bias_Nadam_84.npy", 1)):
        with open(file_name, 'wb') as f:
            np.save(f, model.get_layer("WiringLayer").get_weights()[weight_index])

# Evolution

This section was never actually run; it is only a sketch of an idea.

In [None]:
class EvolutionModule:
    """One individual of the evolutionary search: a model, its lineage and its training history.

    Class-level registries (shared by every instance; the driver clears them between runs):
        generations -- maps a generation number to the list of modules created in it
        results     -- dict filled in by the driver code; not written by this class
        alive       -- set of modules that have not been killed with ``die``
    """
    generations = {}
    results = {}
    alive = set()
    # Position, inside the list returned by ``get_weights()``, of the array that is
    # perturbed when a child is created from its parent.
    MUTATED_WEIGHT_INDEX = 2

    def __init__(self, ID, generation, create_func, noise=.1, parent=None):
        """Register the module and build its model.

        Args:
            ID: index of the module inside its generation (used only for the name).
            generation: generation number this module belongs to.
            create_func: zero-argument callable returning a fresh model exposing
                ``fit``, ``get_weights`` and ``set_weights``.
            noise: scale factor of the perturbation applied to the parent's weights.
            parent: the ``EvolutionModule`` this one descends from, or ``None``.
        """
        self.module = None

        self.generation = generation
        self.cur_generation = generation
        self.ID = ID
        self.parent = parent
        # str(parent)[6:] drops the leading "Module" so the lineage nests without repeating it.
        self.name = "Module" + (f"({parent.__str__()[6:]})-" if parent else "") + f"[{generation}.{ID}]"

        EvolutionModule.generations.setdefault(generation, []).append(self)
        EvolutionModule.alive.add(self)

        self.history = None

        self.children = set()
        self.status = True

        self.create_func = create_func
        self.original_weights = None
        self.noise = noise

        self._create_model()

    def _create_model(self):
        """Build the model; when there is a parent, start from its weights with one array perturbed."""
        self.module = self.create_func()
        if self.parent is not None:
            index = self.MUTATED_WEIGHT_INDEX
            # Copy the list so replacing one entry does not touch the parent's own weight list.
            self.original_weights = list(self.parent.get_weights())
            base = self.original_weights[index]
            # NOTE(review): np.random.rand is uniform on [0, 1), so the perturbation is never
            # negative; confirm this is intended rather than zero-centred noise.
            self.original_weights[index] = base + np.random.rand(*base.shape) * base.std() * self.noise
            self.module.set_weights(self.original_weights)

    def give_birth(self, n):
        """Create ``n`` children in the next generation; each call advances ``cur_generation`` by one."""
        self.cur_generation += 1
        for i in range(n):
            self.children.add(EvolutionModule(i, self.cur_generation, self.create_func, self.noise, self))

    def is_alive(self):
        """Return ``True`` until ``die`` has been called."""
        return self.status

    def fit(self, training_ds, epochs, validation_data):
        """Train the wrapped model on ``training_ds`` and append the metrics to ``self.history``.

        Args:
            training_ds: training data passed straight to the model's ``fit``.
            epochs: number of epochs for this round of training.
            validation_data: validation data passed straight to the model's ``fit``.
        """
        self.to_print()
        # Train on the dataset that was passed in, not on a notebook-level global.
        history = self.module.fit(training_ds, epochs=epochs, validation_data=validation_data).history
        if not self.history:
            # Copy the lists so later extensions never mutate the object returned by the model.
            self.history = {key: list(values) for key, values in history.items()}
        else:
            for key, values in history.items():
                self.history.setdefault(key, []).extend(values)

    def die(self):
        """Mark the module as dead and drop it from the ``alive`` registry (safe to call twice)."""
        self.status = False
        EvolutionModule.alive.discard(self)

    def value(self, n):
        """Return the mean of the last ``n`` recorded ``'loss'`` values, or ``None`` if never trained."""
        if not self.history: return None
        return np.mean(self.history['loss'][-n:])

    def accuracy(self):
        """Return the last ``'val_logits_binary_accuracy'`` value, or ``None`` if never trained."""
        if not self.history: return None
        return self.history['val_logits_binary_accuracy'][-1]

    def value_to_print(self, n):
        """Return (last, mean of last ``n``) pairs for loss, val_loss and validation accuracy, flattened."""
        loss = self.history['loss']
        val_loss = self.history['val_loss']
        val_acc = self.history['val_logits_binary_accuracy']
        return loss[-1], np.mean(loss[-n:]), val_loss[-1], np.mean(val_loss[-n:]), val_acc[-1], np.mean(val_acc[-n:])

    def to_print(self):
        """Announce that fitting of this module is about to start."""
        print(f"\n Beginning fitting for {self.name}...\n")

    def __str__(self):
        return self.name

    def get_weights(self):
        """Return the wrapped model's weights."""
        return self.module.get_weights()

In [None]:
def create_model():
    """Return a new model built by ``playable`` with the fixed settings listed below."""
    settings = dict(
        saccades=None,
        xaxis=True,
        use_sigmoid=True,
        sigmoid_mult=50,
        to_bool=True,
        useSynapse=True,
        qSynapse=0.1,
        sigmoid_threshold=0.8,
        augment=False,
        optimizer=SGD(5e-3),
    )
    return playable(**settings)

In [None]:
def run_evolution(nStart=10, nEpochs=10, nMax=2, nPerMax=3, nGenerations=5, rand=0.1, measure_by=2):
    """Run the evolutionary search using ``EvolutionModule`` objects.

    Generation 0 trains ``nStart`` fresh modules. In every following generation the
    ``nMax`` best modules of the previous ranking each produce ``nPerMax`` children,
    all other ranked modules are killed, and every module still alive (parents and
    children) is trained for ``nEpochs`` more epochs.

    Args:
        nStart: number of modules created for generation 0 (must be at least 1).
        nEpochs: epochs per training round.
        nMax: number of top-ranked modules allowed to reproduce per generation.
        nPerMax: number of children each surviving module produces.
        nGenerations: total number of generations, generation 0 included.
        rand: noise factor handed to every ``EvolutionModule``.
        measure_by: number of trailing loss values averaged to rank a module.

    Returns:
        Tuple ``(best_module, generations, results, alive)`` where ``best_module`` is the
        top-ranked module of the last generation and the other three are the
        ``EvolutionModule`` class-level registries.

    Raises:
        ValueError: if ``nStart`` is smaller than 1.
    """
    if nStart < 1:
        raise ValueError("nStart must be at least 1.")

    def _rank_and_report(generation, name_best=False):
        # Rank all living modules by mean recent loss (lowest first), store and print the ranking.
        ranking = sorted(EvolutionModule.alive, key=lambda x: x.value(measure_by))
        EvolutionModule.results[generation] = ranking
        best = ranking[0]
        winner = f" by {best}" if name_best else ""
        print("\n" + "* "*10)
        print(f'Generation: {generation}:\n Best accuracy (by loss): {best.accuracy()}{winner}\n')
        for module in EvolutionModule.alive:
            print(module, end=': ')
            print("Last loss: {0}, Mean loss: {1}, Last val_loss: {2}, Mean val_loss: {3}, Last val accuracy: {4}, Mean val acc: {5}".format(*module.value_to_print(measure_by)))
        print("\n" + "* "*10+'\n')

    # Start from empty class-level registries so earlier runs do not leak into this one.
    EvolutionModule.generations.clear()
    EvolutionModule.results.clear()
    EvolutionModule.alive.clear()

    print("\nGeneration 0:\n")
    for i in range(nStart):
        module = EvolutionModule(i, 0, create_model, rand)
        module.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)
    _rank_and_report(0, name_best=True)

    for generation in range(1, nGenerations):
        print(f"\nGeneration {generation}:\n")
        # Iterate over the stored ranking (a list), so changing the ``alive`` set here is safe.
        for rank, module in enumerate(EvolutionModule.results[generation-1]):
            if rank < nMax: module.give_birth(nPerMax)
            else: module.die()
        for module in EvolutionModule.alive:
            module.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)
        _rank_and_report(generation)

    print("\n~ ~ ~ ~ ~ ~ Done! ~ ~ ~ ~ ~ ~\n")
    # ``results`` is a dict keyed by generation number, so results[-1] would raise KeyError;
    # look up the highest generation that was actually ranked.
    last_generation = max(EvolutionModule.results)
    return EvolutionModule.results[last_generation][0], EvolutionModule.generations, EvolutionModule.results, EvolutionModule.alive

In [None]:
def evolution(result_dct=None, nStart=10, nEpochs=10, nMax=2, nPerMax=3, nGenerations=5, rand=0.1, measure_by=3):
    """Evolutionary search over models built by ``create_model``, without ``EvolutionModule``.

    Generation 0 trains ``nStart`` fresh models. For each later generation, the ``nMax``
    models of the previous generation with the lowest summed recent training loss are
    selected; each one produces ``nPerMax`` children (a copy of its weights with the
    array at index 2 perturbed) and is itself trained further. Children and retrained
    parents together form the new generation.

    Args:
        result_dct: optional dict to fill; a new dict is created when omitted.
        nStart: number of models in generation 0 (must be at least 1).
        nEpochs: epochs per training round.
        nMax: number of best models selected as parents per generation.
        nPerMax: number of children per selected parent.
        nGenerations: total number of generations, generation 0 included.
        rand: scale factor of the weight perturbation.
        measure_by: number of trailing history values used for ranking and reporting.

    Returns:
        ``result_dct``, mapping generation index to a list of tuples
        ``(model, loss_history, val_loss_history, val_accuracy_history)``.

    Raises:
        ValueError: if ``nStart`` is smaller than 1.
    """
    # A mutable default would be shared between calls, so create the dict per call.
    if result_dct is None:
        result_dct = {}
    if nStart < 1:
        raise ValueError("nStart must be at least 1.")

    def _record(model, history):
        # Package a model with the metric histories of its latest training round.
        return (model, history.history['loss'], history.history['val_loss'], history.history['val_logits_binary_accuracy'])

    def _select_best(results):
        # The nMax entries with the lowest summed training loss over the last measure_by epochs.
        return sorted(results, key=lambda x: sum(x[1][-measure_by:]))[:nMax]

    def _report(label, results, best):
        # Print the summary of one finished generation; ``label`` is its 1-based number.
        print("\n" + "* "*10)
        print(f'Generation: {label}:\n Best accuracy (by loss): {best[0][-1][-1]}\n')
        for result in results:
            print("last loss:", result[1][-1], " loss mean:", np.mean(result[1][-measure_by:]),"   last accuracy:", result[-1][-1], "  accuracy mean:", np.mean(result[-1][-measure_by:]))
        print("\n" + "* "*10+'\n')

    result_dct[0] = []
    for i in range(nStart):
        print("\nStarting first generation module number", i, "...\n")
        model = create_model()
        history = model.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)
        result_dct[0].append(_record(model, history))

    for generation in range(1, nGenerations):
        result_dct[generation] = []
        max_chosen = _select_best(result_dct[generation-1])
        _report(generation, result_dct[generation-1], max_chosen)

        print("Starting generation ", generation+1,"...\n")
        i = 1
        # Each record is a 4-tuple; only the model itself is needed here.
        for max_model, *_ in max_chosen:
            old_weights = max_model.get_weights()
            for _ in range(nPerMax):
                print("Starting module number ", i, " out of", nMax*(nPerMax+1), "\n")
                new_weights = list(old_weights)
                old_conv_weights = new_weights[2]
                # NOTE(review): np.random.rand is uniform on [0, 1), so the perturbation is
                # never negative; confirm this is intended rather than zero-centred noise.
                new_weights[2] = old_conv_weights + np.random.rand(*old_conv_weights.shape) * old_conv_weights.std() * rand
                model = create_model()
                model.set_weights(new_weights)
                history = model.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)
                result_dct[generation].append(_record(model, history))
                i+=1
            print("\nStarting father...\n")
            history = max_model.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)
            # Record the parent itself, not the last child that was created above.
            result_dct[generation].append(_record(max_model, history))

    # Summarise the generation that was trained last (index nGenerations-1; generation 0
    # when no further generations were requested).
    final_results = result_dct[max(nGenerations - 1, 0)]
    _report(nGenerations, final_results, _select_best(final_results))
    return result_dct

In [None]:
# best_module, generations, results, alive = run_evolution(nStart=10, nEpochs=10, nMax=3, nPerMax=2, nGenerations=6)
# print(best_module.value_to_print(1))

In [None]:
# plot_examples(best_module)

In [None]:
# plot_statistics(best_module)

In [None]:
# cur_model = best_module
# while cur_model:
#     print(" * * * * * * *"*2)
#     print("Current module:", cur_model)
#     print("current weights:", best_module.get_weights()[2].std())
#     print("Original weights:", best_module.original_weights[2].std())
#     cur_model = cur_model.parent