In [1]:
"""Learned primal method."""

import os  # NOTE(review): not referenced in any of the visible cells
import time
import adler
# NOTE(review): prints "Picking GPU(s) 0"; presumably pins the process to a
# single GPU. It is called before `import tensorflow` below — confirm whether
# that ordering is required before moving it.
adler.util.gpu.setup_one_gpu()

from adler.odl.phantom import random_phantom
from adler.tensorflow import prelu, cosine_decay

import tensorflow as tf
import numpy as np
import odl
import odl.contrib.tensorflow  # provides odl.contrib.tensorflow.as_tensorflow_layer
from tqdm import tqdm_notebook

# Seed NumPy's global RNG so the generated training/validation data repeat
# between runs (assumes random_phantom and odl.phantom.white_noise draw from
# it — TODO confirm). TensorFlow's own RNG is not seeded in this cell.
np.random.seed(0)

# NOTE(review): not referenced in any of the visible cells.
name = "learned-primal"

Picking GPU(s) 0


In [2]:
# --- ODL forward model -------------------------------------------------------
# Reconstruction space: a size x size float32 image on the square [-64, 64]^2.
size = 128
space = odl.uniform_discr([-64, -64], [64, 64], [size, size], dtype='float32')

# Parallel-beam acquisition with 30 projection angles and its ray transform.
geometry = odl.tomo.parallel_beam_geometry(space, num_angles=30)
operator = odl.tomo.RayTransform(space, geometry)

# Rescale to (power-method estimated) unit operator norm so the learned scheme
# does not depend on the absolute scale of the forward model.
opnorm = odl.power_method_opnorm(operator)
operator = (1 / opnorm) * operator

# TensorFlow wrappers for the normalised operator and its adjoint.
as_tf_layer = odl.contrib.tensorflow.as_tensorflow_layer
odl_op_layer = as_tf_layer(operator, 'RayTransform')
odl_op_layer_adjoint = as_tf_layer(operator.adjoint, 'RayTransformAdjoint')

In [3]:
# User-selected parameters
# -- learned scheme --------------------------------------------------------
n_data = 5        # training batch size (phantoms per generate_data() call)
n_iter = 10       # number of unrolled primal/dual iterations
n_primal = 5      # channels of the primal memory
n_dual = 1        # NOTE(review): not referenced by the visible cells

# -- TensorFlow / training loop -------------------------------------------
print_freq = 100                    # validate and log every print_freq steps
chkpt = False                       # restore a saved checkpoint before training
maximum_steps = 100_000             # total steps; also the cosine-decay horizon
logs_dir = 'logs_radon'             # TensorBoard log root directory
checkpoint_path = 'chkpt_{run_id}'  # str.format template, filled with the run id

# define dedicated functions
def generate_data(validation=False, noise_level=0.05):
    """Generate a batch of (noisy projection data, ground-truth image) pairs.

    Parameters
    ----------
    validation : bool
        If True, produce a single modified Shepp-Logan phantom (the fixed
        validation image); otherwise produce ``n_data`` phantoms drawn with
        ``random_phantom``.
    noise_level : float
        Amplitude of the additive white noise relative to the mean absolute
        value of the noise-free data. The default of 0.05 reproduces the
        previously hard-coded 5 % noise.

    Returns
    -------
    y_arr : np.ndarray of float32, shape (n, *operator.range.shape, 1)
        Noisy projection data, one phantom per leading index.
    x_true_arr : np.ndarray of float32, shape (n, *space.shape, 1)
        The corresponding phantoms.
    """
    n_generate = 1 if validation else n_data

    y_arr = np.empty((n_generate, operator.range.shape[0], operator.range.shape[1], 1), dtype='float32')
    x_true_arr = np.empty((n_generate, space.shape[0], space.shape[1], 1), dtype='float32')

    for i in range(n_generate):
        if validation:
            # Fixed (modified) Shepp-Logan phantom, so validation losses are comparable.
            phantom = odl.phantom.shepp_logan(space, True)
        else:
            phantom = random_phantom(space)
        data = operator(phantom)
        # Noise is scaled relative to the data so noise_level is a relative level.
        noisy_data = data + odl.phantom.white_noise(operator.range) * np.mean(np.abs(data)) * noise_level

        x_true_arr[i, ..., 0] = phantom
        y_arr[i, ..., 0] = noisy_data

    return y_arr, x_true_arr

