In [None]:
import keras
import tensorflow as tf
import keras.layers as layers
from keras.preprocessing.image import load_img
from keras.preprocessing.image import array_to_img
import numpy as np
from math import floor, ceil
from tensorflow.python.ops import math_ops
from tensorflow import math, random, shape
import os
from keras.losses import MeanSquaredError, BinaryCrossentropy
from keras.optimizers import Nadam, SGD, Adam, Adamax
from keras.activations import sigmoid
from tensorflow import convert_to_tensor as tens
from keras import backend as K
from cv2 import getGaborKernel as Gabor
from functools import reduce
from matplotlib import pyplot as plt
from math import sqrt
import itertools
import re
from random import shuffle, seed
from tensorflow.keras.utils import Sequence
from keras.constraints import NonNeg
from keras.regularizers import l1,l2,l1_l2
from keras.initializers import RandomNormal
import tensorflow_addons as tfa

In [None]:
# Notebook-wide configuration.
BATCH_SIZE = 32                  # images per batch in every tf.data dataset below
RESOLUTION = 8                   # stride of the C1 max-pooling in `processor`
N_EXC = 639                      # NOTE(review): presumably the excitatory synapse count of the neuron model — confirm
EXCITATORY_SYNAPSES_WANTED = 8
INHIBITORY_SYNAPSES_WANTED = 6

# Load the image folders as datasets

In [None]:
directory = '/kaggle/input/cat-and-dog/'
LABELS = ['Cat', 'Dog']

# Settings shared by every split: labels come from sub-folder names, images are
# loaded as single-channel and batched with BATCH_SIZE.
_common_kwargs = dict(labels='inferred', color_mode='grayscale', batch_size=BATCH_SIZE)
_train_dir = directory + 'training_set/training_set'

# Train and validation must share the seed and split fraction so that the two
# subsets are complementary.
train_ds, valid_ds = (
    tf.keras.preprocessing.image_dataset_from_directory(
        _train_dir, validation_split=0.2, subset=subset, seed=1337, **_common_kwargs)
    for subset in ("training", "validation")
)

test_ds = keras.preprocessing.image_dataset_from_directory(directory + 'test_set/test_set', **_common_kwargs)

Show a few sample photos from the training set

In [None]:
plt.figure(figsize=(10, 10))
# One panel per batch: the first image of each of the first 9 training batches.
for idx, (batch_images, batch_labels) in enumerate(train_ds.take(9)):
    plt.subplot(3, 3, idx + 1)
    first_image = batch_images[0].numpy()[:, :, 0]
    plt.imshow(first_image, cmap='gray')
    plt.title(LABELS[int(batch_labels[0])])
    plt.axis("off")

# Use David Beniaguev's (selfishgene) trained L5PC model

In [None]:
class Module:
    """A frozen, truncated copy of a pre-trained neuron model.

    The Keras model stored at ``<path>/Models/<module_name>`` is loaded and rebuilt
    from a fresh Input without its last two layers. Every layer is set to
    non-trainable and the model summary is printed.

    Attributes:
        name: label used when reporting which module is in use.
        module: the rebuilt keras.Model.
        input: the rebuilt model's input tensor.
    """
    def __init__(self, path, module_name, name):
        self.name = name
        self.module = self.create_module(path, module_name)
        self.input = self.module.input

    def __call__(self, inp):
        return self.module(inp)

    def create_module(self, path, module_name):
        """Load, truncate and freeze the saved model; returns the new keras.Model."""
        model_path = os.path.join(path, 'Models', module_name)
        source = keras.models.load_model(model_path)
        src_layers = source.layers

        # layers[0] is treated as the InputLayer. The remaining layers are re-applied
        # to a new Input so the graph can end at layers[-3] (dropping the last two).
        fresh_input = keras.Input(shape=src_layers[0].input.shape[1:])
        hidden = src_layers[1](fresh_input)
        for layer in src_layers[2:-3]:
            hidden = layer(hidden)
        truncated = keras.Model(fresh_input, src_layers[-3](hidden))

        for layer in truncated.layers:
            layer.trainable = False

        truncated.summary()
        return truncated

In [None]:
class Modules:
    """Callable wrapper around one `Module` or a list of them.

    With a list, calling it returns the element-wise mean of all modules' outputs
    (stacked on a new last axis and averaged over it). With a single module it just
    forwards the call.
    """
    def __init__(self, neurons):
        self.multiple = isinstance(neurons, list)
        self.neurons = neurons
        if self.multiple:
            self.input = neurons[0].input
            lines = [f"Averaging between {len(neurons)} neurons:\n"]
            lines.extend(f"\t- {neuron.name}\n" for neuron in neurons)
            print("".join(lines))
        else:
            self.input = neurons.input
            print("Neuron module is", neurons.name)

    def __call__(self, inputs):
        if not self.multiple:
            return self.neurons(inputs)
        stacked = tf.stack([module(inputs) for module in self.neurons], axis=-1)
        return math_ops.reduce_mean(stacked, axis=-1)

In [None]:
# Kaggle dataset holding the pre-trained neuron models (in its `Models/` sub-folder).
dataset_folder = '/kaggle/input/single-neurons-as-deep-nets-nmda-test-data'
# True: average several neuron models; False: use a single one.
multiple_neurons = True

In [None]:
if not multiple_neurons:
    neurons = Module(dataset_folder, "NMDA_TCN__DWT_7_128_153__model.h5", "module_7_layers")
else:
    # Modules are loaded in this order; their outputs are averaged.
    neurons = Modules([
        Module(dataset_folder, model_file, module_name)
        for model_file, module_name in (
            ("NMDA_TCN__DWT_8_224_217__model.h5", "Module_8_layers"),
            ("NMDA_TCN__DWT_7_292_169__model.h5", "Module_7_layers"),
            ("NMDA_TCN__DWT_9_256_241__model.h5", "Module_9_layers"),
        )
    ])

# Custom layers and functions

In [None]:
class ToBoolLayer(keras.layers.Layer):
    """Binarise activations around `threshold`.

    When called with training=False it returns ceil(inputs - threshold). For inputs in
    (threshold - 1, threshold + 1] this is 1 above the threshold and 0 at or below it.
    Otherwise (training True or None):
      * if `use_sigmoid` is set, a smooth sigmoid centred at `threshold` is applied;
      * if not, the inputs pass through unchanged.
    The smooth sigmoid is SigmoidThreshold(*use_special_sigmoid, threshold) when
    `use_special_sigmoid` = (pos_mult, neg_mult) is given, else
    SigmoidThresholdEasier(threshold=threshold, mult=mult).

    NOTE(review): if `use_special_sigmoid` is given but `use_sigmoid` is False, the
    special sigmoid is built but never applied.
    """
    def __init__(self, threshold=0.5, mult=15, use_sigmoid=True, use_special_sigmoid=None, name="ToBoolLayer"):
        super().__init__(name=name)
        self.use_sigmoid = use_sigmoid
        self.special_sigmoid = use_special_sigmoid
        self.threshold = threshold
        self.mult = mult

    def build(self, shape):
        special = self.special_sigmoid
        if special is None:
            if self.use_sigmoid:
                self.sigmoid = SigmoidThresholdEasier(threshold=self.threshold, mult=self.mult)
            return
        if len(special) != 2:
            raise Exception("Special sigmoid should be in format (pos_mult, neg_mult).")
        pos_mult, neg_mult = special
        self.sigmoid = SigmoidThreshold(pos_mult, neg_mult, self.threshold)

    def call(self, inputs, training=None):
        if training is False:
            return math_ops.ceil(inputs - self.threshold)
        if self.use_sigmoid:
            return self.sigmoid(inputs)
        return inputs

In [None]:
class SpikeProcessor(keras.layers.Layer):
    """Sum a slice of the last axis, divide by `nSynapses` and flatten.

    Args:
        nSynapses: normaliser applied to the per-position sum.
        start, end: optional bounds of the slice taken on the last axis
            (``inputs[..., start:end]``); None means "from the beginning" / "to the end".
        name: layer name.

    Returns (from call): a (batch, -1) tensor. The last axis is summed with keepdims,
    divided by `nSynapses`, and everything after the batch axis is flattened.
    """
    def __init__(self, nSynapses, start=None, end=None, name="SpikeProcessor"):
        super().__init__(name=name)
        self.nSynapses = nSynapses
        self.start = start
        self.end = end
        # Flatten has no weights; create it once rather than on every call.
        self.flatten = layers.Flatten()

    def build(self, shape):
        # Kept for backward compatibility; slicing no longer depends on the rank.
        self.dims = len(shape)

    def call(self, inputs):
        # `...` addresses the last axis for any rank. A slice with None bounds is a
        # no-op, so this single expression covers every start/end combination.
        selected = inputs[..., self.start:self.end]
        preds_sum = math.reduce_sum(selected, axis=-1, keepdims=True)
        return self.flatten(preds_sum / self.nSynapses)

# Earlier functions and layers (no longer used)

