In [1]:
import numpy as np
import argparse
import os
import sys
import torch
import torch.fft as F
from importlib import reload
from torch.nn.functional import relu
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
import utils
from matplotlib import pyplot as plt
import random
import copy

from utils import kplot,mask_naiveRand,mask_filter, get_x_f_from_yfull
from mnet import MNet

from mask_backward_new import mask_backward, mask_eval
from utils import mask_complete , mask_makebinary,kplot, mask_filter, mask_makebinary,raw_normalize

sys.path.insert(0,'/home/huangz78/mri/unet/')
from unet_model import UNet

# Seed every random number generator this notebook touches (PyTorch, NumPy and
# the stdlib `random` module) with one shared value so reruns are repeatable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)  # kept last: it returns None, so the cell displays no output

In [2]:
# Load the pretrained MNet (mask-prediction network) from its checkpoint.
mnet_ckpt_path = '/home/huangz78/checkpoints/mnet.pth'
# out_size = 320-24: presumably image height minus the fixed low-frequency band
# (cf. corefreq=24 used later) -- TODO confirm against mnet.py.
mnet = MNet(out_size=320-24)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; load_state_dict copies the values into mnet's own parameters.
checkpoint = torch.load(mnet_ckpt_path, map_location='cpu')
mnet.load_state_dict(checkpoint['model_state_dict'])
print('MNet loaded successfully from: ' + mnet_ckpt_path)

MNet loaded successfully from: /home/huangz78/checkpoints/mnet.pth


In [3]:
# Load the pretrained UNet that mask_backward uses for reconstruction.
UNET =  UNet(n_channels=1,n_classes=1,bilinear=True,skip=False)
# The checkpoint file name encodes the number of input channels; build the path
# once so the load and the log message cannot drift apart.
unet_ckpt_path = '/home/huangz78/checkpoints/unet_'+ str(UNET.n_channels) +'.pth'
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; load_state_dict copies the values into UNET's own parameters.
checkpoint = torch.load(unet_ckpt_path, map_location='cpu')
UNET.load_state_dict(checkpoint['model_state_dict'])
print('Unet loaded successfully from : ' + unet_ckpt_path )

Unet loaded successfully from : /home/huangz78/checkpoints/unet_1.pth


In [4]:
# Load the fully-sampled training images from the .npz archive.
train_dir = '/home/huangz78/data/traindata_x.npz'
# np.load on an .npz returns a lazy NpzFile that keeps the file open; the
# context manager closes it once the array has been read into memory.
with np.load(train_dir) as train_npz:
    # train_sub = train_npz['x']
    train_full = train_npz['xfull']

In [5]:
# Load the fully-sampled test images and convert them to a torch tensor.
# NOTE: the next cell rebinds both `test_dir` and `test_full`.
test_dir = '/home/huangz78/data/testdata_x.npz'
# np.load on an .npz returns a lazy NpzFile that keeps the file open; the
# context manager closes it once the array has been copied into the tensor.
with np.load(test_dir) as test_npz:
    # test_sub  = torch.tensor(test_npz['x'])     ; test_sub  = test_sub[0:10,:,:]
    test_full = torch.tensor(test_npz['xfull'])

In [6]:
# Select the images for which a greedy mask is available.
test_dir  = '/home/huangz78/data/data_gt.npz'
# Both archives are opened through context managers so the NpzFile handles are
# closed as soon as the arrays have been read.
with np.load(test_dir) as gt_npz:
    # 'imgdata' is moved from axis order (0,1,2) to (2,0,1), i.e. the last axis
    # becomes the leading (per-image) axis.
    test_full = torch.tensor( np.transpose(gt_npz['imgdata'],axes=(2,0,1)) )
with np.load('/home/huangz78/data/data_gt_greedymask.npz') as greedy_npz:
    # Transposed so that indexing the first axis (mask_greedy[0,:]) yields one mask.
    mask_greedy = greedy_npz['mask'].T
print(mask_greedy.shape)

(199, 320)


In [7]:
# Hard-coded initial mask: a float tensor of 320 entries, each 0. or 1.
# It is passed below as `mask_init` to alternating_update_with_unetRecon, which
# (with mask_init_full=True) feeds it to mask_filter to obtain the starting mask.
# Presumably a 1. marks a sampled k-space row of a 320-row image -- TODO confirm.
fullmask  = torch.tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1.,
       1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
       0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
       0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
       0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],dtype=torch.float)

