# Enhance BART with TensorFlow computation graphs

**Authors**: [Guanxiong Luo](mailto:guanxiong.luo@med.uni-goettingen.de); [Nick Scholand](mailto:nick.scholand@med.uni-goettingen.de); [Christian Holme](mailto:christian.holme@med.uni-goettingen.de).

**Presenter**: [Guanxiong Luo](mailto:guanxiong.luo@med.uni-goettingen.de).

**Institution**: University Medical Center Göttingen

## Overview
This tutorial shows how to create a regularization term with TensorFlow and use it for image reconstruction in BART.

<img src="over.png" width="800"/>

## What we have
TensorFlow provides a C API that can be used to build bindings for other languages.

1. BART's `src/nn/tf_wrapper.c`, which uses this API to

    * create tensors and a TF session

    * import the exported graph

    * restore the session from the saved model

    * get operation nodes from the graph by name

    * execute operations with `TF_SessionRun()`, the C counterpart of Python's `session.run()`


2. TensorFlow C libraries [2.4.0](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz)

3. A Python program to export the graph and weights (if any)

## What you can do with a TF graph

1. Create the regularization term $R(x)$ as a TF graph and use it for image reconstruction (integrated into `bart pics`):

$$\underset{x}{\arg \min}\ \|Ax-y\|^2+\lambda R(x)$$
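For intuition, take $R(x)=\|x\|^2$ and replace the MRI operator $A$ with a small real-valued toy matrix. Setting the gradient $2A^\top(Ax-y)+2\lambda x$ to zero gives the closed-form minimizer $x^\star=(A^\top A+\lambda I)^{-1}A^\top y$. The NumPy sketch below checks this numerically. It illustrates the objective only and is not how `pics` solves the problem.

```python
import numpy as np

def objective(A, x, y, lam):
    """Data consistency plus lambda * R(x) with R(x) = ||x||^2."""
    return np.sum((A @ x - y) ** 2) + lam * np.sum(x ** 2)

def tikhonov_solution(A, y, lam):
    """Closed-form minimizer (A^T A + lam I)^{-1} A^T y."""
    n = A.shape[1]
    return np.linalg.solve(A.T @ A + lam * np.eye(n), A.T @ y)

rng = np.random.default_rng(0)
A = rng.standard_normal((40, 20))   # toy stand-in for the forward operator
y = rng.standard_normal(40)         # toy measurements
lam = 0.01

x_star = tikhonov_solution(A, y, lam)
grad = 2 * A.T @ (A @ x_star - y) + 2 * lam * x_star
print(np.max(np.abs(grad)))  # close to zero at the minimizer
```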

## What you can learn here

1. simple example $R(x)=\|x\|^2$ without trainable weights

2. $R(x)=\log p(x, net(\Theta,x))$ with trainable weights $\Theta$, where the network $net$ acts as a learned prior [1]

[1] Luo G, Zhao N, Jiang W, Hui ES, Cao P. MRI reconstruction using deep Bayesian estimation. Magn Reson Med. 2020;84:2246–2261. https://doi.org/10.1002/mrm.28274 <br />
[2] Proc. Intl. Soc. Mag. Reson. Med. 29 (2021) P.1756

## Part 0: Fetch data

```python
! wget https://raw.githubusercontent.com/mrirecon/bart-workshop/master/ismrm2021/bart_tensorflow/data.zip
! unzip data.zip
```

```
--2021-05-14 17:07:57--  https://raw.githubusercontent.com/mrirecon/bart-workshop/master/ismrm2021/bart_tensorflow/data.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 38699214 (37M) [application/zip]
Saving to: 'data.zip'

2021-05-14 17:07:58 (67.2 MB/s) - 'data.zip' saved [38699214/38699214]
```



## Part I: How to create a TF graph for BART

```python
import os

# Select the GPU before TensorFlow initializes CUDA
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import tensorflow.compat.v1 as tf

# Build a static (TF1-style) graph that can be exported
tf.disable_eager_execution()
```

### Step 1: define input $x$

```python
image_shape = [256, 256, 2]
batch_size = 1

# C API -> TF_GraphOperationByName(graph, "input_0")
# name the inputs input_0, ..., input_I so BART can look them up
x = tf.placeholder(tf.float32,
                   shape=[batch_size]+image_shape,
                   name='input_0')
# trainable dummy variable (v = 1), presumably so the exported checkpoint is not empty
v = tf.Variable(1.)
x = x * v
```
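The placeholder shape `[1, 256, 256, 2]` suggests that a complex 256×256 image is passed as two real channels. The NumPy sketch below converts between the two layouts, assuming the last axis holds (real, imaginary). The helper names are hypothetical, and the channel order is an assumption that this notebook does not state.

```python
import numpy as np

def complex_to_channels(img):
    """(H, W) complex -> (1, H, W, 2) float32 with (real, imag) in the last axis."""
    stacked = np.stack([img.real, img.imag], axis=-1).astype(np.float32)
    return stacked[np.newaxis, ...]

def channels_to_complex(t):
    """(1, H, W, 2) float -> (H, W) complex64."""
    return (t[0, ..., 0] + 1j * t[0, ..., 1]).astype(np.complex64)

rng = np.random.default_rng(1)
img = (rng.standard_normal((256, 256))
       + 1j * rng.standard_normal((256, 256))).astype(np.complex64)
t = complex_to_channels(img)
print(t.shape, t.dtype)  # (1, 256, 256, 2) float32
```

In C (row-major) order, this layout interleaves the real and imaginary parts of each pixel. NumPy stores `complex64` the same way.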

### Step 2: define output $R(x)=\|x\|^2$

```python
# tf.nn.l2_loss computes sum(x**2) / 2, i.e. R(x) = ||x||^2 / 2
l2 = tf.nn.l2_loss(x)  # optional normalization: / np.prod(image_shape) / batch_size
# C API -> TF_GraphOperationByName(graph, "output_0") -> nlop forward
# name the outputs output_0, ..., output_I
output = tf.identity(tf.stack([l2, tf.ones_like(l2)], axis=-1), name='output_0')
```
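`tf.nn.l2_loss` computes half the sum of squares, so the graph's $R(x)$ is actually $\tfrac12\|x\|^2$. `output_0` stacks this value with a constant 1 into a length-2 vector. Here is a NumPy mirror of what `output_0` should evaluate to, using a constant test input:

```python
import numpy as np

def l2_loss(x):
    """NumPy equivalent of tf.nn.l2_loss: sum(x**2) / 2."""
    return np.sum(x ** 2) / 2

def forward_output(x):
    """Mirror of output_0: [R(x), 1] as a length-2 vector."""
    r = l2_loss(x)
    return np.stack([r, np.ones_like(r)], axis=-1)

x = np.full((1, 256, 256, 2), 0.5, dtype=np.float32)
out = forward_output(x)
# first entry: 131072 entries * 0.25 / 2 = 16384
print(out)
```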

### Step 3: define the gradient of $R(x)=\|x\|^2$

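```python
# upstream gradient w.r.t. output_0 (same shape as output_0)
grad_ys = tf.placeholder(tf.float32,
                         shape=[2],
                         name='grad_ys_0')

# C API -> TF_GraphOperationByName(graph, "grad_0") -> nlop adj
# vector-Jacobian product: grad_ys[0] * dR/dx + grad_ys[1] * d(1)/dx
grads = tf.squeeze(tf.gradients(output, x, grad_ys), name='grad_0')
```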
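`tf.gradients(output, x, grad_ys)` returns the vector-Jacobian product $\sum_i g_i\,\partial\,\mathrm{output}_i/\partial x$. Since output $=[\tfrac12\|x\|^2, 1]$ and the constant entry has zero gradient, this reduces to $g_0 x$. The NumPy sketch below checks this with central finite differences. It uses a small toy shape instead of 256×256, and the helper names are illustrative.

```python
import numpy as np

def forward_output(x):
    """[R(x), 1] with R(x) = sum(x**2) / 2, as in output_0."""
    return np.array([np.sum(x ** 2) / 2, 1.0])

def grad_output(x, grad_ys):
    """Analytic vector-Jacobian product, the quantity grad_0 should return."""
    return grad_ys[0] * x  # the constant second entry contributes nothing

def numerical_vjp(x, grad_ys, eps=1e-6):
    """Central finite differences of grad_ys . forward_output(x)."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = grad_ys @ (forward_output(xp) - forward_output(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(2)
x = rng.standard_normal((1, 4, 4, 2))   # float64 toy input
grad_ys = np.array([0.7, 3.0])
print(np.max(np.abs(grad_output(x, grad_ys) - numerical_vjp(x, grad_ys))))
```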

### Step 4: export graph and weights (if any)

```python
from utils import export_model
# export_model(model_path, exported_path, name, as_text, use_gpu):

export_model(None, "./", "l2_toy", as_text=False, use_gpu=False)
```

```
Exported
```


```python
!ls
```

```
__pycache__    data.zip			   l2_toy.index  prior
bart_tf.ipynb  ksp_256.cfl		   l2_toy.meta	 traj_256.cfl
checkpoint     ksp_256.hdr		   l2_toy.pb	 traj_256.hdr
clean	       l2_toy.data-00000-of-00001  over.png	 utils.py
```


```python
# restart the kernel to release the GPU
import IPython
IPython.Application.instance().kernel.do_shutdown(True)
```

```
{'status': 'ok', 'restart': True}
```

## Part II: How to use the graph in BART

### Step 1: set up BART

```python
%%bash

# Download the TensorFlow C libraries
wget https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz
mkdir tensorflow && tar -C tensorflow -xvzf libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz
```

```python
%%bash

# Determine the GPU type
GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader)

echo "GPU Type:"
echo "$GPU_NAME"

if [ "Tesla K80" = "$GPU_NAME" ];
then
    echo "GPU type Tesla K80 does not support CUDA 11. Set CUDA to version 10.1."

    # Install CUDA-10.1 if not already installed (-y: no interactive prompt)
    apt-get install -y cuda-10-1 cuda-drivers &> /dev/null

    # Change default CUDA to version 10.1
    cd /usr/local
    rm cuda
    ln -s cuda-10.1 cuda
    cd /content

else
    echo "GPU Information:"
    nvidia-smi --query-gpu=gpu_name,driver_version,memory.total --format=csv
    nvcc --version
    echo "Current GPU supports default CUDA-11."
    echo "No further actions are necessary."
fi


# Install BART's dependencies
apt-get install -y make gcc libfftw3-dev liblapacke-dev libpng-dev libopenblas-dev &> /dev/null

# Download BART
BRANCH=master
[ -d /content/bart ] && rm -r /content/bart
git clone https://github.com/mrirecon/bart/ bart

[ -d "bart" ] && echo "BART branch ${BRANCH} was downloaded successfully."

cd bart

# Switch to the desired branch of the BART project
git checkout $BRANCH

# Build options for Makefile.local (CUDA and TensorFlow enabled)
COMPILE_SPECS=" PARALLEL=4
                CUDA=1
                CUDA_BASE=/usr/local/cuda
                CUDA_LIB=lib64
                TENSORFLOW=1
                TENSORFLOW_BASE=../tensorflow/"

printf "%s\n" $COMPILE_SPECS > Makefiles/Makefile.local

make &> /dev/null
```

```python
%env LIBRARY_PATH=$LIBRARY_PATH:/content/tensorflow/include
%env LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/content/tensorflow/lib
%env TF_CPP_MIN_LOG_LEVEL=3
```

```python
import os
import sys

if 'TOOLBOX_PATH' not in os.environ:
    os.environ['TOOLBOX_PATH'] = "/content/bart"
os.environ['PATH'] = os.environ['TOOLBOX_PATH'] + ":" + os.environ['PATH']
sys.path.append(os.environ['TOOLBOX_PATH'] + "/python/")
```

### Step 2: check the help text

```python
!bart pics -Rh
```

`pics` uses $R(x)$ through its proximal operator:

$$\hat{x}=\underset{x}{\arg \min} \|x-v\|^2 + \lambda R(x)$$
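The toy regularizer has a closed-form proximal step. For $R(x)=\|x\|^2$ as written, setting the gradient $2(x-v)+2\lambda x$ to zero gives $\hat{x}=v/(1+\lambda)$. For the $\tfrac12\|x\|^2$ that `tf.nn.l2_loss` produces, the result is $\hat{x}=v/(1+\lambda/2)$. The NumPy sketch below compares the closed form with plain gradient descent on the prox objective, using only the gradient of $R$. It does not show how BART solves the prox internally.

```python
import numpy as np

def prox_closed_form(v, lam):
    """argmin_x ||x - v||^2 + lam * ||x||^2 / 2  (tf.nn.l2_loss convention)."""
    return v / (1 + lam / 2)

def prox_gradient_descent(v, lam, grad_R, step=0.1, iters=500):
    """Approximate the same prox by gradient descent, using only grad_R."""
    x = v.copy()
    for _ in range(iters):
        x = x - step * (2 * (x - v) + lam * grad_R(x))
    return x

grad_R = lambda x: x  # gradient of ||x||^2 / 2, what grad_0 returns for g0 = 1
rng = np.random.default_rng(3)
v = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
lam = 0.02
print(np.max(np.abs(prox_gradient_descent(v, lam, grad_R) - prox_closed_form(v, lam))))
```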

### Step 3: extract radial spokes and compute coil sensitivities

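```python
%%bash
# keep only the first $spokes radial spokes of k-space and trajectory
# dimensions (BART non-Cartesian convention): traj = [3, readout, spokes], ksp = [1, readout, spokes, coils]
spokes=60
nx=256

bart extract 2 0 $spokes ksp_256 ksp_256_c
bart extract 2 0 $spokes traj_256 traj_256_c
```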
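Assuming BART's usual non-Cartesian layout (trajectory `[3, readout, spokes]`, k-space `[1, readout, spokes, coils]`), dimension 2 indexes the spokes. `bart extract 2 0 $spokes` keeps indices 0 to `spokes-1` along that dimension, since the end index is exclusive, so in NumPy it is a plain slice. The array sizes below are hypothetical.

```python
import numpy as np

def extract(arr, dim, start, end):
    """NumPy analogue of `bart extract dim start end`: keep start <= i < end along dim."""
    index = [slice(None)] * arr.ndim
    index[dim] = slice(start, end)
    return arr[tuple(index)]

readout, total_spokes, coils = 512, 400, 8   # hypothetical sizes
ksp = np.zeros((1, readout, total_spokes, coils), dtype=np.complex64)
traj = np.zeros((3, readout, total_spokes), dtype=np.complex64)

spokes = 60
ksp_c = extract(ksp, 2, 0, spokes)
traj_c = extract(traj, 2, 0, spokes)
print(ksp_c.shape, traj_c.shape)  # (1, 512, 60, 8) (3, 512, 60)
```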

```python
# the header lists the dimensions of the extracted k-space
!head -n2 ksp_256_c.hdr
```
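The `.hdr` file is plain text: a `# Dimensions` line followed by the array sizes. The matching `.cfl` file holds the raw complex64 samples in column-major (Fortran) order. The sketch below is a minimal reader/writer along those lines. BART pads headers to 16 dimensions and this writer does not, and the `utils.readcfl` used later may differ in details such as dropping singleton dimensions.

```python
import os
import tempfile
import numpy as np

def writecfl(name, array):
    """Write name.hdr (dimensions) and name.cfl (complex64, column-major)."""
    with open(name + ".hdr", "w") as h:
        h.write("# Dimensions\n")
        h.write(" ".join(str(d) for d in array.shape) + "\n")
    array.astype(np.complex64).ravel(order="F").tofile(name + ".cfl")

def readcfl(name):
    """Read a cfl pair back; the dimensions come from the second header line."""
    with open(name + ".hdr") as h:
        h.readline()  # skip "# Dimensions"
        dims = [int(d) for d in h.readline().split()]
    data = np.fromfile(name + ".cfl", dtype=np.complex64, count=int(np.prod(dims)))
    return data.reshape(dims, order="F")

with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "toy")
    arr = (np.arange(24) + 1j).reshape((1, 4, 3, 2))
    writecfl(base, arr)
    with open(base + ".hdr") as h:
        print(h.read(), end="")  # "# Dimensions" followed by "1 4 3 2"
    back = readcfl(base)
```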

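```python
%%bash

# gridding reconstruction from the radial data (inverse NUFFT)
bart nufft -i traj_256_c ksp_256_c zero_filled
# transform back to Cartesian k-space along dimensions 0 and 1
bart fft $(bart bitmask 0 1) zero_filled grid_ksp
# ESPIRiT calibration: 20x20 calibration region, one set of maps, crop threshold 0.0001
bart ecalib -r20 -m1 -c0.0001 grid_ksp coilsen_esp
```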
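`bart bitmask 0 1` turns a list of dimensions into the integer bitmask $2^0+2^1=3$, so `bart fft` transforms along the first two (spatial) dimensions. Below is a NumPy sketch of both steps. It treats BART's `fft` as a centered FFT, `fftshift(fft(ifftshift(·)))`, which is an assumption about its conventions; scaling in particular may differ.

```python
import numpy as np

def bitmask(*dims):
    """Equivalent of `bart bitmask d1 d2 ...`: sum of 2**d."""
    return sum(1 << d for d in set(dims))

def centered_fft(x, flags):
    """FFT over every dimension whose bit is set in flags, DC in the array center."""
    axes = tuple(d for d in range(x.ndim) if flags & (1 << d))
    return np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(x, axes=axes), axes=axes), axes=axes)

print(bitmask(0, 1))  # 3

img = np.zeros((256, 256, 1, 8), dtype=np.complex64)   # hypothetical image x coils array
img[128, 128, 0, :] = 1.0                              # a point at the image center
ksp = centered_fft(img, bitmask(0, 1))                 # flat k-space for every coil
```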

## Example 1: $R(x)=\|x\|^2$

```python
!bart pics -i100 -R TF:{$(pwd)/l2_toy}:0.02 -d5 -e -t traj_256_c ksp_256_c coilsen_esp l2_pics_tf
```

```python
# -l2 selects l2 regularization, -r sets lambda
!bart pics -l2 -r 0.01 -e -d5 -t traj_256_c ksp_256_c coilsen_esp l2_pics
```
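The two runs are meant to solve the same problem. The TF graph evaluates `tf.nn.l2_loss`, i.e. $\tfrac12\|x\|^2$, so $\lambda=0.02$ on the TF regularizer corresponds to $\lambda=0.01$ on $\|x\|^2$. This is presumably why the values differ by a factor of two, assuming `pics -l2` penalizes $\lambda\|x\|^2$. A quick NumPy check of that scaling:

```python
import numpy as np

def tf_regularizer(x, lam_tf=0.02):
    """lambda * tf.nn.l2_loss(x) = lambda * ||x||^2 / 2."""
    return lam_tf * np.sum(np.abs(x) ** 2) / 2

def builtin_l2_regularizer(x, lam=0.01):
    """lambda * ||x||^2, assuming this is what `pics -l2` penalizes."""
    return lam * np.sum(np.abs(x) ** 2)

rng = np.random.default_rng(4)
x = rng.standard_normal((256, 256)) + 1j * rng.standard_normal((256, 256))
print(np.isclose(tf_regularizer(x), builtin_l2_regularizer(x)))  # True
```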

```python
from utils import *
import matplotlib.pyplot as plt
fig, axis = plt.subplots(figsize=(8,4), ncols=2)
l2_pics = readcfl("l2_pics")
l2_pics_tf = readcfl("l2_pics_tf")

axis[0].imshow(abs(l2_pics), cmap='gray', interpolation='None')
axis[1].imshow(abs(l2_pics_tf), cmap='gray', interpolation='None')
axis[0].set_title("l2_pics")
axis[1].set_title("l2_pics_tf")
axis[0].axis('off')
axis[1].axis('off')
```
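A normalized root-mean-square difference gives a single number for how close the two reconstructions are, beyond visual inspection. The helper below is hypothetical and not part of `utils`. With the arrays from the previous cell, it would be called as `nrmse(l2_pics, l2_pics_tf)`.

```python
import numpy as np

def nrmse(reference, estimate):
    """||estimate - reference|| / ||reference|| over all entries (complex-safe)."""
    reference = np.asarray(reference)
    estimate = np.asarray(estimate)
    return np.linalg.norm(estimate - reference) / np.linalg.norm(reference)

ref = np.ones((4, 4), dtype=np.complex64)
print(nrmse(ref, 1.1 * ref))  # about 0.1
```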

## Example 2: $R(x)=\log p(x, net(x))$ 

```python
!ls prior/
```

```python
# generate weights for density compensation
writecfl("weights", gen_weights(60, 256))
```
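`gen_weights` comes from the workshop's `utils.py`, whose implementation is not shown here. For orientation only: a common choice for radial sampling is a ramp ($|k|$) weighting along each spoke. The hypothetical `ramp_weights` below builds such weights in the `[1, readout, spokes]` layout assumed earlier. It is not necessarily what `gen_weights` computes.

```python
import numpy as np

def ramp_weights(spokes, readout):
    """Hypothetical radial density compensation: weight proportional to |k| along each spoke."""
    k = np.abs(np.linspace(-1.0, 1.0, readout))      # normalized radius along a spoke
    k = np.maximum(k, 1.0 / readout)                  # avoid a zero weight at the center
    w = np.tile(k[:, np.newaxis], (1, spokes))        # (readout, spokes)
    return w[np.newaxis, :, :].astype(np.complex64)   # (1, readout, spokes)

w = ramp_weights(60, 256)
print(w.shape)  # (1, 256, 60)
```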

```python
!bart pics -i30 -R TF:{./prior/pixel_cnn}:8 -d5 -e -I -p weights -t traj_256_c ksp_256_c coilsen_esp w_pics_prior
```
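`-I` selects iterative soft-thresholding (ISTA), a proximal-gradient method. Each iteration takes a gradient step on the data term and then applies the proximal operator of $\lambda R$. The NumPy sketch below runs that loop on a toy linear problem and computes the prox with a few gradient steps on the regularizer's gradient, which is the role `grad_0` plays in the exported graph. The inner solver, step sizes and iteration counts are illustrative assumptions, not BART's internals. For $R(x)=\|x\|^2$ the loop should reach the closed-form Tikhonov solution.

```python
import numpy as np

def ista(A, y, lam, grad_R, outer=2000, inner=50, inner_step=0.5):
    """Proximal gradient for ||Ax - y||^2 + lam * R(x), with an inexact prox."""
    tau = 1.0 / (2 * np.linalg.norm(A, 2) ** 2)   # 1 / Lipschitz constant of the data-term gradient
    x = np.zeros(A.shape[1])
    for _ in range(outer):
        v = x - tau * 2 * A.T @ (A @ x - y)        # gradient step on the data term
        z = v.copy()
        for _ in range(inner):                     # prox: argmin 0.5||z - v||^2 + tau*lam*R(z)
            z = z - inner_step * ((z - v) + tau * lam * grad_R(z))
        x = z
    return x

rng = np.random.default_rng(5)
A = rng.standard_normal((40, 20))
y = rng.standard_normal(40)
lam = 0.5
x_ista = ista(A, y, lam, grad_R=lambda x: 2 * x)   # R(x) = ||x||^2
x_ref = np.linalg.solve(A.T @ A + lam * np.eye(20), A.T @ y)
print(np.max(np.abs(x_ista - x_ref)))
```

Replacing the analytic `grad_R` with a learned gradient is, in spirit, what the TF graph supplies. This notebook does not show whether BART's ISTA uses such an inner loop.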

```python
import matplotlib.pyplot as plt
pics_prior = readcfl("w_pics_prior")
fig, axis = plt.subplots(figsize=(12,4), ncols=3)

axis[0].imshow(abs(l2_pics), cmap='gray', interpolation='None')
axis[0].set_title("l2_pics")
axis[1].imshow(abs(l2_pics_tf), cmap='gray', interpolation='None')
axis[1].set_title("l2_pics_tf")
axis[2].imshow(abs(pics_prior), cmap='gray', interpolation='None')
axis[2].set_title("prior_pics")
axis[0].axis('off')
axis[1].axis('off')
axis[2].axis('off')
```


```python
! bash clean
```