In [1]:
!nvidia-smi

Wed Oct 26 15:44:57 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   53C    P0   160W / 275W |  23162MiB / 40536MiB |     99%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   54C    P0   269W / 275W |  37398MiB / 40536MiB |    100%      Default |
|       

In [2]:
import os
# Pin TensorFlow to a single GPU chosen by PCI bus order; these variables are set before
# TensorFlow is imported below. NOTE(review): the nvidia-smi output above lists devices 0 and 1
# (it is truncated) — confirm that device index 2 exists on the machine running this notebook.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn
from enum import Enum
import imageio
import os  # NOTE(review): duplicate of the first import in this cell; harmless
import hashlib

%matplotlib inline

# Floating-point type used for the array casts and the TF variables created in this notebook.
dtype = 'float32'
tf.keras.backend.set_floatx(dtype)

2022-10-26 15:44:58.453648: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-10-26 15:44:58.613384: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-26 15:44:59.233193: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-10-26 15:44:59.233257: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or 

In [143]:
##### datasets.py

def get_dataset_sample(X, y, fraction, seed=None):
    """Draw a random subset of the rows of ``X`` and ``y``.

    Every row is kept independently with probability ``fraction``, so the result only
    approximately contains ``fraction * len(X)`` rows. The same boolean mask is applied to
    ``X`` and ``y``, keeping features and labels aligned.

    Args:
        X: array indexed along its first axis.
        y: array with the same length as ``X``.
        fraction: probability in [0, 1] of keeping each row.
        seed: if given, the mask is drawn from a private ``RandomState(seed)`` so the selection
            is reproducible and the global NumPy random state is left untouched. If None, the
            global NumPy random state is used.

    Returns:
        Tuple ``(X_sampled, y_sampled)``.
    """
    # A private legacy generator seeded with ``seed`` produces exactly the same mask as
    # ``np.random.seed(seed); np.random.choice(...)``. Unlike the previous seed/unseed dance
    # (``np.random.seed()`` re-seeds from OS entropy), it does not destroy a seed the caller set.
    rng = np.random if seed is None else np.random.RandomState(seed)
    selection = rng.choice([True, False], len(X), p=[fraction, 1 - fraction])
    return X[selection], y[selection]


class Dataset:
    """Holds a train/test split plus a shuffled union of both splits.

    After construction the instance exposes:
        X_norm, y: train and test data concatenated and shuffled together.
        X_train_norm, y_train: the training split (not shuffled).
        X_test_norm, y_test: the test split (not shuffled).
    Every array is cast to the module-level ``dtype``.
    """

    def __init__(self, X_train, y_train, X_test, y_test, shape, shape_flattened, fraction, vision=True,
                 standardize=True):
        """
        Args:
            X_train, y_train, X_test, y_test: raw arrays for the two splits.
            shape: target shape (passed to ``np.reshape``) of the returned feature arrays.
            shape_flattened: 2-D shape used while standardizing; each column is one feature.
            fraction: if not None, keep roughly this fraction of each split (seed 42).
            vision: if True, divide pixel values by 255.
            standardize: if True, fit a StandardScaler on the training split and apply it to
                all feature arrays.
        """
        if fraction is not None:
            X_train, y_train = get_dataset_sample(X_train, y_train, fraction, seed=42)
            X_test, y_test = get_dataset_sample(X_test, y_test, fraction, seed=42)

        X_train, y_train, X_test, y_test = (arr.astype(dtype) for arr in (X_train, y_train, X_test, y_test))

        if vision:
            X_train, X_test = X_train / 255.0, X_test / 255.0

        X_train = np.reshape(X_train, shape_flattened)
        X_test = np.reshape(X_test, shape_flattened)

        X_all = np.concatenate((X_train, X_test), axis=0)
        y_all = np.concatenate((y_train, y_test), axis=0)

        if not standardize:
            X_all_norm, X_train_norm, X_test_norm = X_all, X_train, X_test
        else:
            from sklearn.preprocessing import StandardScaler

            scaler = StandardScaler()
            scaler.fit(X_train)  # statistics come from the training split only

            # Drop each scaled input as soon as its standardized copy exists to limit peak memory.
            X_all_norm = scaler.transform(X_all)
            del X_all
            X_train_norm = scaler.transform(X_train)
            del X_train
            X_test_norm = scaler.transform(X_test)
            del X_test

        X_all_norm, X_train_norm, X_test_norm = (
            np.reshape(arr, shape) for arr in (X_all_norm, X_train_norm, X_test_norm)
        )

        # Shuffle the combined data, applying one permutation to features and labels alike.
        assert len(X_all_norm) == len(y_all)
        order = np.random.permutation(len(X_all_norm))

        self.X_norm = X_all_norm[order]
        self.y = y_all[order]
        self.X_train_norm = X_train_norm
        self.y_train = y_train
        self.X_test_norm = X_test_norm
        self.y_test = y_test


def get_cifar_10_dataset(fraction=None):
    """Load CIFAR-10 through Keras and wrap it in a ``Dataset``.

    Args:
        fraction: optional fraction of each split to keep (see ``Dataset``).
    """
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
    return Dataset(
        X_train, y_train, X_test, y_test,
        shape=(-1, 32, 32, 3),
        shape_flattened=(-1, 3072),  # Scaling each feature independently
        fraction=fraction,
    )


def get_cifar_100_dataset(fraction=None):
    """Load CIFAR-100 through Keras and wrap it in a ``Dataset``.

    Args:
        fraction: optional fraction of each split to keep (see ``Dataset``).
    """
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar100.load_data()
    return Dataset(
        X_train, y_train, X_test, y_test,
        shape=(-1, 32, 32, 3),
        shape_flattened=(-1, 3072),  # Scaling each feature independently
        fraction=fraction,
    )


def get_svhn_dataset(fraction=None):
    """Download the SVHN 32x32 train/test splits and wrap them in a ``Dataset``.

    Each ``.mat`` file holds images ``X`` with the sample axis last (moved to the front here)
    and labels ``y``, which are shifted down by one (presumably 1..10 -> 0..9; confirm).

    Args:
        fraction: optional fraction of each split to keep (see ``Dataset``).
    """
    from urllib.request import urlretrieve
    from scipy import io

    def _download_split(url):
        # Download one split, read X and y with a single loadmat call, then delete the file:
        # urlretrieve without a filename writes to a temporary file that otherwise stays on
        # disk until urllib.request.urlcleanup() is called.
        filename, _ = urlretrieve(url)
        try:
            mat = io.loadmat(filename, variable_names=['X', 'y'])
        finally:
            os.remove(filename)
        return np.moveaxis(mat['X'], -1, 0), mat['y'] - 1

    X_train, y_train = _download_split('http://ufldl.stanford.edu/housenumbers/train_32x32.mat')
    X_test, y_test = _download_split('http://ufldl.stanford.edu/housenumbers/test_32x32.mat')

    shape = (-1, 32, 32, 3)
    shape_flattened = (-1, 3072)  # Scaling each feature independently
    return Dataset(X_train, y_train, X_test, y_test, shape=shape, shape_flattened=shape_flattened, fraction=fraction)


def get_tiny_imagenet_dataset(fraction=None):
    """
    Load Tiny ImageNet (64x64 RGB, 200 classes) and wrap it in a ``Dataset``.

    The data repository is cloned into ``./IMagenet`` if that directory does not exist. Training
    images are read from ``train/<wnid>/images/<wnid>_<i>.JPEG`` (500 per class); the images
    listed in ``val/val_annotations.txt`` form the test split.

    Original source: https://github.com/sonugiri1043/Train_ResNet_On_Tiny_ImageNet/blob/master/Train_ResNet_On_Tiny_ImageNet.ipynb
    Original author: sonugiri1043@gmail.com
    """

    if not os.path.isdir('IMagenet'):
        os.system('git clone https://github.com/seshuad/IMagenet')

    print("Processing the downloaded dataset...")

    path = 'IMagenet/tiny-imagenet-200/'

    # Map each class id (one per line of wnids.txt) to its integer label.
    with open(path + 'wnids.txt', 'r') as wnids_file:
        id_dict = {line.replace('\n', ''): i for i, line in enumerate(wnids_file)}

    train_data = list()
    train_labels = list()
    for key, value in id_dict.items():
        train_data += [imageio.imread(path + 'train/{}/images/{}_{}.JPEG'.format(key, key, str(i)), pilmode='RGB')
                       for i in range(500)]
        train_labels += [value] * 500

    test_data = list()
    test_labels = list()
    with open(path + 'val/val_annotations.txt') as annotations_file:
        for line in annotations_file:
            img_name, class_id = line.split('\t')[:2]
            test_data.append(imageio.imread(path + 'val/images/{}'.format(img_name), pilmode='RGB'))
            test_labels.append(id_dict[class_id])

    # The arrays are built only after both lists are filled. X_test used to be created before
    # the validation loop, leaving it with 0 rows while y_test had one label per image.
    X_train = np.array(train_data)
    del train_data
    X_test = np.array(test_data)
    del test_data

    y_train = np.array(train_labels)
    y_test = np.array(test_labels)
    del train_labels, test_labels

    shape = (-1, 64, 64, 3)
    shape_flattened = (-1, 12288)  # Scaling each feature independently
    print("Calling Dataset()")
    return Dataset(X_train, y_train, X_test, y_test, shape=shape, shape_flattened=shape_flattened, fraction=fraction)


def get_mnist_dataset(fraction=None):
    """Load MNIST through Keras and wrap it in a ``Dataset``.

    Args:
        fraction: optional fraction of each split to keep (see ``Dataset``).
    """
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
    return Dataset(
        X_train, y_train, X_test, y_test,
        shape=(-1, 28, 28, 1),
        shape_flattened=(-1, 1),  # Scaling all features together
        fraction=fraction,
    )


def get_fashion_mnist_dataset(fraction=None):
    """Load Fashion-MNIST through Keras and wrap it in a ``Dataset``.

    Args:
        fraction: optional fraction of each split to keep (see ``Dataset``).
    """
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    return Dataset(
        X_train, y_train, X_test, y_test,
        shape=(-1, 28, 28, 1),
        shape_flattened=(-1, 1),  # Scaling all features together
        fraction=fraction,
    )