In [10]:
def alternating_update_with_unetRecon(mnet,unet,trainfulls,testimg,mask_init,mask_init_full=True,\
                                      maxIter_mb=50,evalmode='unet',alpha=2.8*1e-5,c=0.05,\
                                      lr_mb=1e-4,lr_mn=1e-4,maxRep=5,\
                                      corefreq=24,budget=56,plot=False,verbose=False,mask_greedy=None,\
                                      change_initmask=True,validate_every=10,dtyp=torch.float):
    '''
    Alternate between refining a sampling mask and fitting mnet to the refined mask.

    For every training image the function
      (1) calls mask_backward to refine the current high-frequency mask,
      (2) takes maxRep RMSprop steps on mnet so that its output for the
          low-frequency image matches the refined mask (BCE-with-logits loss),
      (3) every validate_every steps, and on the last step, scores the mnet mask
          against a random mask with the same number of extra rows (and against
          mask_greedy when given) on the test image(s) using mask_eval with a
          frozen deep copy of the unet taken at entry.

    Parameters
    ----------
    mnet            : mask-prediction network; its parameters are updated in place
                      and it is left in eval mode.
    unet            : reconstruction network handed to (and returned by) mask_backward.
    trainfulls      : training images, iterated along the first axis; each item is
                      indexed as (height, width).
    testimg         : one 2-D test image, or a stack of test images along axis 0.
                      Assumed to share the training image size -- the low-frequency
                      mask built for the current training image is reused on it.
    mask_init       : initial mask. With mask_init_full=True it is a full mask that
                      is passed through mask_filter; otherwise it is used directly
                      as the high-frequency mask.
    mask_init_full  : see mask_init.
    maxIter_mb      : maxIter forwarded to mask_backward (break_limit is maxIter_mb//2).
    evalmode        : currently unused; kept for interface compatibility.
    alpha           : magnitude of l1 penalty for high-frequency mask
    c               : forwarded to mask_backward.
    lr_mb, lr_mn    : learning rates for mask_backward and for the mnet optimizer.
    maxRep          : number of mnet optimizer steps per training image.
    corefreq        : number of fixed low-frequency rows (the `fix` argument of
                      mask_naiveRand).
    budget          : forwarded to mask_backward and to raw_normalize.
    plot            : if True, plot the validation quality and sparsity curves
                      once training has finished.
    verbose         : forwarded to mask_backward.
    mask_greedy     : optional reference mask evaluated alongside the others.
    change_initmask : if True, from the second image on the mask given to
                      mask_backward is mnet's current prediction.
    validate_every  : validation period, in training steps.
    dtyp            : real dtype used for images and masks.

    Returns
    -------
    None. Progress is reported through print (and the optional plot).
    '''
    # Copy the initial mask; as_tensor + clone avoids the copy-construct warning
    # that torch.tensor(<tensor>) emits.
    init_mask = torch.as_tensor(mask_init).clone().detach()
    if mask_init_full:
        highmask = mask_filter(init_mask,base=corefreq,roll=True)
    else:
        highmask = init_mask

    criterion_mnet = nn.BCEWithLogitsLoss()
    optimizer_m = optim.RMSprop(mnet.parameters(), lr=lr_mn, weight_decay=0, momentum=0)
    # optimizer_u = ......

    # Frozen snapshot of the incoming unet: every validation uses the same
    # reconstructor even though mask_backward hands back a (possibly updated) unet.
    unet_eval = copy.deepcopy(unet)
    unet_eval.eval()

    # Accept a single 2-D image as well as a stack of images.
    testimgs = torch.as_tensor(testimg,dtype=dtyp)
    if testimgs.dim() == 2:
        testimgs = testimgs.unsqueeze(0)
    n_test = testimgs.shape[0]

    # One entry per validation round; lists avoid having to pre-compute the
    # number of rounds and keep the final-step round from overwriting an earlier one.
    val_steps = []
    randqual = []; mnetqual = []; greedyqual = []
    randspar = []; mnetspar = []

    # training loop
    n_train = len(trainfulls)
    for global_step, xstar in enumerate(trainfulls):
        xstar = torch.as_tensor(xstar,dtype=dtyp).detach().clone()
        imgheight, imgwidth = xstar.shape[0], xstar.shape[1]
        yfull = torch.fft.fftshift(F.fftn(xstar,dim=(0,1),norm='ortho')) # y is ROLLED!
        lowfreqmask,_,_ = mask_naiveRand(imgheight,fix=corefreq,other=0,roll=True)
        x_lf            = get_x_f_from_yfull(lowfreqmask,yfull)
        mnet_input      = x_lf.view(1,1,imgheight,imgwidth)
        ########################################
        ## (1) mask_backward
        ########################################
        if change_initmask and global_step>0: # option 2: highmask = mask_pred from step (2)
            highmask = mnet(mnet_input).view(-1)
        highmask_refined,unet = mask_backward(highmask,xstar,unet=unet, mnet=mnet,\
                          beta=1.,alpha=alpha,c=c,\
                          maxIter=maxIter_mb,seed=0,break_limit=maxIter_mb//2,\
                          lr=lr_mb,mode='UNET',budget=budget,normalize=True,\
                          verbose=verbose,dtyp=dtyp)

        ########################################
        ## (2) update mnet
        ########################################
        mnet.train()
        unet.eval()
        # The refined mask is a fixed regression target: no gradient flows into it.
        target = highmask_refined.detach()
        # NOTE: an earlier draft also ran the unet on the mnet-masked image here,
        # but that reconstruction never entered the loss, so the unused forward
        # pass was dropped. Re-introduce it together with a reconstruction term.
        for _ in range(maxRep):
            mask_pred  = mnet(mnet_input)
            train_loss = criterion_mnet(mask_pred,target.view(mask_pred.shape))
            optimizer_m.zero_grad()
            # optimizer step wrt unet parameters ?
            train_loss.backward()
            optimizer_m.step()
        mnet.eval()

        ########################################
        ## (3) check mnet performance: does it beat random sampling?
        ########################################
        if (global_step%validate_every==0) or (global_step==n_train-1):
            randqual_tmp = 0; mnetqual_tmp = 0; greedyqual_tmp = 0
            randspar_tmp = 0; mnetspar_tmp = 0
            with torch.no_grad():  # evaluation only: no autograd graph needed
                for img in testimgs:
                    testheight = img.shape[0]
                    yfull_test = torch.fft.fftshift(F.fftn(img,dim=(0,1),norm='ortho'))
                    x_test_lf  = get_x_f_from_yfull(lowfreqmask,yfull_test)

                    highmask_tmp  = torch.sigmoid( mnet( x_test_lf.view(1,1,img.shape[0],img.shape[1]) ).view(-1) )
                    highmask_test = mask_makebinary( raw_normalize(highmask_tmp,budget) , sigma=False )

                    # Random baseline with as many extra rows as the mnet mask selected.
                    n_extra       = int(highmask_test.sum().item())
                    mask_rand,_,_ = mask_naiveRand(testheight,fix=corefreq,other=n_extra,roll=True)
                    mask_test     = mask_complete(highmask_test.view(-1),testheight,rolled=True,dtyp=dtyp)

                    randqual_img  = mask_eval(mask_rand,img,UNET=unet_eval)
                    mnetqual_img  = mask_eval(mask_test,img,UNET=unet_eval)

                    randqual_tmp += randqual_img
                    mnetqual_tmp += mnetqual_img

                    print('Quality of random mask : ', randqual_img) # UNET=unet_eval
                    print('Quality of mnet   mask : ', mnetqual_img) # UNET=unet_eval

                    rand_sparsity = mask_rand.sum().item()/testheight
                    mnet_sparsity = mask_test.sum().item()/testheight
                    randspar_tmp += rand_sparsity
                    mnetspar_tmp += mnet_sparsity
                    if mask_greedy is not None:
                        greedyqual_img = mask_eval(mask_greedy,img,UNET=unet_eval)
                        print('Quality of greedy mask : ', greedyqual_img)
                        greedyqual_tmp += greedyqual_img
                        print(f'sparsity of random mask: {rand_sparsity}, mnet mask: {mnet_sparsity}, greedy mask: {np.sum(mask_greedy)/testheight}')
                    else:
                        print(f'sparsity of random mask: {rand_sparsity}, mnet mask: {mnet_sparsity}')
                    print('\n')
            val_steps.append(global_step)
            randqual.append(float(randqual_tmp)/n_test)
            mnetqual.append(float(mnetqual_tmp)/n_test)
            randspar.append(randspar_tmp/n_test)
            mnetspar.append(mnetspar_tmp/n_test)
            if mask_greedy is not None:
                greedyqual.append(float(greedyqual_tmp)/n_test)

    if plot and val_steps:
        # Summary figure: mean validation quality and mask sparsity per round.
        fig, (ax_qual, ax_spar) = plt.subplots(1,2,figsize=(12,4))
        ax_qual.plot(val_steps,randqual,marker='o',label='random mask')
        ax_qual.plot(val_steps,mnetqual,marker='o',label='mnet mask')
        if mask_greedy is not None:
            ax_qual.plot(val_steps,greedyqual,marker='o',label='greedy mask')
        ax_qual.set(xlabel='training step',ylabel='mask_eval quality',title='Validation quality')
        ax_qual.legend()
        ax_spar.plot(val_steps,randspar,marker='o',label='random mask')
        ax_spar.plot(val_steps,mnetspar,marker='o',label='mnet mask')
        ax_spar.set(xlabel='training step',ylabel='mask.sum() / image height',title='Mask sparsity')
        ax_spar.legend()
        plt.show()
    # return mnet, unet

