In [None]:
import numpy as np
import argparse
import os
import sys
import torch
import torch.fft as F
from importlib import reload
from torch.nn.functional import relu
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
import utils
from matplotlib import pyplot as plt
import random
import copy
from sklearn.model_selection import train_test_split
from utils import *
import skimage
from mnet import MNet
from loupe_env.loupe_wrap import *
from mask_backward_new import mask_backward, mask_eval

sys.path.insert(0,'/home/huangz78/mri/unet/')
from unet_model import UNet

# Seed every RNG used in this notebook (torch, numpy, stdlib random) with 0 for reproducibility.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)

# Comparison between MNet and LOUPE

In [None]:
dtyp = torch.float
# Read each array inside a `with` block so the underlying .npz file is closed afterwards.
# testxfull: 'xfull' array (presumably image-domain ground truth) as float32.
with np.load('/home/huangz78/data/testdata_x.npz') as testxdata:
    testxfull = torch.tensor(testxdata['xfull'],dtype=dtyp)
# testyfull: 'yfull' array as complex64; mnet_eval treats it as full k-space.
with np.load('/home/huangz78/data/testdata_y.npz') as testydata:
    testyfull = torch.tensor(testydata['yfull'],dtype=torch.cfloat)

In [None]:
# Sampling configuration shared by the cells below:
#   sparsity      - total fraction of k-space rows sampled (mnet_eval budget = int(H*sparsity) - preselect_num)
#   preselect_num - number of low-frequency rows that are always kept
#   unet_skip     - passed as UNet's `skip` flag when building the LOUPE reconstructor
sparsity, preselect_num, unet_skip = .25, 24, True

In [None]:
# Load the trained MNet (mask predictor used by mnet_eval) and the U-Net reconstructor.
mnet = MNet(beta=1,in_channels=2,out_size=320-preselect_num, imgsize=(320,320),poolk=3)
mnetpath = '/home/huangz78/checkpoints/mnet_split_trained.pth'
# map_location='cpu' lets checkpoints saved from a GPU load on CPU-only machines;
# both models are constructed on CPU, so nothing else changes.
checkpoint = torch.load(mnetpath, map_location='cpu')
mnet.load_state_dict(checkpoint['model_state_dict'])
mnet.eval()
print('MNet loaded successfully from: ' + mnetpath)

# Architecture flags must match the checkpoint, so they are fixed here rather than read from config.
unet_recon = UNet(n_channels=1,n_classes=1,bilinear=False,skip=True)
unetpath = '/home/huangz78/checkpoints/unet_split_trained.pth'
checkpoint = torch.load(unetpath, map_location='cpu')
unet_recon.load_state_dict(checkpoint['model_state_dict'])
print('Unet loaded successfully from: ' + unetpath )
unet_recon.eval()
print('nn\'s are ready')

