In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
from networks_model import *
from data_loader import *
from error_functions import *
import datetime
from keras.callbacks import EarlyStopping
from keras.utils import multi_gpu_model
import scipy.io as sio
from QSM_func import *

In [None]:
# Restrict this run to a single GPU. Devices are numbered by PCI bus ID, so
# index '1' refers to the same card nvidia-smi reports as GPU 1.
# These variables only take effect if set before TensorFlow first
# initialises its GPU devices.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES='1',
)

In [None]:
# Experiment configuration shared by every later cell via the `opt` dict.
#
# opt['datatype'] selects the dataset:
#   'cos'     - in-vivo COSMOS data
#   'rc2'     - simulated data using reconstruction challenge 2
#   'phan'    - simulated RC2 phantom test data (see train_path)
#   'phan_lr' - low-resolution variant of the 'phan' test data (see train_path)
opt = {}
opt['datatype'] = 'cos'

if opt['datatype'] == 'cos':
    # in vivo
    opt['train_path'] = '/media/hd1/sylar/data'
    opt['train_img1_path'] = '/msk_arr'
    opt['train_img2_path'] = '/phs_unwrap_total'
    opt['train_label_path'] = '/chi_cosmos'
    opt['patch_size'] = (160,160,160)
    opt['reso'] = (1.06,1.06,1.06)
    opt['out_dir'] = '/home/sylar/data/invivo/data/single-step/'

elif opt['datatype'] == 'rc2':
    opt['train_path'] = '/home/sylar/data/invivo/RC2/DGM_HR/'
    opt['train_img1_path'] = 'msk_arr'
    opt['train_img2_path'] = 'unphs'
    opt['train_label_path'] = 'Chi'
    opt['patch_size'] = (256,320,320)
    opt['reso'] = (0.64,0.64,0.64)
    opt['out_dir'] = '/home/sylar/data/invivo/RC2/DGM_HR/'

elif opt['datatype'] == 'phan':
    opt['train_path'] = '/home/sylar/data/Simu/simu_RC2/test_data'
    opt['train_img1_path'] = '/cmsk_arr'
    opt['train_img2_path'] = '/Cphs_total'
    opt['train_label_path'] = '/Cchi'
    opt['patch_size'] = (336,416,160)
    opt['reso'] = (0.45,0.45,1)
    opt['out_dir'] = '/home/sylar/data/Simu/simu_RC2/test_data/pred/'
    # Rows of swi_index select which samples are evaluated; the inference
    # loop indexes this with the sample position to build output filenames.
    index = sio.loadmat( '/home/sylar/data/invivo/swi_index1.mat')
    opt['train_index'] = index['swi_index'][16:17,:]

elif opt['datatype'] == 'phan_lr':
    opt['train_path'] = '/home/sylar/data/Simu/simu_RC2/test_data/low_reso'
    opt['train_img1_path'] = '/L0p7msk_arr'
    opt['train_img2_path'] = '/L0p7phs_total'
    opt['train_label_path'] = '/L0p7chi'
    opt['patch_size'] = (336,416,128)
    opt['reso'] = (0.45,0.45,1.5)
    opt['out_dir'] = '/home/sylar/data/Simu/simu_RC2/test_data/pred/'
    index = sio.loadmat( '/home/sylar/data/invivo/swi_index1.mat')
    opt['train_index'] = index['swi_index'][16:21,:]

else:
    # Fail here with a clear message instead of a KeyError on 'patch_size' below.
    raise ValueError(
        "Unknown opt['datatype'] %r; expected one of 'cos', 'rc2', 'phan', 'phan_lr'"
        % (opt['datatype'],)
    )

# SMV radii used to build the network input shape below; the inference cell
# later overrides opt['rad'] (and regenerates opt['ker']) for reconstruction.
opt['rad'] = [3]
# Reset before regenerating -- kept in case multi_smv_gen reads opt['ker'] (unverified).
opt['ker'] = []
opt['ker'] = multi_smv_gen(opt)

opt['lbd1'] = 0.1
opt['iter'] = 5
opt['batch_size'] = 1
# Network input: the patch plus len(rad)+1 channels.
opt['img_shape'] = opt['patch_size'] + (len(opt['rad'])+1,)
opt['in_shape'] = opt['patch_size'] + (1,)
opt['is_patch'] = True

opt['model_restored_epoch'] = 111

opt['loss'] = nrmse
opt['display_nums'] = [70,50,100]
# NOTE(review): absolute path under a different home directory than the data
# paths above -- consider making this configurable.
opt['model_save_path'] = '/home/maii_station_1/Desktop/codes/SS-POCSnet/modelss_saved'

# Formatted with epoch=... when restoring weights.
opt['checkpoint_path'] = opt['model_save_path']+"/cp{epoch}"

In [None]:
# index1 covers s..i-1 and index2 covers 1..j-1. Presumably these are subject
# and per-subject image indices, matching the inference loop's
# subject/image counters. `s` and `j` are read again by that loop, so they
# stay as globals.
s, i, j = 1, 11, 61
opt['index1'] = [*range(s, i)]
opt['index2'] = [*range(1, j)]

In [None]:
# Create the test-set loader, fetch sample 0 for the preview cell, and report
# how many samples the loader holds.
opt.update(test_data=Data_loaders_invivo(opt))
x1_test, x2_test, y_test = opt['test_data'].next([0], opt)
print(opt['test_data'].data_size)

In [None]:
# Data preview of the sample fetched above (plt_center is a project helper;
# presumably it plots central slices -- unverified).
plt_center(x1_test, x2_test, y_test)

