In [1]:
"""Learned primal method."""

import os
import time
import adler
# NOTE(review): presumably restricts this process to a single GPU (it prints
# "Picking GPU(s) 0"). It is deliberately called before `import tensorflow`,
# which is presumably required for the device selection to take effect --
# keep this ordering.
adler.util.gpu.setup_one_gpu()

from adler.odl.phantom import random_phantom
from adler.tensorflow import prelu, cosine_decay

import tensorflow as tf
import numpy as np
import odl
import odl.contrib.tensorflow
from tqdm import tqdm_notebook

# Seed NumPy (random phantoms / ODL noise, presumably NumPy-backed) AND the
# TensorFlow default graph (weight initialisers). Seeding NumPy alone left the
# network initialisation random, so runs were not reproducible (up to any
# nondeterministic GPU kernels).
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)

name = "learned-primal"

Picking GPU(s) 0


In [2]:
# ODL set-up: reconstruction space, tomography geometry and forward operator.
# Reconstruction space: a `size` x `size` float32 pixel grid on [-64, 64]^2.
size = 128
space = odl.uniform_discr([-64, -64], [64, 64], [size, size], dtype='float32')

# Parallel-beam geometry with 30 projection angles, and its ray transform.
geometry = odl.tomo.parallel_beam_geometry(space, num_angles=30)
operator = odl.tomo.RayTransform(space, geometry)

# Rescale the ray transform by its power-method operator-norm estimate, so the
# operator has a fixed (~unit) norm and the method is scale invariant.
opnorm = odl.power_method_opnorm(operator)
operator = (1 / opnorm) * operator

# Expose the normalised operator and its adjoint to TensorFlow as layers.
odl_op_layer, odl_op_layer_adjoint = (
    odl.contrib.tensorflow.as_tensorflow_layer(wrapped_op, layer_name)
    for wrapped_op, layer_name in ((operator, 'RayTransform'),
                                   (operator.adjoint, 'RayTransformAdjoint')))

In [3]:
# User-selected parameters
n_data = 5          # number of random phantoms per training batch
n_iter = 10         # number of unrolled learned-primal iterations
n_primal = 5        # number of channels in the primal (image-space) memory
n_dual = 1          # NOTE(review): the model cell builds a 1-channel dual and does not appear to read this
# TensorFlow / training parameters
print_freq = 100            # validate and write summaries every `print_freq` steps
chkpt = False               # restore a saved checkpoint before training
maximum_steps = 100000      # total number of training steps (also passed to the LR schedule)
logs_dir = 'logs'           # TensorBoard log directory
checkpoint_path = 'chkpt_{run_id}'  # checkpoint prefix, formatted with a run id

# define dedicated functions
def generate_data(validation=False, noise_level=0.05):
    """Generate a batch of (noisy data, ground-truth phantom) pairs.

    Parameters
    ----------
    validation : bool, optional
        If True, generate a single example from the (modified) Shepp-Logan
        phantom; otherwise generate ``n_data`` examples from ``random_phantom``.
    noise_level : float, optional
        Relative noise amplitude. The white noise from
        ``odl.phantom.white_noise`` is multiplied by
        ``mean(|data|) * noise_level`` before being added to the noise-free
        data. The default (0.05) is the previously hard-coded value.

    Returns
    -------
    y_arr : numpy.ndarray
        float32 array of shape ``(n, operator.range.shape[0],
        operator.range.shape[1], 1)`` holding the noisy data.
    x_true_arr : numpy.ndarray
        float32 array of shape ``(n, space.shape[0], space.shape[1], 1)``
        holding the phantoms.
    """
    n_generate = 1 if validation else n_data

    y_arr = np.empty((n_generate, operator.range.shape[0], operator.range.shape[1], 1), dtype='float32')
    x_true_arr = np.empty((n_generate, space.shape[0], space.shape[1], 1), dtype='float32')

    for i in range(n_generate):
        # Validation uses the fixed Shepp-Logan phantom (its noise is still random).
        if validation:
            phantom = odl.phantom.shepp_logan(space, True)
        else:
            phantom = random_phantom(space)
        data = operator(phantom)
        # Noise is scaled relative to the mean absolute value of the clean data.
        noisy_data = data + odl.phantom.white_noise(operator.range) * np.mean(np.abs(data)) * noise_level

        x_true_arr[i, ..., 0] = phantom
        y_arr[i, ..., 0] = noisy_data

    return y_arr, x_true_arr

def apply_conv(x, filters=32):
    """Apply a 3x3, 'SAME'-padded 2D convolution with a Xavier-initialised kernel.

    Parameters
    ----------
    x : tf.Tensor
        Channels-last input feature map (tf.layers.conv2d's default data format).
    filters : int, optional
        Number of output channels.

    Returns
    -------
    tf.Tensor
        Convolution output with ``filters`` channels (with bias, no activation).
    """
    xavier_init = tf.contrib.layers.xavier_initializer()
    return tf.layers.conv2d(x,
                            filters=filters,
                            kernel_size=3,
                            padding='SAME',
                            kernel_initializer=xavier_init)