In [None]:
def mnet_eval(mnet,unet,testdata,preselect_num,sparsity,\
             batchsize=5,gt_imgs=None):
    """Evaluate MNet-selected sampling masks followed by U-Net reconstruction.

    For every slice, the ``preselect_num`` low-frequency rows (from
    ``mask_naiveRand``) are shown to MNet, which selects
    ``int(heg*sparsity) - preselect_num`` further rows via ``mnet_wrapper``.
    The k-space restricted to MNet's rows is inverse-FFT'd, its magnitude is
    reconstructed by ``unet``, and the result is scored against ground truth.

    Args:
        mnet: mask-prediction network, invoked through ``mnet_wrapper``.
        unet: reconstruction network; called on (B,1,H,W) magnitude images and
            presumably returns (B,1,H,W) — TODO confirm.
        testdata: complex k-space tensor of shape (N,H,W); assumed un-shifted
            (zero frequency at index 0) since MNet's masks are ifftshift-ed
            before use — TODO confirm.
        preselect_num: number of low-frequency rows always observed.
        sparsity: total fraction of rows sampled (low-frequency + MNet).
        batchsize: number of slices per forward pass.
        gt_imgs: optional (N,H,W) ground-truth images. Defaults to the
            magnitude of the orthonormal inverse FFT of ``testdata``.

    Returns:
        ssim, psnr: whatever ``compute_ssim`` / ``compute_psnr`` return.
        hfen: numpy array (N,), per-slice ``compute_hfen``.
        rmse: numpy array (N,), per-slice relative L2 error ||recon - gt|| / ||gt||.
    """
    nimgs, heg, wid = testdata.shape[0], testdata.shape[1], testdata.shape[2]
    budget = int(heg*sparsity) - preselect_num  # rows MNet may add on top of the preselected ones

    # MNet input: real/imag channels of the k-space restricted to the low-frequency rows.
    # mask_naiveRand returns a 3-tuple (as at its other call sites); only the mask is used.
    lf_mask,_,_ = mask_naiveRand(heg,fix=preselect_num,other=0,roll=False)
    lf_rows = lf_mask == 1
    data = torch.zeros(nimgs,2,heg,wid)
    data[:,0,lf_rows,:] = torch.real(testdata[:,lf_rows,:])
    data[:,1,lf_rows,:] = torch.imag(testdata[:,lf_rows,:])

    pred_mnet = torch.zeros((nimgs,heg))
    imgs_recon = torch.zeros((nimgs,heg,wid))
    # Pure inference: skip autograd so the outputs carry no graph into the metric helpers.
    with torch.no_grad():
        for start in range(0, nimgs, batchsize):
            stop = min(start+batchsize, nimgs)
            preds = mnet_wrapper(mnet,data[start:stop],budget=budget,\
                                 imgshape=[heg,wid],normalize=True)
            # The wrapper's masks are presumably rolled (centred); undo that to match testdata.
            pred_mnet[start:stop] = F.ifftshift(preds,dim=1)

        # Keep only the rows MNet selected for each image, then take the zero-filled magnitude image.
        row_keep = (pred_mnet == 1).unsqueeze(-1)
        observed_kspace = torch.where(row_keep, testdata, torch.zeros_like(testdata))
        input_unet = F.ifftn(observed_kspace,dim=(1,2),norm='ortho').abs().unsqueeze(1)

        for start in range(0, nimgs, batchsize):
            stop = min(start+batchsize, nimgs)
            # Index the channel axis instead of squeeze() so only that axis is dropped.
            imgs_recon[start:stop] = unet(input_unet[start:stop])[:,0]

    # Score against images, not raw k-space.
    if gt_imgs is None:
        gt_imgs = F.ifftn(testdata,dim=(1,2),norm='ortho').abs()
    gt_imgs = torch.as_tensor(gt_imgs, dtype=imgs_recon.dtype)

    ssim = compute_ssim(imgs_recon,gt_imgs)
    psnr = compute_psnr(imgs_recon,gt_imgs)
    hfen = np.zeros(nimgs)
    rmse = np.zeros(nimgs)
    for ind in range(nimgs):
        hfen[ind] = compute_hfen(imgs_recon[ind,:,:],gt_imgs[ind,:,:])
        # Relative L2 error despite the name.
        rmse[ind] = (torch.norm(imgs_recon[ind,:,:] - gt_imgs[ind,:,:],2)
                     / torch.norm(gt_imgs[ind,:,:],2)).item()

    return ssim,psnr,hfen,rmse

In [None]:
from skimage import metrics

In [None]:
# Evaluate MNet masks + U-Net reconstruction on test slices 180 onward,
# using the sampling configuration defined at the top instead of repeated literals.
ssim,psnr,hfen,rmse = mnet_eval(mnet,unet_recon,testyfull[180:],preselect_num=preselect_num,\
                               sparsity=sparsity,batchsize=1)

