In [1]:
"""Learned primal method."""
# Notebook setup: auto-reload edited modules and use the interactive
# nbagg matplotlib backend.
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os
import time
import scipy.io
import adler
# Select a single GPU for this process (the call prints "Picking GPU(s) ...").
# NOTE(review): presumably this has to run before TensorFlow is imported for
# the device selection to take effect -- keep this ordering unless confirmed.
adler.util.gpu.setup_one_gpu()

from adler.odl.phantom import random_phantom
from adler.tensorflow import cosine_decay

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import odl
import odl.contrib.tensorflow
from odl.trafos.non_uniform_fourier import NonUniformFourierTransform
from tqdm import tqdm_notebook

# Seed NumPy's global RNG so the np.random draws made in later cells
# (e.g. the noise added by the data-generation helpers) are reproducible.
np.random.seed(0)

# Name of this method. NOTE(review): not referenced in the cells shown here.
name = "learned-primal"

Picking GPU(s) 0


In [2]:
# Redefinition of adler's prelu so it works in float64 graphs:
# https://github.com/adler-j/adler/blob/master/adler/tensorflow/activation.py
def prelu(_x, init=0.0, name='prelu', trainable=True):
    """Parametric ReLU ``relu(x) - alpha * relu(-x)`` with one slope per channel.

    Parameters
    ----------
    _x : tf.Tensor
        Input tensor. Its last dimension must be statically known; it sets
        the length of the learned ``alphas`` vector.
    init : float
        Initial value of every slope in ``alphas``.
    name : str
        Variable scope in which the ``alphas`` variable is created.
    trainable : bool
        Whether ``alphas`` is added to the trainable variables and therefore
        updated by the optimizer.

    Returns
    -------
    tf.Tensor
        Same shape and dtype as ``_x``.
    """
    with tf.variable_scope(name):
        alphas = tf.get_variable('alphas',
                                 shape=[int(_x.get_shape()[-1])],
                                 initializer=tf.constant_initializer(init),
                                 # Follow the input dtype (float64 in this
                                 # notebook) so `alphas * relu(-_x)` never
                                 # mixes dtypes.
                                 dtype=_x.dtype.base_dtype,
                                 # Honour the caller's choice instead of
                                 # always creating a trainable variable.
                                 trainable=trainable)
        pos = tf.nn.relu(_x)
        neg = -alphas * tf.nn.relu(-_x)

        return pos + neg

In [3]:
# Load the SPARKLING k-space sampling locations used by the NFFT.
sparkling_traj_file_path = ('2019-Mar-01_N512_nc34_ns3073_OS1_decim64_decay2_tau0.75_nrevol1/'
                            'samples_SPARKLING_N512_nc34x3073_OS1.mat')
# Rescale the stored coordinates. NOTE(review): the 2 * 1280 divisor presumably
# maps the file's coordinate range onto [-0.5, 0.5] -- confirm against the
# trajectory generator.
kspace_loc = scipy.io.loadmat(sparkling_traj_file_path)['samples'] / (2 * 1280)
# Points landing exactly on +0.5 are moved to -0.5, then everything is float64.
kspace_loc = np.where(kspace_loc == 0.5, -0.5, kspace_loc).astype(np.float64)

