## Initialization of key variables


In [None]:
import sys
sys.path.append('/home') # set system path

import importlib
import sigpy as sp
import torch
import torch.nn as nn
from espirit.espirit import espirit, fft
import pickle
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np
import dataOpNewKbnufft
from dataOpNewKbnufft import dataAndOperators
from generator import generatorNew
from optimize_gen_sub import optimize_generator
from latentVariable import latentVariableNew
from ptflops import get_model_complexity_info
from showVideo import showImages
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage

# Device on which the tensors and networks of this notebook are placed.
gpu = torch.device('cuda:0')

# Single configuration dictionary handed to the data/operator, generator,
# latent-variable and optimisation helpers used in the cells below.
params = {
    'name': 'parameters',
    'directory': '',
    'device': gpu,
    'filename': "",                 # filled in right after the dict
    'dtype': torch.float,
    # Store variables on the GPU: fast, but consumes memory.
    # Track usage with torch.cuda.memory_allocated.
    'fastMode': True,
    'verbose': True,                # print messages
    'im_size': (168, 168),          # image size
    'nintlPerFrame': 3,             # interleaves per frame
    'nintlvsToDelete': 0,           # initial interleaves to delete to minimize transients
    'nFramesDesired': 900,          # number of frames in the reconstruction
    'slice': tuple(range(10)),      # slices of the series to process (0-based indices)
    'factor': 1,                    # scale image by 1/factor to save compute time
    'nBatch': 5,
    'gen_base_size': 60,            # base number of filters
    'gen_reg': 0.0001,              # regularization penalty on generator
    'virtual_coils': 8,             # number of virtual coils used for reconstruction
    # Radius of the central circle used for coil combination;
    # 1 selects the whole image.
    'mask_size': 0.9,
    'coilEst': 'espirit',           # espirit/jsense
    'siz_l': 20,                    # number of latent parameters
}

# Set the data file path here.
params['filename'] = '/home/data'


## Training the network assuming latent variables to be fixed. 

The training proceeds in two levels.

During the first round of training, the latent variables are held fixed. This allows a good initial network to be learned, which is then used as the starting point for the second and final round.



In [None]:
# First training round: fit the generator while the latent variables are not
# updated (their learning rate is set to zero below), then checkpoint the result.
import os

# Reload the project modules *before* binding names from them. A
# `from module import name` executed ahead of `importlib.reload(module)` keeps
# pointing at the pre-reload object, so edits to the module would be ignored.
import dataOpNewKbnufft
importlib.reload(dataOpNewKbnufft)
from dataOpNewKbnufft import dataAndOperators

import optimize_gen_sub
importlib.reload(optimize_gen_sub)
from optimize_gen_sub import optimize_generator

# Reading and pre-processing the data and parameters
dop = dataAndOperators(params)

# Initialization of the generator
G = generatorNew(params)
G.weight_init()
G.to(torch.float32).cuda(gpu)

# Initialization of the latent variables
alpha = [0.8, .15, 0.2, 0.1, .5, 0.1, .9, 0.2, 0.1, .5,
         0.1, .9, 0.2, 0.1, .5, .6, .4, .2, .1, .3]
alpha = torch.FloatTensor(alpha).to(gpu)
z = latentVariableNew(params, init='ones', alpha=alpha, klreg=0)

# Initial training
params['lr_g'] = 2e-4
params['lr_z'] = 0e-4  # zero learning rate for the latent variables in this round

train_epoch = 30

G, z, train_hist, SER1, epoch0 = optimize_generator(dop, G, z, params, train_epoch=train_epoch)
print(epoch0)

# Keep handles on the first-round result; they are also what gets checkpointed.
G_olda = G.state_dict()
z_olda = z.z_

# Build the checkpoint directory name from the data path and the run settings.
# NOTE: the '.mat' suffix is only substituted if params['filename'] contains it;
# otherwise the settings are simply appended to the unchanged path.
pathname = params['filename'].replace('.mat', '_'+str(params['gen_base_size'])+'d/weights_onlyGenerator_'+str(params['coilEst']))
pathname = (pathname
            + '_' + str(train_epoch)
            + '_epoch' + str(params['slice'])
            + '_' + str(params['nintlPerFrame']) + 'arms_'
            + str(params['siz_l']) + 'latVec'
            + str(params['nFramesDesired']) + 'frms')
# exist_ok avoids the separate existence check and the race it implies.
os.makedirs(pathname, exist_ok=True)

path = os.path.join(
    pathname,
    'net_{}_epoch{}_gloss_{}.pth'.format('onlyGEN'+str(params['coilEst']), train_epoch, train_hist['G_losses'][-1]))
torch.save({'G_olda': G_olda, 'z_olda': z_olda}, path)
print(path)

# Training with latent variables

In [None]:
# Second training round: restore the first-round checkpoint and continue
# training with a non-zero learning rate for the latent variables.
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (10, 10)

# Reload first, then bind names: importing a name before the reload would keep
# the stale pre-reload object and silently ignore edits to the module.
import latentVariable
importlib.reload(latentVariable)
from latentVariable import latentVariableNew

import optimize_gen_sub
importlib.reload(optimize_gen_sub)
from optimize_gen_sub import optimize_generator

torch.cuda.empty_cache()

# `path` is the checkpoint file written at the end of the initial training cell.
checkpoint = torch.load(path)

# Fresh generator populated with the first-round weights.
G1 = generatorNew(params)
G1.to(torch.float32).cuda(gpu)
G1.load_state_dict(checkpoint['G_olda'])  # G_oldi