In [None]:
# Build a U-Net and wrap it in a LOUPE model for the MNet-vs-LOUPE comparison.
UNET = UNet(n_channels=1,n_classes=1,bilinear=(not unet_skip),skip=unet_skip)
loupepath = '/home/huangz78/checkpoints/loupe_model.pt'
# NOTE(review): `loupepath` is never loaded, so `loupe` presumably keeps the weights it was
# constructed with. Load the checkpoint (its save format is not visible here) before comparing.
loupe = LOUPE(n_channels=1,unet_skip=unet_skip,shape=[320,320],slope=5,sparsity=sparsity,\
                  preselect=True,preselect_num=preselect_num,\
                  sampler=None,unet=UNET)

# MNet training history: accuracy and loss curves

In [None]:
# Open MNet's saved training history and list the arrays it contains.
# Kept open: the plotting cell below reads from `rec`.
mnet_history_path = '/home/huangz78/checkpoints/mnet_train_history.npz'
rec = np.load(mnet_history_path)
print(rec.files)

In [None]:
wid1 = 10
wid2 = 5
# One figure per entry: (title, [(history key, legend label), ...], rolling-mean window, extra plot kwargs).
_history_panels = (
    ('training accuracy', (('precision_train', 'precision'), ('recall_train', 'recall')), wid1, {}),
    ('testing accuracy', (('precision_test', 'precision'), ('recall_test', 'recall')), wid2, {}),
    ('train loss in cross entropy', (('loss_train', 'train'),), 20, {}),
    ('test loss in cross entropy', (('loss', 'test'),), 8, {'color': 'orange'}),
)
for _title, _curves, _window, _plot_kw in _history_panels:
    plt.figure()
    plt.title(_title)
    for _key, _label in _curves:
        plt.plot(rolling_mean(rec[_key], _window), label=_label, **_plot_kw)
    plt.legend()
    plt.show()


# MNet quality check: k-space (y) input

In [None]:
# Load MNet and rebuild its k-space inputs (low-frequency rows only) and greedy-mask labels.
base = 24  # low-frequency rows kept by mask_lf; MNet's out_size covers the other 320-base rows
mnet_ckpt_path = '/home/huangz78/checkpoints/mnet.pth'
net = MNet(beta=1,in_channels=2,out_size=320-base,\
                   imgsize=(320,320),poolk=3)
# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
checkpoint = torch.load(mnet_ckpt_path, map_location='cpu')
net.load_state_dict(checkpoint['model_state_dict'])
print('MNet loaded successfully from: ' + mnet_ckpt_path)
net.eval()

# Images are stacked along the last axis on disk; permute so axis 0 indexes images.
with np.load('/home/huangz78/data/data_gt.npz') as gt_npz:
    imgs = torch.tensor( gt_npz['imgdata'] ).permute(2,0,1)
mask_lf,_,_ = mask_naiveRand(imgs.shape[1],fix=base,other=0,roll=True)

yfulls = torch.zeros((imgs.shape[0],2,imgs.shape[1],imgs.shape[2]),dtype=torch.float)
ys     = torch.zeros((imgs.shape[0],2,imgs.shape[1],imgs.shape[2]),dtype=torch.float)
xs     = torch.zeros((imgs.shape[0],1,imgs.shape[1],imgs.shape[2]),dtype=torch.float)
for ind in range(imgs.shape[0]):
    # Scale each image to max |value| 1.
    imgs[ind,:,:] = imgs[ind,:,:]/torch.max(torch.abs(imgs[ind,:,:]))
    y = torch.fft.fftshift(F.fftn(imgs[ind,:,:],dim=(0,1),norm='ortho'))  # centred (rolled) k-space
    ysub = torch.zeros(y.shape,dtype=y.dtype)
    ysub[mask_lf==1,:] = y[mask_lf==1,:]
    xs[ind,0,:,:] = torch.abs(F.ifftn(torch.fft.ifftshift(ysub),dim=(0,1),norm='ortho'))

    # Real/imag channels of the full k-space, and the same restricted to the low-frequency rows.
    yfulls[ind,0,:,:] = torch.real(y)
    yfulls[ind,1,:,:] = torch.imag(y)
    ys[ind,:,mask_lf==1,:] = yfulls[ind,:,mask_lf==1,:]