def get_fifteen_puzzle_dataset(path=None, fraction=None):
    """Load 15-puzzle states and their costs from CSV and wrap them in a ``Dataset``.

    Every column except the last is one board cell whose value indexes a 16-way one-hot
    encoding (so 16 cells become 256 features); the ``cost`` column is the regression target.
    NOTE(review): ``iloc[:, :-1]`` assumes ``cost`` is the last column — confirm the CSV layout.

    Args:
        path: CSV path. If None, Google Drive is mounted (Colab only) and a default file is used.
        fraction: if not None, the fraction of rows to sample (random_state=42) before splitting.
    """
    from sklearn.model_selection import train_test_split

    if path is None:
        from google.colab import drive
        drive.mount('/content/gdrive')
        path = 'gdrive/MyDrive/15-costs-v3.csv'
    costs = pd.read_csv(path)
    # DataFrame.sample(frac=None) returns a single row, so only subsample when a fraction is given.
    if fraction is not None:
        costs = costs.sample(frac=fraction, random_state=42)

    X_raw = costs.iloc[:, :-1].values
    y = costs['cost'].values
    # Vectorized one-hot: (N, cells) -> (N, cells, 16) -> (N, cells * 16), same row layout as
    # ravel-ing np.eye(16)[row] for every row.
    X = np.eye(16)[X_raw].reshape(X_raw.shape[0], X_raw.shape[1] * 16)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    del X, X_raw, y

    shape = (-1, 256)
    shape_flattened = (-1, 256)  # Scaling each feature independently
    return Dataset(X_train, y_train, X_test, y_test, shape=shape, shape_flattened=shape_flattened, vision=False,
                   fraction=None)

##### models.py

# Repeats the precision setting from the import cell so this "models.py" section is self-contained.
dtype = 'float32'
tf.keras.backend.set_floatx(dtype)


class Regularizer(tf.keras.regularizers.Regularizer):
    """Switchable weight penalty used by the DAS layers.

    ``__call__`` dispatches on ``regularization_method`` (``None``, ``'weighted_l1'``,
    ``'weighted_l1_reordered'``, ``'group_sparsity'`` or ``'l1'``). It returns 0 when no method
    is set or the penalty is 0. Dense/Conv2D call it with a 2-D matrix whose columns are the
    layer's output units (weights with the bias appended as the last row).
    """

    def __init__(self):
        self.n_new_neurons = 0  # output units appended by the most recent grow(); reset by prune()
        self.scaling_tensor = None  # cached per-parameter weights for 'weighted_l1_reordered'
        self.set_regularization_penalty(0.)
        self.set_regularization_method(None)

    def copy(self):
        """Return a new Regularizer with the same method, penalty and cached state.

        Because the method setter re-arms ``update_scaling_tensor``, a 'weighted_l1_reordered'
        copy rebuilds (and reshuffles) its scaling tensor on its next call.
        """
        regularizer_copy = Regularizer.__new__(Regularizer)
        regularizer_copy.n_new_neurons = self.n_new_neurons
        regularizer_copy.scaling_tensor = self.scaling_tensor
        regularizer_copy.set_regularization_penalty(self.regularization_penalty)
        regularizer_copy.set_regularization_method(self.regularization_method)
        return regularizer_copy

    def __call__(self, x):
        """Return the penalty of parameter matrix ``x`` under the configured method.

        Raises:
            NotImplementedError: if ``regularization_method`` is not a known name.
        """
        if self.regularization_method is None or self.regularization_penalty == 0:
            return 0
        elif self.regularization_method == 'weighted_l1':
            return self.weighted_l1(x)
        elif self.regularization_method == 'weighted_l1_reordered':
            return self.weighted_l1_reordered(x)
        elif self.regularization_method == 'group_sparsity':
            return self.group_sparsity(x)
        elif self.regularization_method == 'l1':
            return self.l1(x)
        else:
            raise NotImplementedError(f"Unknown regularization method {self.regularization_method}")

    def weighted_l1(self, x):
        """L1 penalty whose weight grows linearly with the output-unit (column) index."""
        # E.g. for a parameter matrix of 4 input and 10 output neurons:
        #
        # [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]
        #
        # the scaling tensor (with regularization_penalty = 1), as well as the resulting
        # weighted values, would be:
        #
        # [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        #  [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        #  [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        #  [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]]
        #
        # Therefore every additional output neuron is regularized more.

        scaling_tensor = tf.cumsum(tf.constant(self.regularization_penalty, shape=x.shape, dtype=dtype), axis=-1)
        weighted_values = scaling_tensor * tf.abs(x)
        return tf.reduce_sum(weighted_values)

    def weighted_l1_reordered(self, x):
        """Weighted L1 where the column weights of the pre-existing units are shuffled.

        Column weights are ``penalty * (j + 1)`` as in ``weighted_l1``. The weights of the old
        units (all but the last ``n_new_neurons`` columns) are randomly permuted among those
        columns, while newly grown units keep the largest weights. The result is cached and
        rebuilt only after ``prune``/``grow``/``set_regularization_method``.
        """
        if self.update_scaling_tensor:
            scaling_tensor_raw = tf.cumsum(tf.constant(self.regularization_penalty, shape=x.shape, dtype=dtype),
                                           axis=-1)

            # Split on an explicit column count. The former ``[:, :-n]`` / ``[:, -n:]`` slices
            # turn into ``[:, :0]`` / ``[:, 0:]`` when n_new_neurons == 0, which treated every
            # unit as new and silently skipped the shuffle (e.g. after every prune()).
            n_old_neurons = x.shape[-1] - self.n_new_neurons
            scaling_tensor_old_neurons = scaling_tensor_raw[:, :n_old_neurons]
            scaling_tensor_new_neurons = scaling_tensor_raw[:, n_old_neurons:]
            # tf.random.shuffle permutes along the first axis, hence the transposes to shuffle columns.
            scaling_tensor_old_neurons_shuffled = tf.transpose(
                tf.random.shuffle(tf.transpose(scaling_tensor_old_neurons)))
            self.scaling_tensor = tf.concat([scaling_tensor_old_neurons_shuffled, scaling_tensor_new_neurons], axis=-1)
            self.update_scaling_tensor = False

        weighted_values = self.scaling_tensor * tf.abs(x)
        return tf.reduce_sum(weighted_values)

    def group_sparsity(self, x):
        """Group-lasso penalty: ``penalty * sum_j ||x[:, j]||_2``."""
        # E.g. for a parameter matrix of 3 input and 5 output neurons:
        #
        # [[1., 1., 1., 1., 1.],
        #  [1., 2., 2., 1., 2.],
        #  [2., 2., 3., 1., 3.]]
        #
        # the resulting vector of group (column) L2 norms is approximately
        # [2.449, 3., 3.742, 1.732, 3.742], therefore for every output neuron, its incoming
        # connections form a group.

        group_norms = tf.norm(x, ord=2, axis=0)
        return self.regularization_penalty * tf.reduce_sum(group_norms)

    def l1(self, x):
        """Plain L1 penalty: ``penalty * sum |x|``."""
        weighted_values = self.regularization_penalty * tf.abs(x)
        return tf.reduce_sum(weighted_values)

    def prune(self):
        """Record that units were pruned: no unit counts as new any more."""
        self.n_new_neurons = 0
        if self.regularization_method == 'weighted_l1_reordered':
            self.update_scaling_tensor = True

    def grow(self, n_new_neurons):
        """Record that ``n_new_neurons`` output units were appended as the last columns."""
        self.n_new_neurons = n_new_neurons
        if self.regularization_method == 'weighted_l1_reordered':
            self.update_scaling_tensor = True

    def set_regularization_penalty(self, regularization_penalty):
        self.regularization_penalty = regularization_penalty

    def set_regularization_method(self, regularization_method):
        """Select the penalty type; 'weighted_l1_reordered' also schedules a scaling-tensor rebuild."""
        self.regularization_method = regularization_method
        if self.regularization_method == 'weighted_l1_reordered':
            self.update_scaling_tensor = True
        else:
            self.update_scaling_tensor = None

    def get_config(self):
        return {'regularization_penalty': float(self.regularization_penalty)}

    @classmethod
    def from_config(cls, config):
        """Rebuild from ``get_config()`` output.

        The inherited implementation calls ``cls(**config)``, which fails because ``__init__``
        takes no arguments.
        """
        regularizer = cls()
        regularizer.set_regularization_penalty(config.get('regularization_penalty', 0.))
        return regularizer


class DASLayer(tf.keras.layers.Layer):
    """Common base class of the resizable layers (Dense, Conv2D).

    Args:
        input_shape: optional input shape without the batch dimension, kept as ``_input_shape``.
        fixed_size: if True, the layer's own output units are never pruned or grown.
    """

    def __init__(self, input_shape, fixed_size):
        super().__init__()

        self._built = False  # becomes True once the subclass has created its variables
        self.fixed_size = fixed_size
        self._input_shape = input_shape


