# Enhance bart with tf computation graph

Author: Guanxiong Luo<br />
Email: guanxiong.luo@med.uni-goettingen.de

## Overview
This tutorial shows how to create a regularization term with TensorFlow and use it for image reconstruction in bart.

<img src="over.png" width="800"/>

## What we have
TensorFlow provides a C API that can be used to build bindings for other languages.

1. bart src/nn/tf_wrapper.c

    * create tensors and a tf session

    * import the exported graph

    * restore the session from the saved model

    * get operation nodes from the graph

    * execute operations with `TF_SessionRun()` (the C API counterpart of `session.run()`)


2. TensorFlow C library [2.4.0](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz) (GPU build for Linux x86_64)

3. A Python program to export the graph and weights (if any)

## What you can do with a tf graph

1. We can build the regularization term $R(x)$ as a tf graph and use it for image reconstruction (integrated in `bart pics`).

$$\underset{x}{\arg \min}\ \|Ax-y\|^2+\lambda R(x)$$
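
For the toy choice $R(x)=\|x\|^2$ this is a Tikhonov (ridge) problem with the closed-form solution $x=(A^HA+\lambda I)^{-1}A^Hy$. The minimal NumPy sketch below checks this on a small random dense matrix that stands in for the forward operator $A$ (an assumption for illustration only; in `bart pics`, $A$ combines the coil sensitivities with the Fourier transform, here a NUFFT).

```python
import numpy as np

def objective(A, y, x, lam):
    """||Ax - y||^2 + lam * R(x) with the toy regularizer R(x) = ||x||^2."""
    r = A @ x - y
    return np.vdot(r, r).real + lam * np.vdot(x, x).real

def solve_l2(A, y, lam):
    """Closed-form minimizer: (A^H A + lam I)^{-1} A^H y."""
    AH = A.conj().T
    return np.linalg.solve(AH @ A + lam * np.eye(A.shape[1]), AH @ y)

rng = np.random.default_rng(0)
# random complex matrix as a stand-in for the MRI forward operator (illustration only)
A = rng.standard_normal((20, 8)) + 1j * rng.standard_normal((20, 8))
y = rng.standard_normal(20) + 1j * rng.standard_normal(20)   # toy measurements
lam = 0.01

x_hat = solve_l2(A, y, lam)
# gradient w.r.t. conj(x) of the objective; it vanishes at the minimizer
grad = A.conj().T @ (A @ x_hat - y) + lam * x_hat
print(np.linalg.norm(grad))
```

The closed form only exists because $R$ is quadratic; for a learned $R$, bart has to solve the problem iteratively.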

## What you can learn here

1. A simple example, $R(x)=\|x\|^2$, without trainable weights

2. $R(x)=\log p(x, net(\Theta,x))$ with trainable weights $\Theta$, where the network $net$ represents a learned prior [1]

[1] Luo G, Zhao N, Jiang W, Hui ES, Cao P. MRI reconstruction using deep Bayesian estimation. Magn Reson Med. 2020;84:2246–2261. https://doi.org/10.1002/mrm.28274 <br />
[2] Proc. Intl. Soc. Mag. Reson. Med. 29 (2021) P.1756

## Part I: How to create tf graph for bart

In [None]:
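import os
# select the GPU before TensorFlow initializes CUDA
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import numpy as np
import tensorflow.compat.v1 as tf

# build a static TF1-style graph that can be exported and loaded through the C API
tf.disable_eager_execution()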

### Step 1: define input $x$

In [None]:
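image_shape = [256, 256, 2]   # height, width, 2 channels (assumed: real and imaginary part)
batch_size = 1

# CAPI -> TF_GraphOperationByName(graph, "input_0")
# name the inputs input_0, ..., input_I
x = tf.placeholder(tf.float32,
                   shape=[batch_size] + image_shape,
                   name='input_0')
# scalar variable fixed at 1: leaves x unchanged, but gives the exported model a variable to save
v = tf.Variable(1.)
x = x * v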
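
The placeholder expects float32 data of shape `[batch, 256, 256, 2]`. Presumably the last axis holds the real and imaginary parts of the complex image; the sketch below assumes that layout and shows the conversion in NumPy. The helpers `to_channels` and `from_channels` are hypothetical and not part of bart or `utils`.

```python
import numpy as np

def to_channels(img, batch_size=1):
    """Complex image (H, W) -> float32 array (batch, H, W, 2) holding [real, imag]."""
    stacked = np.stack([img.real, img.imag], axis=-1).astype(np.float32)
    return np.broadcast_to(stacked, (batch_size,) + stacked.shape).copy()

def from_channels(arr):
    """Inverse of to_channels for the first batch entry."""
    return arr[0, ..., 0] + 1j * arr[0, ..., 1]

img = np.exp(1j * np.linspace(0, np.pi, 256 * 256)).reshape(256, 256)
x_in = to_channels(img)
print(x_in.shape, x_in.dtype)  # (1, 256, 256, 2) float32
```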

### Step 2: define output $R(x)=\|x\|^2$

In [None]:
# tf.nn.l2_loss(x) = sum(x**2) / 2, i.e. R(x) = ||x||^2 / 2
l2 = tf.nn.l2_loss(x)  # optional normalization: / np.prod(image_shape) / batch_size

# CAPI -> TF_GraphOperationByName(graph, "output_0") -> nlop forward
# name the outputs output_0, ..., output_I
# output_0 has shape [2]: [R(x), 1]
output = tf.identity(tf.stack([l2, tf.ones_like(l2)], axis=-1), name='output_0')
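
Note that `tf.nn.l2_loss` includes a factor 1/2, so the graph above evaluates $R(x)=\frac{1}{2}\|x\|^2$. A NumPy reference for `output_0` could look like the sketch below. `reference_output` is a hypothetical helper, useful for checking the exported graph by hand.

```python
import numpy as np

def reference_output(x):
    """NumPy reference for output_0: [tf.nn.l2_loss(x), 1].

    tf.nn.l2_loss(x) computes sum(x**2) / 2, so the first entry is ||x||^2 / 2.
    """
    l2 = 0.5 * np.sum(np.square(x, dtype=np.float64))
    return np.array([l2, 1.0])

x = np.full((1, 256, 256, 2), 0.5, dtype=np.float32)
out = reference_output(x)
print(out)
```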

### Step 3: define the gradient of $R(x)=\|x\|^2$

In [None]:
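# vector that multiplies the Jacobian of output_0 (one entry per output element)
grad_ys = tf.placeholder(tf.float32,
                         shape=[2],
                         name='grad_ys_0')

# CAPI -> TF_GraphOperationByName(graph, "grad_0") -> nlop adj
# tf.gradients returns a list; squeeze drops the list and batch axes -> [256, 256, 2]
grads = tf.squeeze(tf.gradients(output, x, grad_ys), name='grad_0')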
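
`tf.gradients(output, x, grad_ys)` computes the vector-Jacobian product $\sum_k g_k\,\partial\,\mathrm{output}_k/\partial x$ with $g$ = `grad_ys`, which is what the adjoint of the nlop needs. For this graph the constant second entry has zero gradient, so `grad_0` equals $g_0\,x$. The sketch below checks that against a central finite difference on a small 4x4 array; sizes and helper names are illustrative assumptions.

```python
import numpy as np

def reference_output(x):
    """[||x||^2 / 2, 1], mirroring output_0."""
    return np.array([0.5 * np.sum(x.astype(np.float64) ** 2), 1.0])

def reference_grad(x, grad_ys):
    """Expected grad_0: sum_k grad_ys[k] * d output_k / dx, then squeezed.

    d(||x||^2 / 2)/dx = x; the constant entry contributes nothing.
    """
    return np.squeeze(grad_ys[0] * x)

rng = np.random.default_rng(1)
x = rng.standard_normal((1, 4, 4, 2))
g = np.array([2.0, 5.0])

# central finite difference of <g, output(x)> along a random direction d
d = rng.standard_normal(x.shape)
eps = 1e-6
fd = (g @ reference_output(x + eps * d) - g @ reference_output(x - eps * d)) / (2 * eps)
analytic = np.sum(reference_grad(x, g) * np.squeeze(d))
print(abs(fd - analytic))
```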

### Step 4: export graph and weights (if any)

In [None]:
from utils import export_model
# export_model(model_path, exported_path, name, as_text, use_gpu):

export_model(None, "./", "l2_toy", as_text=False, use_gpu=False)

In [None]:
!ls

## Part II: How to use the graph in bart

### Step 1: set environment variables for bart

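In [None]:
# restart the kernel to release the GPU held by TensorFlow;
# run this before setting the environment variables, which a restart discards
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

In [None]:
# LIBRARY_PATH: link-time search path, LD_LIBRARY_PATH: run-time search path for libtensorflow
%env LIBRARY_PATH=$LIBRARY_PATH:/home/gluo/local_lib/tensorflow/lib
%env LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/gluo/local_lib/tensorflow/lib
# silence TensorFlow's C++ log messages
%env TF_CPP_MIN_LOG_LEVEL=3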

In [None]:
%%bash

# link the bart binary into a bin directory (assumed to be on PATH)
TOOLBOX_PATH=/home/gluo/bin

if [ ! -e "$TOOLBOX_PATH/bart" ] ; then
    echo 'create symbolic link...'
    ln -s /home/gluo/bart/bart "$TOOLBOX_PATH/bart"
fi
which bart
bart version


### Step 2: check the help information

In [None]:
!bart pics -Rh

The proximal operator of $R(x)$:

$$\hat{x}=\underset{x}{\arg \min} \|x-v\|^2 + \lambda R(x)$$
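
For the toy regularizer $R(x)=\|x\|^2$, setting the gradient $2(x-v)+2\lambda x$ to zero gives $\hat{x}=v/(1+\lambda)$. The sketch below also approximates the same proximal step using only the gradient of $R$, which is the information the exported `grad_0` node provides. This is one generic way to do it, shown for illustration; it is not a description of how `bart pics` evaluates TF regularizers internally, and `prox_by_gradient_descent` is a hypothetical helper.

```python
import numpy as np

def prox_l2_closed_form(v, lam):
    """argmin_x ||x - v||^2 + lam * ||x||^2 = v / (1 + lam)."""
    return v / (1.0 + lam)

def prox_by_gradient_descent(v, lam, grad_R, step=0.1, iters=500):
    """Approximate the prox with plain gradient descent (hypothetical helper).

    Gradient of the prox objective: 2 (x - v) + lam * grad_R(x).
    """
    x = v.copy()
    for _ in range(iters):
        x = x - step * (2.0 * (x - v) + lam * grad_R(x))
    return x

def grad_R(x):
    """Gradient of R(x) = ||x||^2."""
    return 2.0 * x

v = np.array([1.0, -2.0, 3.0])
lam = 0.5
print(prox_l2_closed_form(v, lam))
print(prox_by_gradient_descent(v, lam, grad_R))
```

With `step=0.1` and `lam=0.5` each iteration shrinks the error by a factor of 0.7, so 500 iterations are far more than needed here; for a learned $R$ the step size would have to be tuned.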

### Step 3: extract radial spokes and compute coil sensitivities

In [None]:
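%%bash
# keep the first $spokes radial spokes of k-space and trajectory
# k-space:    [1, readout samples, spokes, coils]
# trajectory: [3, readout samples, spokes]
# dimension 2 holds the spokes, so extract indices 0 ... spokes-1 along it
spokes=60
nx=256   # image size (not used in this cell)

bart extract 2 0 $spokes ksp_256 ksp_256_c
bart extract 2 0 $spokes traj_256 traj_256_c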
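
`bart extract 2 0 $spokes` slices dimension 2 (the spokes) from index 0 up to, but not including, `$spokes`. A NumPy analogue, assuming bart's usual radial layout and made-up array sizes (`extract` is a hypothetical helper):

```python
import numpy as np

# assumed layouts (bart convention for radial data):
#   k-space:    [1, readout samples, spokes, coils]
#   trajectory: [3, readout samples, spokes]
SPOKES = 60

def extract(arr, dim, start, end):
    """NumPy analogue of `bart extract dim start end` (end exclusive)."""
    index = [slice(None)] * arr.ndim
    index[dim] = slice(start, end)
    return arr[tuple(index)]

ksp = np.zeros((1, 512, 100, 8), dtype=np.complex64)   # hypothetical full data set
traj = np.zeros((3, 512, 100), dtype=np.complex64)
print(extract(ksp, 2, 0, SPOKES).shape)   # (1, 512, 60, 8)
print(extract(traj, 2, 0, SPOKES).shape)  # (3, 512, 60)
```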

In [None]:
!head -n2 ksp_256_c.hdr
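
The `.hdr` file shown by `head` is plain text: a `# Dimensions` line followed by the array dimensions, while the matching `.cfl` file stores the data as raw complex64 values in column-major (Fortran) order. bart ships similar `readcfl`/`writecfl` functions in `python/cfl.py`; the minimal sketch below may differ in detail from the versions imported from `utils`.

```python
import os
import tempfile
import numpy as np

def writecfl(name, array):
    """Write name.hdr (dimensions) and name.cfl (complex64, column-major)."""
    with open(name + ".hdr", "w") as h:
        h.write("# Dimensions\n")
        h.write(" ".join(str(d) for d in array.shape) + "\n")
    np.asarray(array, dtype=np.complex64).ravel(order="F").tofile(name + ".cfl")

def readcfl(name):
    """Read a CFL pair written by bart (or by writecfl above)."""
    with open(name + ".hdr") as h:
        h.readline()                        # skip "# Dimensions"
        dims = [int(d) for d in h.readline().split()]
    while len(dims) > 1 and dims[-1] == 1:  # drop trailing singleton dimensions
        dims.pop()
    data = np.fromfile(name + ".cfl", dtype=np.complex64)
    return data.reshape(dims, order="F")

with tempfile.TemporaryDirectory() as tmp:
    a = (np.arange(6) + 1j).reshape(2, 3)
    writecfl(os.path.join(tmp, "demo"), a)
    b = readcfl(os.path.join(tmp, "demo"))
print(b.shape)  # (2, 3)
```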

In [None]:
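%%bash

# inverse NUFFT: radial k-space -> coil images on a Cartesian grid
bart nufft -i traj_256_c ksp_256_c zero_filled
# FFT along dimensions 0 and 1 -> Cartesian k-space
bart fft $(bart bitmask 0 1) zero_filled grid_ksp
# ESPIRiT: calibration region of size 20, one set of maps, crop threshold 0.0001
bart ecalib -r20 -m1 -c0.0001 grid_ksp coilsen_esp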
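
`bart fft` selects the dimensions to transform with a bitmask in which bit $d$ is set for each dimension $d$, so `$(bart bitmask 0 1)` evaluates to 3, i.e. dimensions 0 and 1. The same computation in Python (`bitmask` is a hypothetical helper):

```python
def bitmask(*dims):
    """Same number as `bart bitmask d1 d2 ...`: bit d is set for each listed dimension."""
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask

print(bitmask(0, 1))  # 3
```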

## Example 1: $R(x)=\|x\|^2$

In [None]:
%%bash

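graph_path=$(pwd)/l2_toy   # graph exported in Part I
lambda=0.01

# -R TF:{graph}:lambda  use the tf graph as regularizer
# -i100 at most 100 iterations, -d5 debug level, -e eigenvalue-based step size, -t trajectory
bart pics -i100 -R TF:{$graph_path}:$lambda -d5 -e -t traj_256_c ksp_256_c coilsen_esp l2_pics_tf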

In [None]:
%%bash

# reference: bart's built-in l2 regularization (-l2) with the same lambda (-r)
bart pics -l2 -r 0.01 -e -t traj_256_c ksp_256_c coilsen_esp l2_pics

In [None]:
from utils import *   # readcfl, writecfl, gen_weights, ...
import matplotlib.pyplot as plt

l2_pics = readcfl("l2_pics")
l2_pics_tf = readcfl("l2_pics_tf")

fig, axis = plt.subplots(figsize=(8,4), ncols=2)
for ax, img, title in zip(axis, [l2_pics, l2_pics_tf], ["l2_pics", "l2_pics_tf"]):
    ax.imshow(abs(img), cmap='gray', interpolation='none')
    ax.set_title(title)
    ax.axis('off')
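
To compare the two reconstructions quantitatively, a normalized RMSE of the magnitude images can be used, e.g. `nrmse(l2_pics, l2_pics_tf)`. The results need not be identical (for example, `tf.nn.l2_loss` includes a factor 1/2, and the two regularizers may scale $\lambda$ differently). A minimal sketch; `nrmse` is a hypothetical helper, not part of `utils`:

```python
import numpy as np

def nrmse(reference, test):
    """Normalized RMSE of magnitude images: || |ref| - |test| || / || |ref| ||."""
    ref, tst = np.abs(reference), np.abs(test)
    return np.linalg.norm(ref - tst) / np.linalg.norm(ref)

ref = np.ones((4, 4), dtype=np.complex64)   # toy stand-in for a reconstruction
print(nrmse(ref, 1.1 * ref))                # about 0.1
```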

## Example 2: $R(x)=\log p(x, net(\Theta,x))$

In [None]:
# generate density-compensation weights (spokes=60, nx=256 as in Step 3)
writecfl("weights", gen_weights(60, 256))
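
`gen_weights` comes from the local `utils` module and is not shown here. As an assumption for illustration, a common density compensation for radial sampling is a ramp proportional to $|k|$ along each spoke; the sketch below builds such weights with shape `[1, nx, spokes]`. `ramp_weights` is hypothetical and may differ from `gen_weights`.

```python
import numpy as np

def ramp_weights(spokes, nx):
    """Hypothetical radial density compensation: |k| ramp, normalized to max 1.

    Shape [1, nx, spokes], matching the assumed radial k-space layout
    [1, readout samples, spokes, coils] (broadcast over coils).
    """
    k = np.abs(np.arange(nx) - nx / 2) / (nx / 2)   # |k| along each spoke
    k[k == 0] = 1.0 / nx                            # small nonzero weight at the center
    w = np.repeat(k[:, None], spokes, axis=1)
    return w[None, :, :].astype(np.complex64)

w = ramp_weights(60, 256)
print(w.shape)  # (1, 256, 60)
```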

In [None]:
!ls prior/

In [None]:
%%bash

graph_path=$(pwd)/prior/pixel_cnn   # exported pixel-CNN prior (see !ls prior/)
lp=10                               # regularization weight lambda
# -p weights: density-compensation weights generated above
bart pics -i30 -R TF:{$graph_path}:$lp -d5 -e \
              -p weights \
              -t traj_256_c \
              ksp_256_c coilsen_esp w_pics_prior

In [None]:
import matplotlib.pyplot as plt
pics_prior = readcfl("w_pics_prior")
fig, axis = plt.subplots(figsize=(12,4), ncols=3)

axis[0].imshow(abs(l2_pics), cmap='gray', interpolation='None')
axis[0].set_title("l2_pics")
axis[1].imshow(abs(l2_pics_tf), cmap='gray', interpolation='None')
axis[1].set_title("l2_pics_tf")
axis[2].imshow(abs(pics_prior), cmap='gray', interpolation='None')
axis[2].set_title("prior_pics")
axis[0].axis('off')
axis[1].axis('off')
axis[2].axis('off')


In [None]:
! bash clean