In [None]:
# Build the network and restore the weights saved at the configured epoch.
model = {'vnet': base_unet(opt, 1)}
model['vnet'].load_weights(
    opt['checkpoint_path'].format(epoch=opt['model_restored_epoch'])
)

In [None]:
# Test-time reconstruction. For every test sample:
#   1. run conjgrad_sst_opt once with a zero prior;
#   2. alternate network prediction and conjgrad_sst_opt for opt['iter'] rounds;
#   3. report the masked NRMSE against the label;
#   4. save NIfTI volumes and display slices.
nr_im = opt['test_data'].data_size
start_time = datetime.datetime.now()

# Reconstruction hyper-parameters (presumably read inside conjgrad_sst_opt
# through `opt` -- TODO confirm which keys it uses).
opt['lbd1'] = 0.1
opt['lbd0'] = 0
opt['c_iter'] = 5
opt['iter'] = 5
opt['thr'] = 0.001
opt['rad'] = [5, 3, 1]
# Reset before regenerating -- kept in case multi_smv_gen reads opt['ker'] (unverified).
opt['ker'] = []
opt['ker'] = multi_smv_gen(opt)
opt['d'] = dipole_kernel(opt)

n_rad = len(opt['rad'])

# Pad the visiting order so the final batch is full.
nr_im_per_epoch = int(np.ceil(nr_im / opt['batch_size']) * opt['batch_size'])
if nr_im < opt['batch_size']:
    order = list(range(nr_im)) + [nr_im - 1] * (opt['batch_size'] - nr_im)
else:
    order = list(range(nr_im)) + list(range(nr_im_per_epoch - nr_im))

# Assumes the loader yields j-1 consecutive images per subject (opt['index2']).
imgs_per_subject = j - 1
m = s - 1     # subject counter; becomes s on the first sample
n = 0         # image counter within the current subject
avg_img_cost = []
# Output scaling factor. It is named `scale` rather than reusing `s`, which
# holds the subject start index from the index cell.
scale = 1

smv, l = smv_lpf_array(opt)

for block_i in range(0, nr_im, opt['batch_size']):
    if block_i % imgs_per_subject == 0:
        m += 1
    n += 1
    if n > imgs_per_subject:
        n = 1
    indices = order[block_i:block_i + opt['batch_size']]
    x1_test, x2_test, y_test = opt['test_data'].next(indices, opt)
    x = np.concatenate((x1_test, x2_test), axis=-1)

    print('-----subject-' + str(m) + '-----' + str(n))

    # Channels 0..n_rad-1 come first; channel n_rad is the next input channel
    # (presumably the phase, per train_img2_path).
    x_msk = x[..., 0:n_rad]
    x_phs = x[..., n_rad:n_rad + 1]

    # Zero prior, and the evaluation mask = sum of the first n_rad channels.
    # The loop variable must not be `m`: reusing it here used to overwrite
    # the subject counter above.
    L = tf.zeros_like(x[..., 1:2])
    msk = L
    for ch in range(n_rad):
        msk = msk + x[..., ch:ch + 1]

    x_k = conjgrad_sst_opt(x_msk, L, x_phs, L, smv, l, opt)
    for _ in range(opt['iter']):
        net_out = model['vnet'].predict(x_k.numpy())
        # NOTE(review): this call passes one more argument (`d`) than the call
        # above, and `d` is never assigned in this notebook (only opt['d'] is).
        # Confirm conjgrad_sst_opt's signature and where `d` comes from.
        x_k = conjgrad_sst_opt(x_msk, x_k, x_phs, net_out, d, smv, l, opt)

    # Relative L2 error (NRMSE) inside the mask.
    m_loss = np.linalg.norm(x_k * msk - y_test * msk) / np.linalg.norm(y_test * msk)

    # Plot the progress
    print("[Batch %d/%d] [Model loss: %f/%f]" % (block_i + 1, nr_im_per_epoch, m_loss, m_loss))
    avg_img_cost.append(m_loss)

    elapsed_time = datetime.datetime.now() - start_time
    print(elapsed_time)

    msk_np = msk.numpy()
    x_k_np = x_k.numpy()

    if opt['datatype'] in ('phan', 'phan_lr'):
        # Filename suffix taken from the selected swi_index row.
        tag = str(opt['train_index'][block_i, 0]) + '-' + str(opt['train_index'][block_i, 1])
        save_nii(msk_np.squeeze(), opt['reso'], opt['out_dir'], 'msk_sslmic_1r_0p7' + tag)
        save_nii((x_k_np * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'qsm_sslmic_0p7xk' + tag)
        save_nii((net_out * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'qsm_sslmic_0p7net' + tag)
        save_nii((y_test * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'qsm_sslmic_0p7truth' + tag)

    # NOTE(review): the filenames below carry no per-sample suffix, so each
    # iteration overwrites the previous sample's files.
    if opt['datatype'] in ('rc2', 'phan_e'):
        save_nii((net_out * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'simu_sslmic_0.1net')
        save_nii((x_k_np * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'simu_sslmic_0.1xk')
        save_nii((y_test * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'simu_sslmic_0.1truth')

    if opt['datatype'] == 'cos':
        save_nii((net_out * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'cosmos_psslmic_trun_0.1net')
        save_nii((x_k_np * msk_np / scale).squeeze(), opt['reso'], opt['out_dir'], 'cosmos_psslmic_trun_0.1xk')

    display_slice(0, opt['display_nums'], net_out / scale, x_k_np / scale, y_test / scale)

epoch_loss = np.mean(avg_img_cost)

print("Testing_cost=", "{:.5f}".format(epoch_loss))
        