In [1]:
import numpy as np
import argparse
import os
import sys
import torch
import torch.fft as F
from importlib import reload
from torch.nn.functional import relu
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
import utils
from matplotlib import pyplot as plt
import random
import copy

from utils import kplot,mask_naiveRand,mask_filter, get_x_f_from_yfull, mnet_wrapper
from utils import mask_complete, mask_makebinary,raw_normalize, visualization
from mnet import MNet
from mask_backward_new import mask_backward, mask_eval
sys.path.insert(0,'/home/huangz78/mri/unet/')
from unet_model import UNet

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

In [2]:
### load a mnet
mnet = MNet(beta=1,in_channels=2,out_size=320-24, imgsize=(320,320),poolk=3)
mnetpath = '/home/huangz78/checkpoints/mnet.pth'
checkpoint = torch.load(mnetpath)
mnet.load_state_dict(checkpoint['model_state_dict'])
mnet.eval()
print('MNet loaded successfully from: ' + mnetpath)

MNet loaded successfully from: /home/huangz78/checkpoints/mnet.pth


In [3]:
### load a unet for maskbackward
# UNET = UNet(n_channels=1,n_classes=1,bilinear=True,skip=False)
# unetpath = '/home/huangz78/checkpoints/unet_'+ str(UNET.n_channels) +'.pth'
# unetpath = '/home/huangz78/checkpoints/unet_1_False.pth'

UNET = UNet(n_channels=1,n_classes=1,bilinear=False,skip=True)
unetpath = '/home/huangz78/checkpoints/unet_1_True.pth'
checkpoint = torch.load(unetpath)
UNET.load_state_dict(checkpoint['model_state_dict'])
print('Unet loaded successfully from: ' + unetpath )
UNET.train()
print('nn\'s are ready')

Unet loaded successfully from: /home/huangz78/checkpoints/unet_1_True.pth
nn's are ready


In [4]:
train_dir = '/home/huangz78/data/traindata_x.npz'
# train_sub = np.load(train_dir)['x']
train_full = torch.tensor(np.load(train_dir)['xfull'])
train_dir = '/home/huangz78/data/traindata_y.npz'
# train_sub = np.load(train_dir)['x']
yfull = torch.tensor(np.load(train_dir)['yfull'])
print('train data fft size:', yfull.shape)
print('train data size:', train_full.shape)

train data size: torch.Size([1014, 320, 320])


In [None]:
# fullmask = torch.fft.fftshift(torch.tensor(np.load(train_dir)['mask'])) # roll the input mask

test_dir = '/home/huangz78/data/testdata_x.npz'
testimg  = torch.tensor(np.load(test_dir)['x']) 
print(testimg.shape)
# test_sub  = test_sub[0:10,:,:]
# test_full = torch.tensor(np.load(test_dir)['xfull']) 
mask_greedy = np.load('/home/huangz78/data/data_gt_greedymask.npz')
mask_greedy = mask_greedy['mask'].T # this greedy mask is rolled
print(mask_greedy.shape)

In [None]:
# select an image whose greedy mask we have
# test_dir  = '/home/huangz78/data/data_gt.npz'
# test_full = torch.tensor( np.transpose(np.load(test_dir)['imgdata'],axes=(2,0,1)) )

