In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
from model_structures import *
from data_loader import *
from loss_func import *
import datetime
import time
from custom_layers import *
from keras.optimizers import Adam
# Echo the installed TensorFlow version as this cell's output, so the notebook records which TF build it ran with.
tf.__version__

In [None]:
# Restrict TensorFlow to a single GPU. Device numbering follows PCI bus order,
# and only device "1" is made visible.
# NOTE(review): CUDA reads these variables when TF first creates its GPU
# context. TF is already imported above, so this cell must run before any TF op.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [None]:
opt = {}

# Training data: root directory plus per-input sub-paths.
# Presumably the data loader joins these; confirm in data_loader.Data_loaders_bgrm.
opt['train_path'] = '/home/sylar/data/Simu/bg_rm_simu_train/brain_phantom_simu/patch_simu_hemr_4/patch_calculated_msk'
opt['train_img1_path'] = ['/msk_arr/msk']
opt['train_img2_path'] = ['/phs_total/phs']
opt['train_label_path'] = ['/phs_tissue/phs']

# Voxel size per axis (x, y, z). The units are presumably mm; confirm.
opt['reso'] = (0.9375, 0.9375, 1.5)


def make_smv_kernel(radius, reso):
    """Build a normalised ellipsoidal averaging kernel for one radius.

    The half-width along each axis is ``round(radius / reso[axis])``, clamped
    to at least 2 voxels, so the ellipsoid in voxel space approximates a sphere
    of ``radius`` in physical units. Voxels inside the ellipsoid get weight
    ``1 / count`` (the kernel sums to 1); all other voxels are 0.

    Args:
        radius: sphere radius, in the same units as ``reso``.
        reso: per-axis voxel size ``(x, y, z)``.

    Returns:
        float ndarray of shape ``(2*nx+1, 2*ny+1, 2*nz+1)``.
    """
    nx, ny, nz = (max(round(radius / r), 2) for r in reso)
    # Axis 0 spans +/-nx, axis 1 spans +/-ny and axis 2 spans +/-nz, so each
    # coordinate grid must be normalised by its own half-width. Unpacking as
    # (ky, kx, kz) would mix up x and y whenever reso[0] != reso[1].
    kx, ky, kz = np.mgrid[-nx:nx + 1, -ny:ny + 1, -nz:nz + 1]
    inside = kx**2 / nx**2 + ky**2 / ny**2 + kz**2 / nz**2 <= 1
    return inside / np.sum(inside)


# Kernel radii, in the same units as reso. One kernel is built per radius
# (presumably consumed by smv_iter_array further down).
opt['rad'] = [5, 3]
opt['ker'] = [make_smv_kernel(r, opt['reso']) for r in opt['rad']]

# --- Patch geometry ---
opt['patch_size'] = (48, 48, 48)

# Weights/threshold and the unrolled-iteration count read by the project's
# model/loss helpers. The meaning of lbd0/lbd1/thr is defined in those modules.
opt.update(lbd0=100, lbd1=0.1, thr=0.3, iter=5)

# --- Batching and tensor shapes ---
opt['batch_size'] = 30
# Channel count is one per kernel radius, plus one.
opt['channels'] = len(opt['rad']) + 1
opt['img_shape'] = (*opt['patch_size'], opt['channels'])
opt['in_shape'] = (*opt['patch_size'], 1)
opt.update(rnd_crop=False, is_aug=False)

# --- Resume / checkpoint control ---
opt.update(
    model_restored=False,       # True: reload the split and weights of `model_restored_epoch`
    model_restored_epoch=10,
    model_total_epoch=150,
    model_save_interval=1,      # save vnet weights every N epochs
)

# --- Adam hyper-parameters ---
opt.update(learning_rate=2e-5, beta_1=0.9, beta_2=0.99)

opt['loss'] = mse

# Slice numbers handed to display_slice during training.
opt['display_nums'] = [10, 20, 30]

# --- Output locations ---
opt['model_save_path'] = '/home/sylar/data/results/model1_revision/pocsnet1_model'
opt['checkpoint_path'] = opt['model_save_path'] + "/cp{epoch}"