class Dense(DASLayer):
    """Fully connected layer computing ``A(inputs @ W + b)`` with prunable/growable units.

    ``W`` has shape (input_units, units) and ``b`` has shape (units,). ``prune``/``grow``
    rebuild both variables. Rows of ``W`` follow the previous layer's units, and, unless
    ``fixed_size`` is set, columns (this layer's output units) are removed or appended.
    """

    def __init__(self, units, activation, kernel_initializer='glorot_uniform',
                 bias_initializer='zeros', input_shape=None, fixed_size=False,
                 regularizer=None):
        super().__init__(input_shape, fixed_size)

        self.units = units
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer

        self.A = tf.keras.activations.get(activation)
        self.W_init = tf.keras.initializers.get(kernel_initializer)
        self.b_init = tf.keras.initializers.get(bias_initializer)
        if regularizer is not None:
            self.regularizer = regularizer
        else:
            self.regularizer = Regularizer()

    def copy(self):
        """Return an already-built copy of this layer.

        The copy gets new ``W``/``b`` variables initialised from this layer's current values
        and its own copy of the regularizer.
        """
        layer_copy = Dense.__new__(Dense)
        # DASLayer.__init__ requires both input_shape and fixed_size. Calling it with only
        # input_shape raised a TypeError, so copy() could never succeed.
        super(Dense, layer_copy).__init__(self._input_shape, self.fixed_size)

        layer_copy.units = self.units
        layer_copy.activation = self.activation
        layer_copy.kernel_initializer = self.kernel_initializer
        layer_copy.bias_initializer = self.bias_initializer

        layer_copy.A = self.A
        layer_copy.W_init = self.W_init
        layer_copy.b_init = self.b_init
        layer_copy.regularizer = self.regularizer.copy()

        layer_copy.W = tf.Variable(
            name='W',
            initial_value=self.W,
            trainable=True)

        layer_copy.b = tf.Variable(
            name='b',
            initial_value=self.b,
            trainable=True)

        layer_copy.add_regularizer_loss()

        layer_copy._built = True
        return layer_copy

    def build(self, input_shape):
        """Create ``W`` and ``b`` for the last dimension of ``input_shape`` (no-op if already built)."""
        if self._built:
            return

        input_units = input_shape[-1]

        self.W = tf.Variable(
            name='W',
            initial_value=self.W_init(shape=(input_units, self.units), dtype=dtype),
            trainable=True)

        self.b = tf.Variable(
            name='b',
            initial_value=self.b_init(shape=(self.units,), dtype=dtype),
            trainable=True)

        self.add_regularizer_loss()

        self._built = True

    def call(self, inputs, training=None):
        return self.A(tf.matmul(inputs, self.W) + self.b)

    def add_regularizer_loss(self):
        """Register the regularizer on ``W`` with ``b`` appended as an extra row.

        The loss is a callable, so it always reads the current ``self.W``/``self.b``, even after
        prune/grow replace the variables.
        """
        self.add_loss(lambda: self.regularizer(tf.concat([self.W, tf.reshape(self.b, (1, -1))], axis=0)))

    def get_size(self):
        """Return ``(input_units, output_units)`` from the current shape of ``W``."""
        return self.W.shape[0], self.W.shape[1]

    def prune(self, threshold, active_input_units_indices):
        """Drop pruned inputs and, unless ``fixed_size``, weak output units.

        Args:
            threshold: an output unit is kept if the largest |value| among its weights and
                bias is >= threshold.
            active_input_units_indices: indices of the previous layer's surviving units (rows
                of ``W`` to keep).

        Returns:
            Indices of the surviving output units (all of them when ``fixed_size``).
        """
        # Remove connections from pruned units in previous layer
        new_W = tf.gather(self.W.value(), active_input_units_indices, axis=0)

        if self.fixed_size:
            active_output_neurons_indices = list(range(new_W.shape[1]))
        else:
            # Prune units in this layer
            weights_with_biases = tf.concat([new_W, tf.reshape(self.b.value(), (1, -1))], axis=0)
            neurons_are_active = tf.math.reduce_max(tf.abs(weights_with_biases), axis=0) >= threshold
            active_output_neurons_indices = tf.reshape(tf.where(neurons_are_active), (-1,))

            new_W = tf.gather(new_W, active_output_neurons_indices, axis=1)
            new_b = tf.gather(self.b.value(), active_output_neurons_indices, axis=0)

            self.b = tf.Variable(name='b', initial_value=new_b, trainable=True)

        self.W = tf.Variable(name='W', initial_value=new_W, trainable=True)

        self.regularizer.prune()
        return active_output_neurons_indices

    def grow(self, n_new_input_units, percentage, min_new_units, scaling_factor):
        """Add rows for new inputs and, unless ``fixed_size``, new output units.

        Args:
            n_new_input_units: units the previous layer appended; matching rows are added to ``W``.
            percentage: fraction of the current output units to add.
            min_new_units: lower bound on the number of added output units.
            scaling_factor: multiplier applied to the freshly initialised weights.

        Returns:
            Number of output units added (0 when ``fixed_size``).
        """
        if n_new_input_units > 0:
            # Add connections to grown units in previous layer. The initializer is called with the
            # enlarged shape and then sliced, presumably so fan-based initializers see the new size.
            W_growth = self.W_init(shape=(self.W.shape[0] + n_new_input_units, self.W.shape[1]), dtype=dtype)[
                       -n_new_input_units:,
                       :] * scaling_factor  # TODO is it better to be multiplying here by scaling_factor? It does help with not increasing the max weights of existing neurons when new neurons are added.
            new_W = tf.concat([self.W.value(), W_growth], axis=0)
        else:
            new_W = self.W.value()

        if self.fixed_size:
            n_new_output_units = 0
        else:
            # Grow new units in this layer
            n_new_output_units = max(min_new_units, int(new_W.shape[1] * percentage))
            if n_new_output_units > 0:
                W_growth = self.W_init(shape=(new_W.shape[0], new_W.shape[1] + n_new_output_units), dtype=dtype)[:,
                           -n_new_output_units:] * scaling_factor
                b_growth = self.b_init(shape=(n_new_output_units,),
                                       dtype=dtype)  # TODO for all possible bias initializers to work properly, the whole bias vector should be initialized at once
                new_W = tf.concat([new_W, W_growth], axis=1)
                new_b = tf.concat([self.b.value(), b_growth], axis=0)

                self.b = tf.Variable(name='b', initial_value=new_b, trainable=True)

        self.W = tf.Variable(name='W', initial_value=new_W, trainable=True)

        self.regularizer.grow(n_new_output_units)
        return n_new_output_units

    def mutate(self, mutation_strength):
        """Add Gaussian noise with standard deviation ``mutation_strength`` to ``W`` and ``b``."""
        # Draw the noise in the module-level dtype so assign_add matches the variables' dtype.
        self.W.assign_add(tf.random.normal(self.W.shape, mean=0.0, stddev=mutation_strength, dtype=dtype))
        self.b.assign_add(tf.random.normal(self.b.shape, mean=0.0, stddev=mutation_strength, dtype=dtype))

    def set_regularization_penalty(self, regularization_penalty):
        if not self.fixed_size:
            self.regularizer.set_regularization_penalty(regularization_penalty)

    def set_regularization_method(self, regularization_method):
        if not self.fixed_size:
            self.regularizer.set_regularization_method(regularization_method)

    def get_param_string(self):
        """Return a string summarising the magnitude of each output unit's largest parameter.

        For every unit, ``str(-floor(log10(max |param|)))`` is appended (clamped at 0), where the
        maximum runs over the unit's weights and bias. Units with a parameter >= 1 give "0", units
        whose maximum is in [0.1, 1) give "1", and so on.
        NOTE(review): a unit whose parameters are all exactly 0 gives log10(0) = -inf, and ``int``
        then raises OverflowError.
        """
        param_string = ""
        weights_with_bias = tf.concat([self.W, tf.reshape(self.b, (1, -1))], axis=0)
        max_parameters = tf.math.reduce_max(tf.abs(weights_with_bias), axis=0).numpy()
        magnitudes = np.floor(np.log10(max_parameters))
        for m in magnitudes:
            if m > 0:
                m = 0
            param_string += str(int(-m))
        return param_string


