In [1]:
%cd ../../fastmri-reproducible-benchmark

/volatile/home/Zaccharie/workspace/fastmri-reproducible-benchmark


In [2]:
"""Learned primal method."""
%matplotlib nbagg
import os
import time

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm_notebook

from pdnet_crop import tf_op, tf_adj_op, tf_crop
from data import MaskedUntouched2DAllLoadedSequence, MaskedUntouched2DSequence
from utils import keras_psnr, keras_ssim

np.random.seed(0)

name = "learned-primal"

Using TensorFlow backend.


In [3]:
# redefining adler's prelu
# https://github.com/adler-j/adler/blob/master/adler/tensorflow/activation.py
def prelu(_x, init=0.0, name='prelu', trainable=True):
    with tf.variable_scope(name):
        alphas = tf.get_variable('alphas',
                                 shape=[int(_x.get_shape()[-1])],
                                 initializer=tf.constant_initializer(init),
                                 dtype=tf.float32,
                                 trainable=True)
        pos = tf.nn.relu(_x)
        neg = -alphas * tf.nn.relu(-_x)

        return pos + neg
    
# redefining adler's cosine decay
# https://github.com/adler-j/adler/blob/master/adler/tensorflow/training.py
def cosine_decay(learning_rate, global_step, maximum_steps,
                 name=None):
    from tensorflow.python.ops import math_ops
    from tensorflow.python.framework import ops

    if global_step is None:
        raise ValueError("global_step is required for cosine_decay.")
    with ops.name_scope(name, "CosineDecay",
                      [learning_rate, global_step, maximum_steps]) as name:
        learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
        dtype = learning_rate.dtype
        global_step = math_ops.cast(global_step, dtype)
        maximum_steps = math_ops.cast(maximum_steps, dtype)

        p = tf.mod(global_step / maximum_steps, 1)

    return learning_rate * (0.5 + 0.5 * math_ops.cos(p * np.pi))
    
def apply_conv(x, filters=32):
    return tf.layers.conv2d(x, filters=filters, kernel_size=3, padding='SAME',
                            kernel_initializer=tf.contrib.layers.xavier_initializer(), use_bias=False,)

def prelu_conv_complex(x, filters=16, name='prelu'):
    with tf.variable_scope(name):
        x_real = tf.math.real(x)
        x_imag = tf.math.imag(x)   
        with tf.variable_scope('real_conv', reuse=False):
            res_real = prelu(apply_conv(x_real, filters=filters))
        with tf.variable_scope('imag_conv', reuse=False):
            res_imag = prelu(apply_conv(x_imag, filters=filters))
    return tf.complex(res_real, res_imag)

def prelu_conv_complex_concat(x, filters=32, name='prelu'):
    with tf.variable_scope(name):
        x_real = tf.math.real(x)
        x_imag = tf.math.imag(x)   
        res = prelu(apply_conv(tf.concat([x_real, x_imag], axis=-1), filters=filters))
    return res

In [4]:
# User selected paramters
n_iter = 10
n_primal = 5
n_dual = 5
# tf params
print_freq = 500
chkpt = False
maximum_steps = 10000
logs_dir = 'logs_fastmri'
checkpoint_path = 'fastmri_chkpt/chkpt_{run_id}'

In [5]:
sess = tf.InteractiveSession()
# define the model
with tf.name_scope('placeholders'):
    x_true = tf.placeholder(tf.float32, shape=[None, 320, 320, 1], name="x_true")
    y_rt = tf.placeholder(tf.complex64, shape=[None, 640, None, 1], name="y_rt")
    mask = tf.placeholder(tf.float32, shape=[None, 640, None], name="mask")
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

with tf.name_scope('MRI'):
    with tf.name_scope('control'):
        zero_filled = tf_adj_op([y_rt, mask])
    with tf.name_scope('initial_values'):
        primal = tf.concat([tf.zeros_like(y_rt, dtype=tf.complex64)] * n_primal, axis=-1)
        dual = tf.concat([tf.zeros_like(y_rt, dtype=tf.complex64)] * n_dual, axis=-1)

    for i in range(n_iter):
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            evalop = tf_op([primal[..., 1:2], mask])
            update = tf.concat([dual, evalop, y_rt], axis=-1)
            update = prelu_conv_complex_concat(update, name='prelu_1')
            update = prelu_conv_complex_concat(update, name='prelu_2')
            update = tf.complex(apply_conv(tf.math.real(update), filters=n_dual), apply_conv(tf.math.imag(update), filters=n_dual))
            dual = dual + update

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            evalop = tf_adj_op([dual[..., 0:1], mask])
            update = tf.concat([primal, evalop], axis=-1)

            update = prelu_conv_complex_concat(update, name='prelu_1')
            update = prelu_conv_complex_concat(update, name='prelu_2')
            update = tf.complex(apply_conv(tf.math.real(update), filters=n_primal), apply_conv(tf.math.imag(update), filters=n_primal))
            primal = primal + update

    x_result = primal[..., 0:1]
    x_result = tf.math.abs(x_result)
    x_result = tf_crop(x_result)