def apply_conv(x, filters=32):
    """Apply a 3x3, 'SAME'-padded 2-D convolution with a Xavier-initialised kernel.

    Parameters
    ----------
    x : tf.Tensor
        Input feature map (NHWC, as built in the model cell — TODO confirm for
        other callers).
    filters : int
        Number of output channels.
    """
    initializer = tf.contrib.layers.xavier_initializer()
    return tf.layers.conv2d(x,
                            filters=filters,
                            kernel_size=3,
                            padding='SAME',
                            kernel_initializer=initializer)

In [4]:
sess = tf.InteractiveSession()

# Graph-level seed. np.random.seed in the setup cell only makes the NumPy data
# generation repeatable; the Xavier kernel initialisation below draws from
# TensorFlow's RNG and would otherwise differ from run to run.
tf.set_random_seed(0)

# define the model
with tf.name_scope('placeholders'):
    # Ground-truth images; used by the loss/summaries, not by the reconstruction.
    x_true = tf.placeholder(tf.float32, shape=[None, size, size, 1], name="x_true")
    # Measured (noisy) projection data.
    y_rt = tf.placeholder(tf.float32, shape=[None, operator.range.shape[0], operator.range.shape[1], 1], name="y_rt")
    # NOTE(review): fed by the training loop but not consumed by any op in this
    # graph; kept so existing feed_dicts remain valid.
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

with tf.name_scope('tomography'):
    with tf.name_scope('initial_values'):
        # Zero-initialised primal memory with n_primal channels. The batch size is
        # read from the data placeholder (it used to be shaped like x_true), so the
        # reconstruction x_result can be evaluated by feeding y_rt alone.
        batch_size = tf.shape(y_rt)[0]
        primal = tf.zeros([batch_size, size, size, n_primal], dtype=tf.float32)

    for i in range(n_iter):
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            # Data residual of the forward-projected primal channel 1.
            evalop = odl_op_layer(primal[..., 1:2])
            dual = evalop - y_rt

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            # Back-project the residual and let a small CNN update the primal memory.
            evalop = odl_op_layer_adjoint(dual[..., 0:1])
            update = tf.concat([primal, evalop], axis=-1)

            update = prelu(apply_conv(update), name='prelu_1')
            update = prelu(apply_conv(update), name='prelu_2')
            update = apply_conv(update, filters=n_primal)
            primal = primal + update

    # Channel 0 of the primal memory is the reconstruction.
    x_result = primal[..., 0:1]


with tf.name_scope('loss'):
    # Mean squared error against the ground truth.
    residual = x_result - x_true
    squared_error = residual ** 2
    loss = tf.reduce_mean(squared_error)


with tf.name_scope('optimizer'):
    # Learning rate: cosine decay over maximum_steps from starter_learning_rate.
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 1e-3
    learning_rate = cosine_decay(starter_learning_rate,
                                 global_step,
                                 maximum_steps,
                                 name='learning_rate')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_func = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta2=0.99)

        # Clip the gradients to a global norm of 1 before applying them.
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 1)
        optimizer = opt_func.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Colocations handled automatically by placer.


In [5]:
# Summaries — inspect with: tensorboard --logdir=<logs_dir>
# A new run id (seconds since the epoch) is drawn every time this cell runs.
run_id = str(int(time.time()))

with tf.name_scope('summaries'):
    # Scalar curves: MSE loss and the PSNR -10*log10(loss).
    psnr = -10 * tf.log(loss) / tf.log(10.0)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('psnr', psnr)

    # Image panels, in the same order as before.
    image_summaries = [('x_result', x_result),
                       ('x_true', x_true),
                       ('squared_error', squared_error),
                       ('residual', residual)]
    for image_tag, image_tensor in image_summaries:
        tf.summary.image(image_tag, image_tensor, max_outputs=n_data)

    merged_summary = tf.summary.merge_all()
    test_summary_writer = tf.summary.FileWriter(f'{logs_dir}/test_{run_id}')
    train_summary_writer = tf.summary.FileWriter(f'{logs_dir}/train_{run_id}', sess.graph)

