# Definition and training of the SpotNet computational graph

In this code section, we define a network that approximately implements a fixed number of iterations of the accelerated proximal gradient algorithm for detecting spots in FluoroSpot images. We then train the network so that its output gets closer to the optimal solution than the same number of iterations of the algorithm would.

<a id="network_figure"></a>

|[![network picture][network]](pics/network.pdf)|
| :---: |
|<center>Fig 1: *Computation graph of SpotNet, extracted from the accelerated proximal gradient algorithm applied to regularized least-squares recovery of a FluoroSpot image using our mathematical model (see the paper and [`data_simulation.ipynb`](./data_simulation.ipynb) and \[[1][1],[2][2]\] for details).*</center>|

# References

[1]: https://arxiv.org/abs/1710.01604
[2]: https://arxiv.org/abs/1710.01622

[\[1\]][1]: Pol del Aguila Pla and Joakim Jaldén, "Cell Detection by Functional Inverse Diffusion and Group Sparsity − Part I: Modeling and Inverse Problems", _IEEE Transactions on Signal Processing_, vol. 66, no. 20, pp. 5407--5421, 2018  
[\[2\]][2]: Pol del Aguila Pla and Joakim Jaldén, "Cell Detection by Functional Inverse Diffusion and Group Sparsity − Part II: Proximal Optimization and Performance Evaluation", _IEEE Transactions on Signal Processing_, vol. 66, no. 20, pp. 5422--5437, 2018  

# Index

1. Definition of the computational graph of SpotNet
    1. [Importing libraries](#libs) to build and initialize the network
    2. [Library of non-linear functions](#phis) $\varphi_\lambda(\cdot)$
    3. Library for initial [kernel generation](#kers)
    4. Generator of [single hidden layers](#hidden) of SpotNet
    5. Definition of the [full SpotNet computational graph](#spotnet)
    6. Definition of the [training strategy](#training_strategy) for SpotNet
2. Training of SpotNet
    1. [Loading of the training database](#data) and setting of parameters
    2. Construction of the [computational graph](#session)
    3. Running the [training](#training) and storing its results
    

[network]: pics/network.png

# Definition of the computational graph of SpotNet

<a id="libs"></a>
## Importing libraries to build and initialize the network

In [None]:
# Tensorflow for differentiable programming
import tensorflow as tf
# Numpy for array management
import numpy as np

<a id="phis"></a>
## Library of non-linear functions $\varphi_\lambda(\cdot)$

These are linked to specific regularizers through proximal operators, in the sense that $\varphi_\lambda(x) = \mbox{prox}_{\lambda\mathcal{R}}(x)$, where $\mathcal{R}(x)$ is a regularizer that promotes certain characteristics of the three-dimensional array $x\in\mathbb{R}^{M\times N \times K}$, and
$$
    \mbox{prox}_{\lambda\mathcal{R}}(x) = \arg \min_y \left\lbrace \mathcal{R}(y) + \frac{1}{2\lambda}\left\|y-x\right\|_2^2 \right\rbrace.\,\,(1)
$$
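As a sanity check of definition (1), here is a small NumPy sketch (not part of SpotNet) for the scalar case $\mathcal{R}(y)=|y|$: minimizing the objective of (1) on a fine grid should reproduce the closed-form soft-thresholding operator. The grid bounds and resolution are arbitrary illustrative choices.

```python
import numpy as np

def prox_bruteforce(x, lam, reg, grid=np.linspace(-5.0, 5.0, 200001)):
    """Approximate prox_{lam*reg}(x) = argmin_y reg(y) + (y - x)^2 / (2*lam) on a grid."""
    objective = reg(grid) + (grid - x) ** 2 / (2.0 * lam)
    return grid[np.argmin(objective)]

def soft_threshold_closed_form(x, lam):
    """Closed-form prox of lam*|.|: shrink x towards zero by lam."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

lam = 0.7
for x in [-2.0, -0.5, 0.0, 0.3, 1.5]:
    approx = prox_bruteforce(x, lam, np.abs)
    exact = soft_threshold_closed_form(x, lam)
    print(f"x={x:+.1f}  grid prox={approx:+.4f}  closed form={exact:+.4f}")
```

Inputs with $|x|\le\lambda$ are mapped to zero, which is what makes this regularizer promote sparsity.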

Here, we include the cases 
$$
\mathcal{R}(x)= \mathcal{R}_\mathrm{s}(x) = \sum_{k=1}^{K} \left\| x_k \right\|_1 = \sum_{m,n,k=1}^{M,N,K} \left|x_{m,n,k}\right| \, \,\,(2)
$$
and $\mathcal{R}(x)= \mathcal{R}_\mathrm{s}(x) + \delta_{\mathbb{R}^{M,N,K}_+}(x)\, \,\,(3)$ to promote sparsity and non-negative sparsity, respectively, but also
$$
\mathcal{R}(x)= \mathcal{R}_\mathrm{gs}(x) = \sum_{m,n=1}^{M,N} \sqrt{ \sum_{k=1}^K x_{m,n,k}^2 }\,, \,\,(4)
$$
and
$\mathcal{R}(x)= \mathcal{R}_\mathrm{gs}(x) + \delta_{\mathbb{R}^{M,N,K}_+}(x)\, \,\,(5)$ to promote group sparsity and non-negative group sparsity, respectively, with groups corresponding to each specific location in the image (see \[[1][1],[2][2]\] for why this is a good choice for the problem addressed in the paper). Here, $\delta_{\mathbb{R}^{M,N,K}_+}(x)$ is the $(0,\infty)$-indicator function of the non-negative orthant, i.e.,
$$
    \delta_{\mathbb{R}^{M,N,K}_+}(x) = \begin{cases} 0 & \mbox{ if }x_{m,n,k}\geq 0, \forall m,n,k, \\
    \infty & \mbox{ otherwise,} \end{cases}
$$
which is a common construct to include constraints in the regularization term in non-smooth convex optimization.
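The closed-form proximal operators of (2) to (5) can be sketched in NumPy as follows, assuming arrays in NHWC layout with each group along the last (channel) axis. These are plain reference implementations for checking purposes, not the TensorFlow layers used by SpotNet; the `eps` floor in the group case is an added assumption to avoid division by zero.

```python
import numpy as np

def prox_l1(x, lam):
    """Prox of lam * R_s, eq. (2): element-wise soft thresholding."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

def prox_l1_nonneg(x, lam):
    """Prox of lam * (R_s + non-negativity indicator), eq. (3): one-sided thresholding."""
    return np.maximum(x - lam, 0.0)

def prox_group(x, lam, eps=1e-12):
    """Prox of lam * R_gs, eq. (4): shrink the channel vector of every pixel
    by lam in Euclidean norm, zeroing it when its norm is at most lam."""
    norm = np.sqrt(np.sum(x ** 2, axis=-1, keepdims=True))
    return np.maximum(1.0 - lam / np.maximum(norm, eps), 0.0) * x

def prox_group_nonneg(x, lam):
    """Prox of lam * (R_gs + non-negativity indicator), eq. (5): project onto
    the non-negative orthant, then apply group shrinkage."""
    return prox_group(np.maximum(x, 0.0), lam)

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4, 4, 3))   # one image, 4x4 pixels, 3 channels
g = prox_group(x, 0.5)
# Each pixel's channel norm is reduced by exactly lam (or set to zero)
print(np.allclose(np.linalg.norm(g, axis=-1),
                  np.maximum(np.linalg.norm(x, axis=-1) - 0.5, 0.0)))
```

The group operator acts on whole channel vectors, so a pixel is either switched off in all channels at once or kept in all of them.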

[1]: https://arxiv.org/abs/1710.01604
[2]: https://arxiv.org/abs/1710.01622

In [None]:
# Prox of sparsity (l1 norm, (2)): element-wise soft thresholding with a learnable threshold
def soft_threshold( x, lam = 1.0 ):
    b = tf.Variable( lam, dtype = tf.float32 )
    return tf.nn.relu( x - b ) - tf.nn.relu( -x - b )
# Prox of group sparsity (l2 norm along groups summed for all groups, (4))
# NHWC format, with the same location in different channels belonging to the same group
def soft_group_threshold( x, lam = 1.0 ):
    b = tf.Variable( lam, dtype = tf.float32 )
    # Euclidean norm of each group; the small offset avoids division by zero
    # and keeps the gradient of the square root finite
    norm = tf.sqrt( tf.reduce_sum( x**2, 3 ) + 1e-12 )
    return tf.expand_dims( tf.maximum( 1 - b / norm, 0 ) , axis = 3 ) * x
# Prox of non-negative sparsity (l1 norm + non-negative (0,inf)-indicator, (3))
def s_nonneg( x, lam = 1.0 ):
    b = tf.Variable( lam, dtype = tf.float32 )
    return tf.nn.relu( x - b )
# Prox of non-negative group sparsity
# (l2 norm along groups summed for all groups + non-negative (0,inf)-indicator, (5)):
# project onto the non-negative orthant, then apply group soft thresholding
def gs_nonneg( x, lam = 1.0 ):
    return soft_group_threshold( tf.maximum( x, 0 ), lam )

<a id="kers"></a>
## Library for initial kernel generation

Kernel generation functions. Conceptually, one would start by using the same kernels as the algorithm. However, since the ultimate objective is to reduce computational cost, we instead generate random kernels of a given size, much smaller than those used by the algorithm.

In [None]:
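# Generate small random kernels of a given size (they are not necessarily non-negative),
# with shape [kernel_size, kernel_size, nrof_kernels, 1] as expected by tf.nn.depthwise_conv2d.
# The truncated normal with standard deviation 0.01 was selected for convenience.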
def generate_random_kernels( kernel_size, nrof_kernels ):
        
    kernels = tf.Variable( tf.truncated_normal( [kernel_size,
                                                 kernel_size,
                                                 nrof_kernels,
                                                 1], stddev = 0.01 ) )

    return kernels

# As an alternative: generate kernels given by the accelerated proximal gradient for the mathematical model
# for FluoroSpot (very large)

# Import kernel generation from corresponding python module in our local folders
import ker
# Set sigma limits to match those obtained in data_simulation.ipynb
sigma_lims = np.array(
                     [ 0.        ,  2.03803742,  4.07607485,  6.44484021,  8.64666049,
                       10.78428037, 12.88968043, 14.97645529, 17.17280857, 19.33452064,
                       21.47205663, 23.59198881, 25.77936086, 27.94416127, 30.09126194,
                       32.22420107, 34.34561714, 36.5144422 , 38.66904129, 40.8116676 ,
                       42.94411325, 45.1138815 , 47.27192547, 49.41978111, 51.55872171,
                       53.68981279, 55.8511504 , 58.00356193, 60.14800566, 62.28530458,
                       64.44840214 ] )
# Produce function to generate kernels according to the FluoroSpot model
generate_diffusion_kernels = lambda: ker.obtain_discrete_kernels( sigma_lims )
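`tf.truncated_normal` samples a normal distribution and re-draws values that lie more than two standard deviations from the mean, so the initial kernels are small but take both signs. Below is a NumPy sketch of the same kind of initialization, using the `[kernel_size, kernel_size, nrof_kernels, 1]` layout of `generate_random_kernels`. The rejection loop and the `seed` argument are illustrative choices, not TensorFlow's internals, and 30 kernels is assumed to match the number of `sigma_lims` intervals.

```python
import numpy as np

def truncated_normal_kernels(kernel_size, nrof_kernels, stddev=0.01, seed=None):
    """Sample kernels shaped [kernel_size, kernel_size, nrof_kernels, 1] from a
    zero-mean normal, re-drawing entries farther than two standard deviations."""
    rng = np.random.default_rng(seed)
    shape = (kernel_size, kernel_size, nrof_kernels, 1)
    kernels = rng.normal(0.0, stddev, size=shape)
    out_of_range = np.abs(kernels) > 2.0 * stddev
    while np.any(out_of_range):
        kernels[out_of_range] = rng.normal(0.0, stddev, size=out_of_range.sum())
        out_of_range = np.abs(kernels) > 2.0 * stddev
    return kernels

kernels = truncated_normal_kernels(5, 30, seed=8888)
print(kernels.shape)   # (5, 5, 30, 1)
print(kernels.size)    # 750 weights per layer
```

With 5x5 kernels, each hidden layer holds only 750 kernel weights, which is the point of replacing the much larger diffusion kernels.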

<a id = "hidden"></a>
## Definition of a single hidden layer of SpotNet

Compare with a generic iteration (hidden layer) in [Fig. 1](#network_figure). Each hidden layer implements one iteration of the accelerated proximal gradient algorithm for the problem at hand, with learnable kernels, threshold and acceleration weight.

In [None]:
def hidden_layer( x_prev, z_prev, source, kernels, non_linearity = soft_threshold ):
    
    # Create variable for tunable kernels
    h = tf.Variable( kernels )
    # Depthwise convolution
    conv = tf.nn.depthwise_conv2d( input = z_prev,
                                   filter = h,
                                   strides = [1, 1, 1, 1],
                                   padding = 'SAME' )
    # Sum of all convolutions and difference to observed image
    u = tf.reduce_sum( conv, 3 ) - source
    u = tf.expand_dims( u, 3 )
    # Defining the matched filters
    h_m = tf.transpose( tf.reverse( tf.reverse( h, [0] ), [1] ), [0, 1, 3, 2] )
    # Convolve results of sum by each of the different matched filters
    conv_m = tf.nn.conv2d( input = u,
                           filter = h_m,
                           strides = [1, 1, 1, 1],
                           padding = 'SAME' )
    # Pass through non-linearity
    x_current = non_linearity( z_prev - conv_m )
    # Perform acceleration step (skip connections)
    alpha = tf.Variable( tf.constant( 0.5 ) )
    z_current = x_current + alpha * (x_current - x_prev)
        
    return (x_current, z_current, h)
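To make the layer concrete, here is a single-image NumPy sketch of the same computation, not the TensorFlow code itself. It assumes odd kernel sizes, drops the batch dimension, stores kernels as `(k, k, K)` instead of `(k, k, K, 1)`, and uses hypothetical names (`forward`, `adjoint`, `hidden_layer_np`). It also checks that flipping the kernels, as done for `h_m` above, yields the exact adjoint of the correlate-and-sum operator under zero `'SAME'` padding.

```python
import numpy as np

def correlate_same(image, kernel):
    """2-D cross-correlation with zero 'SAME' padding for odd kernel sizes,
    i.e. what tf.nn.conv2d computes for one input and one output channel."""
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(image, ((ph, ph), (pw, pw)))
    out = np.zeros(image.shape, dtype=float)
    for a in range(kh):
        for b in range(kw):
            out += kernel[a, b] * padded[a:a + image.shape[0], b:b + image.shape[1]]
    return out

def forward(z, h):
    """Depthwise correlation of each channel z[..., k] with h[..., k], summed over k."""
    return sum(correlate_same(z[..., k], h[..., k]) for k in range(h.shape[-1]))

def adjoint(u, h):
    """Matched filters: correlate the residual with each spatially flipped kernel."""
    return np.stack([correlate_same(u, h[::-1, ::-1, k]) for k in range(h.shape[-1])], axis=-1)

def soft_threshold(x, lam):
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

def hidden_layer_np(x_prev, z_prev, source, h, lam, alpha):
    """One accelerated proximal gradient iteration, structured like hidden_layer."""
    residual = forward(z_prev, h) - source                  # sum of convolutions minus image
    x_current = soft_threshold(z_prev - adjoint(residual, h), lam)
    z_current = x_current + alpha * (x_current - x_prev)    # acceleration (skip connection)
    return x_current, z_current

# Sanity check: the flipped kernels give the exact adjoint, <A z, u> = <z, A^T u>
rng = np.random.default_rng(1)
h = rng.normal(size=(5, 5, 3))
z = rng.normal(size=(8, 9, 3))
u = rng.normal(size=(8, 9))
print(np.isclose(np.sum(forward(z, h) * u), np.sum(z * adjoint(u, h))))  # True
```

Because the matched filters are the exact adjoint, `z_prev - adjoint(residual, h)` is a genuine gradient step on the least-squares term, with the step size folded into the kernels.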

<a id = "spotnet"></a>
## Definition of the full SpotNet computational graph
To be compared with the overall structure in [Fig. 1](#network_figure).

In [None]:
def spot_finding_net( nrof_images,
                      image_height, 
                      image_width, 
                      kernels,
                      nrof_hidden_layers,
                      non_linearity = soft_threshold ):
    
    # Graph-level seed; it only affects random ops created after this call
    tf.set_random_seed(8888)
    
    # Placeholder for final FluoroSpot images
    source = tf.placeholder( tf.float32, [ nrof_images, image_height, image_width ] )
    
    # Create the input half-layer (equivalent to first iteration with initialization of 0)
    # Convolve by each of the matched filters and pass through non-linearity
    h_m = tf.transpose( tf.reverse( tf.reverse( kernels, [0] ), [1] ), [0, 1, 3, 2] )
    x = [ non_linearity( tf.nn.conv2d( input = tf.expand_dims( source, 3 ),
                                       filter = h_m,
                                       strides = [1, 1, 1, 1],
                                       padding = 'SAME' ) ) ]
    # Acceleration step (skip connection)
    alpha = tf.Variable( tf.constant( 0.5 ) )
    z = [ (1 + alpha) * x[-1] ]
    h = []
    
    # Create hidden layers as a feedforward convolutional graph. Append each hidden 
    # layer to a list and connect next layer to the last element of this list.
    for i in range( nrof_hidden_layers ):
        x_current, z_current, h_current = hidden_layer( x[-1], z[-1], source, kernels, non_linearity )
        
        x.append(x_current)
        z.append(z_current)
        h.append(h_current)
        
    # The output is simply the last x
    output = x[-1]
    
    # Placeholder for the target x (one channel per kernel)
    nrof_kernels = int( kernels.get_shape()[2] )
    target = tf.placeholder( tf.float32, [ nrof_images, image_height, image_width, nrof_kernels ] )
    
    # Mean squared error for the prediction of x
    loss = tf.reduce_mean( ( output - target ) ** 2 )
    
    # Return input and target placeholders, loss to optimize, and output to obtain
    return (source, loss, target, output)
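The unrolled structure can be sketched with a generic dense matrix `A` standing in for the convolutional model, for the $\ell_1$ case (2). Unlike `spot_finding_net`, this sketch uses fixed, non-learned parameters and an explicit step size $1/\|A\|_2^2$, so that it reduces to the plain accelerated proximal gradient iteration with a constant momentum `alpha`; all names and the toy problem are illustrative.

```python
import numpy as np

def soft_threshold(x, thr):
    return np.sign(x) * np.maximum(np.abs(x) - thr, 0.0)

def unrolled_network(y, A, lam, nrof_hidden_layers, alpha=0.5):
    """Input half-layer plus nrof_hidden_layers accelerated proximal gradient
    layers for min_x 0.5*||A x - y||^2 + lam*||x||_1."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2   # 1 / Lipschitz constant of the gradient
    # Input half-layer: first iteration started from x = 0
    x = soft_threshold(step * A.T @ y, step * lam)
    z = (1.0 + alpha) * x
    for _ in range(nrof_hidden_layers):
        x_prev = x
        x = soft_threshold(z - step * A.T @ (A @ z - y), step * lam)
        z = x + alpha * (x - x_prev)          # acceleration (skip connection)
    return x

def objective(x):
    return 0.5 * np.sum((A @ x - y) ** 2) + 0.3 * np.sum(np.abs(x))

rng = np.random.default_rng(0)
A = rng.normal(size=(40, 20))
y = rng.normal(size=40)
for layers in (0, 3, 30):
    print(layers, objective(unrolled_network(y, A, 0.3, layers)))
```

Training SpotNet replaces the fixed `A`, thresholds and `alpha` by per-layer learnable ones, aiming to reach a good estimate with far fewer layers than the fixed iteration needs.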

<a id = "training_strategy"></a>
## Definition of the training strategy for SpotNet

In [None]:
# Create the Adam optimizer and return the op that performs one training step on the loss
def network_training( loss, learning_rate ):
    return tf.train.AdamOptimizer( learning_rate ).minimize( loss )
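`minimize` returns an op that, each time it is run, applies one Adam update to all trainable variables (kernels, thresholds and acceleration weights). Below is a NumPy sketch of that update rule, written in the bias-corrected form given in the TensorFlow 1.x `AdamOptimizer` documentation with its default `beta1`, `beta2` and `epsilon`, applied to a toy quadratic loss rather than to SpotNet.

```python
import numpy as np

def adam_minimize(grad, theta0, learning_rate=1e-3, nrof_steps=1000,
                  beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Run nrof_steps Adam updates on parameters theta, given a gradient function."""
    theta = np.asarray(theta0, dtype=float).copy()
    m = np.zeros_like(theta)   # running mean of gradients (first moment)
    v = np.zeros_like(theta)   # running mean of squared gradients (second moment)
    for t in range(1, nrof_steps + 1):
        g = grad(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        # Bias correction folded into the step size
        lr_t = learning_rate * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        theta -= lr_t * m / (np.sqrt(v) + epsilon)
    return theta

# Toy loss 0.5 * ||theta - target||^2, whose gradient is theta - target
target = np.array([1.0, -2.0, 0.5])
theta = adam_minimize(lambda th: th - target, np.zeros(3),
                      learning_rate=1e-2, nrof_steps=5000)
print(np.round(theta, 3))
```

Since each parameter's step is normalized by its own gradient history, kernels, thresholds and `alpha` can be trained with a single learning rate despite living on very different scales.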

<a id = "training_sec"></a>
# Training of SpotNet

<a id = "data"></a>
## Loading of the training database and setting of parameters

In [None]:
# Network parameters
nrof_hidden_layers = 3

# Learning process parameters
learning_rate = 1e-3
batch_size = 1
nrof_train_steps = 400000
nrof_cells = 1250

# Relevant directories
results_dir = 'results/'
data_dir = 'sim_data/'

# Load dataset (a dict stored as a 0-d object array, hence allow_pickle and [()])
data = np.load( data_dir + 'result_' + str( nrof_cells ) + '_cells_10_images.npy', allow_pickle = True )[()]

# Extract images and xs
images = data['fluorospot']
psdrs  = data['psdrs']

# Extract shape parameters
nrof_images, image_height, image_width = images.shape
_, _, _, nrof_kernels = psdrs.shape

# Split dataset for training and validation
nrof_training_samples = int( 0.7 * nrof_images )
train_images, train_psdrs = (images[: nrof_training_samples, ...], psdrs[: nrof_training_samples, ...])
val_images, val_psdrs = (images[nrof_training_samples :, ...], psdrs[nrof_training_samples :, ...])

print( 'Dataset: Num images: %d, Image height: %d, Image width: %d, Num cells: %d, Num kernels: %d'%( nrof_images, 
                                                                                                      image_height, 
                                                                                                      image_width,
                                                                                                      nrof_cells,
                                                                                                      nrof_kernels ) )
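`np.save` stores a Python dict as a 0-d object array, which is why the loading line indexes the result with `[()]`; NumPy 1.16.3 and later also require `allow_pickle=True` to load it. The self-contained sketch below uses a hypothetical toy dataset (random arrays, made-up shapes, same keys as above) to reproduce the load and the 70/30 split.

```python
import os
import tempfile
import numpy as np

# Hypothetical toy dataset with the same keys as the simulated FluoroSpot data
rng = np.random.default_rng(0)
dataset = {'fluorospot': rng.random((10, 16, 16)),
           'psdrs': rng.random((10, 16, 16, 30))}

path = os.path.join(tempfile.mkdtemp(), 'toy_dataset.npy')
np.save(path, dataset)                        # stored as a 0-d object array
data = np.load(path, allow_pickle=True)[()]   # [()] unwraps the dict

images, psdrs = data['fluorospot'], data['psdrs']
nrof_training_samples = int(0.7 * images.shape[0])
train_images, val_images = images[:nrof_training_samples], images[nrof_training_samples:]
print(type(data).__name__, train_images.shape, val_images.shape)
# dict (7, 16, 16) (3, 16, 16)
```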

<a id="session"></a>
## Construction of the computational graph

In [None]:
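# Get kernels to initialize the network (one kernel per target channel).
# The graph-level seed is set first so that the random initialization is reproducible.
tf.set_random_seed( 8888 )
kernels = generate_random_kernels( 5, nrof_kernels )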

# Build computational graph
with tf.name_scope( 'SpotNet' ) as scope:
    source, loss, target, output = spot_finding_net( batch_size,
                                                     image_height, 
                                                     image_width, 
                                                     kernels,
                                                     nrof_hidden_layers,
                                                     soft_threshold )

# Build training computational graph
with tf.name_scope( 'Training' ) as scope:
    train_step = network_training( loss, learning_rate )

# Create TensorFlow session to run the computational graph
nrof_gpus = 1 
config = tf.ConfigProto( device_count = {'GPU': nrof_gpus} )
sess = tf.Session( config = config )
sess.run( tf.global_variables_initializer( ) )

<a id="training"></a>
## Running training and storing its results

In [None]:
# Time library to time training
import time

# Inform user
print( 'Training the network' )

# Space to store loss values
train_loss_record = np.empty( (0, ) )
val_loss_record  = np.empty( (0, ) )
iterations_record = np.empty( (0, ) )

# Indices that will be picked at random at each iteration
indices = np.arange( train_images.shape[0] )

# Seed for reproducibility
np.random.seed( 8888 )

# Train the network
# Start timing
start = time.time()
for training_iteration in range( nrof_train_steps + 1 ):
    # Shuffle indices    
    np.random.shuffle( indices )
    # Extract current batch
    batch_indices = np.take( indices, np.arange( batch_size ), mode = 'wrap' )
    batch_input = train_images[ batch_indices, ... ]
    batch_target = train_psdrs[ batch_indices, ... ]
    # Run a training step
    _, train_loss = sess.run( [ train_step, loss ], feed_dict = { source: batch_input, 
                                                                  target: batch_target } )
    # Every 1/100 of the total number of iterations, record progress in terms of validation loss
    if training_iteration % max( 1, nrof_train_steps // 100 ) == 0:
        # Sum validation losses
        val_loss = 0
        for val_image_index in range( val_images.shape[0] ):
            val_loss += sess.run( loss, feed_dict = { source: np.expand_dims( val_images[ val_image_index, ... ], axis = 0 ), 
                                                      target: np.expand_dims( val_psdrs[  val_image_index, ... ], axis = 0 ) } )
        # Divide by number of validation images
        val_loss = val_loss / val_images.shape[0]
        # Store losses
        val_loss_record = np.append( val_loss_record, val_loss )
        train_loss_record = np.append( train_loss_record, train_loss )
        iterations_record = np.append( iterations_record, training_iteration )
        # Inform the user and store the model if it is the best one yet
        if (val_loss == val_loss_record.min() and training_iteration > 10):
            print( 'Train step %d, Batch loss: %0.4f, Validation loss: %0.4f, Elapsed: %ds. Best in validation yet! Storing.'%( 
                                                                                    training_iteration, 
                                                                                    train_loss, 
                                                                                    val_loss, 
                                                                                    time.time() - start ) ) 
            tf.saved_model.simple_save( sess,
                results_dir + 'trained_spotnet_' + str( training_iteration ) + '_train-steps_5_kersize/',
                inputs = {'image': source},
                outputs = {'psdr': output} )
        else:
            print( 'Train step %d, Batch loss: %0.4f, Validation loss: %0.4f, Elapsed: %ds.'%( 
                                                                                    training_iteration, 
                                                                                    train_loss, 
                                                                                    val_loss, 
                                                                                    time.time() - start ) )

# Store loss values
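The batch selection in the training loop shuffles the indices and then takes the first `batch_size` of them with `mode='wrap'`, so a batch larger than the training set simply repeats images. The NumPy sketch below covers that step and the validation schedule (every 1/100 of the training steps). It uses a `np.random.Generator` instead of the global `np.random` state, and the function names are hypothetical.

```python
import numpy as np

def sample_batch_indices(indices, batch_size, rng):
    """Shuffle `indices` in place and return the first `batch_size` of them,
    wrapping around (and thus repeating images) if batch_size > len(indices)."""
    rng.shuffle(indices)
    return np.take(indices, np.arange(batch_size), mode='wrap')

def evaluation_steps(nrof_train_steps, nrof_evaluations=100):
    """Iterations at which the validation loss would be recorded."""
    interval = max(1, nrof_train_steps // nrof_evaluations)
    return [it for it in range(nrof_train_steps + 1) if it % interval == 0]

rng = np.random.default_rng(8888)
indices = np.arange(7)
print(sample_batch_indices(indices, 3, rng))    # 3 distinct training images
print(sample_batch_indices(indices, 10, rng))   # all 7 images, then 3 repeated
print(evaluation_steps(400000)[:3], len(evaluation_steps(400000)))
# [0, 4000, 8000] 101
```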