with np.load('/home/huangz78/data/data_gt_greedymask.npz') as greedy_npz:
    labels = torch.tensor( greedy_npz['mask'].T ) # labels are already rolled

imgNum = imgs.shape[0]
traininds, testinds = train_test_split(np.arange(imgNum),random_state=0,shuffle=True,train_size=round(imgNum*0.8))
test_total  = testinds.size

# Validation set = first half of the held-out indices.
traindata   = ys[traininds,:,:,:]
valdata     = ys[testinds[0:test_total//2],:,:,:]

trainlabels = mask_filter(labels[traininds,:],base=base)
vallabels   = mask_filter(labels[testinds[0:test_total//2],:],base=base)

In [None]:
def _mnet_prediction(index):
    """Run `net` on validation sample `index`; return (sample, raw output, binarized mask, greedy label)."""
    sample = valdata[index,:,:,:]
    raw = net(sample.view(-1,2,320,320))
    return sample, raw, sigmoid_binarize(raw)[0,:], vallabels[index,:]

imgind = 19
testimg, output_1, binary_1, greedy_1 = _mnet_prediction(imgind)
imgind = 13
testimg, output_2, binary_2, greedy_2 = _mnet_prediction(imgind)

# Number of differing mask entries, then the raw output difference.
print(torch.sum(torch.abs(binary_1-binary_2)))
print(output_1 - output_2)

In [None]:
from utils import mask_complete

In [None]:
# Complete the binarized MNet prediction for image 19 into a full (rolled) mask,
# then plot it next to the greedy label for the same image.
mask_1 = mask_complete(binary_1.view(1, -1), 320, rolled=True)
for _mask_to_plot in (mask_1.view(-1), greedy_1):
    kplot(_mask_to_plot)

# MNet quality check: image-domain (x) input

In [None]:
# Reload MNet plus the test images and greedy masks used to generate example masks.
mnet_ckpt = '/home/huangz78/checkpoints/mnet_split_trained.pth'
# NOTE(review): this rebinds the global `mnet` using MNet's default constructor arguments,
# whereas the model-loading cell builds it with beta=1, in_channels=2, imgsize=(320,320),
# poolk=3 from the same checkpoint. The cells below feed 1-channel input, so the defaults are
# presumably intended — confirm they match the checkpoint, and re-run the model-loading cell
# before calling mnet_eval again.
mnet = MNet(out_size=320-24)
checkpoint = torch.load(mnet_ckpt, map_location='cpu')  # CPU-safe for GPU-saved checkpoints
mnet.load_state_dict(checkpoint['model_state_dict'])
print('MNet loaded successfully from: ' + mnet_ckpt)

test_dir = '/home/huangz78/data/testdata_x.npz'
with np.load(test_dir) as test_npz:
    test_full = torch.tensor(test_npz['xfull'])
with np.load('/home/huangz78/data/data_gt_greedymask.npz') as greedy_npz:
    mask_greedy = greedy_npz['mask'].T # this greedy mask is rolled

In [None]:
# Re-execute utils.py so edits made since the kernel started take effect, then rebind
# mnet_wrapper to the reloaded definition. Other names pulled in earlier by
# `from utils import *` still refer to the pre-reload objects.
reload(utils)
from utils import mnet_wrapper

In [None]:
# NOTE: puts MNet in training mode (affects dropout/batch-norm layers, if it has any).
mnet.train()

# Draw two random test images, indices in [0, 199).
_selected = []
for _ in range(2):
    _idx = np.random.randint(199)
    print('current selected image is indexed: ',_idx)
    _selected.append((_idx, test_full[_idx,:,:]))
(imgind1, img1), (imgind2, img2) = _selected

budget = 56
lowfreqmask,_,_ = mask_naiveRand(img1.shape[0],fix=24,other=0,roll=True)

# fftshift centres the zero frequency, so both full k-spaces are ROLLED.
yfull1, yfull2 = (torch.fft.fftshift(F.fftn(_img,dim=(0,1),norm='ortho')) for _img in (img1, img2))
yfull  = torch.stack((yfull1,yfull2),dim=0)
x_lf   = get_x_f_from_yfull(lowfreqmask,yfull,DTyp=torch.cfloat)

In [None]:
mnet(x_lf.view(x_lf.shape[0], 1, *img1.shape[:2])).shape

In [None]:
yfull1.size()

In [None]:
mnet.eval()
# Feed ten large-amplitude Gaussian noise images through MNet.
_noise = 1e3*torch.randn(10,1,*img1.shape[:2])
mnetmask = mnet(_noise)
print(mnetmask.shape)

# L1 distance between the outputs for the first two noise inputs.
(mnetmask[0,:] - mnetmask[1,:]).abs().sum()

In [None]:
# Scratch check: selecting one channel of (2,2,4,5) tensors drops that axis, giving (2,4,5).
a, b = torch.randn(2,2,4,5), torch.randn(2,2,4,5)
c = a[:,0] + b[:,1]
print(c.shape)

In [None]:
# Raw MNet scores (no sigmoid) for both selected images, flattened into one curve.
highmask_raw  = mnet( x_lf.view(x_lf.shape[0],1,img1.shape[0],img1.shape[1]) ).view(-1)  # no sigmoid
plt.plot(highmask_raw.detach().numpy())
plt.show()
# Use img1/imgind1 from the image-selection cell: `img` is not defined in this section and
# `imgind` would be a stale value left over from the k-space check.
mnetmask = mnet_wrapper(mnet,x_lf,budget,img1.shape)
kplot(mnetmask)
# NOTE(review): mask_greedy is loaded from data_gt_greedymask.npz while img1 comes from
# testdata_x.npz — confirm that row imgind1 of mask_greedy belongs to the same image.
kplot(mask_greedy[imgind1,:])

In [None]:
# Show MNet's module structure (the previous `print(mnet.)` was a SyntaxError).
print(mnet)

In [None]:
# Print every (name, parameter) pair registered on MNet.
for name_and_param in mnet.named_parameters():
    print(name_and_param)

In [None]:
# Snapshot the current mask; the difference below is zero (for finite entries) until
# `mnetmask` is recomputed and this line is evaluated again.
mnetmask_old = torch.clone(mnetmask)
torch.sub(mnetmask, mnetmask_old)

In [None]:
# Preview the random-mask quality history (first entry dropped, as in the plotting cell),
# read directly from disk so this cell does not depend on a later cell having run.
with np.load('/home/huangz78/checkpoints/alternating_update_error_track.npz') as _error_track:
    _randqual_preview = _error_track['randqual'][1:]
_randqual_preview

In [None]:
# Show the quality/sparsity history of random vs MNet (vs greedy, when available) masks.
filepath = '/home/huangz78/checkpoints/alternating_update_error_track.npz'
# Read everything inside `with` so the .npz file is closed afterwards.
with np.load(filepath) as data_loss:
    print(data_loss.files)

    randqual   = data_loss['randqual']
    mnetqual   = data_loss['mnetqual']
    greedyqual = data_loss['greedyqual']
    randspar   = data_loss['randspar']
    mnetspar   = data_loss['mnetspar']

# NOTE(review): `greedyspar` is not defined anywhere in this notebook, so the former
# try/except always hit NameError and silently fell back to the plot without greedy
# masks. Define `greedyspar` (the greedy masks' sampling ratio) to enable the greedy
# comparison; errors raised inside `visualization` are no longer swallowed.
greedyspar = globals().get('greedyspar')
if greedyspar is not None:
    visualization(randqual[1:],mnetqual[1:],greedyqual=greedyqual,\
             randspar=randspar,mnetspar=mnetspar,greedyspar=greedyspar*np.ones(len(greedyqual)))
else:
    visualization(randqual[1:],mnetqual[1:],randspar=randspar,mnetspar=mnetspar,log1=True)