In [6]:
# Initialize all TF variables
sess.run(tf.global_variables_initializer())

# Add op to save and restore
saver = tf.train.Saver()

# Generate validation data (the fixed Shepp-Logan phantom)
y_arr_validate, x_true_arr_validate = generate_data(validation=True)

if chkpt:
    # run_id is regenerated every time the summaries cell runs, so `chkpt = True`
    # can only restore a checkpoint written by this same kernel session. Set
    # `chkpt` to the run_id string of an earlier run to resume from that run.
    restore_run_id = chkpt if isinstance(chkpt, str) else run_id
    saver.restore(sess, checkpoint_path.format(run_id=restore_run_id))

# Resume from the restored step (0 on a fresh start) so the total number of
# optimisation steps, and hence the cosine learning-rate schedule, still ends
# at maximum_steps.
start_step = int(sess.run(global_step))

# Train the network
for i in tqdm_notebook(range(start_step, maximum_steps)):
    # Fresh training batch every 10 steps; always on the first iteration so the
    # batch exists even when resuming.
    if (i - start_step) % 10 == 0:
        y_arr, x_true_arr = generate_data()

    feed_train = {x_true: x_true_arr,
                  y_rt: y_arr,
                  is_training: True}

    if (i + 1) % print_freq == 0:
        # The merged summaries include image summaries, so only evaluate them on
        # the steps where they are actually written out.
        _, merged_summary_result_train = sess.run([optimizer, merged_summary],
                                                  feed_dict=feed_train)

        loss_result, merged_summary_result, global_step_result = sess.run(
            [loss, merged_summary, global_step],
            feed_dict={x_true: x_true_arr_validate,
                       y_rt: y_arr_validate,
                       is_training: False})

        train_summary_writer.add_summary(merged_summary_result_train, global_step_result)
        test_summary_writer.add_summary(merged_summary_result, global_step_result)

        print('iter={}, loss={}'.format(global_step_result, loss_result))
    else:
        sess.run(optimizer, feed_dict=feed_train)

    if (i + 1) % 1000 == 0:
        saver.save(sess, checkpoint_path.format(run_id=run_id))

# Make sure the final summaries reach disk.
train_summary_writer.flush()
test_summary_writer.flush()

HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))

iter=100, loss=0.013401556760072708
iter=200, loss=0.009856192395091057
iter=300, loss=0.009816672652959824
iter=400, loss=0.013956650160253048
iter=500, loss=0.009177898988127708
iter=600, loss=0.007365264929831028
iter=700, loss=0.00783451460301876
iter=800, loss=0.006058550905436277
iter=900, loss=0.005195289850234985
iter=1000, loss=0.004707389511168003
iter=1100, loss=0.005421053152531385
iter=1200, loss=0.004818640649318695
iter=1300, loss=0.0063102710992097855
iter=1400, loss=0.0060930922627449036
iter=1500, loss=0.0038196039386093616
iter=1600, loss=0.003832656191661954
iter=1700, loss=0.004254722036421299
iter=1800, loss=0.004779539071023464
iter=1900, loss=0.0034136828035116196
iter=2000, loss=0.002609821269288659
iter=2100, loss=0.003618695307523012
iter=2200, loss=0.002771900501102209
iter=2300, loss=0.002542594913393259
iter=2400, loss=0.0029977248050272465
iter=2500, loss=0.0025201523676514626
iter=2600, loss=0.0027549811638891697
iter=2700, loss=0.0018051951192319393
ite

iter=21600, loss=0.0005254679708741605
iter=21700, loss=0.0006342025008052588
iter=21800, loss=0.0004963738610967994
iter=21900, loss=0.0005737472674809396
iter=22000, loss=0.0005647923098877072
iter=22100, loss=0.0006720967940054834
iter=22200, loss=0.0005625277990475297
iter=22300, loss=0.0007540753576904535
iter=22400, loss=0.000638811441604048
iter=22500, loss=0.0005314939189702272
iter=22600, loss=0.0005919551476836205
iter=22700, loss=0.0005256544682197273
iter=22800, loss=0.0006311449105851352
iter=22900, loss=0.0005239403108134866
iter=23000, loss=0.000654612435027957
iter=23100, loss=0.0009409851627424359
iter=23200, loss=0.0006053365068510175
iter=23300, loss=0.00053225620649755
iter=23400, loss=0.0005489299073815346
iter=23500, loss=0.00047325441846624017
iter=23600, loss=0.0006655040197074413
iter=23700, loss=0.0005751101416535676
iter=23800, loss=0.0005957772955298424
iter=23900, loss=0.000539942120667547
iter=24000, loss=0.0005606257473118603
iter=24100, loss=0.0005088022