In [4]:
sess = tf.InteractiveSession()

# define the model
# Graph inputs (channels-last): ground-truth images, measured data shaped like
# the ray transform's range, and a train/eval flag.
# NOTE(review): `is_training` is fed by the training loop but no visible layer reads it.
with tf.name_scope('placeholders'):
    x_true = tf.placeholder(tf.float32, [None, size, size, 1], name="x_true")
    y_rt = tf.placeholder(tf.float32, [None, *operator.range.shape[:2], 1], name="y_rt")
    is_training = tf.placeholder(tf.bool, (), name='is_training')

with tf.name_scope('tomography'):
    # Unrolled learned-primal network: `primal` is an `n_primal`-channel
    # image-space memory updated for `n_iter` iterations. Each iteration forms a
    # data residual (dual step) and then applies a learned CNN update to the
    # primal memory (primal step).
    with tf.name_scope('initial_values'):
        # Start from all-zero images, one copy per primal channel.
        primal = tf.concat([tf.zeros_like(x_true)] * n_primal, axis=-1)

    for i in range(n_iter):
        # NOTE: the per-iteration variable scopes give every iteration its own
        # (untied) conv weights and determine the variable names; renaming
        # them breaks restoring existing checkpoints.
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            # Dual step: forward-project primal channel 1 (not 0) and subtract
            # the measured data. Nothing is learned here, and the dual has a
            # single channel (`n_dual` is not used).
            evalop = odl_op_layer(primal[..., 1:2])
            dual = evalop - y_rt

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            # Primal step: back-project the residual with the adjoint, stack it
            # with the primal memory, and run conv-PReLU, conv-PReLU, conv to
            # produce an additive (residual) update of all primal channels.
            evalop = odl_op_layer_adjoint(dual[..., 0:1])
            update = tf.concat([primal, evalop], axis=-1)

            update = prelu(apply_conv(update), name='prelu_1')
            update = prelu(apply_conv(update), name='prelu_2')
            update = apply_conv(update, filters=n_primal)
            primal = primal + update

    # The reconstruction is read out from primal channel 0.
    x_result = primal[..., 0:1]


# Training objective: mean squared error between reconstruction and ground truth.
with tf.name_scope('loss'):
    residual = tf.subtract(x_result, x_true)
    squared_error = tf.pow(residual, 2)
    loss = tf.reduce_mean(squared_error)


with tf.name_scope('optimizer'):
    # NOTE: the creation order and names of the variables built here
    # (global_step, Adam slot variables) are part of every checkpoint; keep
    # the statement order unchanged.
    # Learning rate
    # Step counter: incremented by `apply_gradients` below and passed to the
    # learning-rate schedule.
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 1e-3
    # NOTE(review): `cosine_decay` comes from adler.tensorflow; presumably it
    # decays from `starter_learning_rate` over `maximum_steps` steps -- confirm.
    learning_rate = cosine_decay(starter_learning_rate,
                                 global_step,
                                 maximum_steps,
                                 name='learning_rate')

    # Any ops registered in UPDATE_OPS (e.g. batch-norm statistics) are forced
    # to run before the training step; none appear to be created by the layers
    # used in this model.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_func = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta2=0.99)

        # Clip all gradients jointly to a global norm of 1 before applying them.
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 1)
        # `optimizer` is the training op run once per step by the training loop.
        optimizer = opt_func.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Colocations handled automatically by placer.


In [5]:
# TensorBoard summaries; view with: tensorboard --logdir=<logs_dir>
# The run id (seconds since the epoch) names this run's log and checkpoint files.
run_id = str(int(time.time()))

with tf.name_scope('summaries'):
    # Scalar metrics; PSNR = -10 * log10(MSE), i.e. assuming a peak value of 1.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('psnr', -10 * tf.log(loss) / tf.log(10.0))

    # Image summaries for up to `n_data` batch elements each.
    for summary_tag, summary_image in (('x_result', x_result),
                                       ('x_true', x_true),
                                       ('squared_error', squared_error),
                                       ('residual', residual)):
        tf.summary.image(summary_tag, summary_image, max_outputs=n_data)

    # One merged op evaluates every summary; separate writers for train/test.
    merged_summary = tf.summary.merge_all()
    test_summary_writer = tf.summary.FileWriter(logs_dir + f'/test_{run_id}')
    train_summary_writer = tf.summary.FileWriter(logs_dir + f'/train_{run_id}', sess.graph)

In [6]:
# Initialize all TF variables
sess.run(tf.global_variables_initializer())