In [None]:
class OneTimeStamp2Many(keras.layers.Layer):
    """Broadcast one feature vector to `units` rows with a learned scale and bias.

    Computes expand_dims(inputs, 1) * w + b, where w has shape (units, input_shape[1])
    and b has shape (1, input_shape[1]).
    NOTE(review): assumes 2-D input (batch, features) — TODO confirm.
    """
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        n_features = input_shape[1]
        # Weight creation order (w, then b) is kept unchanged.
        self.w = self.add_weight(shape=(self.units, n_features),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight(shape=(1, n_features),
                                 initializer='zeros',
                                 trainable=True)

    def call(self, inputs):
        expanded = K.expand_dims(inputs, axis=1)
        return tf.multiply(expanded, self.w) + self.b

In [None]:
def noise_init(shape, dtype=None, p=0.08):
    """Random binary mask covering the last two axes of `shape`.

    Each entry is 1 with probability `p` and 0 otherwise. The mask is drawn row by
    row with tf.random.categorical.

    Args:
        shape: shape whose last two entries (rows, cols) size the mask.
        dtype: dtype of the returned tensor; defaults to tf.float32 when None
            (tf.cast(x, None) would raise).
        p: probability of a 1 (previously hard-coded to 0.08).

    Returns:
        Tensor of shape (shape[-2], shape[-1]).
    """
    dtype = tf.float32 if dtype is None else dtype
    log_probs = tf.math.log([[1. - p, p]])
    # Each categorical draw yields a (1, cols) row of class indices {0, 1}.
    rows = [tf.random.categorical(log_probs, shape[-1]) for _ in range(shape[-2])]
    return tf.cast(tf.concat(rows, axis=0), dtype)

In [None]:
class AddNoise(keras.layers.Layer):
    """Multiply the inputs by a fixed random binary mask drawn once in `build`.

    The mask comes from `noise_init(shape, dtype)` and covers the last two axes of the
    input shape. It broadcasts over the leading axes.
    NOTE(review): despite its name this masks (multiplies) the input; it adds nothing.

    Args:
        dtype: dtype of the mask; when None, the layer's compute dtype is used.
    """
    def __init__(self, dtype=None):
        super().__init__()
        self.datatype = dtype

    def build(self, shape):
        # tf.cast(x, None) raises, so the default AddNoise() could never be built;
        # fall back to the layer's own compute dtype.
        mask_dtype = self.datatype if self.datatype is not None else self.compute_dtype
        self.noise = noise_init(shape, mask_dtype)

    def call(self, inputs):
        return math.multiply(inputs, self.noise)

# Custom functions

In [None]:
def loss_for_me(y_true, y_preds):
    """Mean squared error of `y_preds` against a constant target of 1 (`y_true` is ignored)."""
    mse = MeanSquaredError()
    return mse(1, y_preds)

In [None]:
def spikes_preds_processing(preds, n_spikes=50):
    """Sum `preds` over the last axis (kept with size 1) and divide by `n_spikes`."""
    total = math.reduce_sum(preds, axis=-1, keepdims=True)
    return total / n_spikes

In [None]:
def SigmoidThreshold(pos_mult=1, neg_mult=1, threshold=0):
    """Return an asymmetric logistic function centred at `threshold`.

    f(x) = 1 / (1 + exp(m * (threshold - x))), where the slope m is:
      * neg_mult for x < threshold;
      * pos_mult for x > threshold;
      * the mean of the two at x == threshold, so f(threshold) == 0.5.
    """
    def sigmoid_threshold(x):
        distance = -x + threshold
        # K.sign picks neg_mult (sign +1), pos_mult (sign -1) or their average (sign 0).
        slope = (pos_mult + neg_mult + (-pos_mult + neg_mult) * K.sign(distance)) / 2
        return 1 / (1 + math.exp(slope * distance))
    return sigmoid_threshold

In [None]:
# Visualise SigmoidThreshold on [0, 1). Below `threshold` the slope is neg_mult (50),
# above it pos_mult (10), so the curve is steeper left of the green line. The red
# line marks 0.5, the value at x == threshold.
x = np.arange(0,1,0.01)
threshold = 1/16
pos_mult = 10
neg_mult = 50
plt.plot(x, SigmoidThreshold(pos_mult, neg_mult, threshold)(x))
plt.ylim(0,1)
plt.axhline(0.5, color='r', linestyle='--')
plt.axvline(threshold, color='g', linestyle='--')
plt.show()

In [None]:
def SigmoidThresholdEasier(mult=1, threshold=0.5):
    """Return f(x) = 1 / (1 + exp(mult * (threshold - x))).

    This is a symmetric logistic centred at `threshold` (f(threshold) == 0.5) whose
    steepness is `mult`.
    """
    def sigmoid_threshold(x):
        exponent = mult * (threshold - x)
        return 1 / (1 + math.exp(exponent))
    return sigmoid_threshold

What the simpler sigmoid looks like

In [None]:
# Visualise SigmoidThresholdEasier with a steep slope (50) centred at 0.9.
# The red line marks 0.5 and the green line the threshold.
x = np.arange(0,1,0.01)
threshold = 0.9
mult = 50
plt.plot(x, SigmoidThresholdEasier(mult, threshold)(x))
plt.ylim(0,1)
plt.xlim(0,1)
plt.axhline(0.5, color='r', linestyle='--')
plt.axvline(threshold, color='g', linestyle='--')
plt.show()

Loss used later in the notebook

In [None]:
def set_num_syn_loss(syns_wanted_per_ms=50):
    """Loss factory: MSE between y_preds / syns_wanted_per_ms and a constant 1 (y_true unused)."""
    def num_syn_loss(y_true, y_preds):
        normalised = y_preds / syns_wanted_per_ms
        return MeanSquaredError()(1, normalised)
    return num_syn_loss

In [None]:
class GetRandomInt:
    """Draw a uniform random integer in [0, maximum) from NumPy's global RNG on each `call()`."""
    def __init__(self, maximum):
        self.maximum = maximum

    def call(self):
        draws = np.random.randint(0, self.maximum, 1)
        return draws[0]

In [None]:
class ToNeuronInput(keras.layers.Layer):
    """Turn a feature map into a neuron input sequence.

    Steps:
      1. Flatten the two spatial axes into one time axis of H*W steps.
      2. Optionally reorder that axis with `new_order`.
      3. Tile it `part // (H*W)` times (only when that is > 1).
      4. Prefix a window of `full - H*W*times` steps cut from the `padding` tensor
         passed to `call`, starting at a random offset.

    NOTE(review): assumes inputs are (batch, H, W, C) — TODO confirm.

    Args:
        full: intended total length of the time axis.
        padding: only used to set `self.zero_padding`; the zero-padding branch in
            `call` is currently disabled.
        new_order: optional indices along the time axis used to reorder it.
        name: layer name.
        part: time steps the (tiled) input itself should occupy.
        batch_size: stored but not used by this layer.
    """
    def __init__(self, full, padding=0, new_order=None, name="NeuronInput", part=200, batch_size=BATCH_SIZE):
        super().__init__(name=name)
        self.full = full
        # Previously hard-coded to None, which silently discarded the argument.
        self.new_order = new_order
        self.zero_padding = isinstance(padding, int) and padding == 0
        self.padding = padding
        self.part = part
        self.batch_size = batch_size

    def build(self, shape):
        self.ms = shape[1]*shape[2]
        self.times = self.part // self.ms
        self.padding_amount = self.full-self.ms*self.times

    def call(self, inputs, padding):
        # (batch, H, W, C) -> (batch, H*W, C)
        new_inp = layers.Reshape((self.ms, inputs.shape[-1]))(inputs)
        if self.new_order is not None:
            new_inp = self.gather(new_inp, self.new_order, axis=-2)
        if self.times > 1:
            new_inp = layers.Concatenate(axis=1)([new_inp for _ in range(self.times)])
        # NOTE(review): zero padding is deliberately disabled by `and False`;
        # a random window of `padding` is always used.
        if self.zero_padding and False:
            new_inp = layers.ZeroPadding1D(padding=(self.full - self.ms*self.times, 0))(new_inp)
        else:
            # Merge every leading axis of `padding` into a single time axis: (T, C).
            padding = tf.reshape(padding, (padding.shape[0]*padding.shape[1]*padding.shape[2], padding.shape[-1]))
            starting_time = layers.Lambda(lambda x: self.ranInt(x))(padding)
            padding = padding[np.newaxis, starting_time:starting_time+self.padding_amount]
            # Repeat the window for every batch element and put it before the input.
            new_inp = layers.Concatenate(axis=-2)([tf.tile(padding, [tf.shape(new_inp)[0], 1, 1]), new_inp])
        return new_inp

    def ranInt(self, x):
        """Random start index in [0, 1024 - padding_amount); `x` is ignored.

        NOTE(review): 1024 is hard-coded rather than taken from the padding length —
        confirm this is intended.
        """
        return K.random_uniform((1,), 0, 1024-self.padding_amount, dtype=tf.int32)[0]

    @staticmethod
    def gather(x, ind, axis=0):
        """Gather the entries of `x` at indices `ind` along `axis`.

        The previous version lacked `self` and `axis`, so `self.gather(..., axis=-2)`
        always raised.
        """
        return tf.gather(x, ind, axis=axis)

In [None]:
class PredictNeuron(keras.layers.Layer):
    """Run one or several pre-trained neuron models on the input.

    `neuron_module` is either a single object exposing `.module` (a callable Keras
    model) and `.name`, such as `Module`, or a list of such objects. With a list, the
    layer returns the element-wise mean of all the models' outputs.
    NOTE(review): a `Modules` wrapper has neither `.module` nor `.name`; pass its list
    of `Module`s instead.
    """
    def __init__(self, neuron_module, name="NeuronPrediction"):
        super().__init__(name=name)
        self.neurons = neuron_module
        self.single = not isinstance(self.neurons, list)
        if not self.single:
            listing = "".join(f"\t- {neuron.name}\n" for neuron in self.neurons)
            print(f"Averaging between {len(self.neurons)} neurons:\n" + listing)
        else:
            print("Neuron module is", self.neurons.name)

    def call(self, inputs):
        if self.single:
            return self.neurons.module(inputs)
        predictions = [neuron.module(inputs) for neuron in self.neurons]
        return math_ops.reduce_mean(tf.stack(predictions, axis=-1), axis=-1)

# Synapse Loss

In [None]:
def MeanSquaredErrorSynapsesPerMS(batch_size=32):
    """Loss factory: mean over the last axis of (1 - y_preds)**2.

    `y_true` and `batch_size` are not used.
    """
    def mean_squared_error_synapses_per_ms(y_true, y_preds):
        return tf.reduce_mean(tf.square(1. - y_preds), axis=-1)
    return mean_squared_error_synapses_per_ms

In [None]:
def MSE_RMS_SynapsesPerMS(wanted, size=N_EXC, batch_size=32, eps=1e-3):
    """Loss factory comparing the RMS of active and inactive predictions to targets.

    For exactly `wanted` ones among `size` binary values, the RMS over the last axis is
    sqrt(wanted / size), and the RMS of (1 - values) is sqrt(1 - wanted / size).
    For each term the loss:
      * takes the squared deviation of the per-row RMS from its target;
      * sums it over the second-to-last axis;
      * averages the result over the batch.
    It returns sqrt of the mean of the two squared terms.
    `eps` is added before squaring; `y_true` and `batch_size` are unused.
    """
    target_on = np.sqrt(wanted / size)
    target_off = np.sqrt(1. - wanted / size)

    def rms_error(X, mu):
        rms = tf.math.sqrt(tf.math.reduce_mean(tf.square(X + eps), axis=-1))
        per_sample = tf.math.reduce_sum(tf.square(rms - mu), axis=-1)
        return tf.math.reduce_mean(per_sample)

    def mse_rms_synapses_per_ms(y_true, y_preds):
        on_term = rms_error(y_preds, target_on) ** 2
        off_term = rms_error(1. - y_preds, target_off) ** 2
        return tf.math.sqrt((on_term + off_term) / 2)
    return mse_rms_synapses_per_ms

In [None]:
class MeanSynapsesPerMsMetric:
    """Stateless metric: mean count of active synapses per time step.

    A prediction counts as active when it is > 0.5. Active entries are summed over the
    last axis, then averaged over the next two remaining axes. `y_true` is ignored.
    """
    def __init__(self, name="nSynapses"):
        self.name = name

    def __call__(self, y_true, y_pred):
        active = tf.cast(tf.math.greater(y_pred, 0.5), tf.float32)
        per_ms = tf.math.reduce_sum(active, axis=-1)
        per_sample = tf.math.reduce_mean(per_ms, axis=-1)
        return tf.math.reduce_mean(per_sample, axis=-1)
# def MeanSynapsesPerMsMetric(y_true, y_pred):
#     booleans = tf.cast(tf.math.greater(y_pred, 0.5), tf.float32)
#     summed = tf.math.reduce_sum(booleans, axis=-1)
#     mean_by_ms = tf.math.reduce_mean(summed, axis=-1)
#     meaned_batches = tf.math.reduce_mean(mean_by_ms, axis=-1)
#     return meaned_batches

In [None]:
# class MeanSynapsesPerMsMetric(tf.keras.metrics.Metric):
#     def __init__(self, name="SynapsesPerMs", dtype=tf.float32):
#         super().__init__(name=name, dtype=dtype)
#         self.curr_result = self.add_weight(name='nSyns', initializer='zeros')
#         self.dtype=dtype
        
#     def update_state(self, y_true, y_pred, sample_weight=None):
#         booleans = tf.cast(tf.math.greater(y_pred, 0.5), self.dtype)
#         summed = tf.math.reduce_sum(booleans, axis=-1)
#         mean_by_ms = tf.math.reduce_mean(summed, axis=-1)
#         meaned_batches = tf.math.reduce_mean(mean_by_ms, axis=-1)
#         if sample_weight is not None:
#             sample_weight = tf.cast(sample_weight, self.dtype)
#             meaned_batches = tf.multiply(meaned_batches, sample_weight)
#         self.curr_result.assign_add(meaned_batches)
        
#     def result(self):
#         return tf.cast(self.curr_result)
#     def reset_states(self):
#         self.curr_result.assign(0.)
#         self.n.assign(0)

In [None]:
mean = 8
samples = 10
synapses = 100
mu = mean / synapses


def _tile_batch(arr):
    """Repeat a (samples, synapses) array BATCH_SIZE times along a new leading batch axis."""
    return np.tile(arr[np.newaxis], (BATCH_SIZE, 1, 1))


def f(x, y):
    """MSE_RMS loss (`mean` wanted out of `synapses`) of the batch-tiled `y`; `x` is y_true."""
    return MSE_RMS_SynapsesPerMS(mean, size=synapses)(x, _tile_batch(y))


def metric(x):
    """Mean number of active synapses per time step of the batch-tiled array `x`."""
    return MeanSynapsesPerMsMetric()(None, _tile_batch(x))


def roun(x):
    """Format a scalar tensor rounded to 3 decimals."""
    return str(round(x.numpy(), 3))


# Example patterns (rows = samples, columns = synapses). Random draws are made in
# the same order as before: row shuffles, then the two normal-noise arrays.
good = np.array([[1.] * mean + [0] * (synapses - mean)] * samples)
for row in good:
    np.random.shuffle(row)  # rows are views, so `good` is shuffled in place
pretty_good = np.clip(good + np.random.normal(0, .3, good.shape), 0, 1)
opposite = (1. - good).astype(np.float32)
bad = np.clip(np.random.normal(.5, .25, samples * synapses).reshape((samples, synapses)), 0, 1)
zeros = np.zeros((samples, synapses))
ones = np.ones((samples, synapses))

cases = [("perfect", good), ("good", pretty_good), ("opposite", opposite),
         ("random", bad), ("zeros", zeros), ("ones", ones)]

plt.figure(figsize=(10, 5))
plt.suptitle(f"Loss with wanted mean {mean}, {samples} samples and {synapses} synapses")
for position, (label, pattern) in enumerate(cases, start=1):
    plt.subplot(3, 2, position)
    # Format the metric with roun() too; the raw tensor repr used to appear in titles.
    plt.title(f"{label}: {roun(f(None, pattern))}; nSyns {roun(metric(pattern))}")
    plt.imshow(pattern, cmap="gray", vmin=0, vmax=1)
plt.show()

In [None]:
def MaxErrorSynapsesPerMS(batch_size=32, y_true_for_real=False):
    """Loss factory returning the mean over the last axis of (1 - y_preds)**2.

    NOTE(review): despite the name, no maximum is involved. The previous body computed
    tf.reduce_max(y_preds) and discarded it; that dead computation was removed.
    `batch_size`, `y_true_for_real` and `y_true` are unused.
    """
    def max_error_synapses_per_ms(y_true, y_preds):
        squared_difference = tf.square(1 - y_preds)
        return tf.reduce_mean(squared_difference, axis=-1)
    return max_error_synapses_per_ms

In [None]:
def EntropySynapseLoss(batch_size=32):
    """Loss factory returning the per-sample entropy -sum(p * log p) of the flattened predictions.

    Entries equal to 0 contribute 0 (the limit of p*log p), instead of the NaN that
    0 * log(0) produced before. `y_true` and `batch_size` are unused.
    """
    def entropy_synapse_loss(y_true, y_preds):
        # Same result as keras Flatten, without creating a layer on every loss call.
        flattened = tf.reshape(y_preds, [tf.shape(y_preds)[0], -1])
        # xlogy(p, p) == p * log(p), but is defined as 0 where p == 0.
        return -math_ops.reduce_sum(tf.math.xlogy(flattened, flattened), -1)
    return entropy_synapse_loss

In [None]:
def VarianceSynapseLoss(k=1, batch_size=32):
    """Loss factory rewarding spread-out predictions: exp(-k * sample variance).

    Each sample's predictions are flattened and their unbiased (n - 1) variance is
    computed. The loss exp(-k * variance) is returned per sample, so a larger variance
    gives a smaller loss. `y_true` and `batch_size` are unused.
    """
    def variance_synapse_loss(y_true, y_preds):
        flattened = tf.reshape(y_preds, [tf.shape(y_preds)[0], -1])
        n_values = tf.cast(tf.shape(flattened)[-1], flattened.dtype)
        # Broadcasting a keepdims mean replaces the RepeatVector/Flatten round trip,
        # which also failed when the static size of the last axis was unknown.
        mean = math_ops.reduce_mean(flattened, axis=-1, keepdims=True)
        # Guard n == 1, where dividing by n - 1 == 0 would give NaN.
        dof = tf.maximum(n_values - 1., 1.)
        variance = math_ops.reduce_sum((flattened - mean) ** 2, -1) / dof
        return math_ops.exp(-k * variance)
    return variance_synapse_loss

In [None]:
max_acceptable_spikes_per_ms = 3.0
max_acceptable_spikes_deviation = 20.0
activity_reg_constant = 0.2 * 0.0028


def pre_synaptic_spike_regularization(activation_map):
    """Activity penalty on the summed activation along axis 2.

    The sum is computed over axis 2 (per the original comment, the dendritic
    locations — TODO confirm the axis layout). The excess above
    max_acceptable_spikes_per_ms is then:
      * scaled by max_acceptable_spikes_deviation;
      * squared;
      * averaged;
      * multiplied by activity_reg_constant.
    """
    spikes = K.sum(activation_map, axis=2)
    excess = K.relu(spikes - max_acceptable_spikes_per_ms)
    penalty = K.square(excess / max_acceptable_spikes_deviation)
    return activity_reg_constant * K.mean(penalty)

# **Preprocessing**

The Gabor filter parameters and max-pooling kernel sizes are taken from "Robust Object Recognition with Cortex-Like Mechanisms" (Serre et al.)

> # Gabor filters - Simple Cells

In [None]:
n_filters = 64
n_orientations = 4

# 16 scales (kernel sizes 7x7 ... 37x37) x 4 orientations = 64 Gabor filters.
ksizes = [(size, size) for size in range(7, 38, 2)]
thetas = [0 , (45 / 180) * np.pi, (90 / 180) * np.pi, (135 / 180) * np.pi]
gammas = [0.3] * 16
sigmas = [2.8, 3.6, 4.5, 5.4, 6.3, 7.3, 8.2, 9.2, 10.2, 11.3, 12.3, 13.4, 14.6, 15.8, 17., 18.2]
lambdas = [3.5, 4.6, 5.6, 6.8, 7.9, 9.1, 10.3, 11.5, 12.7, 14.1, 15.4, 16.8, 18.2, 19.7, 21.2, 22.8]

# Scale-major flat list: all 4 orientations of scale 0, then of scale 1, ...
all_filters = [(size, sigma, theta, lambd, gamma)
               for size, gamma, sigma, lambd in zip(ksizes, gammas, sigmas, lambdas)
               for theta in thetas]

# Within every block of 8 filters (two neighbouring scales), interleave the two
# scales orientation by orientation: 0, 4, 1, 5, 2, 6, 3, 7, then 8, 12, ...
reoredering_inds = [block + orientation + scale * n_orientations
                    for block in range(0, n_filters, 2 * n_orientations)
                    for orientation in range(n_orientations)
                    for scale in (0, 1)]

all_filters_reordered = [all_filters[k] for k in reoredering_inds]

Gabor initializer

In [None]:
class GaborInitializer(tf.keras.initializers.Initializer):
    """Initializer returning a single OpenCV Gabor kernel as a tensor.

    The constructor arguments are passed positionally to cv2.getGaborKernel as
    (ksize, sigma, theta, lambd, gamma).

    NOTE(review): the returned kernel always has the 2-D shape given by `size`; the
    `shape` Keras passes is not honored.
    """
    def __init__(self, size, sigma, theta, lambd, gamma):
        self.ksize = size
        self.sigma = sigma
        self.theta = theta
        self.lambd = lambd
        self.gamma = gamma

    def __call__(self, shape=None, dtype=None, **kwargs):
        # Keras calls initializers as init(shape, dtype=...). The old signature
        # (self, dtype=None) raised "multiple values for argument 'dtype'" there.
        # Direct calls such as GaborInitializer(*filt)() keep working.
        return tens(Gabor(self.ksize, self.sigma, self.theta, self.lambd, self.gamma))

    def get_config(self):  # To support serialization
        # Keys must match the __init__ parameter names, because from_config() is
        # cls(**config). The old 'ksize'/'lambda' keys made deserialization fail.
        return {'size': self.ksize, 'sigma': self.sigma, 'theta': self.theta, 'lambd': self.lambd, 'gamma': self.gamma}

In [None]:
# The largest kernel (last in the reordered list) fixes the canvas every filter is
# centred on, so all filters can be stored in one (H, W, n_filters) array.
max_filter_shape = all_filters_reordered[-1][0]
center_pixel_ind = (max_filter_shape[0] - 1) // 2

# place to store all filter activations
filters_matrix = np.zeros((max_filter_shape[0], max_filter_shape[1], len(all_filters)))

plt.figure(figsize=(18,22))
plt.subplots_adjust(left=0.04, right=0.96, bottom=0.04, top=0.96, hspace=0.25, wspace=0.1)
for i, filt in enumerate(all_filters_reordered):
    filter_size = filt[0][0]
    orientation = (filt[2] / np.pi) * 180
    curr_sigma = filt[1]
    curr_lambda = filt[3]

    # Top-left corner that centres this filter on the largest canvas.
    half_filter_size = (filter_size - 1) // 2
    upper_left_start = center_pixel_ind - half_filter_size

    curr_small_filter = GaborInitializer(*filt)().numpy()
    curr_full_filter = np.zeros(max_filter_shape)
    curr_full_filter[upper_left_start:upper_left_start + filter_size, upper_left_start:upper_left_start + filter_size] = curr_small_filter

    # store the filter and the activations for later
    filters_matrix[:,:,i] = curr_full_filter

    plt.subplot(8,8,i+1)
    # Backslashes are escaped: '\T', '\s' and '\l' are invalid escape sequences
    # (SyntaxWarning on Python 3.12+). The resulting string is unchanged.
    plt.title('%dx%d, $\\Theta=%d$, \n$\\sigma=%.1f$, $\\lambda=%.1f$' %(filter_size, filter_size, orientation, curr_sigma, curr_lambda))
    plt.imshow(curr_full_filter, cmap='gray')
    plt.axis("off")

Function to create a Conv2D layer initialised with Gabor filters

In [None]:
def simple_cells_module(filters_matrix, strides=(1,1), input_shape=(256,256)):
    """Build a frozen Conv2D model whose kernels are a fixed filter bank (S1 "simple cells").

    Args:
        filters_matrix: (h, w, n_filters) array of kernels; they are used negated.
        strides: Conv2D strides.
        input_shape: (height, width) of the single-channel grayscale input.

    Returns:
        keras.Model mapping (batch, H, W, 1) images to (batch, H', W', n_filters) responses
        ('same' padding).
    """
    height, width, n_kernels = filters_matrix.shape[0], filters_matrix.shape[1], filters_matrix.shape[2]

    image = keras.Input(shape=(input_shape[0], input_shape[1], 1))

    def gabor_filters_init(shape, dtype=None):
        # Ignores the requested shape: returns the negated bank with an in-channel axis (h, w, 1, n).
        return -filters_matrix[:,:,np.newaxis,:]

    conv = layers.Conv2D(n_kernels, (height, width), strides=strides, kernel_initializer=gabor_filters_init, padding='same')
    conv.trainable = False
    return keras.Model(image, conv(image))

In [None]:
# C1 max-pooling grid sizes 8, 10, ..., 22. Each pooling window also spans 2
# neighbouring scale channels (the trailing 2).
max_pooling_size = [2 * k for k in range(4, 12)]
pooling_size_list = [[(size, size, 2)] for size in max_pooling_size]
num_orientations = 4

In [None]:
def pre_process_module(filter_matrix, pooling_size_list, strides=(1,1), num_orientations=4, input_shape=(256,256)):
    """Build the S1 (Gabor) + C1 (3-D max-pool) preprocessing model.

    Args:
        filter_matrix: (h, w, n_filters) Gabor bank for the S1 stage.
        pooling_size_list: list of [(px, py, pz)] pooling sizes, one per group of
            num_orientations * pz consecutive S1 channels.
        strides: spatial strides of the max pooling (the depth stride is fixed to 2).
        num_orientations: orientations per scale in `filter_matrix`.
        input_shape: (height, width) of the single-channel input.

    Returns:
        keras.Model from (batch, H, W, 1) images to the concatenated, squeezed C1 maps.
    """
    print('num gabor filters must be %d' %(sum([x[0][-1] for x in pooling_size_list]) * num_orientations))

    # input is a single channel gray scale image
    input_tensor = keras.Input(shape=(input_shape[0],input_shape[1],1))

    # Use the arguments actually passed in. This previously read the global
    # `filters_matrix` and a hard-coded (256, 256), ignoring both parameters.
    s1_model = simple_cells_module(filter_matrix, strides=(1,1), input_shape=input_shape)
    # Trailing axis so the channel axis can be pooled as a 3-D depth axis.
    simple_cells_activations = K.expand_dims(s1_model(input_tensor), axis=-1)

    # add support for subsampling
    strides_to_use = (strides[0], strides[1], 2)

    # apply max pooling, one channel group per pooling size
    maxpool_3d_layers = []
    for k, pool_size in enumerate(pooling_size_list):
        group_width = num_orientations * pool_size[0][-1]
        start_ind = group_width * k
        end_ind = group_width * (k + 1)
        curr_pool_input_slice = simple_cells_activations[:,:,:,start_ind:end_ind]
        curr_pool_output_slice = layers.MaxPooling3D(pool_size=pool_size[0], strides=strides_to_use, padding='same')(curr_pool_input_slice)
        maxpool_3d_layers.append(curr_pool_output_slice)

    # squeeze the last dimention
    concatenated_pooled_layers = K.squeeze(layers.Concatenate(axis=-2)(maxpool_3d_layers), axis=-1)

    # wrap as module and return
    complex_cell_module = keras.Model(input_tensor, concatenated_pooled_layers)

    return complex_cell_module

In [None]:
def preprocessor(filters_matrix, pooling_sizes, n_orientations=4, conv_strides=(1,1), max_pooling_strides=(1,1), input_shape=(256,256)):
    """Thin alias of pre_process_module with renamed arguments.

    NOTE(review): `conv_strides` is accepted but not used.
    """
    # Positional order of pre_process_module: (filters, pooling, strides, num_orientations, input_shape).
    return pre_process_module(filters_matrix, pooling_sizes, max_pooling_strides, n_orientations, input_shape)


In [None]:
# S1 + C1 model; the pooling stride RESOLUTION subsamples both spatial axes.
processor = preprocessor(filters_matrix, pooling_size_list, max_pooling_strides=(RESOLUTION,) * 2)

In [None]:
def conv2d_with_gabor(filters, trainable=False):
    """Create a single-input-channel Conv2D layer whose kernels are the given Gabor filters.

    Args:
        filters: sequence of (ksize, sigma, theta, lambd, gamma) tuples. The ksize of
            the first tuple sets the layer's kernel_size, so all tuples should share it.
        trainable: whether the kernels may be updated during training.

    Returns:
        A built Conv2D layer ('same' padding) with one output channel per filter and
        its default bias.
    """
    layer = layers.Conv2D(len(filters), filters[0][0], kernel_initializer='zeros', padding='same')
    # Create the weights for any spatial size with one grayscale channel. This replaces
    # a dummy forward pass over a hard-coded 256x256 image.
    layer.build((None, None, None, 1))
    # Stack the (h, w) kernels on the last axis -> (h, w, n), then add the in-channel axis.
    kernels = np.stack([GaborInitializer(*filt)().numpy() for filt in filters], axis=-1)
    _, bias = layer.get_weights()
    layer.set_weights([kernels[:, :, np.newaxis, :], bias])
    layer.trainable = trainable
    return layer

In [None]:
def preprocess(inp):
    """Functional S1 + C1 pipeline built from fresh Gabor Conv2D layers.

    Steps:
      1. Scale the input to [0, 1] (/255).
      2. Convolve with every 4-orientation filter group.
      3. Stack neighbouring scale pairs on a new axis.
      4. Max-pool each pair with the matching size from `max_pooling_size`.
      5. Concatenate the pooled results.
    """
    # Bug fix: `normalized` used to be computed but the raw `inp` was convolved.
    normalized = inp / 255
    s1_pairs = []
    for first_scale in range(0, n_filters // n_orientations, 2):
        pair = [K.expand_dims(conv2d_with_gabor(all_filters[n_orientations*j:n_orientations*(j+1)],
                                                trainable=False)(normalized),
                              axis=-2)
                for j in (first_scale, first_scale + 1)]
        s1_pairs.append(layers.Concatenate(axis=-2)(pair))
    pooled = [layers.MaxPooling3D(pool_size=(ksize, ksize, 2), strides=(1,1,2), padding="same")(pair_maps)
              for pair_maps, ksize in zip(s1_pairs, max_pooling_size)]
    return layers.Concatenate(axis=-2)(pooled)

Show the preprocessing applied to one image from the dataset

In [None]:
plt.figure(figsize=(10, 10))
# Take the first image of one training batch and show the first 32 output channels
# of `processor` for it; every panel is titled with the image's class.
for batch_images, batch_labels in train_ds.take(1):
    image = batch_images[0].numpy()[:, :, 0]
    label = int(batch_labels[0])
    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    processed = processor(image[np.newaxis, :, :, np.newaxis]).numpy()[0]
    for i in range(32):
        plt.subplot(8, 4, i + 1)
        plt.imshow(processed[:, :, i], cmap='gray')
        plt.title(LABELS[label])
        plt.axis("off")

# Noisy Optimizer (not used)
https://arxiv.org/abs/1511.06807

https://github.com/cpury/keras_gradient_noise

In [None]:
import inspect
import importlib

def add_gradient_noise(BaseOptimizer, keras=None):
    """
    Given a Keras-compatible optimizer class, returns a modified class that
    supports adding gradient noise as introduced in this paper:
    https://arxiv.org/abs/1511.06807
    The relevant parameters from equation 1 in the paper can be set via
    noise_eta and noise_gamma, set by default to 0.3 and 0.55 respectively.
    By default, tries to guess whether to use default Keras or tf.keras based
    on where the optimizer was imported from. You can also specify which Keras
    to use by passing the imported module.

    Args:
        BaseOptimizer: optimizer class (not instance) to subclass.
        keras: optional keras module to use; shadows the global `keras` inside.

    Returns:
        A subclass named 'Noisy<BaseOptimizer>'. Its get_gradients adds Gaussian noise
        with variance noise_eta / (1 + iterations) ** noise_gamma to every gradient.

    Raises:
        ValueError: if BaseOptimizer is not a subclass of keras.optimizers.Optimizer.

    NOTE(review): the noise is injected only in get_gradients. Verify that the
    installed Keras version's fit()/minimize() path actually calls it; otherwise the
    noise is silently not applied.
    """
    if keras is None:
        # Import it automatically. Try to guess from the optimizer's module
        if hasattr(BaseOptimizer, '__module__') and BaseOptimizer.__module__.startswith('keras'):
            keras = importlib.import_module('keras')
        else:
            keras = importlib.import_module('tensorflow.keras')

    K = keras.backend

    if not (
        inspect.isclass(BaseOptimizer) and
        issubclass(BaseOptimizer, keras.optimizers.Optimizer)
    ):
        raise ValueError(
            'add_gradient_noise() expects a valid Keras optimizer'
        )

    def _get_shape(x):
        # Sparse (IndexedSlices-like) gradients expose dense_shape; use it for the noise shape.
        if hasattr(x, 'dense_shape'):
            return x.dense_shape

        return K.shape(x)

    class NoisyOptimizer(BaseOptimizer):
        def __init__(self, noise_eta=0.3, noise_gamma=0.55, **kwargs):
            super(NoisyOptimizer, self).__init__(**kwargs)
            # Stored as backend variables so they are saved via get_config/K.get_value.
            with K.name_scope(self.__class__.__name__):
                self.noise_eta = K.variable(noise_eta, name='noise_eta')
                self.noise_gamma = K.variable(noise_gamma, name='noise_gamma')

        def get_gradients(self, loss, params):
            grads = super(NoisyOptimizer, self).get_gradients(loss, params)

            # Add decayed gaussian noise
            # t is the optimizer's step counter, so the variance decays per update step.
            t = K.cast(self.iterations, K.dtype(grads[0]))
            variance = self.noise_eta / ((1 + t) ** self.noise_gamma)

            grads = [
                grad + K.random_normal(
                    _get_shape(grad),
                    mean=0.0,
                    stddev=K.sqrt(variance),
                    dtype=K.dtype(grads[0])
                )
                for grad in grads
            ]

            return grads

        def get_config(self):
            config = {'noise_eta': float(K.get_value(self.noise_eta)),
                      'noise_gamma': float(K.get_value(self.noise_gamma))}
            base_config = super(NoisyOptimizer, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))

    NoisyOptimizer.__name__ = 'Noisy{}'.format(BaseOptimizer.__name__)

    return NoisyOptimizer

In [None]:
# Gradient-noise variants of the imported optimizers, created in this order.
NoisySGD, NoisyNadam, NoisyAdamax = (add_gradient_noise(opt) for opt in (SGD, Nadam, Adamax))

In [None]:
# Decay schedule of the gradient noise: sigma^2_t = eta / (1 + t) ** gamma.
# NOTE(review): NoisyOptimizer defaults to noise_eta=0.3 and counts t in optimizer
# iterations, while this plot uses eta=0.03 and labels t as epochs.
eta = 0.03
gamma = 0.55

n = 150

var = [eta / (1 + t) ** gamma for t in range(n)]
plt.plot(np.arange(n), var)
plt.title(r'Noise $\sigma^{2} $ through epochs')
plt.xlabel('epochs')
plt.ylabel(r'noise $\sigma^{2}$')
plt.show()

# Start noise (hybrid images used as padding)

In [None]:
# Build "hybrid" images to use as start noise / padding (PADDING):
#   1. collect every image of the first 5 training batches, grouped by class;
#   2. cut all images into blocksize x blocksize tiles and, for each tile position,
#      randomly permute which image the tile goes to;
#   3. run the hybrids through `processor` and concatenate them along axis -2.
# NOTE(review): np.random.permutation is unseeded, so PADDING changes on every run.
plt.figure(figsize=(5,20))
dct = {"Dog": [], "Cat": []}
for _, (image, label) in enumerate(train_ds.take(5)):
    for img, lbl in zip(image, label):
        # lbl is a scalar integer tensor used directly as a list index; the image
        # gets a leading batch axis so the lists can be concatenated below.
        dct[LABELS[lbl]].append(img[np.newaxis,])
#     boolean = False
#     if len(dogs) < 400//64:
#         boolean = True
#         dogs += [img[np.newaxis,] for img,lbl in zip(images,label) if LABELS[lbl]=="Dog"]
#     if len(cats) < 400//64:
#         boolean = True   
#         cats += [img[np.newaxis,] for img,lbl in zip(images,label) if LABELS[lbl]=="Cat"]
#     if not boolean: break
# All dog images first, then all cat images, stacked on the batch axis.
all_images = np.concatenate(dct["Dog"] + dct["Cat"], axis=0)
plt.subplot(5,1,1)
plt.title("Original Image")
plt.imshow(all_images[0,:,:,0], cmap="gray")

# Side length (pixels) of the square tiles that are exchanged between images
blocksize = 64

# Create blocks: for every tile position, image `new` sends its tile to image `orig`
# according to a fresh random permutation, so each hybrid mixes tiles from many images.
shuffled_images = all_images.copy()
for j in range(0, all_images.shape[2], blocksize):
    for i in range(0, all_images.shape[1], blocksize):
        indxs = np.random.permutation(all_images.shape[0]).tolist()
        for orig,new in zip(indxs, range(all_images.shape[0])):
            shuffled_images[orig,i:i+blocksize, j:j+blocksize] = all_images[new, i:i+blocksize, j:j+blocksize]
plt.subplot(5,1,2)
plt.title("Hybrid Image")
plt.imshow(shuffled_images[0,:,:,0], cmap="gray")

# plt.subplot(5,1,3)
# plt.title("Smoothed Hybrid")
# smoothed = smoothing_layer(shuffled_images)
# plt.imshow(smoothed[0,:,:,0], cmap="gray")


# image_list = [smoothed[i] for i in range(smoothed.shape[0])]
# PADDING_HYBRID = np.concatenate(image_list, axis=-2)[np.newaxis]

# processed_images = processor(smoothed)

# plt.subplot(5,1,4)
# plt.title("Processed Smoothened Hybrid Image (Channel 0)")
# plt.imshow(processed_images[0,:,:,0], cmap="gray")

# image_list = [processed_images[i] for i in range(processed_images.shape[0])]
# PADDING = np.concatenate(image_list, axis=-2)[np.newaxis]
# print(PADDING.shape)

processed_images = processor(shuffled_images)

plt.subplot(5,1,4)
plt.title("Processed Hybrid Image (Channel 0)")
plt.imshow(processed_images[0,:,:,0], cmap="gray")

# Put all processed hybrids side by side along axis -2, then add a leading axis of size 1.
image_list = [processed_images[i] for i in range(processed_images.shape[0])]
PADDING = np.concatenate(image_list, axis=-2)[np.newaxis]
print(PADDING.shape)

plt.subplot(5,1,5)
plt.title("Concatenated Processed Hybrid Image (Channel 0)")
plt.imshow(PADDING[0,:,:,0], cmap="gray")

plt.show()

# Convert V1-like images to Dataset (not used)

In [None]:
class NumpyDirectoryGeneratorSequence(Sequence):
    """Keras ``Sequence`` serving batches of saved numpy arrays from labelled directories.

    Args:
        dct_label: mapping ``{directory: label}``. Every file in ``directory`` whose name
            ends with ``dtype`` is paired with ``label``.
        dtype: file-name suffix to select (default ``'.npy'``).
        randomize: shuffle the file list reproducibly, using ``random_seed``.
        random_seed: seed for that shuffle.
        validation_split: fraction of the files reserved for validation, or ``None``.
        is_validation: when splitting, serve the validation part instead of the
            training part. Instances built with the same arguments and opposite
            ``is_validation`` get complementary halves.
        batch_size: number of files per batch.
    """

    def __init__(self, dct_label, dtype='.npy', randomize=True, random_seed=1331, validation_split=None, is_validation=False, batch_size=32):
        # Newer Keras `Sequence`/`PyDataset` implementations expect their initializer to run.
        super().__init__()
        self.directories = dct_label
        self.dtype = dtype
        self.randomize = randomize
        self.seed = random_seed
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.batch_size = batch_size
        self.x = self.y = None
        self._create_files()

    def _create_files(self):
        """Collect ``(path, label)`` pairs, optionally shuffle and split them, and fill ``x``/``y``.

        Raises:
            ValueError: if no matching file remains after splitting.
        """
        from random import Random

        files = []
        for directory, label in self.directories.items():
            # Sorted: os.listdir order is arbitrary, so without sorting the seeded shuffle
            # (and therefore the train/validation split) is not reproducible.
            for name in sorted(os.listdir(directory)):
                # Plain suffix test. The former regex search treated '.' as a wildcard and
                # matched anywhere in the name (e.g. 'notes_npy.txt').
                if name.endswith(self.dtype):
                    files.append((os.path.join(directory, name), label))
        if self.randomize:
            # Same permutation as seed()+shuffle() on the global RNG, without resetting
            # the process-wide `random` state.
            Random(self.seed).shuffle(files)
        if self.validation_split:
            split_at = floor(len(files) - len(files) * self.validation_split)
            files = files[split_at:] if self.is_validation else files[:split_at]
        if not files:
            raise ValueError(f"No '{self.dtype}' files found in {list(self.directories)}")

        self.x, self.y = zip(*files)
        self.y = np.array(self.y)

    def __len__(self):
        """Number of batches (the last one may be smaller than ``batch_size``)."""
        return ceil(self.y.shape[0] / self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(stacked arrays, labels)``."""
        window = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_x = [np.load(path) for path in self.x[window]]
        return np.array(batch_x), self.y[window]

In [None]:
class NumpyDirectoryGenerator:
    """Generator factory yielding ``(array, label)`` pairs from labelled directories of numpy files.

    Calling an instance returns a fresh iterator, so the instance itself can serve as a
    zero-argument generator callable.

    Args:
        dct_label: mapping ``{directory: label}``. Every file in ``directory`` whose name
            ends with ``dtype`` is paired with ``label``.
        dtype: file-name suffix to select (default ``'.npy'``).
        randomize: shuffle the file list reproducibly, using ``random_seed``.
        random_seed: seed for that shuffle.
        validation_split: fraction of the files reserved for validation, or ``None``.
        is_validation: when splitting, keep the validation part instead of the training part.
    """

    def __init__(self, dct_label, dtype='.npy', randomize=True, random_seed=1331, validation_split=None, is_validation=False):
        self.directories = dct_label
        self.dtype = dtype
        self.files = []
        self.randomize = randomize
        self.seed = random_seed
        self.validation_split = validation_split
        self.is_validation = is_validation
        self._create_files()

    def _create_files(self):
        """(Re)build ``self.files`` as a list of ``(path, label)`` pairs; safe to call repeatedly."""
        from random import Random

        files = []
        for directory, label in self.directories.items():
            # Sorted: os.listdir order is arbitrary, so without sorting the seeded shuffle
            # (and therefore the split) is not reproducible across machines.
            for name in sorted(os.listdir(directory)):
                # Plain suffix test. The former regex search treated '.' as a wildcard and
                # matched anywhere in the name.
                if name.endswith(self.dtype):
                    files.append((os.path.join(directory, name), label))
        if self.randomize:
            # Same permutation as seed()+shuffle() on the global RNG, without touching it.
            Random(self.seed).shuffle(files)
        if self.validation_split:
            split_at = floor(len(files) - len(files) * self.validation_split)
            files = files[split_at:] if self.is_validation else files[:split_at]
        self.files = files

    def __call__(self):
        return iter(self)

    def __iter__(self):
        for path, label in self.files:
            # Loading by path lets numpy open and close the file itself, so no handle is
            # held open while the consumer processes the yielded item.
            yield (np.load(path), label)

# Microsaccades (not used)
After preprocessing, the model uses a 3D convolutional layer with 1278 filters (whose outputs form the input vector to the synapses) that slices the image into 256 blocks. Each block represents one ms. To keep the "eyesight" non-linear, the vectors are reordered to follow a microsaccade-like path.
The following functions define the order of the blocks.

In [None]:
def saccades_per1ms(shape, noise_mult=0, radius=0, radial_q=0, return_values_and_grades=False):
    """Order the cells of an ``m x n`` grid along a rough, microsaccade-like scan path.

    (The original author notes this was not checked against real microsaccade data.)

    Each cell ``(row, col)`` at Euclidean distance ``d`` from the grid centre gets a grade:

    * if both ``radius`` and ``radial_q`` are non-zero:
      ``min((1 - radial_q) * |d - radius|, radial_q * d)``;
    * otherwise ``|d - radius|``.

    Uniform noise in ``[0, noise_mult)`` is then added, and cells are sorted by grade.
    Ties keep row-major order.

    Args:
        shape: ``(m, n)`` grid size; any positive sizes, square or not.
        noise_mult: scale of the added uniform noise (0 gives a deterministic order;
            one random draw per cell is made regardless).
        radius: distance from the centre preferred first.
        radial_q: blend factor between "near ``radius``" and "near the centre".
        return_values_and_grades: also return the grades and the rank grid.

    Returns:
        By default, a list of flat row-major indices in scan order. With
        ``return_values_and_grades`` it returns ``(grades, rank_grid, order)``:

        * ``grades`` maps ``(flat_index, (row, col))`` to its 1-element grade array.
        * ``rank_grid`` is an ``m x n`` float array holding each cell's position in the order.
        * ``order`` is the list above.
    """
    m, n = shape
    # Geometric centre, correct for odd and even sizes alike.
    center = (np.array(shape) - 1) / 2
    flat = np.arange(m * n)
    # Row-major decomposition: divide by the number of columns (n). The previous `k // m`
    # transposed non-square grids and indexed out of bounds.
    rows, cols = flat // n, np.mod(flat, n)
    use_blend = radius and radial_q

    grades = {}
    for k, i, j in zip(flat, rows, cols):
        dist = np.sqrt(((np.array([i, j]) - center) ** 2).sum())
        if use_blend:
            base = np.minimum((1 - radial_q) * np.abs(dist - radius), radial_q * dist)
        else:
            base = np.abs(dist - radius)
        grades[(k, (i, j))] = base + np.random.rand(1) * noise_mult

    ranked = sorted(grades, key=grades.__getitem__)
    order = [key[0] for key in ranked]
    if not return_values_and_grades:
        return order
    rank_grid = np.zeros((m, n))
    for rank, (_, (i, j)) in enumerate(ranked):
        rank_grid[i, j] = rank
    return grades, rank_grid, order

In [None]:
# Grid size and noise level for the single-grid preview.
m = n = 8
noise_mult = 0.2

# radius=6 with a non-zero radial_q selects the blended radial grading in saccades_per1ms.
value, grading, SACCADES_ORDER = saccades_per1ms((m,n), noise_mult, 6, radial_q=.1, return_values_and_grades=True)
plt.figure(figsize=(8, 8))
# Paint each cell with its raw grade (the value the scan order is sorted by).
grades = np.zeros((m,n))
for key, grade in value.items():
    grades[key[1][0], key[1][1]] = grade
plt.imshow(grades, cmap='gray')
# Overlay each cell's position in the scan order.
for i in range(m):
    for j in range(n):
        text = plt.text(j, i, int(grading[i, j]), ha="center", va="center", color="w")

In [None]:
def saccades_blocks(shape, noise_mult=0, small_mult=0, rows=2, cols=2, radius=0, radial_q=0, return_values_and_grades=False):
    """Two-level scan order: a coarse order over ``rows x cols`` blocks, a fine order inside each.

    The ``m x n`` grid (``m, n = shape``) is tiled by ``rows x cols`` blocks, and
    ``saccades_per1ms`` produces two rank grids with the same ``radius``/``radial_q``:

    * a coarse ``(m // rows) x (n // cols)`` grid, using noise ``noise_mult``;
    * a fine ``rows x cols`` grid, using noise ``small_mult``.

    A cell's grade is ``rows * cols * coarse_rank + fine_rank``. All cells of one block are
    therefore visited consecutively, with blocks taken in coarse order.

    Args:
        shape: ``(m, n)`` grid size. It must be divisible into ``rows x cols`` blocks.
            It was previously ignored in favour of the global ``m``/``n``.
        noise_mult: noise scale for the coarse order.
        small_mult: noise scale for the within-block order.
        rows: block height.
        cols: block width.
        radius: forwarded to ``saccades_per1ms`` for both grids.
        radial_q: forwarded to ``saccades_per1ms`` for both grids.
        return_values_and_grades: also return the grade grid and both rank grids.

    Returns:
        By default, a list of flat row-major indices sorted by grade. With
        ``return_values_and_grades`` it returns ``(grade_grid, order, coarse_ranks, fine_ranks)``,
        where ``grade_grid`` is an integer ``m x n`` array of the grades.

    Raises:
        ValueError: if ``m`` is not divisible by ``rows`` or ``n`` by ``cols``.
    """
    m, n = shape
    if m % rows or n % cols:
        raise ValueError(f"shape {tuple(shape)} cannot be tiled by {rows}x{cols} blocks")
    coarse = saccades_per1ms((m // rows, n // cols), noise_mult, radius, radial_q, True)[1]
    fine = saccades_per1ms((rows, cols), small_mult, radius, radial_q, return_values_and_grades=True)[1]
    keys = np.arange(m * n).reshape((m, n))

    grade_of = {}
    for row in range(coarse.shape[0]):
        for col in range(coarse.shape[1]):
            block = (slice(rows * row, rows * (row + 1)), slice(cols * col, cols * (col + 1)))
            block_grades = rows * cols * coarse[row, col] + fine
            # Record grades against the flat indices before `keys` is overwritten with them.
            for flat_index, grade in zip(keys[block].ravel(), block_grades.ravel()):
                grade_of[flat_index] = grade
            keys[block] = block_grades

    order = sorted(grade_of, key=grade_of.__getitem__)
    if not return_values_and_grades:
        return order
    return keys, order, coarse, fine

Visualize the microsaccade order. The integers give each block's position in the order, and the shading shows the grade the order is sorted by. `noise_mult` controls how centered the order is: the higher it is, the noisier the order.

In [None]:
# Grid size and the two noise levels (coarse block order / within-block order) for the preview.
m = n = 16
noise_mult = 0.6
small_noise_mult = 1

# to_plot: per-cell grade grid; bigarr: rank of each 2x2 block; smallarr: rank inside a block.
to_plot, SACCADES_ORDER, bigarr, smallarr = saccades_blocks((m,n), noise_mult, small_noise_mult, rows=2, cols=2, radius=3, radial_q=0.4, return_values_and_grades=True)
_, (ax0, ax1, ax2) = plt.subplots(3,1, figsize=(20, 20), gridspec_kw={'height_ratios': [4,2,20]})
ax0.imshow(bigarr, cmap='gray')
ax1.imshow(smallarr, cmap='gray')
ax2.imshow(to_plot, cmap='gray')
# Overlay each cell's grade on the full grid.
for i in range(m):
    for j in range(n):
        text = ax2.text(j, i, int(to_plot[i, j]), ha="center", va="center", color="w")
plt.show()

# Model

In [None]:
def data_augmentation(rotate=0.075):
    """Build the (un-applied) augmentation pipeline: random horizontal flip, then random rotation.

    Args:
        rotate: rotation ``factor`` passed to ``RandomRotation``. The previous version ignored
            this argument and always used 0.075, so the default is now 0.075. Every existing
            call therefore behaves as before, and explicit values now take effect.

    Returns:
        A ``tf.keras.Sequential``. Call it on a tensor to apply it, e.g.
        ``data_augmentation()(images)``.
    """
    return tf.keras.Sequential([
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(rotate),
    ])

In [None]:
# Show 9 images from one training batch (top row) above their augmented versions (bottom row).
plt.figure(figsize=(30, 10))
for _, (images, labels) in enumerate(train_ds.take(1)):
    augmented = data_augmentation()(images)
    for i in range(9):
        plt.subplot(2, 9, i + 1)
        plt.imshow(images[i,:,:,0], cmap='gray')
        plt.title(LABELS[int(labels[i])])
        plt.axis("off")
        
        plt.subplot(2, 9, 9 + i + 1)
        plt.imshow(augmented[i,:,:,0], cmap='gray')
        plt.title("Augmented " + LABELS[int(labels[i])])
        plt.axis("off")

# BenchMark

In [None]:
def benchmark(depth=1, width=32, augment=True, optimizer=None, l2_reg=1e-3):
    """Build and compile the baseline classifier: preprocessor followed by a small MLP.

    Pipeline: optional augmentation, then ``processor`` and Flatten. That is followed by
    ``depth`` repeats of [Dense(``width``, L2) + LeakyReLU(0.3) + BatchNorm], and finally
    Dense(1, sigmoid).

    Args:
        depth: number of hidden Dense blocks.
        width: units per hidden Dense layer.
        augment: apply ``data_augmentation()`` to the input first.
        optimizer: optimizer instance. ``None`` (the default) creates a fresh ``Nadam(5e-3)``
            for this model. A single default instance would otherwise be shared by every
            model built with the default.
        l2_reg: L2 kernel-regularization strength for the hidden Dense layers.

    Returns:
        A compiled ``keras.Model`` (MSE loss, binary-accuracy metric) for 256x256x1 inputs.
    """
    if optimizer is None:
        optimizer = Nadam(5e-3)
    inp = keras.Input(shape=(256, 256, 1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    x = layers.Flatten()(x)
    for d in range(depth):
        x = layers.Dense(units=width, activation='linear',
                         kernel_regularizer=keras.regularizers.l2(l2_reg),
                         name='FC_layer_%d' % (d + 1))(x)
        x = layers.LeakyReLU(alpha=0.3, name='LReLU_%d' % (d + 1))(x)
        x = layers.BatchNormalization(name='BN_layer_%d' % (d + 1))(x)
    output = layers.Dense(1, activation='sigmoid', name='logits')(x)
    model = keras.Model(inp, output)
    model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=[keras.metrics.BinaryAccuracy()])
    return model

In [None]:
# bench = benchmark(width=2, depth=1, optimizer=Nadam(1e-3), augment=False)
# bench.summary()

In [None]:
# bench.fit(train_ds, epochs=3, validation_data=valid_ds)

In [None]:
# Results store for the (currently commented-out) benchmark sweep below: depth -> width -> histories.
archives = dict()
# for depth in (1,2,3):
#   archives[depth] = {}
#   for width in [1,2,4,8,16]:
#     archives[depth][width] = set()
#     for _ in range(3):
#       bench = benchmark(depth=depth, width=width, optimizer=Nadam(1e-3), augment=True)
#       archives[depth][width].add(bench.fit(train_ds, epochs=12, validation_data=valid_ds))

In [None]:
# pickle.dump(archives, open('benchmark_archives.pickle', "wb"))

In [None]:
# mean_dct = {nLayer: {nDepth: np.mean([np.mean(i[-3]) for i in nDepthValue]) for nDepth, nDepthValue in depths.items()} for nLayer, depths in archives.items()}

In [None]:
# mean_archives_df = pandas.DataFrame.from_dict(archives[1])
# for i in (2,3):
#   mean_archives_df.append(pandas.DataFrame.from_dict(archives[i]))
# mean_archives_df

In [None]:
# best_dct = {nLayer: {nDepth: np.max([i[-1] for i in nDepthValue]) for nDepth, nDepthValue in depths.items()} for nLayer, depths in archives.items()}

In [None]:
# best_archives_df = pandas.DataFrame.from_dict(archives[1])
# for i in (2,3):
#   best_archives_df.append(pandas.DataFrame.from_dict(archives[i]))
# best_archives_df

In [None]:
# for depth in (1,2,3):
#   curDepth = archives[depth]
#   for width in [1,2,4,8,16]:
#     archives[depth][width] = set()

In [None]:
# for i, col in zip(range(3), ('b', 'r', 'g')):
#   for history, boolean in zip(archives[i+1], (True, False, False)):
#     if boolean:
#       plt.plot(np.arange(50), history.history['val_binary_accuracy'], color=col, label=f"nLayers={i+1}")
#     else:
#       plt.plot(np.arange(50), history.history['val_binary_accuracy'], color=col)
# plt.ylim(0,1)
# plt.legend()
# plt.show()

# Preprocess + Conv3D + Nx(Conv2D + MaxPooling2D) + Dense

Using SGD, this model learns.

In [None]:
def pre_conv3d_conv2d_max_dense(lr, n_conv):
    """Build and compile a preprocessor + small CNN binary classifier trained with SGD.

    Pipeline: ``processor``, then a 3x3 Conv2D(32, relu, stride 3) with BatchNorm. Next come
    ``n_conv`` repeats of [3x3 Conv2D(32, relu) + BatchNorm + 2x2 MaxPool], then Flatten and
    Dense(1, sigmoid). The Conv3D / neuron-input stage the name refers to is currently
    commented out.

    Args:
        lr: SGD learning rate.
        n_conv: number of Conv2D/BatchNorm/MaxPool stages after the first convolution.

    Returns:
        A compiled ``keras.Model`` (MSE loss, binary-accuracy metric) for 256x256x1 inputs.
    """
    inp = keras.Input(shape=(256, 256, 1))
    x = processor(inp)
    x = layers.Conv2D(32, 3, activation='relu', strides=(3, 3))(x)
    x = layers.BatchNormalization()(x)
    for _ in range(n_conv):
        x = layers.Conv2D(32, 3, activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
#     x = layers.Conv3D(1278, (16,16,8), activation='sigmoid', strides=(16,16,8))(x)
#     x = conv_to_neuron_input(x, 400, saccades)
#     x = layers.Flatten()(x)
    output = layers.Dense(1, activation='sigmoid')(x)
    model = keras.Model(inp, output)
    # `learning_rate`, not the deprecated `lr` alias that newer Keras optimizers reject.
    model.compile(optimizer=SGD(learning_rate=lr), loss=MeanSquaredError(),
                  metrics=[keras.metrics.BinaryAccuracy()])
    return model

One Conv2D - learns

In [None]:
def preprocess_linear(lr, loss_func=SGD):
    """Build and compile a linear read-out on top of the preprocessor.

    Pipeline: ``processor``, then Flatten, then Dense(1, sigmoid).

    Args:
        lr: learning rate.
        loss_func: despite its name, the *optimizer class* to use. It is instantiated
            as ``loss_func(learning_rate=lr)``.

    Returns:
        A compiled ``keras.Model`` (MSE loss, binary-accuracy metric) for 256x256x1 inputs.
    """
    inp = keras.Input(shape=(256, 256, 1))
    x = processor(inp)
    x = layers.Flatten()(x)
#     x = layers.Conv3D(1278, (16,16,8), activation='sigmoid', strides=(16,16,8))(x)
#     x = conv_to_neuron_input(x, 400, saccades)
#     x = layers.Flatten()(x)
    output = layers.Dense(1, activation='sigmoid')(x)
    model = keras.Model(inp, output)
    # `learning_rate`, not the deprecated `lr` alias that newer Keras optimizers reject.
    model.compile(optimizer=loss_func(learning_rate=lr), loss=MeanSquaredError(),
                  metrics=[keras.metrics.BinaryAccuracy()])
    return model

# Dense Initializer

In [None]:
# NOTE: superseded. The next cell redefines FULL and block_initializer (Gaussian bumps),
# so after Run All this flat-block version is not the one in use.
FULL = 400
def block_initializer(*windows):
    """Return a Keras-style initializer: a (FULL, 1) column that is flat inside the given windows.

    Every index inside any ``(start, end)`` window gets ``1 / total_width``, where
    ``total_width`` is the summed length of all windows; all other entries are 0. The
    column therefore sums to 1 when the windows do not overlap and lie inside [0, FULL).
    ``shape`` and ``dtype`` are ignored. Each call of the returned initializer also plots
    the column on the current matplotlib axes.
    """
    rate = sum([j-i for i,j in windows])
    def block(shape, dtype=None):
        weights = np.zeros(FULL)
        for i,j in windows:
            weights[i:j] = 1 / (rate)
        plt.plot(weights, label="spikes wanted")
        return weights[:,np.newaxis]
    return block

In [None]:
FULL = 400
def block_initializer(*windows):
    """Return a Keras-style initializer: a (FULL, 1) column of Gaussian bumps.

    Each ``(start, end)`` window is filled with a bell curve centred on its midpoint.
    Its spread is proportional to the window length (sigma = length / 1.5). The whole
    column is then rescaled so its entries sum to 1; later windows overwrite earlier ones
    where they overlap. ``shape`` and ``dtype`` are ignored. Every call of the returned
    initializer also plots the column on the current matplotlib axes (label "spikes wanted").
    """
    total_width = sum([end - start for start, end in windows])  # not used below

    def block(shape, dtype=None):
        column = np.zeros(FULL)
        for start, end in windows:
            middle = (start + end) // 2
            offsets = np.arange(start - middle, end - middle)
            column[start:end] = 1/np.sqrt(2*np.pi)*np.exp(-0.5*(offsets / (end-start)*1.5)**2) / len(windows)
        column = column / sum(column)
        plt.plot(column, label="spikes wanted")
        return column[:, np.newaxis]

    return block

In [None]:
# Preview the read-out weights for two target windows (the initializer also plots them).
bloop = block_initializer((250, 280), (200, 230))(shape=1)

In [None]:
def dense_initializer_AP(cycle_time, full=400, sin_intializer=False, calc_init=False, norm_calc=False):
    """Return a Keras-style kernel initializer producing a ``(full, 1)`` column.

    ``offset = full % cycle_time`` leading steps are left for padding, and
    ``n_cycles = full // cycle_time`` whole cycles follow. Which shape is returned:

    * ``sin_intializer``: ``offset`` zeros, then ``n_cycles`` sine periods shifted to
      ``[0, 2 / n_cycles]``;
    * otherwise, ``calc_init``: one bump ``(sin(...) + 1) / pi`` spanning ``3 * cycle_time``
      steps starting at ``offset`` (zeros elsewhere);
    * otherwise: an unnormalised Gaussian (peak ``1/sqrt(2*pi)``, sigma ``cycle_time``)
      centred at ``offset + cycle_time``.

    ``norm_calc`` is accepted but unused. The returned initializers ignore ``shape``/``dtype``.
    """
    n_cycles = full // cycle_time
    offset = full % cycle_time

    def sinus_initializer(shape, dtype=None):
        phase = np.arange(0, cycle_time * n_cycles) * 2 * np.pi / cycle_time + np.pi
        pulses = (np.sin(phase) + 1) / n_cycles
        return np.concatenate([np.zeros(offset), pulses])[:, np.newaxis]

    def calcium_initializer(shape, dtype=None):
        column = np.zeros(full)
        phase = np.arange(0, 3 * cycle_time) * 2 / 3 * np.pi / cycle_time - np.pi / 2
        column[offset:offset + 3 * cycle_time] = (np.sin(phase) + 1) / np.pi
        return column[:, np.newaxis]

    def normal_calcium(shape, dtype=None):
        sigma = cycle_time
        z = (np.arange(full) - (offset + cycle_time)) / sigma
        column = 1 / np.sqrt(2 * np.pi) * np.exp(-z ** 2 * 0.5)
        return column[:, np.newaxis]

    if sin_intializer:
        return sinus_initializer
    if calc_init:
        return calcium_initializer
    return normal_calcium

In [None]:
# Compare the candidate 400-long read-out kernel initializers on one axis.
# block_initializer also draws its own "spikes wanted" curve when called.
plt.plot(np.arange(400), dense_initializer_AP(32, calc_init=True)(1), color='b', label="calcium spike initializer")
plt.plot(np.arange(400), dense_initializer_AP(32, sin_intializer=True)(1),  color='r', label="pulses initializer")
plt.plot(np.arange(400), dense_initializer_AP(32)(1), color='g', label="normal calcium initializater")
plt.plot(np.arange(400), block_initializer((250,280), (200, 230))(1), color='k', label="block")
plt.ylim(0,1)
plt.legend()
plt.show()

# Previous Modules and Funcs

In [None]:
def identity_init_1d_conv_layer(shape, dtype=None):
    """Kernel initializer that makes a kernel-size-1 Conv1D layer the identity map.

    Args:
        shape: kernel shape ``(kernel_size, in_channels, out_channels)``.
        dtype: accepted for the Keras initializer signature; ignored (float64 is returned).

    Returns:
        ``np.ndarray`` of shape ``(1, channels, channels)``: the identity matrix with a
        leading kernel axis.

    Raises:
        ValueError: if ``shape`` is not 3-D, the kernel size is not 1, or the input and
            output channel counts differ.
    """
    if len(shape) != 3:
        raise ValueError('Expected a 3-D kernel shape (kernel_size, in_channels, out_channels)')
    if shape[0] != 1:
        raise ValueError('Can only be used with kernel size of 1')
    if shape[1] != shape[2]:
        raise ValueError('Can only be used with same number of filters for input and output')

    return np.identity(shape[1])[np.newaxis]

In [None]:
def playable(saccades=None, xaxis=False, use_sigmoid=True, sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
             useSynapse=False, qSynapse=0.2, augment=False, nSynapse=50, optimizer=SGD, conv_shape=(16,16)):
    """Build and compile the image -> wiring conv -> L5PC neuron -> read-out model.

    Pipeline:

    1. Optional augmentation, then ``processor``.
    2. Conv2D with 1278 filters, where kernel = stride = ``conv_shape``, or a full-width
       row when ``xaxis``.
    3. BatchNorm, sigmoid and ``ToBoolLayer``.
    4. ``ToNeuronInput(400, saccades)`` and ``L5PC_model``.
    5. Flatten, an optional ``ToBoolLayer``, and Dense(1, sigmoid) named ``postNeuron``.

    Args:
        saccades: block order forwarded to ``ToNeuronInput``.
        optimizer: optimizer instance, or an optimizer *class* to instantiate with its
            defaults. The default ``SGD`` is a class, which ``compile`` cannot interpret
            directly.
        (Other arguments configure the layers named above.)

    Returns:
        A compiled model. With ``useSynapse`` it has a second output ``nSynapses``
        (from ``SpikeProcessor``), and the losses are weighted ``1 - qSynapse`` / ``qSynapse``.
    """
    if isinstance(optimizer, type):
        # A class, not an instance: build a fresh one for this model.
        optimizer = optimizer()
    inp = keras.Input(shape=(256,256,1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    conv_shape = (1,x.shape[2]) if xaxis else conv_shape
    x = layers.Conv2D(1278, conv_shape, strides=conv_shape, activity_regularizer=pre_synaptic_spike_regularization)(x)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='Neuron_Input_Bool')(x)  # only in validation and test
    if useSynapse:
        nSynapsesPerMS = SpikeProcessor(nSynapse, name='nSynapses')(x)
    x = ToNeuronInput(400, saccades, name="NeuronInput")(x)
    x = L5PC_model(x)   # run through david's model

    x = layers.Flatten()(x)
    if to_bool:
        x = ToBoolLayer(threshold=0.1, use_sigmoid=False, name='postNeuronBool')(x)  # only in validation and test
    output = layers.Dense(1, activation='sigmoid', name='postNeuron')(x)
    if useSynapse:
        model = keras.Model(inp, [output, nSynapsesPerMS])
        model.compile(optimizer=optimizer, loss={'postNeuron': MeanSquaredError(), 'nSynapses': loss_for_me}, 
                      metrics={'postNeuron': keras.metrics.BinaryAccuracy()}, loss_weights=[1-qSynapse, qSynapse])
    else:
        model = keras.Model(inp, output)
        model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=[keras.metrics.BinaryAccuracy()])
    return model

In [None]:
# Default L1 / L2 regularization strengths.
l1_reg, l2_reg = 1e-7, 1e-6

In [None]:
def playable_initialized_dense(saccades=None, xaxis=False, use_sigmoid=True, 
                       sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                       useSynapse=False, qSynapse=0.2, augment=False, 
                       nSynapse=50, optimizer=None, conv_shape=(16,16),
                       dense_init=block_initializer):
    """Build the wiring-conv + L5PC model with a fixed (non-trainable) spike-count read-out.

    The read-out Dense layer ``nSpikes`` has its kernel set by
    ``dense_init((400 - full_time, 400))``. That is a single window covering the last
    ``full_time`` of the 400 neuron-input steps, where ``full_time`` is the largest
    multiple of ``x.shape[1]`` not exceeding 200. Its bias is 0 and it is frozen.

    Args:
        optimizer: optimizer instance. ``None`` (the default) creates a fresh ``SGD(5e-3)``
            per model instead of sharing one default instance across models.
        dense_init: initializer factory called with one ``(start, end)`` window, as
            ``block_initializer(*windows)`` expects. It was previously called with
            ``start=``/``end=`` keywords, which ``block_initializer`` does not accept.
        (Other arguments as in ``playable``; the pre-neuron padding is the global ``NOISE``.)

    Returns:
        A compiled model. With ``useSynapse`` it also outputs ``nSynapses`` and the losses
        are weighted ``1 - qSynapse`` / ``qSynapse``.
    """
    if optimizer is None:
        optimizer = SGD(5e-3)
    inp = keras.Input(shape=(256,256,1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    conv_shape = (1,x.shape[2]) if xaxis else conv_shape
    x = layers.Conv2D(1278, conv_shape, strides=conv_shape,
                      activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer")(x)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(x)
    if useSynapse:
        nSynapsesPerMS = SpikeProcessor(nSynapse, name='nSynapses')(x)
    one_cycle = 200//x.shape[1]
    full_time = one_cycle*x.shape[1]

    x = ToNeuronInput(400, new_order=saccades, name="NeuronInput", padding=NOISE)(x)
    x = L5PC_model(x)   # run through david's model

    x = layers.Flatten()(x)
    if to_bool:
        x = ToBoolLayer(threshold=0.2, use_sigmoid=False, mult=25, name='postNeuronBool')(x)  # only in validation and test

    output = layers.Dense(1, 
                          kernel_initializer=dense_init((400 - full_time, 400)), 
                          bias_initializer=lambda shape, dtype: np.array([0.]), 
                          trainable=False,
                          name="nSpikes")(x)

    if useSynapse:
        model = keras.Model(inp, [output, nSynapsesPerMS])
        model.compile(optimizer=optimizer, 
                      loss={'nSpikes': MeanSquaredError(), 'nSynapses': MeanSquaredErrorSynapsesPerMS()}, 
                      metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                      loss_weights=[1-qSynapse, qSynapse])
    else:
        model = keras.Model(inp, output)
        model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=[keras.metrics.BinaryAccuracy()])
    return model

In [None]:
def playable_dense(saccades=None, xaxis=False, use_sigmoid=True, 
                       sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                       useSynapse=False, qSynapse=0.2, augment=False, 
                       nSynapse=50, optimizer=None, conv_shape=(16,16),
                       dense_init=block_initializer):
    """Build the wiring-conv + L5PC model with a trainable read-out on the last steps.

    The read-out Dense layer ``nSpikes`` sees only the last ``full_time`` entries of the
    flattened neuron output. ``full_time`` is the largest multiple of ``x.shape[1]`` not
    exceeding 400.

    Args:
        optimizer: optimizer instance. ``None`` (the default) creates a fresh ``SGD(5e-3)``
            per model instead of sharing one default instance across models.
        dense_init: accepted for signature parity with ``playable_initialized_dense``; unused.
        (Other arguments as in ``playable``; the pre-neuron padding is the global ``NOISE``.)

    Returns:
        A compiled model. With ``useSynapse`` it also outputs ``nSynapses`` and the losses
        are weighted ``1 - qSynapse`` / ``qSynapse``.
    """
    if optimizer is None:
        optimizer = SGD(5e-3)
    inp = keras.Input(shape=(256,256,1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    conv_shape = (1,x.shape[2]) if xaxis else conv_shape
    x = layers.Conv2D(1278, conv_shape, strides=conv_shape,
                      activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer")(x)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(x)
    if useSynapse:
        nSynapsesPerMS = SpikeProcessor(nSynapse, name='nSynapses')(x)
    one_cycle = 400//x.shape[1]
    # Fixed NameError: this previously read the undefined name `once_cycle`.
    full_time = one_cycle*x.shape[1]

    x = ToNeuronInput(400, new_order=saccades, name="NeuronInput", padding=NOISE)(x)
    x = L5PC_model(x)   # run through david's model

    x = layers.Flatten()(x)
    if to_bool:
        x = ToBoolLayer(threshold=0.2, use_sigmoid=False, mult=25, name='postNeuronBool')(x)  # only in validation and test

    output = layers.Dense(1, name="nSpikes")(x[:, -full_time:])

    if useSynapse:
        model = keras.Model(inp, [output, nSynapsesPerMS])
        model.compile(optimizer=optimizer, 
                      loss={'nSpikes': MeanSquaredError(), 'nSynapses': MeanSquaredErrorSynapsesPerMS()}, 
                      metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                      loss_weights=[1-qSynapse, qSynapse])
    else:
        model = keras.Model(inp, output)
        model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=[keras.metrics.BinaryAccuracy()])
    return model

In [None]:
def around_spikes_loss(windows, r, full=400):
    """Build a loss that penalizes activity just outside the wanted spike windows.

    A length-``full`` mask is 1 on the ``r`` steps immediately before and after each
    ``(start, end)`` window and 0 everywhere else, including inside every window. The
    returned loss ignores ``y_true``. For ``y_preds`` of shape ``(batch, full)`` it returns,
    per sample, the squared masked sum (mean over the trailing size-1 axis). The mask is
    also plotted on the current matplotlib axes (label "silence wanted").

    Args:
        windows: iterable of ``(start, end)`` index pairs.
        r: margin width, in steps, on each side of a window.
        full: sequence length.
    """
    mask = np.zeros(full)
    for start, end in windows:
        # Clamp at 0. A negative slice start wraps around, so a window starting before `r`
        # previously got no pre-window margin at all.
        mask[max(start - r, 0):start] = 1
        mask[end:end + r] = 1
    for start, end in windows:
        mask[start:end] = 0
    weights = tens(mask[:, np.newaxis], dtype=tf.float32)

    def loss(y_true, y_preds):
        masked_activity = tf.matmul(y_preds, weights)
        return tf.reduce_mean(tf.square(masked_activity), axis=-1)

    plt.plot(mask, label="silence wanted")
    return loss

In [None]:
def different_nSynapses_initialized_dense(saccades=None, xaxis=False, use_sigmoid=True, 
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, dropout=.2,
                        excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED,
                        qSynapse=(0.1, 0.1), augment=False, 
                        optimizer=None, conv_shape=(16,16),
                        dense_init=block_initializer, spike_wanted = [(260,280)], loss_radius=20,
                                          padding=PADDING):
    """Build the wiring-conv + L5PC model with separate excitatory/inhibitory synapse-count outputs.

    Outputs: ``nSpikes`` (a frozen Dense read-out initialised by ``dense_init(*spike_wanted)``),
    ``SpikeTrain`` (the neuron trace), ``nExcitatory`` and ``nInhibitory``. The losses are
    MSE, ``around_spikes_loss(spike_wanted, loss_radius)``, and the two synapse-count losses.
    They are weighted ``[0.95 - sum(qSynapse), 0.05, *qSynapse]``.

    Args:
        dropout: dropout rate applied to the neuron input; falsy disables it. The layer's
            output used to be discarded, so dropout was never actually applied.
        to_bool: binarise ``SpikeTrain`` with ``ToBoolLayer``; otherwise expose the raw trace
            under the same name, so the ``SpikeTrain`` loss still has an output.
        optimizer: optimizer instance. ``None`` (the default) creates a fresh ``SGD(5e-3)``
            per model.
        padding: input-like array encoded by the same wiring layer and passed to
            ``ToNeuronInput``.
        (Other arguments as in ``playable``.)
    """
    if optimizer is None:
        optimizer = SGD(5e-3)
    inp = keras.Input(shape=(256,256,1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    conv_shape = (1,x.shape[2]) if xaxis else conv_shape
    # Kept as an object so the same weights also encode `padding` below.
    wiringLayer = layers.Conv2D(1278, conv_shape, strides=conv_shape,#kernel_constraint=keras.constraints.non_neg(),
                      activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer")
    x = wiringLayer(x)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(x)
    # NOTE(review): presumably channels [0, 639) are excitatory and [639, 1278) inhibitory —
    # confirm against SpikeProcessor.
    nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=639)(x)
    nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=639)(x)

    # NOTE(review): `padding` is presumably a concrete array, so this runs eagerly at build
    # time. It then uses the untrained wiring weights and its own fresh BatchNormalization —
    # confirm this is intended.
    convertedPadding = wiringLayer(padding)
    convertedPadding = layers.BatchNormalization()(convertedPadding)
    convertedPadding = sigmoid(convertedPadding)
    convertedPadding = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(convertedPadding)

    x = ToNeuronInput(400, new_order=saccades, name="NeuronInput", padding=convertedPadding)(x)
    if dropout:
        x = layers.Dropout(dropout)(x)
    x = L5PC_model(x)   # run through david's model

    x = layers.Flatten()(x)
    if to_bool:
        spike_train = ToBoolLayer(threshold=0.2, use_sigmoid=False, mult=25, name='SpikeTrain')(x)  # only in validation and test
    else:
        spike_train = layers.Activation('linear', name='SpikeTrain')(x)

    # The frozen read-out consumes the continuous trace `x`, not `spike_train`.
    output = layers.Dense(1, 
                          kernel_initializer=dense_init(*spike_wanted), 
                          bias_initializer=lambda shape, dtype: np.array([0.]), 
                          trainable=False,
                          name="nSpikes")(x)

    model = keras.Model(inp, [output, spike_train, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'nSpikes': MeanSquaredError(), 'SpikeTrain': around_spikes_loss(spike_wanted, loss_radius), 'nExcitatory': MeanSquaredErrorSynapsesPerMS(), 'nInhibitory': MeanSquaredErrorSynapsesPerMS()}, 
                  metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                  loss_weights=[0.95-sum(qSynapse),0.05, *qSynapse])
    return model

In [None]:
def different_nSynapses(saccades=None, xaxis=False, use_sigmoid=True, 
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                        excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED,
                        qSynapse=(0.1, 0.1), augment=False, 
                        optimizer=None, conv_shape=(16,16),
                        dense_init=block_initializer, padding=PADDING):
    """Build the wiring-conv + L5PC model with a trainable read-out and synapse-count outputs.

    Outputs:

    * ``nSpikes``: a trainable Dense on the last ``full_time`` flattened steps, where
      ``full_time`` is the largest multiple of ``x.shape[1]`` not exceeding 200.
    * ``nExcitatory`` and ``nInhibitory``.

    The losses are weighted ``[1 - sum(qSynapse), *qSynapse]``.

    Args:
        optimizer: optimizer instance. ``None`` (the default) creates a fresh ``SGD(5e-3)``
            per model.
        dense_init: accepted for signature parity; unused.
        padding: input-like array encoded by the same wiring layer and passed to
            ``ToNeuronInput``.
        (Other arguments as in ``playable``.)
    """
    if optimizer is None:
        optimizer = SGD(5e-3)
    inp = keras.Input(shape=(256,256,1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    conv_shape = (1,x.shape[2]) if xaxis else conv_shape
    # Kept as an object so the same weights also encode `padding` below.
    wiringLayer = layers.Conv2D(1278, conv_shape, strides=conv_shape,
                      activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer")
    x = wiringLayer(x)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(x)
    nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=639)(x)
    nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=639)(x)
    one_cycle = 200//x.shape[1]
    full_time = one_cycle*x.shape[1]
    convertedPadding = wiringLayer(padding)
    convertedPadding = layers.BatchNormalization()(convertedPadding)
    convertedPadding = sigmoid(convertedPadding)
    convertedPadding = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(convertedPadding)

    x = ToNeuronInput(400, new_order=saccades, name="NeuronInput", padding=convertedPadding)(x)
    x = L5PC_model(x)   # run through david's model

    x = layers.Flatten()(x)
    if to_bool:
        x = ToBoolLayer(threshold=0.2, use_sigmoid=False, mult=25, name='postNeuronBool')(x)  # only in validation and test

    output = layers.Dense(1, name="nSpikes")(x[:, -full_time:])

    model = keras.Model(inp, [output, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'nSpikes': MeanSquaredError(), 'nExcitatory': MeanSquaredErrorSynapsesPerMS(), 'nInhibitory': MeanSquaredErrorSynapsesPerMS()}, 
                  metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                  loss_weights=[1-sum(qSynapse), *qSynapse])
    return model

In [None]:
def different_nSynapses_max_loss(window, saccades=None, xaxis=False, use_sigmoid=True, 
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                        excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED,
                        qSynapse=(0.1, 0.1), augment=False, 
                        optimizer=None, conv_shape=(16,16),
                        dense_init=block_initializer, padding=PADDING):
    """Build the wiring-conv + L5PC model trained on the peak activity inside ``window``.

    Outputs are ``SpikeTrain`` (the flattened neuron trace, optionally passed through
    ``ToBoolLayer``), ``nExcitatory`` and ``nInhibitory``. The ``SpikeTrain`` loss is
    ``max_loss(window)``, and the losses are weighted ``[1 - sum(qSynapse), *qSynapse]``.

    Args:
        window: iterable of ``(start, end)`` index pairs forwarded to ``max_loss``.
        to_bool: binarise the trace. Either way the output is named ``SpikeTrain``; previously
            it was ``SpikeTrain1`` when ``to_bool`` was false, so the loss key did not match.
        optimizer: optimizer instance. ``None`` (the default) creates a fresh ``SGD(5e-3)``
            per model.
        dense_init: accepted for signature parity; unused.
        (Other arguments as in ``playable``.)
    """
    if optimizer is None:
        optimizer = SGD(5e-3)
    inp = keras.Input(shape=(256,256,1))
    # data_augmentation() returns a pipeline that must be *applied* to the input tensor.
    x = data_augmentation()(inp) if augment else inp
    x = processor(x)
    conv_shape = (1,x.shape[2]) if xaxis else conv_shape
    # Kept as an object so the same weights also encode `padding` below.
    wiringLayer = layers.Conv2D(1278, conv_shape, strides=conv_shape,
                      activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer")
    x = wiringLayer(x)
    x = layers.BatchNormalization()(x)
    x = sigmoid(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(x)
    nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=639)(x)
    nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=639)(x)
    convertedPadding = wiringLayer(padding)
    convertedPadding = layers.BatchNormalization()(convertedPadding)
    convertedPadding = sigmoid(convertedPadding)
    convertedPadding = ToBoolLayer(threshold=sigmoid_threshold, 
                    use_sigmoid=use_sigmoid, 
                    mult=sigmoid_mult, 
                    name='preNeuronBool')(convertedPadding)

    x = ToNeuronInput(400, new_order=saccades, name="NeuronInput", padding=convertedPadding)(x)
    x = L5PC_model(x)   # run through david's model

    output = layers.Flatten(name="SpikeTrain1" if to_bool else "SpikeTrain")(x)
    if to_bool:
        output = ToBoolLayer(threshold=0.2, use_sigmoid=True, mult=25, name='SpikeTrain')(output)  # only in validation and test

    model = keras.Model(inp, [output, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'SpikeTrain': max_loss(window), 'nExcitatory': MeanSquaredErrorSynapsesPerMS(), 'nInhibitory': MeanSquaredErrorSynapsesPerMS()}, 
                  metrics={'SpikeTrain': keras.metrics.BinaryAccuracy()}, 
                  loss_weights=[1-sum(qSynapse), *qSynapse])
    return model

In [None]:
def max_loss(windows, full=400):
    """Loss on the peak predicted activity inside the given time windows (not in use).

    For ``y_preds`` of shape ``(batch, full)``, everything outside the ``(start, end)``
    windows is zeroed, and the per-sample maximum is compared with the label. The result
    is ``(y_true - peak) ** 2`` with shape ``(batch,)``. Because masked positions become 0,
    this assumes predictions are non-negative (e.g. after a sigmoid).

    Args:
        windows: iterable of ``(start, end)`` index pairs.
        full: sequence length.
    """
    mask = np.zeros(full)
    for start, end in windows:
        mask[start:end] = 1
    weights = tens(mask, dtype=tf.float32)

    def loss(y_true, y_preds):
        # Flatten the label to (batch,). Keras may pass it as (batch, 1), and subtracting a
        # (batch,) peak from that would broadcast to a (batch, batch) matrix.
        target = tf.reshape(tf.cast(y_true, y_preds.dtype), [-1])
        peak = K.max(y_preds * weights, axis=-1)
        return (target - peak) ** 2

    return loss

# Lowering Synapse loss Callback

# Pruning (not in use)

In [None]:
class SynapseLossDecay(keras.callbacks.Callback):
    """Callback that shrinks two loss coefficients, ``alpha`` and ``beta``, after every epoch.

    At each epoch end, every positive coefficient is either multiplied by its decay factor,
    or set to 0 once it has fallen below ``low``. Non-positive coefficients are left alone.

    If a coefficient is a variable (anything with an ``assign`` method, e.g. a
    ``tf.Variable`` or ``K.variable`` used as a loss weight), it is updated in place. A loss
    that captured the variable therefore sees the new value; previously the attribute was
    only rebound to a new tensor and the captured variable never changed. Plain numbers are
    simply replaced on the callback.
    """

    def __init__(self, alpha, beta, decay_alpha=.9, decay_beta=.9, low=.03):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.decay_alpha = decay_alpha
        self.decay_beta = decay_beta
        self.low = low

    def on_epoch_end(self, batch, logs=None):
        # Keras passes the epoch index as the first argument; it keeps its historical name.
        self.alpha = self._decay(self.alpha, self.decay_alpha)
        self.beta = self._decay(self.beta, self.decay_beta)

    def _decay(self, value, decay):
        """Return ``value`` after one decay step (updated in place when it is a variable)."""
        if not value > 0:
            return value
        new_value = value * 0 if value < self.low else value * decay
        if hasattr(value, 'assign'):
            value.assign(new_value)
            return value
        return new_value

In [None]:
class SynapsePruner(keras.callbacks.Callback):
    """Keras callback that periodically keeps only the strongest wiring weights.

    Every ``iterations`` training batches the kernel of the layer named
    ``prune_layer`` is pruned.  Its last axis is split at ``splitPoint``:
    entries before the split are treated as excitatory synapses, the rest as
    inhibitory.  For every position of the leading axes, only the ``kmax1``
    largest excitatory and ``kmax2`` largest inhibitory weights are kept
    (values tied with the k-th largest are kept too); all others are set to 0.
    ``iterations=0`` disables pruning.

    Assumes a 4-D kernel whose second axis has size 1 (a Conv2D kernel one
    column wide).  That axis is dropped before pruning and restored afterwards.

    Args:
        kmax1: excitatory weights kept per (row, channel); a value >= the
            number of excitatory synapses keeps everything.
        kmax2: same for the inhibitory weights.
        iterations: pruning period in training batches (0 = never prune).
        axis: unused; kept for backward compatibility.
        splitPoint: index on the kernel's last axis where inhibitory synapses start.
        prune_layer: name of the layer whose kernel is pruned.
    """

    def __init__(self, kmax1=5, kmax2=15, iterations=8, axis=-1, splitPoint = N_EXC, prune_layer="WiringLayer"):
        super().__init__()
        self.kmax1 = kmax1
        self.kmax2 = kmax2
        self.iterations = iterations
        self.curIterations = iterations
        self.split = splitPoint
        self.layer_to_prune = prune_layer
        print(self)

    def __str__(self):
        string = f"\nSynapsePruner callback:\nkmax1 = {self.kmax1}\n"
        string += f"kmax2 = {self.kmax2}\n"
        string += f"every {self.iterations} iterations"
        return string

    def build(self, shape):
        # Not a Keras callback hook; kept as a no-op for backward compatibility.
        pass

    def on_train_batch_end(self, batch, logs=None):
        """Count one finished batch and prune when the countdown reaches zero."""
        if not self.iterations:
            return
        # Decrement first so pruning happens every `iterations` batches, as
        # __str__ reports. Checking before decrementing pruned only every
        # `iterations + 1` batches.
        self.curIterations -= 1
        if self.curIterations <= 0:
            self.prune()
            self.curIterations = self.iterations

    @staticmethod
    def _keep_top_k(group, k):
        """Zero every entry of ``group`` below the k-th largest value of its last axis."""
        if k >= group.shape[-1]:
            # Keeping k or more of the available synapses prunes nothing
            # (this also avoids np.partition's out-of-bounds error).
            return group
        threshold = np.partition(group, -k, axis=-1)[..., -k][..., np.newaxis]
        return group * (group >= threshold)

    def prune(self):
        """Prune the kernel of ``self.layer_to_prune`` in place (see class docstring)."""
        layer = self.model.get_layer(self.layer_to_prune)
        weights = layer.get_weights()
        kernels = weights[0][:, 0]  # drop the width-1 kernel axis
        excitatory = self._keep_top_k(kernels[..., :self.split], self.kmax1)
        inhibitory = self._keep_top_k(kernels[..., self.split:], self.kmax2)
        pruned = np.concatenate([excitatory, inhibitory], axis=-1)
        if pruned.shape != kernels.shape:
            raise AssertionError(
                f"Pruned kernel shape {pruned.shape} differs from original {kernels.shape}")
        # Restore the width-1 axis and keep any remaining weight tensors (e.g. bias) untouched.
        layer.set_weights([pruned[:, np.newaxis]] + weights[1:])

In [None]:
# Toy demonstration of the top-k pruning rule used by SynapsePruner, on random weights.
nFilters = 32
nSyn = 10
nPixels = 32  # pixels per column
a_kmax1 = 1   # excitatory synapses kept per filter per pixel
a_kmax2 = 1   # inhibitory synapses kept per filter per pixel

kernels = np.random.rand(nPixels*nFilters*nSyn).reshape((nPixels, nFilters, nSyn))
splat = nSyn // 2  # synapses [0, splat) are "excitatory", [splat, nSyn) "inhibitory"

# Split the synapse axis into its two populations
p1_x, p2_x = kernels[:, :, :splat], kernels[:, :, splat:]

# k-th largest weight per (pixel, filter), kept as a trailing axis for broadcasting
kmax1 = np.partition(p1_x, -a_kmax1, axis=-1)[:, :, -a_kmax1][:, :, np.newaxis]
kmax2 = np.partition(p2_x, -a_kmax2, axis=-1)[:, :, -a_kmax2][:, :, np.newaxis]

# Zero every weight below its threshold, then re-join the populations
arr1, arr2 = p1_x * (p1_x >= kmax1), p2_x * (p2_x >= kmax2)
arr = np.concatenate([arr1, arr2], axis=-1)

# Top row: original weights of each filter; bottom row: the same weights after pruning
plt.figure(figsize=(50, 10))
for i in range(nFilters):
    for _panel_row, (_panel_title, _panel_data) in enumerate(((f"Filter {i}", kernels), (f"Pruned {i}", arr))):
        plt.subplot(2, nFilters, _panel_row * nFilters + i + 1)
        plt.title(_panel_title)
        plt.imshow(_panel_data[:, i, :], cmap="gray")
        plt.xlabel("Synapse")
        plt.ylabel("Pixel")
        plt.axvline(splat, color="r", linestyle="--")

plt.show()

In [None]:
def printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, exWanted, inWanted, qSynapse, augment, synLoss):
    """Print a summary of a synapse-counting module's hyper-parameters.

    Args:
        dropout1: printed as "Image Dropout Rate".
        dropout2: printed as "Synapses Dropout Rate".
        sigmoid_threshold: printed as "Synapse Threshold".
        sigmoid_mult: printed as "Synapse Training Sigmoid Multiplication".
        exWanted, inWanted: wanted excitatory / inhibitory synapse counts
            (their sum is printed as well).
        qSynapse: indexable pair of (excitatory, inhibitory) loss rates.
        augment: if truthy, an augmentation line including its value is printed.
        synLoss: object whose ``__name__`` is printed as the loss function.
    """
    emit = print
    emit("~*~ Visual Module ~*~")
    if augment:
        emit(f"Images are augmented (rotated by {augment})")
    emit(f"Image Dropout Rate: {dropout1}")
    emit(f"Synapse Threshold: {sigmoid_threshold}")
    emit(f"Synapse Training Sigmoid Multiplication: {sigmoid_mult}")
    emit(f"Synapses Wanted: Exc-{exWanted}; Inh-{inWanted}; Sum-{exWanted+inWanted}")
    emit(f"Synapses Loss Rate: Exc-{qSynapse[0]}; Inh-{qSynapse[1]}")
    emit(f"Synapses Dropout Rate: {dropout2}")
    emit(f"Synapse Loss Function: {synLoss.__name__}")

In [None]:
def different_nSynapses(saccades=None, xaxis=True, use_sigmoid=True, pruning=True, dropout1=.2, dropout2=False,
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                        excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED,
                        qSynapse=(0.1, 0.1), augment=False, 
                        optimizer=None, conv_shape=(16,16),
                        padding=PADDING, synLoss=MeanSquaredErrorSynapsesPerMS):
    """Build and compile the first version of the synapse-counting visual module.

    NOTE: a later cell of this notebook defines another ``different_nSynapses``;
    whichever cell ran last is the one in effect.

    Pipeline: optional ``data_augmentation`` -> ``processor`` -> optional dropout
    -> a shared wiring stack (Conv2D "WiringLayer" -> BatchNorm -> sigmoid ->
    ``ToBoolLayer``) applied to the image features and to ``padding`` ->
    ``ToNeuronInput`` -> per-ms synapse counters -> ``L5PC_model`` -> max over time.

    Args:
        saccades: forwarded to ``ToNeuronInput`` as ``new_order``.
        pruning, to_bool: unused; kept for backward compatibility.
        xaxis: if true the wiring kernel spans the full input height and is one column wide.
        qSynapse: (excitatory, inhibitory) loss weights; the spike loss gets
            ``1 - sum(qSynapse)``.
        optimizer: Keras optimizer; ``None`` (default) creates a fresh ``SGD(5e-3)``.
        padding: value fed through the wiring stack and passed to ``ToNeuronInput``;
            a literal 0 skips the wiring stack for it.
        synLoss: loss class instantiated for both synapse-count outputs.

    Returns:
        A compiled ``keras.Model`` with outputs ``[nSpikes, nExcitatory, nInhibitory]``.
    """
    if optimizer is None:
        # One optimizer per model: a default instance would be shared by every
        # model built with defaults, and Keras optimizers are bound to the
        # variables they first trained.
        optimizer = SGD(5e-3)
    printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, excitatory_wanted, inhibitory_wanted, qSynapse, augment, synLoss)
    # Decide *before* `padding` is rebound to a tf.Variable. The former check
    # `padding is not 0` ran on the Variable, so it was always true (and it is an
    # identity test on an int literal).
    has_padding = not (isinstance(padding, (int, float)) and padding == 0)
    padding = tf.Variable(lambda: padding, trainable=False)
    inp = keras.Input(shape=(256,256,1))
    if augment:
        # NOTE(review): the later version calls `data_augmentation(augment)(x)`;
        # confirm which signature the current `data_augmentation` has.
        x = data_augmentation(inp)
    x = processor(x if augment else inp)
    if dropout1:
        x = layers.Dropout(dropout1)(x)
    conv_shape = (x.shape[2], 1) if xaxis else conv_shape
    f = tf.keras.Sequential([layers.Conv2D(2*N_EXC, conv_shape, strides=conv_shape,
                                activity_regularizer=pre_synaptic_spike_regularization, name="WiringLayer"),
                             layers.BatchNormalization(name="BatchNorm"),
                             layers.Activation(sigmoid),
                             ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')])
    x = f(x)
    # The padding goes through the same wiring stack as the image (shared weights).
    pad = f(padding) if has_padding else padding

    cycles = 200//x.shape[-2]
    full_time = cycles*x.shape[-2]
    print("how many cycles:", cycles)

    x = ToNeuronInput(400, new_order=saccades, name="NeuronInput")(x, padding=pad)

    # Synapse columns [0, N_EXC) are excitatory, the rest inhibitory.
    nExcitatorySynapsesPerMS = SpikeProcessor(excitatory_wanted, name='nExcitatory', end=N_EXC)(x)
    nInhibitorySynapsesPerMS = SpikeProcessor(inhibitory_wanted, name='nInhibitory', start=N_EXC)(x)

    if dropout2:
        x = layers.Dropout(dropout2)(x)
    x = L5PC_model(x)[:,-full_time:,:]   # run through david's model
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2], name="MaxPooling")(x)
    output = layers.Flatten(name="nSpikes")(x)

    model = keras.Model(inp, [output, nExcitatorySynapsesPerMS, nInhibitorySynapsesPerMS])
    model.compile(optimizer=optimizer, 
                  loss={'nSpikes': MeanSquaredError(), 'nExcitatory': synLoss(), 'nInhibitory': synLoss()}, 
                  metrics={'nSpikes': keras.metrics.BinaryAccuracy()}, 
                  loss_weights=[1-sum(qSynapse), *qSynapse])
    return model

In [None]:
class Pad(keras.layers.Layer):
    """Repeat the input along axis -2 and prepend a random window of ``padding``.

    ``call`` performs these steps:
      1. optionally re-order axis -2 with ``new_order`` (``tf.gather``);
      2. repeat the result ``times`` times along axis -2, alternating forward
         and time-reversed copies when ``reverse`` is true;
      3. prepend a random window of length ``full - length * times`` cut from
         ``padding`` (tiled over the batch), so axis -2 ends up ``full`` long.

    Assumes 4-D input ``(batch, rows, length, channels)`` and a 4-D ``padding``
    with leading dimension 1 whose axis -2 is at least ``full - length * times``
    long — TODO confirm against PADDING.
    """
    def __init__(self, padding, full=400, times=6, reverse=True, new_order=None, name="PadLayer"):
        super().__init__(name=name)
        self.full = full
        self.padding = padding
        self.times = times
        self.reverse = reverse
        self.new_order = new_order
        self.shape = None
        self.pad_to_add = None

    def build(self, shape):
        self.shape = shape
        # Number of padding steps needed so that the output is `full` long.
        self.pad_to_add = self.full - self.shape[-2]*self.times
        if self.pad_to_add < 0:
            raise ValueError(
                f"Pad: input length {self.shape[-2]} x times={self.times} exceeds full={self.full}")

    def call(self, inputs):
        if self.new_order is not None:
            inputs = self.gather(inputs, self.new_order, axis=-2)
        if self.times > 1:
            if self.reverse:
                copies = [inputs, K.reverse(inputs, axes=-2)] * (self.times // 2) + [inputs] * (self.times % 2)
            else:
                copies = [inputs] * self.times
            new_inp = layers.Concatenate(axis=-2)(copies)
        else:
            new_inp = inputs
        # Random window of the padding; drawn inside the graph, so it changes every step.
        starting_time = self.ranInt(self.padding)
        padding = self.padding[:, :, starting_time:starting_time + self.pad_to_add]
        new_inp = layers.Concatenate(axis=-2)([tf.tile(padding, [tf.shape(new_inp)[0], 1, 1, 1]), new_inp])
        new_inp = K.reshape(new_inp, (tf.shape(new_inp)[0], self.shape[-3], self.shape[-2]*self.times + self.pad_to_add, self.shape[-1]))
        return new_inp

    def ranInt(self, x):
        """Random window start in ``[0, len(padding) - pad_to_add]`` (both ends inclusive); ``x`` is unused."""
        # maxval is exclusive for integer random_uniform, hence the +1. Without it
        # the last padding sample could never be used, and len(padding) == pad_to_add
        # would fail (minval >= maxval).
        return K.random_uniform((1,), 0, self.padding.shape[-2] - self.pad_to_add + 1, dtype=tf.dtypes.int32)[0]

    @tf.function
    def gather(self, x, ind, axis):
        # `axis` must be a keyword: tf.gather's third positional parameter is the
        # deprecated `validate_indices`. The former `tf.gather(x+0, ind, axis)`
        # therefore ignored `axis` and gathered along axis 0 (the batch).
        return tf.gather(x, ind, axis=axis)

In [None]:
class SliceLayer(tf.keras.layers.Layer):
    """Keras layer returning ``inputs[..., start:end]`` (a slice of the last axis).

    ``start``/``end`` follow Python slice semantics; ``None`` means open-ended,
    so with both ``None`` the input is returned whole.
    """
    def __init__(self, start=None, end=None, name="SliceLayer"):
        super().__init__(name=name)
        self.start = start
        self.end = end

    def build(self, shape):
        # Input rank; kept for backward compatibility (slicing no longer depends on it).
        self.dims = len(shape)

    def call(self, inputs):
        # The Ellipsis always targets the last axis, whatever the rank. The former
        # explicit `[:,:,...]` indexing hit the wrong axis for rank > 4 and returned
        # None when both bounds were None.
        return inputs[..., self.start:self.end]

In [None]:
def different_nSynapses(saccades=None, xaxis=True, use_sigmoid=True, pruning=True, dropout1=False, dropout2=False,
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                        excitatory_wanted=EXCITATORY_SYNAPSES_WANTED, inhibitory_wanted=INHIBITORY_SYNAPSES_WANTED,
                        qSynapse=None, augment=False, 
                        optimizer=None, conv_shape=(16,16), threshold=.2, times=6,
                        padding=PADDING, synLoss=MeanSquaredErrorSynapsesPerMS, nSynapse=True, neurons=neurons, regular_sigmoid=True, non_neg=False):
    """Build and compile the synapse-counting visual module.

    Pipeline: optional augmentation -> ``processor`` -> optional dropout ->
    ``Pad`` (repeat ``times`` times + random noise prefix) -> Conv2D
    "WiringLayer" -> BatchNorm -> optional sigmoid -> ``ToBoolLayer`` ->
    ``neurons`` -> max over the post-noise time -> ``ToBoolLayer``.

    Args:
        qSynapse: (excitatory, inhibitory) synapse-loss weights; ``None``
            (default) creates two fresh non-trainable ``tf.Variable(0.1)``.
        optimizer: Keras optimizer; ``None`` (default) creates a fresh
            ``SGD(momentum=.9)``.
        nSynapse: if false only the spike output has a loss; the synapse outputs
            are still reported through metrics.
        non_neg: constrain the wiring kernel to be non-negative.
        saccades, pruning, to_bool, conv_shape (when xaxis): unused / overridden;
            kept for backward compatibility.

    Returns:
        ``(model, module_name)``: the compiled ``keras.Model`` with outputs
        ``[nSpikes, ExcSyns, InhSyns]``, and a name built from the hyper-parameters.
    """
    # Defaults are built per call: default objects would be shared by every model
    # (Keras optimizers are bound to the variables they first trained).
    if qSynapse is None:
        qSynapse = (tf.Variable(0.1, trainable=False), tf.Variable(0.1, trainable=False))
    if optimizer is None:
        optimizer = SGD(momentum=.9)
    printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, excitatory_wanted, inhibitory_wanted, qSynapse, augment, synLoss)
    module_name = "module_syn_" + "".join(
        str(part) for part in [dropout1, dropout2, sigmoid_threshold, sigmoid_mult, threshold,
                               excitatory_wanted, inhibitory_wanted, qSynapse, augment, synLoss.__name__])

    inp = keras.Input(shape=(256,256,1))

    padding = tf.Variable(lambda: padding, trainable=False)
    x = inp
    if augment: x = data_augmentation(augment)(x)
    x = processor(x)
    if dropout1: x = layers.Dropout(dropout1)(x)
    conv_shape = (x.shape[-3], 1) if xaxis else conv_shape
    # Length of the repeated image input that follows the noise prefix.
    real_input_time = x.shape[-2]*times
    x = Pad(padding, full=neurons.input.shape[-2], times=times)(x)
    x = layers.Conv2D(2*N_EXC, conv_shape, strides=conv_shape, use_bias=True,
                      kernel_constraint=keras.constraints.NonNeg() if non_neg else None,
                      name="WiringLayer")(x)
    x = layers.BatchNormalization(name="BatchNorm")(x)
    if regular_sigmoid: x = layers.Activation(sigmoid)(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')(x)

    x = K.squeeze(x, axis=-3)

    # Synapse columns [0, N_EXC) are excitatory, the rest inhibitory.
    ExcitatorySynapses = SliceLayer(end=N_EXC, name="ExcSyns")(x)
    InhibitorySynapses = SliceLayer(start=N_EXC, name="InhSyns")(x)

    if dropout2: x = layers.Dropout(dropout2)(x)
    x = neurons(x)[:,-real_input_time:,:]   # run through david's model, keep only the post-noise time
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2], name="MaxPooling")(x)
    x = ToBoolLayer(threshold=threshold, use_sigmoid=True, mult=1, name='postNeuronBool')(x)
    output = layers.Flatten(name="nSpikes")(x)

    model = keras.Model(inp, [output, ExcitatorySynapses, InhibitorySynapses])
    metrics = {"nSpikes": keras.metrics.BinaryAccuracy(), "ExcSyns": MeanSynapsesPerMsMetric(), "InhSyns": MeanSynapsesPerMsMetric()}
    if not nSynapse:
        model.compile(optimizer=optimizer, loss={'nSpikes': MeanSquaredError()}, metrics=metrics)
    else:
        # NOTE(review): with tf.Variable weights, the first entry is evaluated once here
        # and does not follow later assignments to qSynapse.
        model.compile(optimizer=optimizer, 
                  loss=[MeanSquaredError(), synLoss(excitatory_wanted), synLoss(inhibitory_wanted)], 
                  metrics=metrics,
                  loss_weights=[1.0 -qSynapse[0] - qSynapse[1], qSynapse[0], qSynapse[1]])
    return model, module_name

In [None]:
# Loss-weight Variables for use with SynapseLossDecay (currently disabled):
# ex_synapse_weight = tf.Variable(0.05, trainable=False)
# inh_synapse_weight = tf.Variable(0.05, trainable=False)

# Baseline module: only the spike output carries a loss (nSynapse=False).
model_synapses, module_name = different_nSynapses(
    threshold=.2,
    non_neg=False,
    nSynapse=False,
    dropout1=.3,
    dropout2=.3,
    use_sigmoid=True,
    regular_sigmoid=True,
    sigmoid_mult=50,
    sigmoid_threshold=.9,
    augment=1.,
    optimizer=Nadam(),
    qSynapse=(.1, .1),
    synLoss=MeanSquaredErrorSynapsesPerMS,
)

In [None]:
# lr(step) = 5e-3 * 0.9 ** (step / 10000), where `step` counts optimizer steps.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=5e-3,
    decay_steps=10000,
    decay_rate=0.9,
)

In [None]:
# Synapse-count module trained with both spike and synapse losses.
# `learning_rate` is the supported argument name; the deprecated `lr` is ignored
# (with only a warning) or rejected by newer Keras optimizers.
model_synapses, module_name = different_nSynapses(
    threshold=.2, non_neg=True, nSynapse=True, dropout1=0.2, dropout2=0.1,
    use_sigmoid=True, regular_sigmoid=True, sigmoid_mult=50, sigmoid_threshold=.95,
    augment=True, optimizer=Nadam(learning_rate=5e-3), qSynapse=(.05, .05),
    synLoss=MSE_RMS_SynapsesPerMS)

In [None]:
# Layer-by-layer overview of the synapse-count module built above.
model_synapses.summary()

In [None]:
# `lr_schedule` is defined in optimizer steps (decay_steps=10000 batches), but
# LearningRateScheduler calls its schedule once per epoch with the epoch index.
# Handing it the schedule directly decayed by 0.9 ** (epoch / 10000), i.e.
# practically not at all over 200 epochs. Evaluate it at the number of batches
# seen so far instead.
steps_per_epoch = len(train_ds)
history_syn = model_synapses.fit(
    train_ds,
    epochs=200,
    validation_data=valid_ds,
    callbacks=[
        SynapsePruner(5, 3, iterations=1),
        tf.keras.callbacks.LearningRateScheduler(
            lambda epoch, lr: float(lr_schedule(epoch * steps_per_epoch))),
    ],
)

In [None]:
# Generic aliases used by the plotting / saving cells below.
model, history = model_synapses, history_syn

# Best Module so Far

In [None]:
def printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, augment, *synapse_args):
    """Print a summary of a visual module's hyper-parameters.

    ``create_module`` calls it as
    ``printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, augment)``.

    This definition replaces the notebook's earlier 9-argument ``printModule``,
    which ``different_nSynapses`` still calls as
    ``printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult,
    exWanted, inWanted, qSynapse, augment, synLoss)``.  Without support for that
    form, re-running ``different_nSynapses`` after this cell raised TypeError.
    With four extra positional arguments, the fifth positional argument is
    therefore ``exWanted`` and the printed text matches the earlier function.

    Raises:
        TypeError: if neither 5 nor 9 positional arguments are given.
    """
    if synapse_args:
        if len(synapse_args) != 4:
            raise TypeError(
                "printModule() takes 5 or 9 positional arguments but "
                f"{5 + len(synapse_args)} were given")
        exWanted = augment
        inWanted, qSynapse, augment, synLoss = synapse_args
        print("~*~ Visual Module ~*~")
        if augment: print(f"Images are augmented (rotated by {augment})")
        print(f"Image Dropout Rate: {dropout1}")
        print(f"Synapse Threshold: {sigmoid_threshold}")
        print(f"Synapse Training Sigmoid Multiplication: {sigmoid_mult}")
        print(f"Synapses Wanted: Exc-{exWanted}; Inh-{inWanted}; Sum-{exWanted+inWanted}")
        print(f"Synapses Loss Rate: Exc-{qSynapse[0]}; Inh-{qSynapse[1]}")
        print(f"Synapses Dropout Rate: {dropout2}")
        print(f"Synapse Loss Function: {synLoss.__name__}")
        return
    print("~*~ Visual Module ~*~")
    if augment: print("Images are augmented")
    print(f"Image Dropout Rate: {dropout1}")
    print(f"Synapse Threshold: {sigmoid_threshold}")
    print(f"Synapse Training Sigmoid Multiplication: {sigmoid_mult}")
    print(f"Synapses Dropout Rate: {dropout2}")

In [None]:
def create_module(saccades=None, xaxis=True, use_sigmoid=True, pruning=True, dropout1=.2, dropout2=False,
                        sigmoid_threshold=0.9, sigmoid_mult=15, to_bool=True, 
                        augment=False, optimizer=None, conv_shape=(16,16),
                        padding=PADDING, neurons=neurons, model_name="L5PC", non_neg=False):
    """Build and compile the single-output visual module (spike count only).

    Pipeline: optional augmentation -> ``processor`` -> optional dropout ->
    ``Pad`` (repeat + random noise prefix) -> Conv2D "WiringLayer" -> BatchNorm
    -> sigmoid -> ``ToBoolLayer`` -> ``PredictNeuron(neurons)`` -> max over the
    post-noise time.

    Args:
        optimizer: Keras optimizer; ``None`` (default) creates a fresh ``SGD(5e-3)``.
        non_neg: constrain the wiring kernel to be non-negative.
        model_name: prefix used in the returned module name.
        saccades, pruning, to_bool: unused; kept for backward compatibility.

    Returns:
        ``(model, module_name)`` where ``model`` is a compiled ``keras.Model``
        with the single output "nSpikes".
    """
    if optimizer is None:
        # One optimizer per model: a default instance would be shared by every model built with defaults.
        optimizer = SGD(5e-3)
    printModule(dropout1, dropout2, sigmoid_threshold, sigmoid_mult, augment)
    module_name = f"module_{model_name}_{dropout1}_{dropout2}_{sigmoid_threshold}_{sigmoid_mult}_{augment}"

    inp = keras.Input(shape=(256,256,1))

    padding = tf.Variable(lambda: padding, trainable=False)
    x = inp
    if augment: x = data_augmentation(augment)(x)
    x = processor(x)
    if dropout1: x = layers.Dropout(dropout1)(x)
    conv_shape = (x.shape[-3], 1) if xaxis else conv_shape
    pad_layer = Pad(padding)
    # Length of the repeated image input after the noise prefix (same rule as
    # different_nSynapses). It replaces a hard-coded 198, which disagreed with
    # the documented "32 ms X 6 times".
    real_input_time = x.shape[-2] * pad_layer.times
    x = pad_layer(x)
    # NonNeg is the constraint class; the lowercase `non_neg` alias is missing in newer Keras.
    x = layers.Conv2D(2*N_EXC, conv_shape, strides=conv_shape, activity_regularizer=pre_synaptic_spike_regularization,
                      kernel_constraint=keras.constraints.NonNeg() if non_neg else None,
                      use_bias=True, name="WiringLayer")(x)

    x = layers.BatchNormalization(name="BatchNorm")(x)
    x = layers.Activation(sigmoid)(x)
    x = ToBoolLayer(threshold=sigmoid_threshold, use_sigmoid=use_sigmoid, mult=sigmoid_mult, name='preNeuronBool')(x)

    x = K.squeeze(x, axis=-3)

    if dropout2: x = layers.Dropout(dropout2)(x)
    x = PredictNeuron(neurons)(x)[:,-real_input_time:,:]   # run through david's model, keep only the post-noise time
    x = layers.MaxPooling1D(x.shape[-2], strides=x.shape[-2], name="MaxPooling")(x)
    output = layers.Flatten(name="nSpikes")(x)

    model = keras.Model(inp, output)
    model.compile(optimizer=optimizer, loss={'nSpikes': MeanSquaredError()}, metrics={'nSpikes': keras.metrics.BinaryAccuracy()})
    return model, module_name

In [None]:
# Continue training the synapse module for 100 more epochs, without callbacks.
history_syn_2 = model_synapses.fit(
    train_ds,
    epochs=100,
    validation_data=valid_ds,
)

In [None]:
# model, module_name = create_module(use_sigmoid=True, dropout1=False, dropout2=False, sigmoid_mult=75, sigmoid_threshold=0.9, augment=.8, optimizer=Nadam())

In [None]:
# Layer-by-layer overview of the module currently bound to `model`.
model.summary()

In [None]:
# history = model.fit(train_ds, epochs=100, validation_data=valid_ds, callbacks=[SynapsePruner(100, 100, iterations=0)])

In [None]:
hist = history.history

# Multi-output models (different_nSynapses) log the spike accuracy as
# "nSpikes_binary_accuracy"; a single-output model (create_module) logs it as
# "binary_accuracy". Pick whichever exists so the cell works for both.
accuracy = "nSpikes_binary_accuracy" if "nSpikes_binary_accuracy" in hist else "binary_accuracy"

plt.figure(figsize=(20,10))
plt.subplot(1,2,1)
plt.title("Loss")
plt.plot(hist["loss"], color="k", label="train")
plt.plot(hist["val_loss"], color="r", label="validation")
plt.axhline(0, color="b", linestyle="--")
plt.xlabel("epochs")
plt.legend()

plt.subplot(1,2,2)
plt.title("Accuracy")
plt.plot(hist[accuracy], color="k", label=f"train (last: {round(hist[accuracy][-1], 2)})")
plt.plot(hist["val_"+accuracy], color="r", label=f"validation (last: {round(hist['val_'+accuracy][-1], 2)})")
plt.axhline(.5, color="b", linestyle="--", label="chance")
plt.ylim((0,1))
plt.xlabel("epochs")
plt.legend()

plt.show()

In [None]:
# Loss and metrics of `model` on the held-out test set.
model.evaluate(test_ds)

# Plot

In [None]:
def _detect_spikes(trace, start_from, on_threshold=0.25, off_threshold=0.1):
    """Return ``(start, end)`` pairs of supra-threshold episodes in ``trace`` (along its first axis).

    Scanning from ``start_from``, an episode starts at the first index whose
    value exceeds ``on_threshold`` and ends at the first index from there on
    whose value is below ``off_threshold``; scanning resumes after the end.
    An episode that never drops below ``off_threshold`` is not reported.
    """
    spikes = []
    curr_index = start_from
    while True:
        above = np.where(trace[curr_index:] > on_threshold)[0]
        if not above.shape[0]:
            break
        start = curr_index + above[0]
        below = np.where(trace[start:] < off_threshold)[0]
        if not below.shape[0]:
            break
        end = below[0] + start
        spikes.append((start, end))
        curr_index = end + 1
    return spikes


def plot_examples(model, startFrom=200, plot_cycle=False, plot_spikes=True, plot_start=None, write_last=False, weights_start=None, preNeuron="NeuronInput", spikeTrain="SpikeTrain", postNeuron="nSpikes"):
    """Plot image, synapse raster and neuron output for examples of one training batch.

    Uses up to 10 images of the first batch of ``train_ds``; fewer when the batch is smaller.

    Args:
        model: model containing layers named ``preNeuron`` and ``spikeTrain``
            (and ``postNeuron`` when ``write_last``).
        startFrom: first time index shown in the raster and scanned for spikes.
        plot_cycle: show only the last 32 time steps of the raster.
        plot_spikes: append the detected (start, end) spike episodes to the title.
        plot_start: x position of a dotted vertical marker on the output plot.
        write_last: prepend the ``postNeuron`` output ("Score") to the title.
        weights_start: if given, overlay the weights of layer "nSpikes" from this time index.
    """
    plt.figure(figsize=(200, 250))
    for img, label in train_ds.take(1):
        inputs = K.function(model.input, model.get_layer(preNeuron).output)([img])
        outputs = K.function(model.input, model.get_layer(spikeTrain).output)([img])
        if write_last: nAP = K.function(model.input, model.get_layer(postNeuron).output)([img])
        how_many = min(10, len(img))  # a batch may hold fewer than 10 images
        for i in range(how_many):
            # Column 1: the input image
            plt.subplot(how_many, 3, 3*i+1)
            plt.imshow(img[i,:,:,0]/255, cmap='gray')
            plt.title(LABELS[label[i]], fontdict={'fontsize':200})
            plt.axis("off")

            # Column 2: the pre-neuron synapse raster with its per-step count statistics
            plt.subplot(how_many, 3, 3*i+2)
            if len(inputs[i].shape)==3:
                curr_input = np.squeeze(inputs[i], axis=-3)[startFrom:]
            else:
                curr_input = inputs[i][startFrom:]
            num_of_synapses = np.sum(curr_input, axis=-1)
            mean_synapses = round(num_of_synapses.mean(), 2)
            std_synapses = round(num_of_synapses.std(), 2)
            plt.title(f"\u03BC: {str(mean_synapses)},  \u03C3: {str(std_synapses)}", fontdict={'fontsize':200})
            plt.imshow(curr_input[-32:,:] if plot_cycle else curr_input, cmap='binary', vmin=0, vmax=1)
            plt.axis("off")

            # Column 3: the neuron output trace
            plt.subplot(how_many, 3, 3*i+3)
            curr_output = outputs[i]
            spikes = _detect_spikes(curr_output, startFrom)
            plt.title((f"Score: {str(round(nAP[i][0],2))}," if write_last else "") +f"\u03A3: {str(round(curr_output.sum(),2))}" + (f", spikes:{spikes}" if plot_spikes else ""), fontdict={'fontsize':150})
            plt.plot(curr_output, linewidth=5)
            if weights_start is not None:
                plt.plot(np.arange(weights_start, 400), model.get_layer("nSpikes").get_weights()[0][:,0], color='r')
            plt.ylim(0,1)
            if plot_start: plt.axvline(plot_start, color='g', linestyle=':', linewidth=10.)
            plt.axis('on')

In [None]:
def plot_statistics(model, starting_time=0, cycle_time=32, preNeuron="NeuronInput", spikeTrain="SpikeTrain"):
    """Plot synapse-activity and neuron-output statistics over 25 training batches.

    For every image in the first 25 batches of ``train_ds``, the outputs of the
    layers ``preNeuron`` (synapse activity) and ``spikeTrain`` (neuron output)
    are computed.  Synapse counts are taken over the window
    ``[starting_time, starting_time + cycle_time)`` of the first axis of each
    image's (squeezed) ``preNeuron`` output.

    Panels: all pre-neuron values; mean / std / max / min of the per-step
    excitatory count; mean neuron output per label; per-label histograms of the
    mean excitatory, inhibitory and total counts.
    """
    all_inputs = []
    synapses_mu = []
    synapses_std = []
    synapses_max = []
    synapses_min = []
    label_dct = {}
    synapses_mean_per_label = {}
    window = slice(starting_time, starting_time + cycle_time)

    for img, label in train_ds.take(25):
        inputs = K.function(model.input, model.get_layer(preNeuron).output)([img])
        outputs = K.function(model.input, model.get_layer(spikeTrain).output)([img])
        all_inputs.append(inputs)
        labels = [LABELS[lbl] for lbl in label]
        for j in range(len(inputs)):
            if len(inputs[j].shape)==3:
                curr_input = np.squeeze(inputs[j], axis=-3)
            else:
                curr_input = inputs[j]

            # NOTE(review): only the excitatory columns (:N_EXC) are counted here,
            # although the panel titles say "sum synapses" — confirm intent.
            num_of_synapses = np.sum(curr_input[window, :N_EXC], axis=-1)
            synapses_mu.append(num_of_synapses.mean())
            synapses_std.append(num_of_synapses.std())
            synapses_max.append(num_of_synapses.max())
            synapses_min.append(num_of_synapses.min())
            curr_output = outputs[j]
            label_dct.setdefault(labels[j], []).append(curr_output)
            per_label = synapses_mean_per_label.setdefault(labels[j], {"Sum": [], "Ex":[], "Inh":[]})
            per_label["Sum"].append(np.sum(curr_input[window], axis=-1).mean())
            per_label["Ex"].append(np.sum(curr_input[window, :N_EXC], axis=-1).mean())
            per_label["Inh"].append(np.sum(curr_input[window, N_EXC:], axis=-1).mean())

    mean_dct = {lbl: np.stack(outs).mean(axis=0) for lbl, outs in label_dct.items()}

    # Batches may differ in size (e.g. a short last batch), so flatten each one
    # separately; np.array(all_inputs) cannot handle such ragged lists.
    flat_inputs = np.concatenate([np.ravel(batch) for batch in all_inputs]) if all_inputs else np.array([])

    plt.figure(figsize=(15,10))

    plt.subplot(3,3,1)
    plt.title('input values')
    plt.hist(flat_inputs)

    plt.subplot(3,3,2)
    plt.title('mean of sum synapses')
    plt.hist(synapses_mu)

    plt.subplot(3,3,3)
    plt.title('std of sum synapses')
    plt.hist(synapses_std)

    plt.subplot(3,3,4)
    plt.title('max sum synapses')
    plt.hist(synapses_max)

    plt.subplot(3,3,5)
    plt.title('min sum synapses')
    plt.hist(synapses_min)

    plt.subplot(3,3,6)
    plt.title(f'mean output of neuron (AP)')
    for lbl, value in mean_dct.items():
        plt.plot(value, label=lbl, linewidth=2.)
    plt.ylim(0,1)
    plt.legend()

    label_names = list(synapses_mean_per_label.keys())

    plt.subplot(3,3,7)
    plt.title('Excitatory Synapses Mean per Label')
    plt.hist([synapses_mean_per_label[lbl]["Ex"] for lbl in label_names], label=label_names)
    plt.legend()

    plt.subplot(3,3,8)
    plt.title('Inhibitory Synapses Mean per Label')
    plt.hist([synapses_mean_per_label[lbl]["Inh"] for lbl in label_names], label=label_names)
    plt.legend()

    plt.subplot(3,3,9)
    plt.title('All Synapses Mean per Label')
    plt.hist([synapses_mean_per_label[lbl]["Sum"] for lbl in label_names], label=label_names)
    plt.legend()

    plt.show()

In [None]:
import re
# Pick the first layer auto-named like "tf_op_layer_strided_slice_<n>" (presumably the
# sliced neuron output); the names of the layers skipped on the way are printed.
# NOTE(review): newer TF versions name slicing layers differently; then nothing matches
# and `selectedLayer` stays None.
moduleLayer = r"tf_op_layer_strided_slice_\d+"
selectedLayer = None
for layer in model.layers:
    if re.search(moduleLayer, layer.name) is None:
        print(layer.name)
        continue
    selectedLayer = layer.name
    break
print(selectedLayer)

In [None]:
# Example images, their boolean synapse rasters ("preNeuronBool") and the output of
# `selectedLayer`, with the "nSpikes" score in each title.
plot_examples(model, startFrom=0, plot_cycle=False, plot_spikes=False, plot_start=None, write_last=True, preNeuron="preNeuronBool", spikeTrain=selectedLayer)

In [None]:
# starting_time = 200 + 200 % 32 = 208 = 400 - 6*32, presumably the first step of the
# repeated image input after the noise prefix (400 steps, 32-step image repeated 6 times) — TODO confirm.
plot_statistics(model, 200 + 200%32, preNeuron="preNeuronBool", spikeTrain=selectedLayer)

In [None]:
plt.figure(figsize=(5,15))
# Conv2D kernels are laid out (rows, cols, in_channels, filters). With xaxis=True the
# WiringLayer kernel is one column wide, so [:, 0] gives (pixel rows, channels, synapses).
# The former [0][0] kept only pixel row 0, which was then broadcast down every image column.
weights = model.get_layer("WiringLayer").get_weights()[0][:, 0]
w_min, w_max = weights.min(), weights.max()
plt.suptitle(f"Weights Examples (8 synapses out of {2*N_EXC})")
for i in range(8):
    for j in range(4):
        plt.subplot(8, 4, 4*i+j+1)
        # Column l shows channel l*4 + j (presumably filter l at orientation thetas[j]).
        x = np.stack([weights[:, l*4+j, i] for l in range(32//4)], axis=1)
        plt.imshow(x, vmin=w_min, vmax=w_max)
        if not i:
            plt.title(r"$\Theta={}\pi$".format(thetas[j]/np.pi))
        if not j:
            plt.ylabel(f"Syn {i}\nyaxis pixels")
        if i == 7:
            plt.xlabel(f"filter")
        plt.xticks([])
        plt.yticks([])

In [None]:
# Signed and absolute mean of the wiring weights per channel (averaged over axes 0 and 2).
channel_mean = weights.mean(axis=(0, 2))
channel_abs_mean = np.abs(weights).mean(axis=(0, 2))
plt.plot(channel_mean, label="mean")
plt.plot(channel_abs_mean, label="abs mean")
plt.title("Mean and Abs Mean of Synapse Weights by Channel (after processing)")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(20,20))

# Left column: excitatory synapses; right column: inhibitory synapses.
# Top row: per-synapse std of the weights over axes 0 and 1; bottom row: their sum.
for _column, _group, _std_title, _sum_title in (
    (1, weights[:, :, :N_EXC], r"Exc Synapses Weights (std by Synapse)", r"Exc Synapses Weights (sum by Synapse)"),
    (2, weights[:, :, N_EXC:], r"Inh Synapses Weights (std by Synapse)", r"Inh Synapses Weights (sum by Synapse)"),
):
    plt.subplot(2, 2, _column)
    plt.title(_std_title)
    plt.hist(_group.std(axis=(0, 1)))

    plt.subplot(2, 2, _column + 2)
    plt.title(_sum_title)
    plt.hist(_group.sum(axis=(0, 1)))

plt.show()

In [None]:
# Summary of the wiring kernel and, per synapse, how many (pixel, channel)
# inputs carry a weight above 0.01.
# Conv2D kernel is (rows, cols, channels, synapses) with a single column (xaxis=True);
# [:, 0] drops that axis. The former [0][0] kept only pixel row 0.
weights = model.get_layer("WiringLayer").get_weights()[0][:, 0]
print(f"min: {weights.min()}, max: {weights.max()}\nmean: {weights.mean()}, sd: {weights.std()}\n")
exc = weights[:,:,:N_EXC]
inh = weights[:,:,N_EXC:]
amountExc = [(exc>0.01).sum(axis=(0,1))]
amountInh = [(inh>0.01).sum(axis=(0,1))]

plt.subplot(1,2,1)
plt.hist(amountExc)
plt.title("Exc Synapses: inputs with weight > 0.01")
plt.xlabel("(pixel, channel) inputs per synapse")
plt.ylabel("number of synapses")

plt.subplot(1,2,2)
plt.hist(amountInh)
plt.title("Inh Synapses: inputs with weight > 0.01")
plt.xlabel("(pixel, channel) inputs per synapse")

plt.show()

# Convert Images to 400×1278 Matrices

# Save Weights

In [None]:
# NOTE(review): `serial` is not referenced in the visible cells — possibly a leftover
# counter for naming saved files; confirm or remove.
serial = 0

In [None]:
# Persist the wiring kernel, its bias and the BatchNorm parameters as .npy files
# named after the module configuration.
wiring_params = model.get_layer("WiringLayer").get_weights()
np.save(f"{module_name}_weights.npy", wiring_params[0])
np.save(f"{module_name}_bias.npy", wiring_params[1])
np.save(f"{module_name}_batchnorm.npy", model.get_layer("BatchNorm").get_weights())

In [None]:
# Save the whole model under a path named after the module configuration.
# NOTE(review): with no file extension TF2 Keras writes a SavedModel directory;
# newer Keras versions require a ".keras" suffix — confirm the target version.
model.save(f"./{module_name}")

# Evolution (not used)

In [None]:
class EvolutionModule:
    """One individual of a simple evolutionary search over Keras models.

    Class-level registries, shared by every instance (``run_evolution`` clears them):
        generations: generation number -> list of modules created in it.
        results: generation number -> modules ranked by recent mean loss.
        alive: modules that have not called ``die()``.

    Args:
        ID: index of the module within its generation (used in its name).
        generation: generation the module is born in.
        create_func: zero-argument factory returning a compiled model.
        noise: scale of the perturbation added to an inherited weight tensor.
        parent: module whose weights are inherited, or None.
    """
    generations = {}
    results = {}
    alive = set()

    def __init__(self, ID, generation, create_func, noise=.1, parent=None):
        self.module = None

        self.generation = generation
        self.cur_generation = generation
        self.ID = ID
        self.parent = parent
        # Children carry their ancestry in the name, e.g. "Module[0.3]-[1.0]".
        self.name = "Module" + (f"({parent.__str__()[6:]})-" if parent else "") + f"[{generation}.{ID}]"

        if generation not in EvolutionModule.generations: EvolutionModule.generations[generation] = [self]
        else: EvolutionModule.generations[generation].append(self)
        EvolutionModule.alive.add(self)

        self.history = None

        self.children = set()
        self.status = True

        self.create_func = create_func
        self.original_weights = None
        self.noise = noise

        self._create_model()

    def _create_model(self):
        """Build the model; a child copies its parent's weights and perturbs weight tensor 2."""
        self.module = self.create_func()
        if self.parent:
            self.original_weights = [weight for weight in self.parent.get_weights()]
            # NOTE(review): np.random.rand is uniform on [0, 1), so this perturbation only
            # ever increases the weights; zero-mean noise (np.random.randn) may have been
            # intended — confirm.
            self.original_weights[2] = self.original_weights[2] + np.random.rand(*self.original_weights[2].shape) * self.original_weights[2].std() * self.noise
            self.module.set_weights(self.original_weights)

    def give_birth(self, n):
        """Create ``n`` perturbed children in the next generation of this module."""
        self.cur_generation += 1
        for i in range(n):
            self.children.add(EvolutionModule(i, self.cur_generation, self.create_func, self.noise, self))

    def is_alive(self):
        return self.status

    def fit(self, training_ds, epochs, validation_data):
        """Train the wrapped model on ``training_ds`` and append its metrics to ``self.history``."""
        self.to_print()
        # Train on the dataset actually passed in (this previously used the global `train_ds`).
        history = self.module.fit(training_ds, epochs=epochs, validation_data=validation_data).history
        if not self.history:
            self.history = history
        else:
            for key, value in history.items():
                self.history[key].extend(value)

    def die(self):
        """Mark the module dead and remove it from ``EvolutionModule.alive``."""
        self.status = False
        EvolutionModule.alive.remove(self)

    def value(self, n):
        """Mean training loss over the last ``n`` epochs (None before any training)."""
        if not self.history: return None
        return np.mean(self.history['loss'][-n:])

    def accuracy(self):
        """Last recorded validation accuracy (metric key 'val_logits_binary_accuracy')."""
        return self.history['val_logits_binary_accuracy'][-1]

    def value_to_print(self, n):
        """(last loss, mean loss, last val_loss, mean val_loss, last val acc, mean val acc) over the last ``n`` epochs."""
        return self.history['loss'][-1], np.mean(self.history['loss'][-n:]), self.history['val_loss'][-1], np.mean(self.history['val_loss'][-n:]), self.history['val_logits_binary_accuracy'][-1], np.mean(self.history['val_logits_binary_accuracy'][-n:])

    def to_print(self):
        print(f"\n Beginning fitting for {self.name}...\n")

    def __str__(self):
        return self.name

    def get_weights(self):
        return self.module.get_weights()

In [None]:
def create_model():
    """Build one fresh candidate model for the evolutionary search.

    Delegates to ``playable`` (defined elsewhere in the notebook) with fixed
    hyper-parameters; a new ``SGD(5e-3)`` optimizer is created on every call.
    """
    return playable(
        saccades=None,
        xaxis=True,
        use_sigmoid=True,
        sigmoid_mult=50,
        to_bool=True,
        useSynapse=True,
        qSynapse=0.1,
        sigmoid_threshold=0.8,
        augment=False,
        optimizer=SGD(5e-3),
    )

In [None]:
def _print_generation_summary(generation, ranked, alive, measure_by):
    """Print the best-ranked module's accuracy and per-module metrics for one generation."""
    print("\n" + "* "*10)
    if ranked:
        best = ranked[0]
        print(f'Generation: {generation}:\n Best accuracy (by loss): {best.accuracy()} by {best}\n')
    for module in alive:
        print(module, end=': ')
        print("Last loss: {0}, Mean loss: {1}, Last val_loss: {2}, Mean val_loss: {3}, Last val accuracy: {4}, Mean val acc: {5}".format(*module.value_to_print(measure_by)))
    print("\n" + "* "*10+'\n')


def run_evolution(nStart=10, nEpochs=10, nMax=2, nPerMax=3, nGenerations=5, rand=0.1, measure_by=2):
    """Run an elitist evolutionary search over models built by ``create_model``.

    Generation 0 constructs ``nStart`` ``EvolutionModule`` objects and trains
    each for ``nEpochs``. For every later generation, the modules of the
    previous generation are ranked ascending by ``value(measure_by)`` (mean
    training loss over the last ``measure_by`` epochs); the best ``nMax`` each
    spawn ``nPerMax`` children and stay alive, the rest ``die()``, and every
    living module is trained for another ``nEpochs``.

    NOTE(review): newly constructed modules are assumed to register themselves
    in ``EvolutionModule.alive`` (done in ``__init__``, outside this view).

    Args:
        nStart: number of modules in generation 0.
        nEpochs: epochs per ``fit`` call.
        nMax: how many top-ranked modules reproduce each generation.
        nPerMax: children spawned per reproducing module.
        nGenerations: total generations including generation 0.
        rand: noise scale passed to each generation-0 module.
        measure_by: window (in epochs) for ranking and printed means.

    Returns:
        ``(best_module, generations, results, alive)``: the top-ranked module
        of the final generation plus the class-level containers
        ``EvolutionModule.generations``, ``.results`` (generation index ->
        ranked modules) and ``.alive``.
    """
    EvolutionModule.generations.clear()
    EvolutionModule.results.clear()
    EvolutionModule.alive.clear()

    print("\nGeneration 0:\n")
    for i in range(nStart):
        module = EvolutionModule(i, 0, create_model, rand)
        module.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)
    EvolutionModule.results[0] = sorted(EvolutionModule.alive, key=lambda x: x.value(measure_by))
    _print_generation_summary(0, EvolutionModule.results[0], EvolutionModule.alive, measure_by)

    for generation in range(1, nGenerations):
        print(f"\nGeneration {generation}:\n")
        # The nMax best of the previous generation reproduce (and survive); the rest die.
        for rank, module in enumerate(EvolutionModule.results[generation - 1]):
            if rank < nMax:
                module.give_birth(nPerMax)
            else:
                module.die()
        for module in EvolutionModule.alive:
            module.fit(train_ds, epochs=nEpochs, validation_data=valid_ds)

        EvolutionModule.results[generation] = sorted(EvolutionModule.alive, key=lambda x: x.value(measure_by))
        _print_generation_summary(generation, EvolutionModule.results[generation], EvolutionModule.alive, measure_by)

    print("\n~ ~ ~ ~ ~ ~ Done! ~ ~ ~ ~ ~ ~\n")
    # ``results`` is keyed by generation index, so ``results[-1]`` would be a
    # missing key; index the final generation explicitly.
    last_generation = max(nGenerations - 1, 0)
    return EvolutionModule.results[last_generation][0], EvolutionModule.generations, EvolutionModule.results, EvolutionModule.alive