In [None]:
# Train/validation split over sample ids 1..10000 (9500 train, 500 val).
# A fresh run draws a new random split and saves it beside the checkpoints.
# A resumed run reloads the saved split, so membership is identical across sessions.
# NOTE(review): `random` is not seeded, so a fresh split is not reproducible
# except through the saved .npy files.
train_index_file = os.path.join(opt['model_save_path'], 'train_index8.npy')
val_index_file = os.path.join(opt['model_save_path'], 'val_index2.npy')

if not opt['model_restored']:
    index = list(range(1, 10001))
    random.shuffle(index)
    opt['train_index'] = index[:9500]
    opt['val_index'] = index[9500:10000]
    # np.save does not create parent directories; without this, a first run
    # against a new model_save_path fails with FileNotFoundError.
    os.makedirs(opt['model_save_path'], exist_ok=True)
    np.save(train_index_file, opt['train_index'])
    np.save(val_index_file, opt['val_index'])
else:
    opt['train_index'] = np.load(train_index_file)
    opt['val_index'] = np.load(val_index_file)

In [None]:
# Create the training-data loader and fetch the sample at index 1,
# used only for the preview plot in the next cell.
opt['train_data'] = Data_loaders_bgrm(opt)
x1_train, x2_train, y_train = opt['train_data'].next([1], opt)

print(f"training data size: {opt['train_data'].data_size}")

In [None]:
'''loading data preview'''
# Slice positions come from findcenter3d on the label volume (ch is unused).
ch, cw, cd = findcenter3d(y_train[0])
aa = cw
dd = cd

# Row 0: slice `dd` along axis 2 of each volume, transposed.
# Row 1: slice `aa` along axis 1, transposed and flipped along axis 0.
# Column 0 shows channel 1 of x1, column 1 shows x2 and column 2 shows the label y.
panels = {
    (0, 0): np.transpose(x1_train[0, :, :, dd, 1].squeeze()),
    (0, 1): np.transpose(x2_train[0, :, :, dd].squeeze()),
    (0, 2): np.transpose(y_train[0, :, :, dd].squeeze()),
    (1, 0): np.flip(np.transpose(x1_train[0, :, aa, :, 1].squeeze()), 0),
    (1, 1): np.flip(np.transpose(x2_train[0, :, aa, :].squeeze()), 0),
    (1, 2): np.flip(np.transpose(y_train[0, :, aa, :].squeeze()), 0),
}

fig, axarr = plt.subplots(2, 3, figsize=(20, 10))
for (row, col), image in panels.items():
    axarr[row, col].imshow(image)
    axarr[row, col].axis('off')
plt.show()

In [None]:
# Graph-mode Keras. This must run before any TF op or layer is created in the
# kernel, so restart the kernel before re-running this cell.
tf.compat.v1.disable_eager_execution()
opt['c_iter'] = 5  # presumably the inner CG iteration count for smv_iter_array / cg_br_grad_model; confirm
smv, trun = smv_iter_array(opt)
model = {}

model['vnet'] = unet_at1(opt, 1)

x = Input(shape=opt['img_shape'])
# Initial estimate: zeros with the shape of a single input channel.
LP = tf.zeros_like(x[..., 1:2])

model['CG_grad'] = cg_br_grad_model(x, LP, smv, trun, opt)

# Unrolled network: one CG-gradient step, then opt['iter'] rounds of
# (vnet -> CG-gradient step). At least one vnet pass is required, because
# `net_out` is one of the combined model's outputs.
if opt['iter'] < 1:
    raise ValueError("opt['iter'] must be >= 1 (the combined model outputs the last vnet result)")
LP = model['CG_grad'](LP)
for unroll_step in range(opt['iter']):
    net_out = model['vnet'](LP)
    LP = model['CG_grad'](net_out)

model['optimizer'] = Adam(opt['learning_rate'], opt['beta_1'], opt['beta_2'])

model['combine'] = Model(inputs=x,
                         outputs=[net_out, LP],
                         name='combine_model')