iter=42500, loss=0.0003168671391904354
iter=42600, loss=0.00037690604222007096
iter=42700, loss=0.0003860926954075694
iter=42800, loss=0.0003092371625825763
iter=42900, loss=0.0004397864977363497
iter=43000, loss=0.0003380870330147445
iter=43100, loss=0.00032736663706600666
iter=43200, loss=0.00030836043879389763
iter=43300, loss=0.0004063511732965708
iter=43400, loss=0.0004177007940597832
iter=43500, loss=0.00036302220541983843
iter=43600, loss=0.00041245363536290824
iter=43700, loss=0.0003430037177167833
iter=43800, loss=0.00034489939571358263
iter=43900, loss=0.00038762204349040985
iter=44000, loss=0.00032269611256197095
iter=44100, loss=0.00031701874104328454
iter=44200, loss=0.0004047924594487995
iter=44300, loss=0.00038702855817973614
iter=44400, loss=0.0004595720674842596
iter=44500, loss=0.0003645707620307803
iter=44600, loss=0.0005042238044552505
iter=44700, loss=0.0004185682046227157
iter=44800, loss=0.00032745327916927636
iter=44900, loss=0.0005034086061641574
iter=45000, lo

iter=63300, loss=0.00029139683465473354
iter=63400, loss=0.00026550597976893187
iter=63500, loss=0.0004021083004772663
iter=63600, loss=0.00035487377317622304
iter=63700, loss=0.00028605395345948637
iter=63800, loss=0.00026130519108846784
iter=63900, loss=0.00028575892793014646
iter=64000, loss=0.0003083109331782907
iter=64100, loss=0.0002518828259781003
iter=64200, loss=0.00029543228447437286
iter=64300, loss=0.0002535175590310246
iter=64400, loss=0.0002527812903281301
iter=64500, loss=0.00026042881654575467
iter=64600, loss=0.00026173534570261836
iter=64700, loss=0.0002718601026572287
iter=64800, loss=0.00026283730403520167
iter=64900, loss=0.00026587527827359736
iter=65000, loss=0.00023845990654081106
iter=65100, loss=0.0003091043618042022
iter=65200, loss=0.0002860981912817806
iter=65300, loss=0.00024070742074400187
iter=65400, loss=0.0002934572985395789
iter=65500, loss=0.00023650772345717996
iter=65600, loss=0.00024286469852086157
iter=65700, loss=0.00025790531071834266
iter=6580

iter=84100, loss=0.00031138473423197865
iter=84200, loss=0.0002737544709816575
iter=84300, loss=0.00023470932501368225
iter=84400, loss=0.00024266084074042737
iter=84500, loss=0.00022319331765174866
iter=84600, loss=0.0002964366867672652
iter=84700, loss=0.00023257547582034022
iter=84800, loss=0.00027724960818886757
iter=84900, loss=0.0002069795154966414
iter=85000, loss=0.00019455913570709527
iter=85100, loss=0.00021704146638512611
iter=85200, loss=0.00021850131452083588
iter=85300, loss=0.00020028295693919063
iter=85400, loss=0.00019846268696710467
iter=85500, loss=0.00019665165746118873
iter=85600, loss=0.00021514357649721205
iter=85700, loss=0.0002102152066072449
iter=85800, loss=0.00024382064293604344
iter=85900, loss=0.00020399497589096427
iter=86000, loss=0.00023462915851268917
iter=86100, loss=0.00023577547108288854
iter=86200, loss=0.00022697488020639867
iter=86300, loss=0.00020771389245055616
iter=86400, loss=0.0002604102483019233
iter=86500, loss=0.00021776727226097137
iter=