In [None]:
def _fit_and_record(model, nEpochs):
    """Fit ``model`` on the notebook's ``train_ds``/``valid_ds`` and return its result record.

    Returns:
        ``(model, loss_list, val_loss_list, val_accuracy_list)`` taken from the
        ``.history`` dict that ``model.fit`` returns for this call.
    """
    history = model.fit(train_ds, epochs=nEpochs, validation_data=valid_ds).history
    return (model, history['loss'], history['val_loss'], history['val_logits_binary_accuracy'])


def _report_evolution_generation(label, results, ranked, measure_by):
    """Print the best accuracy among ``ranked`` and loss/accuracy stats for each record in ``results``."""
    print("\n" + "* "*10)
    if ranked:
        print(f'Generation: {label}:\n Best accuracy (by loss): {ranked[0][-1][-1]}\n')
    for result in results:
        print("last loss:", result[1][-1], " loss mean:", np.mean(result[1][-measure_by:]),"   last accuracy:", result[-1][-1], "  accuracy mean:", np.mean(result[-1][-measure_by:]))
    print("\n" + "* "*10+'\n')


def evolution(result_dct=None, nStart=10, nEpochs=10, nMax=2, nPerMax=3, nGenerations=5, rand=0.1, measure_by=3):
    """Simple weight-perturbation evolution over models built by ``create_model``.

    Generation 0 trains ``nStart`` fresh models. Each later generation keeps
    the ``nMax`` records of the previous generation with the lowest summed
    training loss over their last ``measure_by`` epochs; for each of them it
    trains ``nPerMax`` children (fresh models whose weight index 2 is the
    parent's plus noise scaled by ``rand``) and then retrains the parent.

    Args:
        result_dct: optional dict to fill; a new one is created when omitted.
        nStart, nEpochs, nMax, nPerMax, nGenerations, rand, measure_by: see above.

    Returns:
        ``result_dct`` mapping generation index -> list of
        ``(model, loss_list, val_loss_list, val_accuracy_list)`` records.
    """
    if result_dct is None:
        # Fresh dict per call; a ``{}`` default would be shared between calls.
        result_dct = {}

    def select_best(records):
        # Lowest summed training loss over the last ``measure_by`` epochs first.
        return sorted(records, key=lambda x: sum(x[1][-measure_by:]))[:nMax]

    result_dct[0] = []
    for i in range(nStart):
        print("\nStarting first generation module number", i, "...\n")
        result_dct[0].append(_fit_and_record(create_model(), nEpochs))

    for generation in range(1, nGenerations):
        result_dct[generation] = []
        previous = result_dct[generation - 1]
        max_chosen = select_best(previous)
        _report_evolution_generation(generation, previous, max_chosen, measure_by)

        print("Starting generation ", generation+1,"...\n")
        i = 1
        # Records are 4-tuples; only the model is needed here.
        for max_model, *_metrics in max_chosen:
            old_weights = max_model.get_weights()
            for _ in range(nPerMax):
                print("Starting module number ", i, " out of", nMax*(nPerMax+1), "\n")
                new_weights = list(old_weights)
                old_conv_weights = new_weights[2]
                # NOTE(review): np.random.rand is uniform on [0, 1), so this only ever
                # increases weights[2]; confirm whether zero-mean noise was intended.
                new_weights[2] = old_conv_weights + np.random.rand(*old_conv_weights.shape) * old_conv_weights.std() * rand
                model = create_model()
                model.set_weights(new_weights)
                result_dct[generation].append(_fit_and_record(model, nEpochs))
                i += 1
            print("\nStarting father...\n")
            # Record the retrained parent itself, not the last child built above.
            result_dct[generation].append(_fit_and_record(max_model, nEpochs))

    # Rank and report the final generation (also valid when nGenerations == 1).
    final_records = result_dct[max(nGenerations - 1, 0)]
    _report_evolution_generation(nGenerations, final_records, select_best(final_records), measure_by)
    return result_dct

In [None]:
# best_module, generations, results, alive = run_evolution(nStart=10, nEpochs=10, nMax=3, nPerMax=2, nGenerations=6)
# print(best_module.value_to_print(1))

In [None]:
# plot_examples(best_module)

In [None]:
# plot_statistics(best_module)

In [None]:
# cur_model = best_module
# while cur_model:
#     print(" * * * * * * *"*2)
#     print("Current module:", cur_model)
#     print("current weights:", best_module.get_weights()[2].std())
#     print("Original weights:", best_module.original_weights[2].std())
#     cur_model = cur_model.parent