class Conv2D(DASLayer):
    """2-D convolution ``A(conv2d(inputs, F) + b)`` with prunable/growable filters.

    ``F`` has shape (filter_size[0], filter_size[1], input_channels, filters) and ``b`` has
    shape (filters,). ``prune``/``grow`` rebuild both variables. Input channels follow the
    previous layer, and, unless ``fixed_size`` is set, filters are removed or appended.
    """

    def __init__(self, filters, filter_size, activation, strides=(1, 1),
                 padding='SAME', kernel_initializer='glorot_uniform',
                 bias_initializer='zeros', input_shape=None, fixed_size=False,
                 regularizer=None):
        super().__init__(input_shape, fixed_size)

        self.filters = filters
        self.filter_size = filter_size
        self.activation = activation
        self.strides = strides
        self.padding = padding
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer

        self.A = tf.keras.activations.get(activation)
        self.F_init = tf.keras.initializers.get(kernel_initializer)
        self.b_init = tf.keras.initializers.get(bias_initializer)
        if regularizer is not None:
            self.regularizer = regularizer
        else:
            self.regularizer = Regularizer()

    def copy(self):
        """Return an already-built copy of this layer.

        The copy gets new ``F``/``b`` variables initialised from this layer's current values
        and its own copy of the regularizer.
        """
        layer_copy = Conv2D.__new__(Conv2D)
        # DASLayer.__init__ requires both input_shape and fixed_size. Calling it with only
        # input_shape raised a TypeError, so copy() could never succeed.
        super(Conv2D, layer_copy).__init__(self._input_shape, self.fixed_size)

        layer_copy.filters = self.filters
        layer_copy.filter_size = self.filter_size
        layer_copy.activation = self.activation
        layer_copy.strides = self.strides
        layer_copy.padding = self.padding
        layer_copy.kernel_initializer = self.kernel_initializer
        layer_copy.bias_initializer = self.bias_initializer

        layer_copy.A = self.A
        layer_copy.F_init = self.F_init
        layer_copy.b_init = self.b_init
        layer_copy.regularizer = self.regularizer.copy()

        layer_copy.F = tf.Variable(
            name='F',
            initial_value=self.F,
            trainable=True)

        layer_copy.b = tf.Variable(
            name='b',
            initial_value=self.b,
            trainable=True)

        layer_copy.add_regularizer_loss()

        layer_copy._built = True
        return layer_copy

    def build(self, input_shape):
        """Create ``F`` and ``b`` for the channel count ``input_shape[-1]`` (no-op if already built)."""
        if self._built:
            return

        input_filters = input_shape[-1]

        self.F = tf.Variable(
            name='F',
            initial_value=self.F_init(
                shape=(self.filter_size[0], self.filter_size[1], input_filters, self.filters), dtype=dtype
            ),
            trainable=True)

        self.b = tf.Variable(
            name='b',
            initial_value=self.b_init(shape=(self.filters,), dtype=dtype),
            trainable=True)

        self.add_regularizer_loss()

        self._built = True

    def call(self, inputs, training=None):
        y = tf.nn.conv2d(inputs, self.F, strides=self.strides, padding=self.padding)
        y = tf.nn.bias_add(y, self.b)
        y = self.A(y)
        return y

    def add_regularizer_loss(self):
        """Register the regularizer on ``F`` reshaped to (-1, filters) with ``b`` as an extra row.

        The loss is a callable, so it always reads the current ``self.F``/``self.b``.
        """
        self.add_loss(lambda: self.regularizer(
            tf.concat([tf.reshape(self.F, (-1, self.F.shape[-1])), tf.reshape(self.b, (1, -1))], axis=0)))

    def get_size(self):
        """Return ``(input_channels, filters)`` from the current shape of ``F``."""
        return self.F.shape[-2], self.F.shape[-1]

    def prune(self, threshold, active_input_units_indices):
        """Drop pruned input channels and, unless ``fixed_size``, weak filters.

        Args:
            threshold: a filter is kept if the largest |value| over its kernel weights and its
                bias is >= threshold.
            active_input_units_indices: indices of the previous layer's surviving units (input
                channels of ``F`` to keep).

        Returns:
            Indices of the surviving filters (all of them when ``fixed_size``).
        """
        # Remove connections from pruned units in previous layer
        new_F = tf.gather(self.F.value(), active_input_units_indices, axis=-2)

        if self.fixed_size:
            active_output_filters_indices = list(range(new_F.shape[-1]))
        else:
            # Prune units in this layer
            F_reduced_max = tf.reshape(tf.math.reduce_max(tf.abs(new_F), axis=(0, 1, 2)), (1, -1))
            F_reduced_max_with_biases = tf.concat([F_reduced_max, tf.reshape(self.b.value(), (1, -1))], axis=0)
            filters_are_active = tf.math.reduce_max(tf.abs(F_reduced_max_with_biases), axis=0) >= threshold
            active_output_filters_indices = tf.reshape(tf.where(filters_are_active), (-1,))

            new_F = tf.gather(new_F, active_output_filters_indices, axis=-1)
            new_b = tf.gather(self.b.value(), active_output_filters_indices, axis=0)

            self.b = tf.Variable(name='b', initial_value=new_b, trainable=True)

        self.F = tf.Variable(name='F', initial_value=new_F, trainable=True)

        self.regularizer.prune()
        return active_output_filters_indices

    def grow(self, n_new_input_units, percentage, min_new_units, scaling_factor):
        """Add input channels for new previous-layer units and, unless ``fixed_size``, new filters.

        Args:
            n_new_input_units: channels the previous layer appended.
            percentage: fraction of the current filters to add.
            min_new_units: lower bound on the number of added filters.
            scaling_factor: multiplier applied to the freshly initialised weights.

        Returns:
            Number of filters added (0 when ``fixed_size``).
        """
        if n_new_input_units > 0:
            # Add connections to grown units in previous layer. The initializer is called with the
            # enlarged shape and then sliced, presumably so fan-based initializers see the new size.
            F_growth = self.F_init(
                shape=(self.F.shape[0], self.F.shape[1], self.F.shape[2] + n_new_input_units, self.F.shape[3]),
                dtype=dtype)[:, :, -n_new_input_units:,
                       :] * scaling_factor  # TODO is it better to be multiplying here by scaling_factor? It does help with not increasing the max weights of existing neurons when new neurons are added.
            new_F = tf.concat([self.F.value(), F_growth], axis=-2)
        else:
            new_F = self.F.value()

        if self.fixed_size:
            n_new_output_units = 0
        else:
            # Grow new units in this layer
            n_new_output_units = max(min_new_units, int(new_F.shape[-1] * percentage))
            if n_new_output_units > 0:
                F_growth = self.F_init(
                    shape=(new_F.shape[0], new_F.shape[1], new_F.shape[2], new_F.shape[3] + n_new_output_units),
                    dtype=dtype)[:, :, :, -n_new_output_units:] * scaling_factor
                b_growth = self.b_init(shape=(n_new_output_units,),
                                       dtype=dtype)  # TODO for all possible bias initializers to work properly, the whole bias vector should be initialized at once
                new_F = tf.concat([new_F, F_growth], axis=-1)
                new_b = tf.concat([self.b.value(), b_growth], axis=0)

                self.b = tf.Variable(name='b', initial_value=new_b, trainable=True)

        self.F = tf.Variable(name='F', initial_value=new_F, trainable=True)

        self.regularizer.grow(n_new_output_units)
        return n_new_output_units

    def mutate(self, mutation_strength):
        """Add Gaussian noise with standard deviation ``mutation_strength`` to ``F`` and ``b``."""
        # Draw the noise in the module-level dtype so assign_add matches the variables' dtype.
        self.F.assign_add(tf.random.normal(self.F.shape, mean=0.0, stddev=mutation_strength, dtype=dtype))
        self.b.assign_add(tf.random.normal(self.b.shape, mean=0.0, stddev=mutation_strength, dtype=dtype))

    def set_regularization_penalty(self, regularization_penalty):
        if not self.fixed_size:
            self.regularizer.set_regularization_penalty(regularization_penalty)

    def set_regularization_method(self, regularization_method):
        if not self.fixed_size:
            self.regularizer.set_regularization_method(regularization_method)

    def get_param_string(self):
        """Placeholder: always returns an empty string for convolutional layers."""
        param_string = ""
        # TODO
        return param_string


class Flatten(tf.keras.layers.Layer):
    """Flatten (batch, H, W, C) feature maps to (batch, C * H * W) in channel-major order.

    Transposing to (batch, C, H, W) before reshaping keeps each channel's values contiguous.
    """

    def call(self, inputs, training=None):
        # Use the run-time batch size: the static inputs.shape[0] is None when the batch
        # dimension is unknown (e.g. while tracing a tf.function), and tf.reshape rejects None.
        batch_size = tf.shape(inputs)[0]
        return tf.reshape(tf.transpose(inputs, perm=[0, 3, 1, 2]), (batch_size, -1))