In [4]:
# Build the ODL forward model: a size x size complex image space centred on the
# origin, and the non-uniform Fourier transform sampling it at `kspace_loc`.
size = 512
space = odl.uniform_discr(min_pt=[-size // 2] * 2,
                          max_pt=[size // 2] * 2,
                          shape=[size] * 2,
                          dtype='complex128')

operator = NonUniformFourierTransform(space=space, samples=kspace_loc,
                                      skip_normalization=True)

# Rescale the operator by its estimated norm so the learned network sees a
# forward model of (approximately) unit operator norm, i.e. is scale invariant.
opnorm = odl.power_method_opnorm(operator)
operator = (1 / opnorm) * operator

# Expose the forward operator and its adjoint as TensorFlow layers.
odl_op_layer, odl_op_layer_adjoint = (
    odl.contrib.tensorflow.as_tensorflow_layer(op, op_name)
    for op, op_name in [(operator, 'NFFT'), (operator.adjoint, 'NFFTAdjoint')]
)

In [5]:
# User-selected parameters
# -- data / model --
n_data = 1        # samples produced per generate_data() call during training
n_iter = 10       # number of unrolled primal iterations
n_primal = 3      # channels of the primal (image) memory
n_dual = 1        # NOTE(review): not used by the model in the cells shown
# -- TensorFlow training --
print_freq = 250                               # validate / log every N steps
chkpt = False                                  # restore a checkpoint before training
maximum_steps = 20000                          # total number of optimizer steps
logs_dir = 'logs_fft'                          # TensorBoard log root directory
checkpoint_path = 'fft_chkpt/chkpt_{run_id}'   # formatted with the run id

# define dedicated functions
def generate_data(validation=False, noise_std=0.05):
    """Generate a batch of (noisy k-space data, ground-truth image) pairs.

    Parameters
    ----------
    validation : bool
        If True, produce a single sample from the Shepp-Logan phantom;
        otherwise produce ``n_data`` samples from ``random_phantom``.
    noise_std : float
        Standard deviation of the Gaussian noise (drawn from NumPy's global
        RNG) added to the phantom in image space before applying ``operator``.

    Returns
    -------
    y_arr : ndarray, complex128, shape ``(n, operator.range.shape[0], 1)``
        Measurements ``operator(phantom + noise)``.
    x_true_arr : ndarray, complex128, shape ``(n,) + space.shape + (1,)``
        The noise-free phantoms.
    """
    n_generate = 1 if validation else n_data

    y_arr = np.empty((n_generate, operator.range.shape[0], 1), dtype='complex128')
    x_true_arr = np.empty((n_generate,) + tuple(space.shape) + (1,), dtype='complex128')

    for i in range(n_generate):
        if validation:
            phantom = odl.phantom.shepp_logan(space, True)
        else:
            phantom = random_phantom(space)
        # NOTE(review): the noise is real-valued and added in image space
        # before the forward operator, not complex noise in k-space -- confirm
        # this is the intended noise model.
        noise = noise_std * np.random.randn(*space.shape)
        x_true_arr[i, ..., 0] = phantom
        y_arr[i, ..., 0] = operator(phantom + noise)

    return y_arr, x_true_arr

def get_phantom_data(phantom_path='brain_phantom.npy', noise_std=0.05):
    """Load the brain phantom and simulate one noisy k-space measurement of it.

    Parameters
    ----------
    phantom_path : str
        ``.npy`` file holding the phantom image; it must have shape
        ``space.shape``.
    noise_std : float
        Standard deviation of the Gaussian noise (drawn from NumPy's global
        RNG) added to the phantom in image space before applying ``operator``.

    Returns
    -------
    y_arr : ndarray, complex128, shape ``(1, operator.range.shape[0], 1)``
        Measurement ``operator(phantom + noise)``.
    x_true_arr : ndarray, complex128, shape ``(1,) + space.shape + (1,)``
        The noise-free phantom.
    """
    # Cast to complex128 (matching the space and the output arrays) rather
    # than complex64, which silently dropped precision.
    # NOTE(review): dividing by 255 assumes 8-bit intensities -- confirm.
    phantom = np.load(phantom_path).astype(np.complex128) / 255
    noise = noise_std * np.random.randn(*space.shape)

    y_arr = np.empty((1, operator.range.shape[0], 1), dtype='complex128')
    x_true_arr = np.empty((1,) + tuple(space.shape) + (1,), dtype='complex128')
    y_arr[0, ..., 0] = operator(phantom + noise)
    x_true_arr[0, ..., 0] = phantom

    return y_arr, x_true_arr

def apply_conv(x, filters=32):
    """Real-valued 3x3 convolution with 'SAME' padding and Xavier-initialised kernel.

    Returns a tensor with ``filters`` output channels.
    """
    xavier = tf.contrib.layers.xavier_initializer()
    conv = tf.layers.conv2d(
        x,
        filters=filters,
        kernel_size=3,
        padding='SAME',
        kernel_initializer=xavier,
    )
    return conv

def prelu_conv_complex(x, filters=32, name='prelu'):
    """Split-complex layer: conv + PReLU applied separately to Re(x) and Im(x).

    The real and imaginary parts each get their own convolution (via
    ``apply_conv``) and PReLU, built in the sub-scopes ``real_conv`` and
    ``imag_conv`` of the variable scope ``name``. The two results are
    recombined into a complex tensor with ``filters`` channels; no weights
    mix the real and imaginary parts.
    """
    with tf.variable_scope(name):
        split_parts = [('real_conv', tf.math.real(x)),
                       ('imag_conv', tf.math.imag(x))]
        activated = []
        for scope_name, part in split_parts:
            with tf.variable_scope(scope_name, reuse=False):
                activated.append(prelu(apply_conv(part, filters=filters)))
    real_out, imag_out = activated
    return tf.complex(real_out, imag_out)

In [6]:
sess = tf.InteractiveSession()
n_filters = 32
# define the model
# Learned primal reconstruction unrolled for `n_iter` iterations: each one
# computes the data residual with the (normalised) NFFT, back-projects it with
# the adjoint, and applies a learned CNN update to the primal memory.
with tf.name_scope('placeholders'):
    # Ground-truth images, shape (batch, size, size, 1)
    x_true = tf.placeholder(tf.complex128, shape=[None, size, size, 1], name="x_true")
    # Measured k-space data, shape (batch, number of samples, 1)
    y_rt = tf.placeholder(tf.complex128, shape=[None, operator.range.shape[0], 1], name="y_rt")
    # NOTE(review): fed by the training loop but not consumed by any op built here
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

with tf.name_scope('MRI'):
    with tf.name_scope('initial_values'):
        # Primal memory: `n_primal` channels, all initialised to zero
        primal = tf.concat([tf.zeros_like(x_true)] * n_primal, axis=-1)

    for i in range(n_iter):
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            # Residual A(x) - y, with x taken from primal channel 1 (not 0)
            evalop = odl_op_layer(primal[..., 1:2])
            dual = evalop - y_rt

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            # Back-project the residual with the adjoint: A^*(A(x) - y)
            evalop = odl_op_layer_adjoint(dual[..., 0:1])
            # CNN input: the primal memory plus the back-projected residual
            update = tf.concat([primal, evalop], axis=-1)

            # Two split-complex conv+PReLU layers, then a linear convolution
            # back to `n_primal` channels (real and imaginary parts separately)
            update = prelu_conv_complex(update, name='prelu_1', filters=n_filters)
            update = prelu_conv_complex(update, name='prelu_2', filters=n_filters)
            update = tf.complex(apply_conv(tf.math.real(update), filters=n_primal), apply_conv(tf.math.imag(update), filters=n_primal))
            # Additive (residual) update of the primal memory
            primal = primal + update

    # The reconstruction is primal channel 0
    x_result = primal[..., 0:1]


with tf.name_scope('loss'):
    residual = x_result - x_true
    # |residual|^2 built from the real and imaginary parts (a real tensor)
    squared_error = tf.math.real(residual)**2 + tf.math.imag(residual)**2
    # Mean over all elements (batch and pixels)
    loss = tf.reduce_mean(squared_error)


with tf.name_scope('optimizer'):
    # Learning rate
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 1e-3
    # Cosine-decayed learning rate over `maximum_steps` steps (adler helper)
    learning_rate = cosine_decay(starter_learning_rate,
                                 global_step,
                                 maximum_steps,
                                 name='learning_rate')

    # Make the train step depend on any registered UPDATE_OPS.
    # NOTE(review): the layers built above do not appear to register any.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_func = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta2=0.99)

        tvars = tf.trainable_variables()
        # Clip all gradients jointly to a global norm of 1 before Adam
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 1)
        # `optimizer` is the train op; it also increments `global_step`
        optimizer = opt_func.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Colocations handled automatically by placer.


In [7]:
# Summaries
# View with: tensorboard --logdir=<logs_dir>
run_id = str(int(time.time()))
print(run_id)

with tf.name_scope('summaries'):
    tf.summary.scalar('loss', loss)
    # PSNR in dB for a unit peak value: -10 * log10(MSE). The log(10)
    # constant is a float64 scalar, matching the float64 loss exactly.
    # NOTE(review): assumes images are scaled to a peak of 1 -- confirm.
    tf.summary.scalar('psnr', -10 * tf.log(loss) / np.log(10.0))

    # Image summaries show magnitudes of the complex tensors
    tf.summary.image('x_result', tf.abs(x_result), max_outputs=n_data)
    tf.summary.image('x_true', tf.abs(x_true), max_outputs=n_data)
    tf.summary.image('squared_error', squared_error, max_outputs=n_data)
    tf.summary.image('residual', tf.abs(residual), max_outputs=n_data)

    merged_summary = tf.summary.merge_all()
    # One writer per data stream so each appears as its own run in TensorBoard
    test_summary_writer = tf.summary.FileWriter(os.path.join(logs_dir, f'test_{run_id}'))
    phantom_summary_writer = tf.summary.FileWriter(os.path.join(logs_dir, f'phantom_{run_id}'))
    train_summary_writer = tf.summary.FileWriter(os.path.join(logs_dir, f'train_{run_id}'), sess.graph)

1558535056


In [8]:
# Initialize all TF variables
sess.run(tf.global_variables_initializer())

# Add op to save and restore
saver = tf.train.Saver()
checkpoint_file = checkpoint_path.format(run_id=run_id)
checkpoint_dir = os.path.dirname(checkpoint_file)
# tf.train.Saver.save refuses to write into a missing parent directory
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Generate validation data: one Shepp-Logan sample and the brain phantom
y_arr_validate, x_true_arr_validate = generate_data(validation=True)
y_arr_brain, x_true_arr_brain = get_phantom_data()

if chkpt:
    # `run_id` is freshly generated from the current time, so a checkpoint for
    # *this* run never exists yet; resume from the newest checkpoint written
    # into the checkpoint directory instead.
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        raise FileNotFoundError(
            f'chkpt=True but no checkpoint was found in {checkpoint_dir!r}')
    saver.restore(sess, latest_checkpoint)

# Continue from the restored global step (0 for a fresh run) so the step-based
# schedules below and the cosine learning-rate decay stay aligned.
start_step = int(sess.run(global_step))
data_refresh_freq = 10  # steps between fresh random training batches
save_freq = 1000        # steps between checkpoints

# Train the network
for i in tqdm_notebook(range(start_step, maximum_steps)):
    # Draw a new random training batch periodically, and on the first step in
    # case training resumed in the middle of a period.
    if i % data_refresh_freq == 0 or i == start_step:
        y_arr, x_true_arr = generate_data()
    train_feed = {x_true: x_true_arr, y_rt: y_arr, is_training: True}

    if i > 0 and (i + 1) % print_freq == 0:
        # Summaries (including the image summaries) are only evaluated on the
        # steps where they are written; they share the optimizer step's run.
        _, merged_summary_result_train = sess.run(
            [optimizer, merged_summary], feed_dict=train_feed)

        loss_result, merged_summary_result, global_step_result = sess.run(
            [loss, merged_summary, global_step],
            feed_dict={x_true: x_true_arr_validate,
                       y_rt: y_arr_validate,
                       is_training: False})

        loss_result_brain, merged_summary_result_brain = sess.run(
            [loss, merged_summary],
            feed_dict={x_true: x_true_arr_brain,
                       y_rt: y_arr_brain,
                       is_training: False})

        train_summary_writer.add_summary(merged_summary_result_train, global_step_result)
        test_summary_writer.add_summary(merged_summary_result, global_step_result)
        phantom_summary_writer.add_summary(merged_summary_result_brain, global_step_result)

        print('iter={}, loss={}'.format(global_step_result, loss_result))
    else:
        sess.run(optimizer, feed_dict=train_feed)

    if i > 0 and (i + 1) % save_freq == 0:
        saver.save(sess, checkpoint_file)

HBox(children=(IntProgress(value=0, max=20000), HTML(value='')))

iter=250, loss=0.007842192885411943
iter=500, loss=0.006143378857919086
iter=750, loss=0.01064054422379035
iter=1000, loss=0.006430048706156618
iter=1250, loss=0.002617312314507102
iter=1500, loss=0.0037992448437270607
iter=1750, loss=0.004139043175190264
iter=2000, loss=0.0037822742149142756
iter=2250, loss=0.002137479441712567
iter=2500, loss=0.0034462633865335796
iter=2750, loss=0.003326310413119175
iter=3000, loss=0.0038740484539963047
iter=3250, loss=0.00332332250018915
iter=3500, loss=0.0014841862162186397
iter=3750, loss=0.001608060114841227
iter=4000, loss=0.0036568257229508657
iter=4250, loss=0.0018119782150223735
iter=4500, loss=0.0009728605186669279
iter=4750, loss=0.0011218436822880832
iter=5000, loss=0.0023693628473413003
iter=5250, loss=0.0013515868691864187
iter=5500, loss=0.0011038089657285635
iter=5750, loss=0.0025088272329283233
iter=6000, loss=0.0014275163725186086
iter=6250, loss=0.0011549086924581206
iter=6500, loss=0.0009915640832010993
iter=6750, loss=0.001431458