In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
from networks_model import *
from data_loader import *
from error_functions import *
import datetime
from keras.callbacks import EarlyStopping
from keras.utils import multi_gpu_model
import time
from layer_custom import *
# Display the TensorFlow version in use (last expression of the cell, so Jupyter renders it).
tf.__version__

In [None]:
# Restrict this process to one GPU: number CUDA devices by PCI bus order,
# then expose only device 1.
# NOTE(review): these only take effect if set before TensorFlow initialises its GPU devices.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="1")

In [None]:
# All training settings live in the `opt` dict, which is passed to the project
# helpers (multi_smv_gen, dipole_kernel, Data_loaders, joint_model, ...).
opt = {}

# Dataset root and per-input sub-paths. The names suggest mask, total phase and
# QSM target respectively — presumably resolved by Data_loaders; confirm there.
opt['train_path'] = '/media/hd3/sylar/simu_data/calculated'
opt['train_img1_path'] = ['/msk_arr/msk']
opt['train_img2_path'] = ['/phs_total/phs']
opt['train_label_path'] = ['/qsm/qsm']

# Presumably the voxel resolution, and the spatial size of each training patch.
opt['reso'] = (0.45,0.45,1.0)
opt['patch_size'] = (64,64,64)

# SMV radii. opt['ker'] is initialised empty before multi_smv_gen(opt) is called,
# in case that helper reads it — confirm before removing the placeholder.
opt['rad'] = [5,3,1]
opt['ker'] = []
opt['ker'] = multi_smv_gen(opt)  # initial SMV kernel generation for multiple radii
opt['d'] = dipole_kernel(opt)    # initial dipole kernel generation

opt['lbd1'] = 0.1
opt['iter'] = 5
opt['thr'] = 0.05
opt['batch_size'] = 2
# One channel per SMV radius plus one extra channel.
opt['channels'] = len(opt['rad'])+1
opt['img_shape'] = opt['patch_size'] + (opt['channels'],)
opt['in_shape'] = opt['patch_size'] + (1,)
opt['is_patch'] = True
opt['is_aug'] = True

# Set model_restored=True to resume from checkpoint cp{model_restored_epoch}.
opt['model_restored'] = False
opt['model_restored_epoch'] = 196
opt['model_total_epoch'] = 300
opt['model_save_interval'] = 1

# Adam optimiser hyper-parameters and training loss.
opt['learning_rate'] = 2e-4
opt['beta_1'] = 0.9
opt['beta_2'] = 0.99
opt['loss'] = nrmse
opt['display_nums'] = [10,20,30]
opt['c_iter'] = 5 # conjugate gradient iteration number

# Checkpoints are written as <model_save_path>/cp<epoch>; the template is
# filled with .format(epoch=...) when saving/loading weights.
opt['model_save_path'] = '/home/maii_station_1/Desktop/codes/SS-POCSnet/modelss_saved' 
opt['checkpoint_path'] = opt['model_save_path']+"/cp{epoch}"

# Create the checkpoint directory. This previously called
# os.makedirs(log_path), but log_path is undefined (NameError).
os.makedirs(opt['model_save_path'], exist_ok=True)

In [None]:
import re
# Train/test split of sample indices, stored next to the checkpoints.
# - A fresh run shuffles the indices and saves the split.
# - A resumed run reloads the saved split, so training continues on the same samples.
train_index_file = opt['model_save_path'] + '/train_index8.npy'
test_index_file = opt['model_save_path'] + '/test_index2.npy'
if not opt['model_restored']:
    # NOTE(review): range(1, 15030) yields 15029 indices, so the test slice
    # index[15000:15030] holds 29 items, not 30 — confirm the dataset size.
    index = list(range(1,15030))
    random.shuffle(index)
    opt['train_index'] = index[0:15000]
    opt['test_index'] = index[15000:15030]  
    np.save(train_index_file, opt['train_index'])
    np.save(test_index_file, opt['test_index'])
else:
    # On a resumed run, opt['train_index'] / opt['test_index'] have not been set
    # yet, so they must be loaded rather than saved (saving raised KeyError).
    opt['train_index'] = np.load(train_index_file).tolist()
    opt['test_index'] = np.load(test_index_file).tolist()

In [None]:
# Build the training data loader and store it in opt for the training loop.
train_loader = Data_loaders(opt)
opt['train_data'] = train_loader

# Fetch sample [1] as a sanity check and report the loader's data size.
x1_train, x2_train, y_train = train_loader.next([1], opt)
print(train_loader.data_size)