class Sequential(tf.keras.Model):
    """Keras model over an explicit list of layers that can grow, prune and mutate itself.

    Layers that are ``DASLayer`` instances take part in structural changes (grow, prune,
    mutate, regularization settings). Every other layer (pooling, dropout, ``Flatten``) only
    takes part in the forward pass. ``fit`` drives training according to a schedule of epochs.

    NOTE: ``evaluate`` and ``fit`` intentionally override the ``tf.keras.Model`` methods of the
    same name with different signatures.
    """

    def __init__(self, layers):
        super().__init__()

        # Stored as `lrs` rather than `layers` -- presumably to avoid clashing with the
        # `tf.keras.Model.layers` property (NOTE(review): confirm).
        self.lrs = layers

    def call(self, inputs, training=None):
        """Run `inputs` through every layer in order, forwarding the `training` flag."""
        x = inputs
        for layer in self.lrs:
            x = layer(x, training=training)
        return x

    def copy(self):
        """Return a new ``Sequential`` built from copies of this model's layers.

        ``DASLayer`` instances are copied with their own ``copy()``. Every other layer is
        deep-copied with ``copy.deepcopy``.
        """
        # Local import so this method works even if the `copy` module was never imported at
        # notebook level.
        import copy

        copied_layers = [layer.copy() if isinstance(layer, DASLayer) else copy.deepcopy(layer)
                         for layer in self.lrs]
        return Sequential(copied_layers)

    def get_layer_input_shape(self, target_layer):
        """Return the per-example input shape of `target_layer` (batch dimension excluded).

        Uses the layer's declared ``_input_shape`` when it is set. Otherwise one random example
        is pushed through the preceding layers to discover the shape.

        Raises:
            Exception: if `target_layer` is not part of this model.
        """
        if target_layer._input_shape is not None:
            return target_layer._input_shape

        activations = np.random.normal(size=(1,) + self.lrs[0]._input_shape)
        for layer in self.lrs:
            if layer is target_layer:
                return tuple(activations.shape[1:])
            activations = layer(activations)
        raise Exception("Layer not found in the model.")

    def get_layer_output_shape(self, target_layer):
        """Return the per-example output shape of `target_layer` (batch dimension excluded).

        Computed by pushing one random example through the model up to and including
        `target_layer`.

        Raises:
            Exception: if `target_layer` is not part of this model.
        """
        activations = np.random.normal(size=(1,) + self.lrs[0]._input_shape)
        for layer in self.lrs:
            output = layer(activations)
            if layer is target_layer:
                return tuple(output.shape[1:])
            activations = output
        raise Exception("Layer not found in the model.")

    def get_layer_sizes(self):
        """Return the sizes of the growable layers of the model.

        Only ``DASLayer`` instances with ``fixed_size == False`` are considered. Fixed-size
        layers, such as an output layer created with ``fixed_size=True``, are NOT included.

        The list contains:
        - first, element 0 of the first growable layer's ``get_size()`` (presumably its input size);
        - then, element 1 of every growable layer's ``get_size()`` (presumably its output size).
        """
        layer_sizes = list()
        first_layer = True
        for layer in self.lrs:
            if isinstance(layer, DASLayer) and not layer.fixed_size:
                layer_size = layer.get_size()
                if first_layer:
                    layer_sizes.append(layer_size[0])
                    first_layer = False
                layer_sizes.append(layer_size[1])
        return layer_sizes

    def get_hidden_layer_sizes(self):
        """Alias of ``get_layer_sizes``.

        NOTE(review): despite the name, the result also starts with the leading element taken
        from the first growable layer (see ``get_layer_sizes``). It is kept as-is because
        recorded histories use this format.
        """
        return self.get_layer_sizes()

    def get_regularization_penalty(self):
        """Return the regularization penalty of the second-to-last ``Dense`` layer.

        This value is used as the model-wide penalty.
        """
        # TODO improve: penalties can differ per layer, and this needs at least two Dense layers.
        dense_layers = [layer for layer in self.lrs if isinstance(layer, Dense)]
        return dense_layers[-2].regularizer.regularization_penalty

    def set_regularization_penalty(self, regularization_penalty):
        """Set the same regularization penalty on every growable ``DASLayer``."""
        for layer in self.lrs:
            if isinstance(layer, DASLayer) and not layer.fixed_size:
                layer.set_regularization_penalty(regularization_penalty)

    def set_regularization_method(self, regularization_method):
        """Set the same regularization method on every growable ``DASLayer``."""
        for layer in self.lrs:
            if isinstance(layer, DASLayer) and not layer.fixed_size:
                layer.set_regularization_method(regularization_method)

    def prune(self, params):
        """Prune every ``DASLayer`` using ``params.pruning_threshold``.

        The indices of the units kept by one ``DASLayer`` are passed on to the next one
        (presumably so it can keep only the matching inputs). At a ``Flatten`` layer, the
        surviving channel indices are expanded to the flattened indices of those channels.
        """
        input_shape = self.get_layer_input_shape(self.lrs[0])
        n_input_units = input_shape[-1]
        active_units_indices = list(range(n_input_units))

        last_custom_layer = None
        for layer in self.lrs:
            if isinstance(layer, Flatten):
                convolutional_shape = self.get_layer_output_shape(last_custom_layer)
                active_units_indices = self.convert_channel_indices_to_flattened_indices(active_units_indices,
                                                                                         convolutional_shape)
            elif isinstance(layer, DASLayer):
                active_units_indices = layer.prune(params.pruning_threshold, active_units_indices)
                last_custom_layer = layer

    def grow(self, params):
        """Add new units to every ``DASLayer``.

        The number of units added by one layer is passed to the next ``DASLayer``. At a
        ``Flatten`` layer, the number of new channels is multiplied by H * W of the preceding
        feature map, because each channel becomes H * W flattened units.
        """
        n_new_units = 0

        last_custom_layer = None
        for layer in self.lrs:
            if isinstance(layer, Flatten):
                convolutional_shape = self.get_layer_output_shape(last_custom_layer)
                n_new_units = n_new_units * convolutional_shape[0] * convolutional_shape[1]
            elif isinstance(layer, DASLayer):
                # New weights are scaled by the pruning threshold (NOTE(review): presumably so
                # that fresh units start small).
                n_new_units = layer.grow(n_new_units, params.growth_percentage, min_new_units=params.min_new_neurons,
                                         scaling_factor=params.pruning_threshold)
                last_custom_layer = layer

    def mutate(self, mutation_strength):
        """Add Gaussian noise of stddev `mutation_strength` to every ``DASLayer``."""
        for layer in self.lrs:
            if isinstance(layer, DASLayer):
                layer.mutate(mutation_strength)

    @staticmethod
    def convert_channel_indices_to_flattened_indices(channel_indices, convolutional_shape):
        """Map channel indices of an (H, W, C) feature map to their indices after ``Flatten``.

        ``Flatten`` lays the tensor out channel-major. Channel ``c`` therefore occupies the
        contiguous flattened range ``[c * H * W, (c + 1) * H * W)``.
        """
        units_per_channel = convolutional_shape[0] * convolutional_shape[1]
        return [channel_index * units_per_channel + offset
                for channel_index in channel_indices
                for offset in range(units_per_channel)]

    def print_neurons(self):
        """Print ``get_param_string()`` of every growable ``DASLayer`` except the last layer.

        Non-DAS layers (pooling, dropout, ``Flatten``) have no ``get_param_string``. They, and
        fixed-size layers, are skipped instead of raising AttributeError.
        """
        for layer in self.lrs[:-1]:
            if isinstance(layer, DASLayer) and not layer.fixed_size:
                print(layer.get_param_string())

    def evaluate(self, params, summed_training_loss, summed_training_metric):
        """Return ``(loss, metric, val_loss, val_metric)`` as per-example averages.

        Training values are the given sums divided by the number of training examples; a
        ``None`` sum gives ``None``. Validation values are computed over
        ``params.val_dataset`` in inference mode. Every non-None result is a Python float, so
        printing and comparisons are consistent.
        """
        n_train_instances = params.x.shape[0]
        loss = None if summed_training_loss is None else float(summed_training_loss) / n_train_instances
        metric = None if summed_training_metric is None else float(summed_training_metric) / n_train_instances

        summed_val_loss = 0.
        summed_val_metric = 0.
        n_val_instances = 0
        for x_batch, y_batch in params.val_dataset:
            y_pred = self(x_batch, training=False)
            summed_val_loss += float(tf.reduce_sum(params.loss_fn(y_batch, y_pred)))
            summed_val_metric += float(tf.reduce_sum(params.metric_fn(y_batch, y_pred)))
            n_val_instances += x_batch.shape[0]

        val_loss = summed_val_loss / n_val_instances
        val_metric = summed_val_metric / n_val_instances

        return loss, metric, val_loss, val_metric

    def list_params(self):
        """Print and return ``(total, trainable, non_trainable)`` parameter counts."""
        trainable_count = np.sum([K.count_params(w) for w in self.trainable_weights])
        non_trainable_count = np.sum([K.count_params(w) for w in self.non_trainable_weights])
        total_count = trainable_count + non_trainable_count

        print('Total params: {:,}'.format(total_count))
        print('Trainable params: {:,}'.format(trainable_count))
        print('Non-trainable params: {:,}'.format(non_trainable_count))

        return total_count, trainable_count, non_trainable_count

    def print_epoch_statistics(self, params, summed_training_loss, summed_training_metric, message=None,
                               require_result=False):
        """Print epoch statistics when ``params.verbose`` is set.

        Returns:
            The ``evaluate`` tuple when `require_result` is True, otherwise None. When not
            verbose and no result is required, nothing is evaluated.
        """
        if not params.verbose:
            if require_result:
                return self.evaluate(params, summed_training_loss, summed_training_metric)
            else:
                return

        loss, metric, val_loss, val_metric = self.evaluate(params, summed_training_loss, summed_training_metric)

        if message is not None:
            print(message)

        print(
            f"loss: {loss} - metric: {metric} - val_loss: {val_loss} - val_metric: {val_metric} - penalty: {self.get_regularization_penalty()}")
        hidden_layer_sizes = self.get_hidden_layer_sizes()
        print(f"hidden layer sizes: {hidden_layer_sizes}, total units: {sum(hidden_layer_sizes)}")
        if params.print_neurons:
            self.print_neurons()

        if require_result:
            return loss, metric, val_loss, val_metric

    def update_history(self, params, loss, metric, val_loss, val_metric):
        """Append one epoch's statistics and current layer sizes to ``params.history``."""
        params.history['loss'].append(float(loss))
        params.history['metric'].append(float(metric))
        params.history['val_loss'].append(float(val_loss))
        params.history['val_metric'].append(float(val_metric))
        params.history['hidden_layer_sizes'].append(self.get_hidden_layer_sizes())

    @staticmethod
    def prepare_datasets(x, y, batch_size, validation_data):
        """Build the prefetched training and validation datasets.

        The training data is shuffled with a 20000-element buffer; validation data is not shuffled.
        """
        train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
        train_dataset = train_dataset.shuffle(buffer_size=20000).batch(batch_size)
        val_dataset = tf.data.Dataset.from_tensor_slices(validation_data).batch(batch_size)
        return train_dataset.prefetch(tf.data.AUTOTUNE), val_dataset.prefetch(tf.data.AUTOTUNE)

    def manage_dynamic_regularization(self, params, val_loss):
        """Scale the regularization penalty when validation loss stalls.

        Training counts as stalled when `val_loss` is not below
        ``best_conditional_val_loss * stall_coefficient``. On entering a stall, the penalty is
        multiplied once by ``regularization_penalty_multiplier``. It is not multiplied again
        until an improvement clears the stall flag.
        """
        if val_loss >= params.best_conditional_val_loss * params.stall_coefficient:
            # Training is currently in stall
            if not params.training_stalled:
                penalty = self.get_regularization_penalty() * params.regularization_penalty_multiplier
                if params.verbose:
                    print("Changing penalty...")
                # TODO this must be modified, penalty can differ for each layer
                self.set_regularization_penalty(penalty)
                params.training_stalled = True
        else:
            params.best_conditional_val_loss = val_loss
            params.training_stalled = False

    def grow_wrapper(self, params):
        """Report statistics, optionally adjust the penalty, grow the model, then report again."""
        dynamic_regularization_active = params.regularization_penalty_multiplier != 1.
        if dynamic_regularization_active:
            loss, metric, val_loss, val_metric = self.print_epoch_statistics(params, None, None, "Before growing:",
                                                                             require_result=True)
            self.manage_dynamic_regularization(params, val_loss)
        else:
            self.print_epoch_statistics(params, None, None, "Before growing:")

        self.grow(params)
        if params.verbose:
            print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        self.print_epoch_statistics(params, None, None, "After growing:")

    def prune_wrapper(self, params, summed_loss, summed_metric):
        """Prune the model and record the epoch in the history.

        Training loss/metric are taken before pruning; validation loss/metric after pruning.
        """
        loss, metric, _, _ = self.print_epoch_statistics(params, summed_loss, summed_metric, "Before pruning:",
                                                         require_result=True)
        self.prune(params)
        _, _, val_loss, val_metric = self.print_epoch_statistics(params, None, None, "After pruning:",
                                                                 require_result=True)
        self.update_history(params, loss, metric, val_loss, val_metric)

    class ParameterContainer:
        """Bundle of the ``fit`` arguments plus the mutable training state.

        The state covers the datasets, the history, and the stall tracking.
        """

        def __init__(self, x, y, optimizer, batch_size, min_new_neurons, validation_data, pruning_threshold,
                     regularization_penalty_multiplier,
                     stall_coefficient, growth_percentage, mini_epochs_per_epoch, verbose, print_neurons,
                     use_static_graph, loss_fn, metric_fn):
            self.x = x
            self.y = y
            self.optimizer = optimizer
            self.batch_size = batch_size
            self.min_new_neurons = min_new_neurons
            self.validation_data = validation_data
            self.pruning_threshold = pruning_threshold
            self.regularization_penalty_multiplier = regularization_penalty_multiplier
            self.stall_coefficient = stall_coefficient
            self.growth_percentage = growth_percentage
            self.mini_epochs_per_epoch = mini_epochs_per_epoch
            self.verbose = verbose
            self.print_neurons = print_neurons
            self.use_static_graph = use_static_graph
            self.loss_fn = loss_fn
            self.metric_fn = metric_fn

            self.train_dataset, self.val_dataset = Sequential.prepare_datasets(x, y, batch_size, validation_data)
            self.history = self.prepare_history()

            # Stall tracking used by Sequential.manage_dynamic_regularization.
            self.best_conditional_val_loss = np.inf
            self.training_stalled = False

        @staticmethod
        def prepare_history():
            """Return an empty history dict with one list per tracked statistic."""
            history = {
                'loss': list(),
                'metric': list(),
                'val_loss': list(),
                'val_metric': list(),
                'hidden_layer_sizes': list(),
            }
            return history

    def fit_single_step(self, x_batch, y_batch, optimizer, loss_fn, metric_fn):
        """Run one optimization step on a batch.

        The gradient is taken on the batch-MEAN loss plus ``self.losses`` (layer losses such as
        regularization).

        Returns:
            (loss, metric): per-example loss and metric summed over the batch, without the
            regularization terms, used for epoch-level averages.
        """
        with tf.GradientTape() as tape:
            y_pred = self(x_batch, training=True)
            raw_loss = loss_fn(y_batch, y_pred)
            loss_value = tf.reduce_mean(raw_loss)
            loss_value += sum(self.losses)  # Add losses registered by model.add_loss

            loss = tf.reduce_sum(raw_loss)
            metric = float(tf.reduce_sum(metric_fn(y_batch, y_pred)))

        grads = tape.gradient(loss_value, self.trainable_variables)
        optimizer.apply_gradients(zip(grads, self.trainable_variables))

        return loss, metric

    def fit_single_epoch(self, params):
        """Train for ``params.mini_epochs_per_epoch`` passes over the training data.

        Returns:
            (summed_loss, summed_metric), summed over the examples of the LAST pass only,
            because the sums are reset at the start of every pass. Returns (0, 0) if no pass runs.
        """
        # A single tf.function per epoch is enough. grow/prune replace the model's variables
        # between epochs, but not between the passes of one epoch. Creating the wrapper once
        # avoids retracing on every pass.
        if params.use_static_graph:
            fit_single_step_function = tf.function(self.fit_single_step)
        else:
            fit_single_step_function = self.fit_single_step

        summed_loss = 0
        summed_metric = 0
        for _ in range(params.mini_epochs_per_epoch):
            summed_loss = 0
            summed_metric = 0
            for x_batch, y_batch in params.train_dataset:
                loss, metric = fit_single_step_function(x_batch, y_batch, params.optimizer, params.loss_fn,
                                                        params.metric_fn)
                summed_loss += loss
                summed_metric += metric

        return summed_loss, summed_metric

    def fit(self, x, y, optimizer, schedule, batch_size, min_new_neurons, validation_data, pruning_threshold=0.001,
            regularization_penalty_multiplier=1.,
            stall_coefficient=1, growth_percentage=0.2, mini_epochs_per_epoch=1, verbose=True, print_neurons=False,
            use_static_graph=True,
            loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
            metric_fn=tf.keras.metrics.sparse_categorical_accuracy):
        """Train the model following `schedule`, growing and pruning when an epoch asks for it.

        For every epoch in `schedule`:
        1. apply its regularization penalty and method;
        2. optionally grow;
        3. train for one epoch;
        4. optionally prune;
        5. record the epoch's statistics.

        Returns:
            The history dict: lists 'loss', 'metric', 'val_loss', 'val_metric' and
            'hidden_layer_sizes', with one entry per epoch.
        """
        params = self.ParameterContainer(x=x, y=y, optimizer=optimizer, batch_size=batch_size,
                                         min_new_neurons=min_new_neurons, validation_data=validation_data,
                                         pruning_threshold=pruning_threshold,
                                         regularization_penalty_multiplier=regularization_penalty_multiplier,
                                         stall_coefficient=stall_coefficient,
                                         growth_percentage=growth_percentage,
                                         mini_epochs_per_epoch=mini_epochs_per_epoch, verbose=verbose,
                                         print_neurons=print_neurons,
                                         use_static_graph=use_static_graph, loss_fn=loss_fn, metric_fn=metric_fn)
        self.build(x.shape)  # Necessary when verbose == False

        for epoch_no, epoch in enumerate(schedule):
            if verbose:
                print("##########################################################")
                print(f"Epoch {epoch_no + 1}/{len(schedule)}")

            self.set_regularization_penalty(epoch.regularization_penalty)
            self.set_regularization_method(epoch.regularization_method)

            if epoch.grow:
                self.grow_wrapper(params)

            summed_loss, summed_metric = self.fit_single_epoch(params)

            if epoch.prune:
                self.prune_wrapper(params, summed_loss, summed_metric)
            else:
                loss, metric, val_loss, val_metric = self.print_epoch_statistics(params, summed_loss, summed_metric,
                                                                                 require_result=True)
                self.update_history(params, loss, metric, val_loss, val_metric)

        return params.history