# Both outputs (last vnet output and last CG output) use the same loss with equal weight.
# NOTE(review): 'accuracy' is not meaningful for a continuous regression target.
# It only appends extra entries after the losses in train_on_batch's result.
model['combine'].compile(optimizer=model['optimizer'],
                         loss=[opt['loss'], opt['loss']],
                         loss_weights=[0.5, 0.5],
                         metrics=['accuracy'])

# Resume support. Only the vnet weights are reloaded; the optimizer state starts fresh.
# Checkpoint "cp{N}" is written after epoch N has finished, so training resumes
# at N + 1. Otherwise epoch N would be trained twice and its checkpoint overwritten.
if opt['model_restored']:
    model['vnet'].load_weights(opt['checkpoint_path'].format(epoch=opt['model_restored_epoch']))
    start_epoch = opt['model_restored_epoch'] + 1
else:
    start_epoch = 1

In [None]:
nr_train_im = opt['train_data'].data_size
if nr_train_im <= 0:
    raise ValueError("training loader reports no samples (data_size <= 0)")
batch_size = opt['batch_size']
# Round the epoch up to a whole number of batches; the shortfall is filled
# with randomly re-drawn samples.
nr_im_per_epoch = int(np.ceil(nr_train_im / batch_size) * batch_size)
nr_batches = nr_im_per_epoch // batch_size
display_every = 30  # batches between progress prints / slice displays

start_time = time.time()
print("Start Time:" + str(start_time))

avg_epoch_cost = []
print('-----Start Training-----')
# model_total_epoch is inclusive: epochs (and checkpoint names) run 1..model_total_epoch.
for epoch in range(start_epoch, opt['model_total_epoch'] + 1):

    # A full permutation of the training set, plus random fill for the last batch.
    order = np.concatenate((np.random.permutation(nr_train_im),
                            np.random.randint(nr_train_im, size=nr_im_per_epoch - nr_train_im)))
    avg_img_cost = []
    # Positions into `order` are 0-based. Starting at 0 keeps every batch exactly
    # `batch_size` long and uses every entry of `order`, including order[0].
    for batch_i, block_i in enumerate(range(0, nr_im_per_epoch, batch_size)):
        index = order[block_i:block_i + batch_size]
        x1_train, x2_train, y_train = opt['train_data'].next(index, opt)

        # Network input is the x1 channels followed by the x2 channels.
        x_in = np.concatenate((x1_train, x2_train), axis=-1)

        # Presumably [total weighted loss, loss(out1), loss(out2), *metrics];
        # check model['combine'].metrics_names to confirm.
        batch_metrics = model['combine'].train_on_batch(x_in, [y_train, y_train])
        m_loss1 = batch_metrics[0]
        m_loss2 = batch_metrics[1]
        m_loss3 = batch_metrics[2]
        avg_img_cost.append(m_loss1)

        if batch_i % display_every == 0:
            # The extra forward pass is needed only for reporting, so run it here
            # rather than on every batch.
            y_pred1, y_pred2 = model['combine'].predict(x_in)
            norm_loss = np.linalg.norm(y_pred1 - y_train) / np.linalg.norm(y_train)
            # Plot the progress
            print("[Epoch %d/%d] [Batch %d/%d] [Model loss: %f-%f-%f ; nrmse: %f]" % (
                epoch, opt['model_total_epoch'],
                batch_i + 1, nr_batches,
                m_loss1, m_loss2, m_loss3, norm_loss))
            display_slice(opt['display_nums'], y_pred1, y_pred2, y_train)
    avg_epoch_cost.append(np.mean(avg_img_cost))

    print("Epoch:", '%04d' % (epoch), "Training_cost=", "{:.5f}".format(avg_epoch_cost[-1]))
    display_error(range(start_epoch, epoch + 1), avg_epoch_cost)

    # If at save interval => save vnet weights as cp{epoch}
    if epoch % opt['model_save_interval'] == 0:
        model['vnet'].save_weights(opt['checkpoint_path'].format(epoch=epoch))

elapsed_time = time.time() - start_time
print("Total Time:" + str(elapsed_time))