# Op to save and restore all variables
saver = tf.train.Saver()

# Fixed validation example (Shepp-Logan phantom)
y_arr_validate, x_true_arr_validate = generate_data(validation=True)

data_refresh_freq = 10  # draw a fresh random training batch every N steps
save_freq = 1000        # write a checkpoint every N steps

start_step = 0
if chkpt:
    # `run_id` is regenerated every time the summaries cell runs, so it only
    # names an existing checkpoint if this cell already saved in this session.
    # `chkpt` may therefore also be set to a previous run id (a string) to
    # resume that run.
    restore_run_id = chkpt if isinstance(chkpt, str) else run_id
    saver.restore(sess, checkpoint_path.format(run_id=restore_run_id))
    # Continue from the restored step, so the learning-rate schedule and the
    # logging/checkpoint cadence resume instead of restarting at 0.
    start_step = int(sess.run(global_step))

# Train the network
for i in tqdm_notebook(range(start_step, maximum_steps)):
    # Relative to `start_step`, so a batch always exists on the first iteration.
    if (i - start_step) % data_refresh_freq == 0:
        y_arr, x_true_arr = generate_data()

    train_feed = {x_true: x_true_arr, y_rt: y_arr, is_training: True}
    log_step = (i + 1) % print_freq == 0

    if log_step:
        # Evaluate the image-heavy merged summary only on steps that log it.
        _, merged_summary_result_train, global_step_result = sess.run(
            [optimizer, merged_summary, global_step], feed_dict=train_feed)

        loss_result, merged_summary_result, global_step_result = sess.run(
            [loss, merged_summary, global_step],
            feed_dict={x_true: x_true_arr_validate,
                       y_rt: y_arr_validate,
                       is_training: False})

        train_summary_writer.add_summary(merged_summary_result_train, global_step_result)
        test_summary_writer.add_summary(merged_summary_result, global_step_result)

        print('iter={}, loss={}'.format(global_step_result, loss_result))
    else:
        sess.run(optimizer, feed_dict=train_feed)

    if (i + 1) % save_freq == 0:
        saver.save(sess, checkpoint_path.format(run_id=run_id))

# Flush pending events so the last summaries reach disk. The writers are kept
# open so this cell can be re-run.
train_summary_writer.flush()
test_summary_writer.flush()

HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))

iter=11, loss=0.030791152268648148
iter=21, loss=0.025490805506706238
iter=31, loss=0.025570228695869446
iter=41, loss=0.01621125265955925
iter=51, loss=0.014470618218183517
iter=61, loss=0.014319454319775105
iter=71, loss=0.0151748675853014
iter=81, loss=0.01388933788985014
iter=91, loss=0.016491960734128952
iter=101, loss=0.01638680323958397
iter=111, loss=0.012187549844384193
iter=121, loss=0.011276472359895706
iter=131, loss=0.011320197023451328
iter=141, loss=0.01084638200700283
iter=151, loss=0.010776890441775322
iter=161, loss=0.010750843212008476
iter=171, loss=0.009907884523272514
iter=181, loss=0.009856287389993668
iter=191, loss=0.00994746945798397
iter=201, loss=0.009803042747080326
iter=211, loss=0.010763794183731079
iter=221, loss=0.009239017963409424
iter=231, loss=0.008684330619871616
iter=241, loss=0.009320997633039951
iter=251, loss=0.009821703657507896
iter=261, loss=0.009501857683062553
iter=271, loss=0.00914940144866705
iter=281, loss=0.008979646489024162
iter=291,

iter=2251, loss=0.003138615982607007
iter=2261, loss=0.0030099889263510704
iter=2271, loss=0.0029240637086331844
iter=2281, loss=0.0031833634711802006
iter=2291, loss=0.002735947258770466
iter=2301, loss=0.0028414903208613396
iter=2311, loss=0.0034679402597248554
iter=2321, loss=0.002733362140133977
iter=2331, loss=0.0031002811156213284
iter=2341, loss=0.0029247375205159187
iter=2351, loss=0.0026825941167771816
iter=2361, loss=0.0027931900694966316
iter=2371, loss=0.0026691036764532328
iter=2381, loss=0.0029773805290460587
iter=2391, loss=0.0029486212879419327
iter=2401, loss=0.003084425814449787
iter=2411, loss=0.002909980248659849
iter=2421, loss=0.0028691855259239674
iter=2431, loss=0.002755998633801937
iter=2441, loss=0.0025971606373786926
iter=2451, loss=0.002548417542129755
iter=2461, loss=0.003121796529740095
iter=2471, loss=0.002975266659632325
iter=2481, loss=0.0031408476643264294
iter=2491, loss=0.002813628874719143
iter=2501, loss=0.002926656510680914
iter=2511, loss=0.00282