##### schedule.py

class Epoch:
    """One entry of a training schedule.

    It records whether to grow, whether to prune, and which regularization to apply.
    """

    def __init__(self, grow, prune, regularization_penalty, regularization_method):
        self.grow, self.prune = grow, prune
        self.regularization_penalty = regularization_penalty
        self.regularization_method = regularization_method

    def __str__(self):
        # Compact concatenated encoding: grow flag, prune flag, penalty, method.
        # Schedule.__str__ hashes these strings.
        return '{}{}{}{}'.format(int(self.grow), int(self.prune),
                                 self.regularization_penalty, self.regularization_method)

    def __repr__(self):
        return str(self)


class DynamicEpoch(Epoch):
    """Schedule epoch that both grows and prunes the model."""

    def __init__(self, regularization_penalty, regularization_method):
        super().__init__(grow=True, prune=True,
                         regularization_penalty=regularization_penalty,
                         regularization_method=regularization_method)


class StaticEpoch(Epoch):
    """Schedule epoch that trains without growing or pruning."""

    def __init__(self, regularization_penalty, regularization_method):
        super().__init__(grow=False, prune=False,
                         regularization_penalty=regularization_penalty,
                         regularization_method=regularization_method)


class StaticEpochNoRegularization(StaticEpoch):
    """Static epoch with a zero penalty and no regularization method."""

    def __init__(self):
        super().__init__(regularization_penalty=0., regularization_method=None)


class Schedule:
    """An ordered, iterable, sized collection of schedule epochs."""

    def __init__(self, epochs):
        self.epochs = epochs

    def __iter__(self):
        return iter(self.epochs)

    def __len__(self):
        return len(self.epochs)

    def __str__(self):
        # Short identifier with two parts:
        # - the first 10 hex characters of the SHA-1 of all epoch encodings;
        # - the first epoch's regularization penalty, in parentheses.
        encoded_epochs = ''.join(map(str, self.epochs)).encode('utf-8')
        fingerprint = hashlib.sha1(encoded_epochs).hexdigest()
        return '{}({})'.format(fingerprint[:10], self.epochs[0].regularization_penalty)

    def __repr__(self):
        return str(self)

##### helpers.py

def get_statistics_from_history(history):
    """Return the statistics of the epoch with the highest ``val_metric``.

    Ties go to the earliest epoch (``np.argmax``).

    Returns:
        (loss, metric, val_loss, val_metric, hidden_layer_sizes) of that epoch.
    """
    best_epoch = np.argmax(history['val_metric'])
    wanted_keys = ('loss', 'metric', 'val_loss', 'val_metric', 'hidden_layer_sizes')
    return tuple(history[key][best_epoch] for key in wanted_keys)


def get_statistics_from_histories(histories):
    """Average the best-epoch statistics of several runs.

    Returns:
        (mean_best_val_loss, mean_best_val_metric, mean_best_hidden_layer_sizes). The last
        item is a list with the mean size of each layer position across runs.
    """
    per_run_stats = [get_statistics_from_history(history) for history in histories]

    best_val_losses = [stats[2] for stats in per_run_stats]
    best_val_metrics = [stats[3] for stats in per_run_stats]
    best_sizes_per_run = [stats[4] for stats in per_run_stats]

    # zip(*...) regroups the runs' size lists by layer position.
    mean_best_hidden_layer_sizes = [np.mean(sizes_at_position) for sizes_at_position in zip(*best_sizes_per_run)]

    return np.mean(best_val_losses), np.mean(best_val_metrics), mean_best_hidden_layer_sizes


def cross_validate(train_fn, x, y, n_splits, random_state=42, *args, **kwargs):
    """Run shuffled K-fold cross-validation of `train_fn` and report mean best-epoch statistics.

    For every fold, ``train_fn(x_train, y_train, (x_val, y_val), *args, **kwargs)`` is called.
    It must return a history dict as produced by ``Sequential.fit``.

    NOTE: `random_state` comes before `*args`, so when extra positional arguments are used,
    `random_state` must also be passed positionally right before them.

    Returns:
        (histories, mean_best_hidden_layer_sizes)
    """
    from sklearn.model_selection import KFold

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    histories = list()
    for i, (train_index, test_index) in enumerate(kf.split(x)):
        xtrain, xtest = x[train_index], x[test_index]
        ytrain, ytest = y[train_index], y[test_index]

        # validation_data is passed positionally, as the third argument of train_fn (like
        # hyperparameter_search does). Extra positional *args then follow it instead of
        # colliding with a `validation_data=` keyword ("got multiple values" TypeError).
        history = train_fn(xtrain, ytrain, (xtest, ytest), *args, **kwargs)
        histories.append(history)

        _, _, best_val_loss, best_val_metric, best_hidden_layer_sizes = get_statistics_from_history(history)
        print(
            f"Run {i} completed, best_val_loss: {best_val_loss}, best_val_metric: {best_val_metric}, best_hidden_layer_sizes: {best_hidden_layer_sizes}")

    mean_best_val_loss, mean_best_val_metric, mean_best_hidden_layer_sizes = get_statistics_from_histories(histories)
    print(f'mean_best_val_loss: {mean_best_val_loss}')
    print(f'mean_best_val_metric: {mean_best_val_metric}')
    print(f'mean_best_hidden_layer_sizes: {mean_best_hidden_layer_sizes}')

    return histories, mean_best_hidden_layer_sizes


def hyperparameter_search(train_fn, x, y, validation_data, *args, **kwargs):
    """Grid-search every combination of candidate hyperparameter values.

    Every element of `args` and every value of `kwargs` is an iterable of candidate values.
    Each combination from their Cartesian product is passed to
    ``train_fn(x, y, validation_data, *positional, **keyword)``. The combination is stored in
    the returned history under ``'parameters'``. The best combination is the one with the
    lowest best-epoch validation loss.

    Returns:
        (histories, best_overall_combination)
    """
    from itertools import product

    n_positional = len(args)
    keyword_names = list(kwargs.keys())
    candidate_grids = list(args) + list(kwargs.values())

    histories = []
    best_overall_val_loss = np.inf
    best_overall_val_metric = None
    best_overall_combination = None

    for combination in product(*candidate_grids):
        positional_values = combination[:n_positional]
        keyword_values = dict(zip(keyword_names, combination[n_positional:]))

        history = train_fn(x, y, validation_data, *positional_values, **keyword_values)
        history['parameters'] = combination
        histories.append(history)

        _, _, best_val_loss, best_val_metric, best_hidden_layer_sizes = get_statistics_from_history(history)
        print(
            f"Run with parameters {combination} completed, best_val_loss: {best_val_loss}, best_val_metric: {best_val_metric}, best_hidden_layer_sizes: {best_hidden_layer_sizes}")

        if not best_val_loss < best_overall_val_loss:
            continue
        best_overall_val_loss = best_val_loss
        best_overall_val_metric = best_val_metric
        best_overall_combination = combination

    print(f'Best overall combination: {best_overall_combination}, val_metric: {best_overall_val_metric}')

    return histories, best_overall_combination



def merge_histories(history1, history2):
    """Return a new history whose lists are ``history1[key] + history2[key]``.

    One entry is produced for each key of `history1`. Keys present only in `history2` are
    ignored.
    """
    return {key: history1[key] + history2[key] for key in history1}