with tf.name_scope('loss'):
    residual = x_result - x_true
    squared_error = tf.math.real(residual)**2 + tf.math.imag(residual)**2
    loss = tf.reduce_mean(squared_error)


with tf.name_scope('optimizer'):
    # Learning rate
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 1e-2
    learning_rate = 1e-3
    learning_rate = cosine_decay(starter_learning_rate,
                                 global_step,
                                 maximum_steps,
                                 name='learning_rate')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_func = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta2=0.99)

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 1)
        optimizer = opt_func.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

Instructions for updating:
Use tf.cast instead.

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Colocations handled automatically by placer.


In [6]:
# Summaries
run_id = str(int(time.time()))
print(run_id)

with tf.name_scope('summaries'):
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('psnr', tf.reduce_mean(keras_psnr(x_true, x_result)))
    tf.summary.scalar('ssim', tf.reduce_mean(keras_ssim(x_true, x_result)))

    tf.summary.image('x_result', tf.abs(x_result), max_outputs=1)
    tf.summary.image('x_true', tf.abs(x_true), max_outputs=1)
    tf.summary.image('zero_filled', tf.abs(tf_crop(zero_filled)), max_outputs=1)
    tf.summary.image('squared_error', squared_error, max_outputs=1)
    tf.summary.image('residual', tf.abs(residual), max_outputs=1)

    merged_summary = tf.summary.merge_all()
    test_summary_writer = tf.summary.FileWriter(logs_dir + f'/test_{run_id}')
    train_summary_writer = tf.summary.FileWriter(logs_dir + f'/train_{run_id}', sess.graph)

1567435989


In [7]:
# paths
train_path = '/media/Zaccharie/UHRes/singlecoil_train/singlecoil_train/'
val_path = '/media/Zaccharie/UHRes/singlecoil_val/'
test_path = '/media/Zaccharie/UHRes/singlecoil_test/'

n_samples_train = 34742
n_samples_val = 7135

n_volumes_train = 973
n_volumes_val = 199

# generators
AF = 4
# MaskShifted2DSequence, MaskShiftedSingleImage2DSequence, MaskedUntouched2DSequence
train_gen = MaskedUntouched2DAllLoadedSequence(train_path, af=AF, inner_slices=1)
val_gen = MaskedUntouched2DSequence(val_path, af=AF)

In [None]:
# Initialize all TF variables
sess.run(tf.global_variables_initializer())

# Add op to save and restore
saver = tf.train.Saver()

# Generate validation data
if chkpt:
    saver.restore(sess, checkpoint_path.format(run_id=run_id))

# Train the network
for i in tqdm_notebook(range(0, maximum_steps)):
#     [kspaces, mask_batch], images = train_gen[i % n_volumes_train]
    [kspaces, mask_batch], images = train_gen[0]
    _, merged_summary_result_train, global_step_result = sess.run([optimizer, merged_summary, global_step],
                              feed_dict={x_true: images,
                                         y_rt: kspaces,
                                         mask: mask_batch,
                                         is_training: True})
    train_summary_writer.add_summary(merged_summary_result_train, global_step_result)
    if i>0 and i%print_freq == 0:
        [kspaces_val, mask_batch_val], images_val = val_gen[i % n_volumes_val]
        loss_result, merged_summary_result, global_step_result = sess.run([loss, merged_summary, global_step],
                              feed_dict={x_true: images_val,
                                         y_rt: kspaces_val,
                                         mask: mask_batch_val,
                                         is_training: False})

        
        test_summary_writer.add_summary(merged_summary_result, global_step_result)

        print('iter={}, loss={}'.format(global_step_result, loss_result))


    if i>0 and (i+1)%1000 == 0:
        saver.save(sess, checkpoint_path.format(run_id=run_id))

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

iter=501, loss=2.2292662160694476e-10
iter=1001, loss=6.701293298050359e-09
iter=1501, loss=1.734375004724953e-10
iter=2001, loss=2.0298562830589617e-09
iter=2501, loss=3.3552047251639294e-10
iter=3001, loss=2.143660943199066e-10
iter=3501, loss=2.5833543637610035e-10
iter=4001, loss=3.6143443793434926e-10
iter=4501, loss=3.418285932088594e-10
iter=5001, loss=1.643195912670592e-09