In [26]:
def alternating_update_with_unetRecon(mnet,unet,trainfulls,yfulls=None,\
                                      maxIter_mb=20,alpha=2.8*1e-5,c=0.05, lr_mb=1e-2,\
                                      maxRep=5,lr_mn=1e-4,\
                                      epoch=1,batchsize=5,\
                                      corefreq=24,budget=56,\
                                      verbose=False,hfen=False,dtyp=torch.float,\
                                      save_cp=False):
    '''
    alpha: magnitude of l1 penalty for high-frequency mask
    mnet : the input mnet must match corefreq exactly
    '''

    DTyp = torch.cfloat if dtyp==torch.float else torch.cdouble
    dir_checkpoint = '/home/huangz78/checkpoints/'
    criterion_mnet = nn.BCEWithLogitsLoss()
    optimizer_m = optim.RMSprop(mnet.parameters(), lr=lr_mn, weight_decay=0, momentum=0)
    
    # training loop
    global_step = 0; epoch_count = 0
    batch_nums  = int(np.ceil(trainfulls.shape[0]/batchsize))
    loss_before = list([]); loss_after = list([]); loss_rand = list([]);
    while epoch_count<epoch:
        for batchind in range(batch_nums):
            batch = np.arange(batchsize*batchind, min(batchsize*(batchind+1),trainfulls.shape[0]))
            xstar = trainfulls[batch,:,:]
            if yfulls is None:
                yfull = torch.fft.fftshift(F.fftn(xstar,dim=(1,2),norm='ortho')) # y is ROLLED!
            else:
                yfull = torch.fft.fftshift(yfulls[batch,:,:],dim=(1,2))
            lowfreqmask,_,_ = mask_naiveRand(xstar.shape[1],fix=corefreq,other=0,roll=True)
            
            ########################################  
            ## (1) mask_backward
            ######################################## 
            if mnet.in_channels == 1:
                x_lf     = get_x_f_from_yfull(lowfreqmask,yfull)
                highmask = mnet(x_lf.view(batch.size,1,xstar.shape[1],xstar.shape[2]))
            elif mnet.in_channels == 2:
                y = torch.zeros((yfull.shape[0],2,yfull.shape[1],yfull.shape[2]),dtype=torch.float)
                y[:,0,lowfreqmask==1,:] = torch.real(yfull)[:,lowfreqmask==1,:]
                y[:,1,lowfreqmask==1,:] = torch.imag(yfull)[:,lowfreqmask==1,:]
                highmask = mnet(y)
            highmask_refined,unet,loss_aft,loss_bef = mask_backward(highmask,xstar,unet=unet,mnet=mnet,\
                              beta=1.,alpha=alpha,c=c,\
                              maxIter=maxIter_mb,seed=0,break_limit=np.inf,\
                              lr=lr_mb,mode='UNET',testmode='UNET',\
                              budget=budget,normalize=True,\
                              verbose=verbose,dtyp=torch.float,\
                              hfen=False,return_loss_only=False)        
            ########################################  
            ## (2) update mnet
            ########################################  
            mnet.train()
#             unet.eval()
            rep = 0
            while rep < maxRep:
                if   mnet.in_channels == 1:
                    mask_pred  = mnet(x_lf.view(batch.size,1,xstar.shape[1],xstar.shape[2]))
                elif mnet.in_channels == 2:
                    mask_pred  = mnet(y)

                train_loss = criterion_mnet(mask_pred,highmask_refined)
                optimizer_m.zero_grad()
                # optimizer step wrt unet parameters?
                train_loss.backward()
                optimizer_m.step()
                rep += 1
            mnet.eval()
            ########################################  
            ## (3) check mnet performance: does it beat random sampling?
            ########################################
            breakpoint()
            mask_rand,_,_ = mask_naiveRand(xstar.shape[1],fix=corefreq,other=budget,roll=True)
            mask_rand     = mask_rand.repeat(xstar.shape[0],1)
            randqual      = mask_eval(mask_rand,xstar,mode='UNET',UNET=UNET,dtyp=dtyp,hfen=hfen)
            
            iterprog = f'[{epoch_count+1}/{epoch}][{min(batchsize*(batchind+1),trainfulls.shape[0])}/{trainfulls.shape[0]}]'
            print(iterprog + f', quality of random   mask : {randqual}') 
            print(iterprog + f', quality of old mnet mask : {loss_bef}')
            print(iterprog + f', quality of refined  mask : {loss_aft}')
                        
            loss_rand.append(randqual); loss_after.append(loss_aft); loss_before.append(loss_bef)
            
            if (global_step%10==0) and save_cp:
                torch.save({'model_state_dict': mnet.state_dict()}, dir_checkpoint + 'mnet_split_trained.pth')
                torch.save({'model_state_dict': unet.state_dict()}, dir_checkpoint + 'unet_split_trained.pth')
                print(f'\t Checkpoint saved at epoch {epoch_count}, iter {global_step + 1}!')
                filepath = '/home/huangz78/checkpoints/alternating_update_error_track.npz'
                np.savez(filepath,loss_rand=loss_rand,loss_after=loss_after,loss_before=loss_before,freqs=(corefreq,budget))
            global_step += 1
        epoch_count+= 1
#     return mnet, unet

In [27]:
import mask_backward_new
reload(mask_backward_new)
from mask_backward_new import mask_backward

In [28]:
alternating_update_with_unetRecon(mnet,UNET,train_full,\
                                  maxIter_mb=15,alpha=10**(-5.5),c=5e-2,lr_mb=1e-2,\
                                  maxRep=2,lr_mn=1e-4,\
                                  corefreq=24,budget=56,\
                                  epoch=1,batchsize=5,\
                                  verbose=True,save_cp=False)

loss of the input mask:  0.1745380014181137
Iter 12, rows added: 0.0, rows reducted: 1.4

return at Iter ind:  15
loss of returned mask: 0.13238006830215454
samp. ratio: 0.275, Recon. rel. err: 0.13478673994541168 