In [16]:
# Run the alternating mask / mnet update on the first ten training images,
# validating on the first ground-truth image and its greedy mask.
alternating_update_with_unetRecon(
    mnet, UNET, train_full[:10], test_full[0], fullmask,
    maxIter_mb=50, alpha=2e-4,
    lr_mb=1e-4, lr_mn=1e-4, maxRep=5,
    budget=24, verbose=False,
    mask_greedy=mask_greedy[0], change_initmask=True,
)

  # Remove the CWD from sys.path while we load stuff.
  fullmask = torch.tensor( mask_complete(M_high,imgHeg,dtyp=dtyp) )


> [0;32m/home/huangz78/mri/mask_backward_new.py[0m(191)[0;36mmask_backward[0;34m()[0m
[0;32m    189 [0;31m            [0mx[0m   [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mabs[0m[0;34m([0m[0mF[0m[0;34m.[0m[0mifftn[0m[0;34m([0m[0mz[0m[0;34m,[0m[0mdim[0m[0;34m=[0m[0;34m([0m[0;36m0[0m[0;34m,[0m[0;36m1[0m[0;34m)[0m[0;34m,[0m[0mnorm[0m[0;34m=[0m[0;34m'ortho'[0m[0;34m)[0m[0;34m)[0m [0;31m# should involve the mask to cater for the lower-level objective[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    190 [0;31m        [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 191 [0;31m        [0mloss[0m [0;34m=[0m [0mnrmse[0m[0;34m([0m[0mx[0m[0;34m,[0m[0mxstar[0m[0;34m)[0m [0;34m+[0m [0malpha[0m [0;34m*[0m [0mtorch[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mM_high[0m[0;34m,[0m[0mp[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0mc[0m [0;34m*[0m [0mcriterion_mnet[0m[0;34m([0m[0mm

ipdb> n
> [0;32m/home/huangz78/mri/mask_backward_new.py[0m(209)[0;36mmask_backward[0;34m()[0m
[0;32m    207 [0;31m        [0mfullmask_old[0m [0;34m=[0m [0mmask_makebinary[0m[0;34m([0m[0mfullmask[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mnumpy[0m[0;34m([0m[0;34m)[0m[0;34m,[0m[0mthreshold[0m[0;34m=[0m[0;36m0.5[0m[0;34m,[0m[0msigma[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    208 [0;31m        [0moptimizer[0m[0;34m.[0m[0mstep[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 209 [0;31m        [0mM_high[0m [0;34m=[0m [0mproj_eps[0m[0;34m([0m[0mM_high[0m[0;34m,[0m[0meps[0m[0;34m)[0m [0;31m# soft-hard-thresholding as postprocessing[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    210 [0;31m        [0mfullmask[0m [0;34m=[0m [0mmask_complete[0m[0;34m([0m[0mM_high[0m[0;34m,[0m[0mimgHeg[0m[0;34m,[0m[0mdtyp[0m[0;34m=[0m[0mdtyp[0m[0;34m)

ipdb> n
> [0;32m/home/huangz78/mri/mask_backward_new.py[0m(217)[0;36mmask_backward[0;34m()[0m
[0;32m    215 [0;31m        [0madded_rows[0m   [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdelta_mask[0m[0;34m==[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m;[0m   [0mreducted_rows[0m[0;34m=[0m [0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdelta_mask[0m[0;34m==[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    216 [0;31m        [0mchanged_rows[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mabs[0m[0;34m([0m[0mdelta_mask[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 217 [0;31m        [0mcr_per_batch[0m [0;34m+=[0m [0mchanged_rows[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    218 [0;31m        [0;32mif[0m [0mchanged_row

BdbQuit: 

In [15]:
# Pick up on-disk edits to the helper modules without restarting the kernel.
# reload() refreshes the module objects only: names bound earlier through
# `from ... import ...` keep pointing at the pre-reload objects, so every helper
# this notebook uses is re-imported below, not just the ones edited most recently.
reload(utils)
import mask_backward_new
reload(mask_backward_new)  # reloaded after utils; presumably it imports from utils -- TODO confirm
from mask_backward_new import mask_backward, mask_eval
from utils import kplot, mask_naiveRand, mask_filter, get_x_f_from_yfull
from utils import mask_complete, mask_makebinary, raw_normalize