# Restore the first-round latent values into the existing latent object.
z.z_ = checkpoint['z_olda']  # z_oldi

# Final training
params['lr_g'] = 2e-4
params['lr_z'] = 4e-3   # 4e-3
alpha = [0.8, .15, 0.2, 0.1, .5, 0.1, .9, 0.2, 0.1, .5,
         0.1, .9, 0.2, 0.1, .5, .6, .4, .2, .1, .3]
alpha = torch.FloatTensor(alpha).to(gpu)

# New latent-variable object seeded from the restored one via `z_in`.
z1 = latentVariableNew(params, z_in=z, alpha=alpha, klreg=0.000)

final_epoch = 80
G1, z1, train_hist, SER1, epoch1 = optimize_generator(dop, G1, z1, params, train_epoch=final_epoch)
print(epoch1)

In [None]:
# Render every reconstructed frame with the trained generator, save each frame
# as a PNG, assemble the PNGs into a GIF and store the raw image stack as .npy.
# `train_epoch` is expected to still be defined by the initial-training cell.
import os

import imageio
from matplotlib.transforms import Bbox

# Use the epoch value returned by the final optimisation for the output names.
final_epoch = epoch1

im_size = params["im_size"]

# params['slice'] may be a single int or a sequence of slice indices.
nsl = 1 if isinstance(params['slice'], int) else len(params['slice'])

TR = 6.00e-3
frames_per_second = 1. / (params['nintlPerFrame'] * TR)
my_dpi = 100  # Good default - doesn't really matter

n_frames = params["nFramesDesired"]
images = np.zeros((n_frames, im_size[0], im_size[1] * nsl))

# Pure inference: disable autograd so no graph is built for each frame.
with torch.no_grad():
    for i in range(n_frames):
        # NOTE(review): the latent index 7 on the last axis is hard-coded here;
        # confirm this is intended for every value of params['slice'].
        image = G1(z1.z_[i:i+1, :, :, :, 7]).squeeze(0).squeeze(0).detach().abs().cpu().numpy()
        # Lay the slices (last axis of `image`) side by side along the width.
        images[i] = np.concatenate([image[..., j] for j in range(nsl)], axis=1)
maxval = np.max(images)

gifs_tobe = []
dirname = params['filename'].replace('.mat', '_'+str(params['gen_base_size'])+'d/results_'+str(params['coilEst']))
dirname = (dirname
           + '_' + str(params['slice'])
           + '_' + str(params['nintlPerFrame']) + 'arms_'
           + str(params['siz_l']) + 'latVec'
           + str(params['nFramesDesired']) + 'frms'
           + '_' + str(train_epoch) + '_Geph'
           + str(final_epoch) + '_feph_Sl5KL0.00zR1')
dn = dirname
os.makedirs(dirname, exist_ok=True)

# Figure size in inches chosen so that one image pixel maps to one output pixel.
fig_w = (im_size[1] * nsl) / my_dpi
fig_h = im_size[0] / my_dpi

for k in range(n_frames):
    fig, ax = plt.subplots(1, figsize=(fig_w, fig_h), dpi=my_dpi)
    ax.set_position([0, 0, 1, 1])  # axes fill the whole figure, no margins
    ax.imshow(images[k], cmap='gray')
    ax.axis('off')
    img_name = dirname + '/frame_' + str(k) + '.png'
    fig.savefig(img_name, bbox_inches=Bbox([[0, 0], [fig_w, fig_h]]), dpi=my_dpi)
    plt.close(fig)  # close this specific figure so figures do not accumulate
    gifs_tobe.append(imageio.imread(img_name))

imageio.mimsave(dn+'_'+str(train_epoch)+'_Gepoch'+'.gif', gifs_tobe, fps=frames_per_second)
print('Frame rate= %.4f fps; TR=%f sec; Narms= %d.' %(frames_per_second, TR, params['nintlPerFrame']))
print('\nResults saved in %s' %dirname)

# Raw image stack, saved next to the results directory (np.save appends '.npy').
np.save(dn, images)

In [None]:
# Plot the learned latent vectors over the frames, one subplot per slice, and
# save them as a (frames, latent components, slices) array.
import numpy as np
import matplotlib.pyplot as plt

siz_l = params['siz_l']

# params['slice'] may be a single int or a sequence of slice indices.
nsl = 1 if isinstance(params['slice'], int) else len(params['slice'])

# One label per latent component. A single string must not be passed to
# legend(): it is iterated character by character, mislabelling the curves.
legend_labels = [str(i) for i in range(siz_l)]

# squeeze=False always returns a 2-D array of axes, so nsl == 1 needs no
# special casing; take the single column to index by slice.
fig, ax = plt.subplots(nsl, 1, squeeze=False)
ax = ax[:, 0]
fig.set_figheight(20)
fig.set_figwidth(15)

lat_vec = np.zeros((params["nFramesDesired"], params["siz_l"], nsl))

for sl in range(nsl):
    # Latent values of this slice, moved to the CPU as a NumPy array.
    temp = z1.z_[..., sl].data.squeeze().cpu().numpy()
    lat_vec[..., sl] = temp
    ax[sl].plot(temp[..., 0:siz_l])
    ax[sl].legend(legend_labels, loc='best')

plt.pause(0.00001)
np.save(dirname+'_latent_vectors', lat_vec)
print(lat_vec.shape)