> [0;32m<ipython-input-26-a441eb29e7c9>[0m(73)[0;36malternating_update_with_unetRecon[0;34m()[0m
[0;32m     71 [0;31m            [0;31m########################################[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m            [0mmask_rand[0m[0;34m,[0m[0m_[0m[0;34m,[0m[0m_[0m [0;34m=[0m [0mmask_naiveRand[0m[0;34m([0m[0mxstar[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m[0mfix[0m[0;34m=[0m[0mcorefreq[0m[0;34m,[0m[0mother[0m[0;34m=[0m[0mbudget[0m[0;34m,[0m[0mroll[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0

ipdb> n
[1/1][5/1014], quality of refined  mask : 0.13238006830215454
> [0;32m<ipython-input-26-a441eb29e7c9>[0m(82)[0;36malternating_update_with_unetRecon[0;34m()[0m
[0;32m     80 [0;31m            [0mprint[0m[0;34m([0m[0miterprog[0m [0;34m+[0m [0;34mf', quality of refined  mask : {loss_aft}'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m            [0mloss_rand[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mrandqual[0m[0;34m)[0m[0;34m;[0m [0mloss_after[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mloss_aft[0m[0;34m)[0m[0;34m;[0m [0mloss_before[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mloss_bef[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0;32mif[0m [0;34m([0m[0mglobal_step[0m[0;34m%[0m[0;36m10[0m[0;34m==[0m[0;36m0[0m[0;34m)[0m [0;32mand[0m [0msave_cp[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


BdbQuit: 

In [None]:
reload(utils)
import mask_backward_new
reload(mask_backward_new)
from mask_backward_new import mask_backward
from utils import raw_normalize, visualization

## arxiv

In [None]:
def alternating_update_with_unetRecon(mnet,unet,trainfulls,testimg,mask_init,mask_init_full=True,\
                                      maxIter_mb=50,evalmode='unet',alpha=2.8*1e-5,c=0.05, Lambda=1e-4,\
                                      lr_mb=1e-4,lr_mn=1e-4,maxRep=5,epoch=1,\
                                      corefreq=24,budget=24,plot=False,verbose=False,mask_greedy=None,\
                                      change_initmask=True,validate_every=10,dtyp=torch.float,\
                                      save_cp=False):
    '''
    alpha: magnitude of l1 penalty for high-frequency mask
    mnet : the input mnet needs to coordinate exactly with corefreq
    '''
    if mask_init_full:
        fullmask = torch.tensor(mask_init).clone()
        highmask = mask_filter(fullmask,base=corefreq,roll=True)
    else:
        fullmask = mask_complete(torch.tensor(mask_init),trainfulls.shape[1],rolled=True,dtyp=dtyp)
        highmask = torch.tensor(mask_init).clone()
    DTyp = torch.cfloat if dtyp==torch.float else torch.cdouble
    dir_checkpoint = '/home/huangz78/checkpoints/'
    criterion_mnet = nn.BCEWithLogitsLoss()
    optimizer_m = optim.RMSprop(mnet.parameters(), lr=lr_mn, weight_decay=0, momentum=0)
    # optimizer_u = ......
    
    unet_eval = UNet(n_channels=1,n_classes=1,bilinear=True,skip=False)
    unet_eval = copy.deepcopy(unet)
    unet_eval.eval()
    # training loop
    global_step = 0

    randqual = []; mnetqual = []
    randspar = []; mnetspar = []
    if mask_greedy is not None:
        greedyqual = []
        greedyspar = np.sum(mask_greedy[0,:])/trainfulls.shape[1]
    else:
        greedyqual = None; greedyspar = None
    epoch_count = 0
    while epoch_count<epoch:
        for xstar in trainfulls:
            xstar = torch.tensor(xstar,dtype=dtyp)
            yfull = torch.fft.fftshift(F.fftn(xstar,dim=(0,1),norm='ortho')) # y is ROLLED!
            lowfreqmask,_,_ = mask_naiveRand(xstar.shape[0],fix=corefreq,other=0,roll=True)
            x_lf            = get_x_f_from_yfull(lowfreqmask,yfull)
            ########################################  
            ## (1) mask_backward
            ########################################        
            if change_initmask and global_step>0: # option 2: highmask = mask_pred from step (2)
                highmask = mnet(x_lf.view(1,1,xstar.shape[0],xstar.shape[1])).view(-1)
            highmask_refined,unet = mask_backward(highmask,xstar_batch,unet=unet, mnet=mnet,\
                              beta=1.,alpha=alpha,c=c,\
                              maxIter=maxIter_mb,seed=0,break_limit=maxIter_mb*3//5,\
                              lr=lr_mb,mode='UNET',budget=budget,normalize=True,\
                              verbose=verbose,dtyp=torch.float)        
            ########################################  
            ## (2) update mnet
            ########################################        
            mnet.train()
            unet.eval()
            rep = 0
            while rep < maxRep:
                mask_pred  = mnet(x_lf.view(1,1,xstar.shape[0],xstar.shape[1]))
    #             mask_pred_full = mask_complete(mask_pred.view(-1),xstar.shape[0],rolled=True,dtyp=dtyp) # ??
    #             x_lf_new   = get_x_f_from_yfull(mask_pred_full,yfull).view(1,1,xstar.shape[0],xstar.shape[1]) # ??
    #             x_unet     = unet(x_lf_new) # ??
                train_loss = criterion_mnet(mask_pred,highmask_refined.view(mask_pred.shape))
                optimizer_m.zero_grad()
                # optimizer step wrt unet parameters ?
                train_loss.backward()
                optimizer_m.step()
                rep += 1
            mnet.eval()             
            ########################################  
            ## (3) check mnet performance: does it beat random sampling?
            ########################################
            if (global_step%validate_every==0) or (global_step==trainfulls.shape[0]-1):
                randqual_tmp = 0; mnetqual_tmp = 0; greedyqual_tmp = 0
                randspar_tmp = 0; mnetspar_tmp = 0
                imgind = 0
                for img in testimg:
                    x_test_lf     = img
                    mask_test     = mnet_wrapper(mnet,x_test_lf,budget,img.shape,dtyp=dtyp)
                    mask_rand,_,_ = mask_naiveRand(img.shape[0],fix=corefreq,other=mask_test.sum().item()-corefreq,roll=True)              
                    
                    randqual_img  = mask_eval(mask_rand,img,UNET=unet_eval)
                    mnetqual_img  = mask_eval(mask_test,img,UNET=unet) # UNET = unet_eval               
                    randqual_tmp += randqual_img
                    mnetqual_tmp += mnetqual_img                
                    if verbose:
                        print('Quality of random mask : ', randqual_img) 
                        print('Quality of mnet   mask : ', mnetqual_img)

                    ### compute sampling ratio of generated masks
                    randspar_img  = mask_rand.sum().item()/img.shape[0]
                    mnetspar_img  = mask_test.sum().item()/img.shape[0]
                    randspar_tmp += randspar_img
                    mnetspar_tmp += mnetspar_img
                    if mask_greedy is not None:
                        greedyqual_img = mask_eval(mask_greedy[imgind,:],img,mode='sigpy',Lambda=Lambda) # UNET=unet_eval
                        greedyqual_tmp += greedyqual_img
                        if verbose:
                            print('Quality of greedy mask : ', greedyqual_img)
                            print(f'sparsity of random mask: {randspar_img},mnet mask: {mnetspar_img}, \
                                    greedy mask: {greedyspar}\n')
                    else:
                        if verbose:
                            print(f'sparsity of random mask: {randspar_img},mnet mask: {mnetspar_img}\n')
                    imgind += 1
                randqual.append( randqual_tmp/testimg.shape[0] )
                mnetqual.append( mnetqual_tmp/testimg.shape[0] )
                if mask_greedy is not None:
                    greedyqual.append( greedyqual_tmp/testimg.shape[0] )
                randspar.append( randspar_tmp/testimg.shape[0] )
                mnetspar.append( mnetspar_tmp/testimg.shape[0] )
                if plot:
                    try:
                        visualization(randqual,mnetqual,greedyqual=greedyqual,\
                                 randspar=randspar,mnetspar=mnetspar,greedyspar=greedyspar*np.ones(len(greedyqual)))
                    except Exception:
                        visualization(randqual,mnetqual,randspar=randspar,mnetspar=mnetspar)
                if save_cp:
                    torch.save({'model_state_dict': mnet.state_dict()}, dir_checkpoint + 'mnet_split_trained.pth')
                    torch.save({'model_state_dict': unet.state_dict()}, dir_checkpoint + 'unet_split_trained.pth')
                    print(f'\t Checkpoint saved at epoch {epoch_count}, iter {global_step + 1}!')
                    filepath = '/home/huangz78/checkpoints/alternating_update_error_track.npz'
                    np.savez(filepath,randqual=randqual,mnetqual=mnetqual,greedyqual=greedyqual,randspar=randspar,mnetspar=mnetspar)
            global_step += 1
        epoch_count+= 1
    # return mnet, unet