# Data preview of the fetched sample.
plt_center(x1_train, x2_train, y_train)

In [None]:
# Switch to TF1-style graph execution; this must happen before any model,
# graph or tensor is created.
tf.compat.v1.disable_eager_execution()

# smv_lpf_array returns (smv, l); only smv is passed to the model here.
smv, l = smv_lpf_array(opt)
model = {}

model['joint'] = joint_model(smv,opt)
model['optimizer'] = Adam(opt['learning_rate'], opt['beta_1'], opt['beta_2'])

# The model has two outputs, each scored against the target with opt['loss']
# and weighted equally.
# NOTE(review): 'accuracy' is likely not meaningful for a continuous target.
# It is kept because the training loop indexes train_on_batch's return list.
model['joint'].compile(optimizer = model['optimizer'],
                     loss = [opt['loss'],opt['loss']],
                     loss_weights = [0.5, 0.5],
                     metrics=['accuracy'])

# Optionally resume from a checkpoint.
# - The weights live in model['joint']; model['vnet'] is never created (KeyError).
# - Checkpoint cp{N} is written after epoch N finishes, so training resumes at
#   N + 1 instead of repeating epoch N.
if opt['model_restored']:
    model['joint'].load_weights(opt['checkpoint_path'].format(epoch = opt['model_restored_epoch']))
    start_epoch = opt['model_restored_epoch'] + 1
else:
    start_epoch = 1

In [None]:
# Main training loop.
# - Each epoch visits every training sample once, in random order.
# - The index list is padded with random repeats up to a multiple of the batch
#   size, so every batch is full.
nr_train_im = opt['train_data'].data_size
nr_im_per_epoch = int(np.ceil(nr_train_im / opt['batch_size']) * opt['batch_size'])

start_time = time.time()
print("Start Time:" + str(start_time))

avg_epoch_cost = []
print('-----Start Training-----')
# The + 1 makes the last epoch, opt['model_total_epoch'], actually run.
for epoch in range(start_epoch, opt['model_total_epoch'] + 1):

    order = np.concatenate((np.random.permutation(nr_train_im),
                            np.random.randint(nr_train_im, size=nr_im_per_epoch - nr_train_im)))
    avg_img_cost = []
    # Slice from position 0. Starting at 1 skipped order[0] and left the
    # final batch with a single sample.
    for block_i in range(0, nr_im_per_epoch, opt['batch_size']):
        index = order[block_i:block_i+opt['batch_size']]
        x1_train, x2_train, y_train = opt['train_data'].next(index, opt)

        # Training: both outputs are fitted to the same target.
        x_in = np.concatenate((x1_train, x2_train),axis=-1)
        history = model['joint'].train_on_batch(x_in, [y_train,y_train] )
        # Presumably [total weighted loss, output-1 loss, output-2 loss, metrics...]
        # per Keras' metrics_names order — confirm via model['joint'].metrics_names.
        m_loss1 = history[0]
        m_loss2 = history[1]
        m_loss3 = history[2]
        avg_img_cost.append(m_loss1)
        if block_i % (30 * opt['batch_size']) == 0:
            # Predictions are only needed for this progress report. Running
            # predict on every batch doubled the forward passes for nothing.
            [y_pred1, y_pred2] = model['joint'].predict(x_in)
            norm_loss = np.linalg.norm(y_pred1-y_train)/np.linalg.norm(y_train)
            # Plot the progress
            print ("[Epoch %d/%d] [Batch %d/%d] [Model loss: %f-%f-%f ; nrmse: %f]" % (epoch, opt['model_total_epoch'],
                                                                block_i, nr_train_im,
                                                                m_loss1,m_loss2,m_loss3, norm_loss))
            display_slice(3,opt['display_nums'], y_pred1,y_pred2, y_train)
    avg_epoch_cost.append(np.mean(avg_img_cost)) 

    print("Epoch:", '%04d' % (epoch), "Training_cost=", "{:.5f}".format(avg_epoch_cost[-1]))
    display_error(range(start_epoch,epoch+1),avg_epoch_cost)

    # If at save interval => save models. The compiled model lives in
    # model['joint']; model['vnet'] does not exist (KeyError).
    if epoch % opt['model_save_interval'] == 0:
        model['joint'].save_weights(opt['checkpoint_path'].format(epoch=epoch))

elapsed_time = time.time() - start_time
print("Total Time:" + str(elapsed_time))