def get_convolutional_model(x, layer_sizes, output_neurons=10, dropout_rate=0.3):
    """Build an AlexNet-style growable CNN classifier.

    Args:
        x: training inputs indexed as ``x[n, h, w, c]``. ``x[0]`` gives the input shape of
            the first layer.
        layer_sizes: 7 initial widths: filter counts of the 5 ``Conv2D`` layers, followed by
            the unit counts of the 2 hidden ``Dense`` layers.
        output_neurons: number of softmax output units. This layer is fixed-size.
        dropout_rate: rate used by every ``Dropout`` layer. Defaults to 0.3, the previously
            hard-coded value.

    Returns:
        A ``Sequential`` model.
    """
    model = Sequential([
        Conv2D(layer_sizes[0], filter_size=(3, 3), activation='selu', padding='SAME', kernel_initializer='lecun_normal',
               input_shape=x[0, :, :, :].shape),
        BatchNormalization(),
        tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2)),
        tf.keras.layers.Dropout(dropout_rate),
        Conv2D(layer_sizes[1], filter_size=(3, 3), activation='selu', padding='SAME', kernel_initializer='lecun_normal'),
        BatchNormalization(),
        tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2)),
        tf.keras.layers.Dropout(dropout_rate),
        Conv2D(layer_sizes[2], filter_size=(3, 3), activation='selu', padding='SAME', kernel_initializer='lecun_normal'),
        BatchNormalization(),
        tf.keras.layers.Dropout(dropout_rate),
        Conv2D(layer_sizes[3], filter_size=(3, 3), activation='selu', padding='SAME', kernel_initializer='lecun_normal'),
        BatchNormalization(),
        tf.keras.layers.Dropout(dropout_rate),
        # Final conv downsamples with stride 2 instead of a third max-pool.
        Conv2D(layer_sizes[4], filter_size=(3, 3), strides=(2, 2), activation='selu', padding='SAME',
               kernel_initializer='lecun_normal'),
        BatchNormalization(),
        tf.keras.layers.Dropout(dropout_rate),
        Flatten(),
        Dense(layer_sizes[5], activation='selu', kernel_initializer='lecun_normal'),
        BatchNormalization(),
        tf.keras.layers.Dropout(dropout_rate),
        Dense(layer_sizes[6], activation='selu', kernel_initializer='lecun_normal'),
        BatchNormalization(),
        tf.keras.layers.Dropout(dropout_rate),
        Dense(output_neurons, activation='softmax', fixed_size=True),
    ])
    return model


def get_dense_model(x, layer_sizes, output_neurons=1):
    """Build a growable fully connected model with a linear output layer.

    Args:
        x: 2-D training inputs. ``x[0, :]`` is one example and gives the input shape.
        layer_sizes: initial widths of the hidden (growable) SELU ``Dense`` layers, in order.
            Must contain at least one element.
        output_neurons: number of linear output units in the fixed-size output layer.
            Defaults to 1, the previously hard-coded value.

    Returns:
        A ``Sequential`` model.
    """
    layers = [Dense(layer_sizes[0], activation='selu', kernel_initializer='lecun_normal',
                    input_shape=x[0, :].shape)]
    layers.extend(Dense(layer_size, activation='selu', kernel_initializer='lecun_normal')
                  for layer_size in layer_sizes[1:])
    layers.append(Dense(output_neurons, activation=None, kernel_initializer='lecun_normal', fixed_size=True))

    return Sequential(layers)


def train_fn_conv(x, y, validation_data, learning_rate, schedule, layer_sizes, output_neurons=10, min_new_neurons=20,
                  growth_percentage=0.2, verbose=False, use_static_graph=True, batch_size=128):
    """Train a fresh convolutional model with Adam according to `schedule`.

    Training uses ``Sequential.fit``'s default (sparse categorical) loss and metric.

    Returns:
        The training history dict.
    """
    model = get_convolutional_model(x, layer_sizes, output_neurons)
    return model.fit(
        x=x,
        y=y,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        schedule=schedule,
        batch_size=batch_size,
        min_new_neurons=min_new_neurons,
        validation_data=validation_data,
        growth_percentage=growth_percentage,
        verbose=verbose,
        use_static_graph=use_static_graph,
    )


def squared_error(y_true, y_pred):
    """Element-wise squared residual ``(y_true - y_pred) ** 2``; no reduction is applied."""
    residual = y_true - y_pred
    return residual ** 2


def negative_squared_error(y_true, y_pred):
    """Element-wise negated squared residual.

    Used as a metric where larger is better (see ``get_statistics_from_history``'s argmax).
    """
    residual = y_true - y_pred
    return -(residual ** 2)


def train_fn_dense(x, y, validation_data, learning_rate, schedule, layer_sizes, min_new_neurons=20,
                   growth_percentage=0.2, verbose=False, use_static_graph=True, batch_size=128):
    """Train a fresh dense model with Adam according to `schedule`.

    The loss is squared error and the metric is negative squared error.

    Returns:
        The training history dict.
    """
    model = get_dense_model(x, layer_sizes)
    return model.fit(
        x=x,
        y=y,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        schedule=schedule,
        batch_size=batch_size,
        min_new_neurons=min_new_neurons,
        validation_data=validation_data,
        growth_percentage=growth_percentage,
        verbose=verbose,
        use_static_graph=use_static_graph,
        loss_fn=squared_error,
        metric_fn=negative_squared_error,
    )


def _fit_until_stalled(model, fit_kwargs, max_setbacks):
    """Call ``model.fit(**fit_kwargs)`` repeatedly until validation loss stops improving.

    The histories of all calls are concatenated. The loop ends once the last-epoch
    validation loss has failed to beat the best value seen so far on more than
    `max_setbacks` consecutive calls.
    """
    history = Sequential.ParameterContainer.prepare_history()
    best_val_loss = np.inf
    n_setbacks = 0
    while True:
        epoch_history = model.fit(**fit_kwargs)
        history = merge_histories(history, epoch_history)
        val_loss = epoch_history['val_loss'][-1]
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            n_setbacks = 0
        else:
            n_setbacks += 1
            if n_setbacks > max_setbacks:
                return history


def early_stopping_conv(x, y, validation_data, learning_rate, schedule, layer_sizes, output_neurons=10,
                        min_new_neurons=20,
                        growth_percentage=0.2, verbose=False, use_static_graph=True, batch_size=128, max_setbacks=2):
    """Train a convolutional model with early stopping.

    `schedule` must contain exactly one epoch. It is replayed until validation loss stalls
    (see ``_fit_until_stalled``). One Adam optimizer is shared across all replays.

    Returns:
        The concatenated history of all replays.
    """
    assert len(schedule) == 1

    model = get_convolutional_model(x, layer_sizes, output_neurons)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    fit_kwargs = dict(x=x, y=y, optimizer=optimizer, schedule=schedule, batch_size=batch_size,
                      min_new_neurons=min_new_neurons, validation_data=validation_data,
                      growth_percentage=growth_percentage, verbose=verbose, use_static_graph=use_static_graph)
    return _fit_until_stalled(model, fit_kwargs, max_setbacks)


def early_stopping_dense(x, y, validation_data, learning_rate, schedule, layer_sizes, output_neurons=1,
                         min_new_neurons=20,
                         growth_percentage=0.2, verbose=False, use_static_graph=True, batch_size=128, max_setbacks=2):
    """Train a dense squared-error model with early stopping.

    `schedule` must contain exactly one epoch. It is replayed until validation loss stalls
    (see ``_fit_until_stalled``). Only ``output_neurons == 1`` is supported.

    Returns:
        The concatenated history of all replays.
    """
    assert len(schedule) == 1
    assert output_neurons == 1

    model = get_dense_model(x, layer_sizes)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    fit_kwargs = dict(x=x, y=y, optimizer=optimizer, schedule=schedule, batch_size=batch_size,
                      min_new_neurons=min_new_neurons, validation_data=validation_data,
                      growth_percentage=growth_percentage, verbose=verbose, use_static_graph=use_static_graph,
                      loss_fn=squared_error, metric_fn=negative_squared_error)
    return _fit_until_stalled(model, fit_kwargs, max_setbacks)


def layer_sizes_join_postprocess(args, kwargs):
    """Replace ``layer_1_size, layer_2_size, ...`` in `kwargs` with one ``layer_sizes`` tuple.

    Consecutive keys starting at ``layer_1_size`` are collected, so any number of layers works
    (e.g. the 7 of ``get_convolutional_model``), not just 5. `kwargs` is modified in place
    and also returned. ``layer_sizes`` is added as the last key.

    Returns:
        (args, kwargs)

    Raises:
        KeyError: if ``layer_1_size`` is missing.
    """
    if 'layer_1_size' not in kwargs:
        raise KeyError('layer_1_size')

    layer_sizes = []
    layer_index = 1
    while f'layer_{layer_index}_size' in kwargs:
        layer_sizes.append(kwargs.pop(f'layer_{layer_index}_size'))
        layer_index += 1

    kwargs['layer_sizes'] = tuple(layer_sizes)
    return args, kwargs

In [9]:
# Load CIFAR-100 with the loader defined in an earlier cell.
# NOTE(review): the structure of the returned value is not visible here -- confirm against get_cifar_100_dataset.
cifar100 = get_cifar_100_dataset()

