# Definition and testing of the SpotNet structure

In this code section, we define a network that initially implements a fixed number of iterations of the accelerated proximal gradient algorithm for detecting spots in FluoroSpot. From there on, we train the network so that its output gets closer to the optimum than the same number of iterations of the algorithm would. For pictorial reference, the structure of the network extracted from the algorithm is shown below.

[![network picture][network]](../../paper/figs/network.pdf)

[network]: ../../paper/figs/network.png

## Import libraries to build and initialize the network

In [None]:
# Tensorflow for differentiable programming
import tensorflow as tf
# Numpy for array management
import numpy as np

## Library of non-linear functions $\varphi_\lambda(\cdot)$

These are linked to specific regularizers through proximal operators, in the sense that in general $\varphi_\lambda(x) = \mbox{prox}_{\lambda\mathcal{R}}(x)$, where $\mathcal{R}(x)$ is a regularizer that promotes certain characteristics of $x$, and
$$
    \mbox{prox}_{\lambda\mathcal{R}}(x) = \arg \min_y \left\lbrace \mathcal{R}(y) + \frac{1}{2\lambda}\left\|y-x\right\|_2^2 \right\rbrace\,  \,\,(1)
$$

Here, we include the cases 
$$
\mathcal{R}(x)= \mathcal{R}_\mathrm{s}(x) = \sum_{k=1}^{K} \left\| x_k \right\|_1 = \sum_{m,n,k}^{M,N,K} |x_{m,n,k}| \, \,\,(2)
$$
and $\mathcal{R}(x)= \mathcal{R}_\mathrm{s}(x) + \delta_{\mathbb{R}^{M,N,K}_+}(x)\, \,\,(3)$ to promote sparsity and non-negative sparsity, respectively, but also
$$
\mathcal{R}(x)= \mathcal{R}_\mathrm{gs}(x) = \sum_{g=1}^{G} \sqrt{ \sum_{(i,j,k)\in \mathcal{G}_g}\!\!\!\!\! x_k^2[i,j] } \, \,\,(4)
$$
and
$\mathcal{R}(x)= \mathcal{R}_\mathrm{gs}(x) + \delta_{\mathbb{R}^{M,N,K}_+}(x)\, \,\,(5)$ to promote group-sparsity and non-negative group-sparsity, respectively.

In [None]:
# Prox of Sparsity (l1 norm, (2)): elementwise soft thresholding
def soft_threshold( x, lam = 1.0 ):
    """Soft-threshold `x` with a trainable threshold initialised to `lam`.

    Computes sign(x) * max(|x| - b, 0) elementwise, the proximal operator of
    b * ||x||_1, where b is a new tf.Variable with initial value `lam`.

    Args:
        x: input tensor.
        lam: initial value of the trainable threshold b.

    Returns:
        Tensor with the same shape as `x`.
    """
    b = tf.Variable( lam )
    # relu(x - b) keeps what lies above +b, relu(-x - b) what lies below -b;
    # their difference shrinks every entry towards zero by b and zeroes
    # everything inside [-b, b].
    return tf.nn.relu( x - b ) - tf.nn.relu( -x - b )
# Prox of Group sparsity (l2 norm along groups summed for all groups, (4))
# NHWC format with the same location in different channels belonging to the same group
def soft_group_threshold( x, lam = 1.0 ):
    """Group soft thresholding across the channel axis (axis 3).

    Every spatial location forms one group made of its values along axis 3.
    Each group is scaled by max(1 - b / ||x_g||_2, 0), the proximal operator
    of b times the sum of group l2 norms. b is a new tf.Variable initialised
    to `lam`.

    Args:
        x: tensor of rank >= 4 whose axis 3 holds the group members
           (assumed NHWC).
        lam: initial value of the trainable threshold b.

    Returns:
        Tensor with the same shape as `x`.
    """
    b = tf.Variable( lam )
    # l2 norm of each group. The tiny floor avoids 0/0 and the infinite
    # gradient of sqrt at 0 for all-zero groups; such groups map to 0 anyway
    # because they are multiplied by x itself.
    norm = tf.sqrt( tf.maximum( tf.reduce_sum( x**2, 3 ), 1e-12 ) )
    return tf.expand_dims( tf.maximum( 1 - b / norm, 0 ), axis = 3 ) * x
# Prox of Non-negative Sparsity (l1 norm + non-negative infinity-0 indicator, (3))
def s_nonneg( x, lam = 1.0 ):
    """One-sided soft thresholding max(x - b, 0) with a trainable threshold.

    b is a new tf.Variable created with initial value `lam`; entries of `x`
    below b are set to zero and the rest are shifted down by b.
    """
    threshold = tf.Variable( lam )
    shifted = x - threshold
    return tf.nn.relu( shifted )
# Prox of Non-negative Group Sparsity
# (l2 norm along groups summed for all groups + non-negative infinity-0 indicator, (5))
def gs_nonneg( x, lam = 1.0 ):
    """Clip `x` to non-negative values, then apply group soft thresholding.

    Args:
        x: tensor passed to soft_group_threshold after clipping at 0.
        lam: initial value of the trainable threshold, forwarded to
             soft_group_threshold (same default as before, so existing calls
             behave identically).

    Returns:
        Tensor with the same shape as `x`.
    """
    return soft_group_threshold( tf.maximum( x, 0 ), lam )

## Initial kernel generation

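In generic neural networks, one would generate the kernels randomly, as `generate_random_kernels` below does. Since our network is derived from the accelerated proximal gradient algorithm, we can alternatively initialize it with the kernels that the algorithm itself uses for spot detection in FluoroSpot (`generate_diffusion_kernels`).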

In [None]:
# Generate random kernels corresponding to random features
def generate_random_kernels( kernel_size, nrof_kernels ):
    """Create a trainable bank of random square convolution kernels.

    Returns a tf.Variable of shape [kernel_size, kernel_size, nrof_kernels, 1]
    filled from a truncated normal distribution with stddev 0.01. The
    trailing 1 is the channel-multiplier axis of the filter layout used by
    tf.nn.depthwise_conv2d in hidden_layer.
    """
    shape = [ kernel_size, kernel_size, nrof_kernels, 1 ]
    initial_values = tf.truncated_normal( shape, stddev = 0.01 )
    return tf.Variable( initial_values )

# Generate kernels given by the accelerated proximal gradient for spot detection in FluoroSpot
# Import kernel generation from corresponding python module in our local folders
import ker
# Sigma limits handed to ker.obtain_discrete_kernels
# NOTE(review): 31 limits presumably delimit 30 kernels - confirm against ker
sigma_lims = np.array(
                     [ 0.        ,  2.03803742,  4.07607485,  6.44484021,  8.64666049,
                       10.78428037, 12.88968043, 14.97645529, 17.17280857, 19.33452064,
                       21.47205663, 23.59198881, 25.77936086, 27.94416127, 30.09126194,
                       32.22420107, 34.34561714, 36.5144422 , 38.66904129, 40.8116676 ,
                       42.94411325, 45.1138815 , 47.27192547, 49.41978111, 51.55872171,
                       53.68981279, 55.8511504 , 58.00356193, 60.14800566, 62.28530458,
                       64.44840214 ] )

def generate_diffusion_kernels( sigmas = None ):
    """Return ker.obtain_discrete_kernels( sigmas ).

    Args:
        sigmas: sigma limits to use; when omitted, the module-level
            `sigma_lims` is looked up at call time (as the former lambda did).
    """
    if sigmas is None:
        sigmas = sigma_lims
    return ker.obtain_discrete_kernels( sigmas )

### Definition of a single hidden layer

In [None]:
def hidden_layer( x_prev, z_prev, source, kernels, non_linearity = soft_threshold ):
    """Build one unrolled accelerated proximal gradient iteration.

    Args:
        x_prev: previous estimate (one map per kernel along axis 3).
        z_prev: previous extrapolated point, the input of the gradient step.
        source: batch of images; in spot_finding_net this is the
            [nrof_images, height, width] placeholder.
        kernels: initial value for this layer's own trainable kernel copy.
        non_linearity: proximal operator applied after the gradient step.

    Returns:
        (x_current, z_current, h): the new estimate, the new extrapolated
        point and this layer's kernel variable.
    """
    # Each layer gets an independent trainable copy of the kernels
    layer_kernels = tf.Variable( kernels )

    # Forward model: filter every coefficient map with its own kernel, sum
    # over the kernel axis and subtract the observed image.
    per_kernel = tf.nn.depthwise_conv2d( input = z_prev,
                                         filter = layer_kernels,
                                         strides = [1, 1, 1, 1],
                                         padding = 'SAME' )
    residual = tf.expand_dims( tf.reduce_sum( per_kernel, 3 ) - source, 3 )

    # Adjoint operator: spatially flipped kernels with the kernel axis moved
    # to the output-channel slot, so conv2d maps the one-channel residual
    # back to one map per kernel.
    flipped = tf.reverse( tf.reverse( layer_kernels, [0] ), [1] )
    adjoint_kernels = tf.transpose( flipped, [0, 1, 3, 2] )
    back_projection = tf.nn.conv2d( input = residual,
                                    filter = adjoint_kernels,
                                    strides = [1, 1, 1, 1],
                                    padding = 'SAME' )

    # Gradient step followed by the proximal operator
    x_current = non_linearity( z_prev - back_projection )

    # Extrapolation with a trainable momentum weight initialised to 0.5
    alpha = tf.Variable( tf.constant( 0.5 ) )
    z_current = x_current + alpha * ( x_current - x_prev )

    return x_current, z_current, layer_kernels

## Definition of the SpotNet

In [None]:
def spot_finding_net( nrof_images,
                      image_height, 
                      image_width, 
                      kernels,
                      nrof_hidden_layers,
                      non_linearity = soft_threshold ):
    """Build the unrolled SpotNet graph and its mean-squared-error loss.

    The input half-layer is the first accelerated proximal gradient iteration
    started from zero. Each of the `nrof_hidden_layers` following layers is
    built by hidden_layer and owns a trainable copy of `kernels`.

    Args:
        nrof_images: batch dimension fixed into both placeholders.
        image_height, image_width: spatial size of the FluoroSpot images.
        kernels: kernel bank laid out as [k, k, nrof_kernels, 1] (as produced
            by generate_random_kernels). The input layer uses it directly, and
            it initialises every hidden layer's kernels.
        nrof_hidden_layers: number of unrolled iterations after the input layer.
        non_linearity: proximal operator applied in every layer.

    Returns:
        (source, loss, target, output): image placeholder, MSE loss, PSDR
        placeholder and the network output (the last x).
    """
    # Seed for random initializations
    tf.set_random_seed(8888)

    # The network outputs one map per kernel. Take that count from the kernel
    # bank itself so the target placeholder always matches the output,
    # instead of relying on a module-level `nrof_kernels`.
    nrof_kernels = int( kernels.shape[2] )

    # Placeholder for final FluoroSpot images
    source = tf.placeholder( tf.float32, [ nrof_images, image_height, image_width ] )

    # Create the input half-layer (equivalent to first iteration with initialization of 0):
    # from z_0 = 0 the gradient step reduces to applying the adjoint
    # (flipped-kernel) convolution to the image.
    h_m = tf.transpose( tf.reverse( tf.reverse( kernels, [0] ), [1] ), [0, 1, 3, 2] )
    x = [ non_linearity( tf.nn.conv2d( input = tf.expand_dims( source, 3 ),
                                       filter = h_m,
                                       strides = [1, 1, 1, 1],
                                       padding = 'SAME' ) ) ]

    # With x_0 = 0 the extrapolation x_1 + alpha * (x_1 - x_0) is (1 + alpha) * x_1
    alpha = tf.Variable( tf.constant( 0.5 ) )
    z = [ (1 + alpha) * x[-1] ]

    # Create hidden layers as a feedforward convolutional graph. Append each hidden 
    # layer to a list and connect next layer to the last element of this list.
    for _layer in range( nrof_hidden_layers ):
        x_current, z_current, _layer_kernels = hidden_layer( x[-1], z[-1], source, kernels, non_linearity )
        x.append( x_current )
        z.append( z_current )

    # The output is simply the last x
    output = x[-1]

    # Placeholder for PSDRs
    target = tf.placeholder( tf.float32, [ nrof_images, image_height, image_width, nrof_kernels ] )

    # Mean squared error for the prediction of the PSDRs
    loss = tf.reduce_mean( ( output - target ) ** 2 )

    return (source, loss, target, output)

## Definition of training strategy

In [None]:
# Create and return the optimizer
def network_training( loss, learning_rate ):
    """Return the op that performs one Adam update step minimising `loss`.

    tf.train.GradientDescentOptimizer can be swapped in here for plain
    gradient descent.
    """
    optimizer = tf.train.AdamOptimizer( learning_rate )
    train_op = optimizer.minimize( loss )
    return train_op

### Load data, split it into training and test sets, and set parameters

In [None]:
# Fluorospot data-------------------------------------------
# Network parameters
nrof_hidden_layers = 3
learning_rate = 1e-3
batch_size = 1
nrof_train_steps = 400000
nrof_cells = 1250

results_dir = 'spotnet_results/'
prefix = 'SPOTNET_%d_CELLS_%d_LAYERS_%0.4f_LR_%d_STEPS_%d_BATCHSIZE'%( nrof_cells,
                                                                       nrof_hidden_layers, 
                                                                       learning_rate,
                                                                       nrof_train_steps,
                                                                       batch_size )

# Load dataset. The file holds a dict wrapped in a 0-d object array
# (unwrapped with [()]), so numpy >= 1.16.3 requires allow_pickle=True.
# NOTE: unpickling can execute arbitrary code - only load trusted, locally
# generated simulation files here.
data = np.load( '../../sim_data/result_' + str(nrof_cells) + '_cells_10_images.npy',
                allow_pickle = True )[()]

nrof_cells = data['nrof_cells']

# Extract images and PSDRs
images = data['fluorospot']
psdrs = data['psdrs']

# Extract shape parameters
nrof_images, image_height, image_width = images.shape
_, _, _, nrof_kernels = psdrs.shape

# Split dataset for training and testing (first 70% train, rest test)
nrof_training_samples = int( 0.7 * nrof_images )
train_images, train_psdrs = (images[ : nrof_training_samples, ... ], psdrs[ : nrof_training_samples, ... ])
test_images, test_psdrs = (images[nrof_training_samples : , ... ], psdrs[nrof_training_samples : , ... ])

print( 'Dataset: Num images: %d, Image height: %d, Image width: %d, Num cells: %d, Num kernels: %d'%( nrof_images, 
                                                                                                      image_height, 
                                                                                                      image_width,
                                                                                                      nrof_cells,
                                                                                                      nrof_kernels ) )

### Construct network graph

In [None]:
# Build network computational graph

# Extend to placeholder for kernels in future version to allow for different kernel sizes
# kernels = tf.expand_dims( tf.Variable( generate_diffusion_kernels( ) ), 3 )

# Smaller random kernels (for testing). Use one kernel per PSDR channel so the
# network output has the same number of channels as the PSDR targets.
kernel_size = 5
kernels = generate_random_kernels( kernel_size, nrof_kernels )

with tf.name_scope( 'SpotNet' ):
    source, loss, target, output = spot_finding_net( batch_size,
                                                     image_height, 
                                                     image_width, 
                                                     kernels,
                                                     nrof_hidden_layers,
                                                     soft_threshold )

# Build training computational graph
with tf.name_scope( 'Training' ):
    train_step = network_training( loss, learning_rate )

### Create TF session

In [None]:
# Number of GPUs TensorFlow may use: 0 keeps everything on the CPU, 1 allows one GPU
nrof_gpu = 1 # {0, 1}
session_devices = { 'GPU': nrof_gpu }
config = tf.ConfigProto( device_count = session_devices )
sess = tf.Session( config = config )

# Assign initial values to every variable created while building the graph
init_op = tf.global_variables_initializer( )
sess.run( init_op )

### Run session and training, see loss evolution

In [None]:
import os
import time

print( 'Training the network' )

# Space to store loss values
train_loss_record = np.empty( (0, ) )
test_loss_record  = np.empty( (0, ) )
iterations_record = np.empty( (0, ) )

# Indices that will be picked at random at each iteration
indices = np.arange( train_images.shape[0] )
np.random.seed(8888)

# Evaluate on the test set about 100 times during training. Integer division
# keeps the modulo test exact; max(1, ...) avoids a zero or fractional period
# when nrof_train_steps is below 100 or not a multiple of 100.
eval_every = max( 1, nrof_train_steps // 100 )

# Train the network
start = time.time()
for training_iteration in range( nrof_train_steps + 1 ):

    # Random mini-batch; mode='wrap' repeats indices if batch_size exceeds
    # the number of training images.
    np.random.shuffle( indices )
    batch_indices = np.take( indices, np.arange( batch_size ), mode = 'wrap' )
    batch_input = train_images[ batch_indices, ... ]
    batch_target = train_psdrs[ batch_indices, ... ]

    _, train_loss = sess.run( [ train_step, loss ], feed_dict = { source: batch_input, 
                                                                  target: batch_target } )

    if training_iteration % eval_every != 0:
        continue

    # Mean test loss, feeding one test image at a time.
    # NOTE(review): the placeholders were built for batch_size images, so this
    # only works with batch_size == 1 (the value configured above).
    test_loss = 0
    for test_image_index in range( test_images.shape[0] ):
        test_loss += sess.run( loss, feed_dict = { source: np.expand_dims( test_images[ test_image_index, ... ], axis = 0 ), 
                                                   target: np.expand_dims( test_psdrs[  test_image_index, ... ], axis = 0 ) } )
    test_loss = test_loss / test_images.shape[0]
    test_loss_record = np.append( test_loss_record, test_loss )
    train_loss_record = np.append( train_loss_record, train_loss )
    iterations_record = np.append( iterations_record, training_iteration )

    status = 'Train step %d, Batch loss: %0.4f, Test loss: %0.4f, Elapsed: %ds.'%( training_iteration, 
                                                                                 train_loss, 
                                                                                 test_loss, 
                                                                                 time.time() - start )
    if test_loss == test_loss_record.min() and training_iteration > 10:
        # Best test loss so far: export a checkpoint tagged with this step
        print( status + ' Best in test yet! Storing.' )
        tf.saved_model.simple_save( sess,
            results_dir + 'trained_spotnet_' + str( training_iteration ) + '_train-steps_5_kersize/',
            inputs = {'image': source},
            outputs = {'psdr': output} )
    else:
        print( status )

# Persist the loss history as columns (iteration, train loss, test loss),
# the file and layout read back by the results-loading cell.
os.makedirs( results_dir, exist_ok = True )
np.savetxt( results_dir + 'trained_spotnet_history_5_kersize.txt',
            np.column_stack( ( iterations_record, train_loss_record, test_loss_record ) ) )

## Save the model

In [None]:
# Export the model as it stands after the full training run
final_export_dir = results_dir + 'trained_spotnet_' + str( nrof_train_steps ) + '_train-steps/'
tf.saved_model.simple_save( sess,
                            final_export_dir,
                            inputs = {'image': source},
                            outputs = {'psdr': output} )

In [None]:
# Reload a stored loss history: columns are iteration, train loss, test loss
results_dir = 'spotnet_results/'
history_file = results_dir + 'trained_spotnet_history_5_kersize.txt'
results = np.loadtxt( history_file )
iterations_record = results[:, 0]
train_loss_record = results[:, 1]
test_loss_record = results[:, 2]

In [None]:
import matplotlib.pyplot as plt

# Train and test MSE against the training step at which they were recorded
fig = plt.figure( figsize = [20, 20] )
plt.plot( iterations_record, train_loss_record )
plt.plot( iterations_record, test_loss_record )
plt.legend( ( 'Train loss (Batch size = 1)', 'Test loss (3 images)' ) )
plt.xlabel( 'Training steps' )
# Trailing semicolon keeps Jupyter from echoing the returned Text object
plt.ylabel( "MSE" );