In [10]:
# Initial widths: AlexNet's layer widths divided by 3 (integer division).
alexnet_layer_sizes = [96, 256, 384, 384, 256, 4096, 4096]
layer_sizes = list(map(lambda width: width // 3, alexnet_layer_sizes))
layer_sizes

[32, 85, 128, 128, 85, 1365, 1365]

In [139]:
class BatchNormalization(DASLayer):
    """Batch normalization layer whose parameters can be pruned/grown along the last axis.

    Holds four variables, all shaped like ``input_shape[1:]``:
    - learnable ``offset`` (beta) and ``scale`` (gamma);
    - non-trainable ``moving_mean`` and ``moving_variance``.

    Statistics are reduced over the batch axis (axis 0) only, so one statistic is kept per
    element of ``input_shape[1:]``.
    NOTE(review): for 4-D conv activations this normalizes per (H, W, C) position rather
    than per channel as ``tf.keras.layers.BatchNormalization`` does. Confirm this is intended.

    Args:
        momentum: decay factor of the exponential moving averages of mean/variance.
        epsilon: small constant added to the variance inside the normalization.
        input_shape: forwarded to ``DASLayer``.
    """

    def __init__(self, momentum=0.99, epsilon=0.001, input_shape=None):
        # NOTE(review): fixed_size=True presumably marks this layer as having no units of its
        # own to grow/prune; confirm against DASLayer.
        super().__init__(input_shape, fixed_size=True)

        self.momentum = momentum
        self.epsilon = epsilon

    def copy(self):
        """Copying is not supported for this layer."""
        raise NotImplementedError

    def _replace_bn_variables(self, offset, scale, moving_mean, moving_variance):
        """(Re)create the four per-feature variables from the given initial values.

        Fresh variables are created instead of ``assign``-ing into the old ones because
        ``prune``/``grow`` change the size of the last axis.
        """
        self.offset = tf.Variable(name='offset', initial_value=offset, trainable=True)
        self.scale = tf.Variable(name='scale', initial_value=scale, trainable=True)
        self.moving_mean = tf.Variable(name='moving_mean', initial_value=moving_mean, trainable=False)
        self.moving_variance = tf.Variable(name='moving_variance', initial_value=moving_variance, trainable=False)

    def build(self, input_shape):
        """Create the variables with shape ``input_shape[1:]``; no-op if already built.

        Initial values: offset=0, scale=1, moving_mean=0, moving_variance=1.
        """
        if self._built:
            return

        feature_shape = input_shape[1:]
        self._replace_bn_variables(
            offset=tf.zeros(feature_shape),
            scale=tf.ones(feature_shape),
            moving_mean=tf.zeros(feature_shape),
            moving_variance=tf.ones(feature_shape))

        self._built = True

    def call(self, inputs, training=None):
        """Normalize ``inputs``, then apply ``scale`` and ``offset``.

        In training mode the current batch's mean/variance (over axis 0) are used. The moving
        statistics are updated as ``moving = moving * momentum + batch_stat * (1 - momentum)``.
        Otherwise the stored moving statistics are used.
        """
        if training:
            mean, variance = tf.nn.moments(inputs, axes=[0])
            self.moving_mean.assign(self.moving_mean * self.momentum + mean * (1 - self.momentum))
            self.moving_variance.assign(self.moving_variance * self.momentum + variance * (1 - self.momentum))
        else:
            mean = self.moving_mean
            variance = self.moving_variance
        return tf.nn.batch_normalization(inputs, mean, variance, self.offset, self.scale, self.epsilon)

    def prune(self, threshold, active_input_units_indices):
        """Keep only the entries at ``active_input_units_indices`` along the last axis.

        ``threshold`` is not used by this layer. Returns ``active_input_units_indices`` unchanged.
        """
        def keep(variable):
            return tf.gather(variable.value(), active_input_units_indices, axis=-1)

        # All four gathers are evaluated (as arguments) before any variable is replaced.
        self._replace_bn_variables(
            offset=keep(self.offset),
            scale=keep(self.scale),
            moving_mean=keep(self.moving_mean),
            moving_variance=keep(self.moving_variance))
        return active_input_units_indices

    def grow(self, n_new_input_units, percentage, min_new_units, scaling_factor):
        """Append ``n_new_input_units`` entries along the last axis (no-op if it is <= 0).

        New entries start at offset=0, scale=1, moving_mean=0, moving_variance=1.
        ``percentage``, ``min_new_units`` and ``scaling_factor`` are not used by this layer.
        Returns ``n_new_input_units``.
        """
        if n_new_input_units > 0:
            growth_shape = list(self.offset.shape)
            growth_shape[-1] = n_new_input_units
            # Match the padding dtype to the existing variables so tf.concat cannot mismatch.
            dtype = self.offset.dtype
            zeros = tf.zeros(growth_shape, dtype=dtype)
            ones = tf.ones(growth_shape, dtype=dtype)
            self._replace_bn_variables(
                offset=tf.concat([self.offset, zeros], axis=-1),
                scale=tf.concat([self.scale, ones], axis=-1),
                moving_mean=tf.concat([self.moving_mean, zeros], axis=-1),
                moving_variance=tf.concat([self.moving_variance, ones], axis=-1))
        return n_new_input_units

In [142]:
%%time

# Dynamic-architecture run on CIFAR-100: 20 DynamicEpoch(0.0001, 'weighted_l1') epochs
# followed by 20 StaticEpochNoRegularization epochs. In the captured log for epoch 1 the
# hidden layers are grown and then pruned, and the reported penalty is 0.0001.
# NOTE(review): `[obj] * 20` repeats the *same* epoch instance 20 times. If epoch objects
# carry state, build the list with a comprehension instead.
# NOTE(review): the other training cells wrap the epoch list in Schedule(...), but a plain
# list is passed here. TODO confirm train_fn_conv accepts both.
schedule = [DynamicEpoch(0.0001, 'weighted_l1')] * 20 + [StaticEpochNoRegularization()] * 20
train_fn_conv(x=cifar100.X_train_norm, y=cifar100.y_train, 
              validation_data=(cifar100.X_test_norm, cifar100.y_test), learning_rate=0.0002, 
              schedule=schedule, layer_sizes=layer_sizes, output_neurons=100, verbose=True)

##########################################################
Epoch 1/40
Before growing:
loss: None - metric: None - val_loss: 5.3226447105407715 - val_metric: 0.0109 - penalty: 0.0001
hidden layer sizes: [32, 85, 128, 128, 85, 1365], total units: 1823
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
After growing:
loss: None - metric: None - val_loss: 5.322633266448975 - val_metric: 0.0109 - penalty: 0.0001
hidden layer sizes: [52, 105, 153, 153, 105, 1638], total units: 2206
Before pruning:
loss: 4.945155620574951 - metric: 0.0459199994802475 - val_loss: 4.541201114654541 - val_metric: 0.0172 - penalty: 0.0001
hidden layer sizes: [52, 105, 153, 153, 105, 1638], total units: 2206
After pruning:
loss: None - metric: None - val_loss: 4.542862415313721 - val_metric: 0.0263 - penalty: 0.0001
hidden layer sizes: [33, 85, 128, 135, 105, 530], total units: 1016
##########################################################
Epoch 2/40
Before growing:
loss: None - metric: None - val_loss: 4.542862415313721 - val

KeyboardInterrupt: 

In [77]:
%%time

# Static baseline: 40 StaticEpochNoRegularization epochs at learning_rate=0.001. In the
# captured log the hidden layer sizes stay fixed and the penalty is 0.0.
# NOTE(review): this cell duplicates another training cell verbatim; consider one
# parameterized call instead of copies.
schedule = Schedule([StaticEpochNoRegularization()] * 40)
train_fn_conv(x=cifar100.X_train_norm, y=cifar100.y_train, 
              validation_data=(cifar100.X_test_norm, cifar100.y_test), learning_rate=0.001, 
              schedule=schedule, layer_sizes=layer_sizes, output_neurons=100, verbose=True)

##########################################################
Epoch 1/40
loss: 4.131021022796631 - metric: 0.1186399981379509 - val_loss: 5.007061958312988 - val_metric: 0.1172 - penalty: 0.0
hidden layer sizes: [32, 85, 128, 128, 85, 1365], total units: 1823
##########################################################
Epoch 2/40
loss: 3.346764326095581 - metric: 0.21362000703811646 - val_loss: 3.8107926845550537 - val_metric: 0.2324 - penalty: 0.0
hidden layer sizes: [32, 85, 128, 128, 85, 1365], total units: 1823
##########################################################
Epoch 3/40


KeyboardInterrupt: 

In [18]:
%%time

# Static baseline: 40 StaticEpochNoRegularization epochs at learning_rate=0.001. In the
# captured log the hidden layer sizes stay fixed and the penalty is 0.0. The returned
# history dict is the cell's displayed output.
# NOTE(review): this cell duplicates another training cell verbatim; consider one
# parameterized call instead of copies.
schedule = Schedule([StaticEpochNoRegularization()] * 40)
train_fn_conv(x=cifar100.X_train_norm, y=cifar100.y_train, 
              validation_data=(cifar100.X_test_norm, cifar100.y_test), learning_rate=0.001, 
              schedule=schedule, layer_sizes=layer_sizes, output_neurons=100, verbose=True)

##########################################################
Epoch 1/40
loss: 4.164563179016113 - metric: 0.1171799972653389 - val_loss: 4.554275989532471 - val_metric: 0.1434 - penalty: 0.0
hidden layer sizes: [32, 85, 128, 128, 85, 1365, 1365], total units: 3188
##########################################################
Epoch 2/40
loss: 3.3377609252929688 - metric: 0.21318000555038452 - val_loss: 3.5045278072357178 - val_metric: 0.2326 - penalty: 0.0
hidden layer sizes: [32, 85, 128, 128, 85, 1365, 1365], total units: 3188
##########################################################
Epoch 3/40
loss: 2.9550507068634033 - metric: 0.2761400043964386 - val_loss: 3.0465550422668457 - val_metric: 0.3036 - penalty: 0.0
hidden layer sizes: [32, 85, 128, 128, 85, 1365, 1365], total units: 3188
##########################################################
Epoch 4/40
loss: 2.701124429702759 - metric: 0.3158800005912781 - val_loss: 2.7370235919952393 - val_metric: 0.3477 - penalty: 0.0
hidden layer siz

{'loss': [4.164563179016113,
  3.3377609252929688,
  2.9550507068634033,
  2.701124429702759,
  2.4953885078430176,
  2.3433902263641357,
  2.222804069519043,
  2.1114468574523926,
  2.018197536468506,
  1.9336079359054565,
  1.860167384147644,
  1.7841598987579346,
  1.7304868698120117,
  1.6743803024291992,
  1.6206616163253784,
  1.5723124742507935,
  1.5143556594848633,
  1.4662446975708008,
  1.4362691640853882,
  1.3897252082824707,
  1.3558814525604248,
  1.3159631490707397,
  1.282547116279602,
  1.2527399063110352,
  1.2210615873336792,
  1.2001926898956299,
  1.1615076065063477,
  1.130737066268921,
  1.1033198833465576,
  1.0799574851989746,
  1.0673094987869263,
  1.0290958881378174,
  1.008555293083191,
  0.9994961619377136,
  0.9761124849319458,
  0.9613146185874939,
  0.9395744800567627,
  0.9266930222511292,
  0.9204065799713135,
  0.8923746943473816],
 'metric': [0.1171799972653389,
  0.21318000555038452,
  0.2761400043964386,
  0.3158800005912781,
  0.3582000136375427