In [1]:
import numpy as np
import scipy.io as sio
import os
import sys
import torch
from matplotlib import image
import matplotlib.pyplot as plt
import torch.fft as F
from sklearn.metrics import mean_squared_error as mse
from sklearn.model_selection import train_test_split
# import gurobipy
import h5py
from importlib import reload
from torch.nn.functional import relu
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
import utils
from utils import *
import time

In [None]:
# Load the pre-extracted validation / test image stacks (stored under the key 'data').
val_dir  = '/mnt/shared_a/fastMRI/knee_singlecoil_val.npz'
test_dir = '/mnt/shared_a/fastMRI/knee_singlecoil_test.npz'
# np.load on an .npz returns a lazily-read NpzFile that keeps the archive open;
# the context manager closes it as soon as the array has been materialised.
with np.load(val_dir) as val_npz:
    valdata = val_npz['data']
with np.load(test_dir) as test_npz:
    testdata = test_npz['data']
print(valdata.shape)
print(testdata.shape)

In [None]:
# First five validation slices reshaped to a 4-D (5, 1, 320, 320) mini-batch;
# the label starts out as an independent copy of the input.
data  = torch.tensor(valdata[:5]).reshape(5, 1, 320, 320)
label = data.clone()

In [None]:
def _shifted_crop(img, pad, top, left):
    """Zero-pad `img` (C, H, W) by `pad` on every side and cut an H x W window at (top, left)."""
    chans, heg, wid = img.shape
    # new_zeros keeps the dtype/device of the input instead of forcing float32.
    canvas = img.new_zeros((chans, heg + 2 * pad, wid + 2 * pad))
    # Explicit end indices so that pad == 0 works (a `pad:-pad` slice would be empty).
    canvas[:, pad:pad + heg, pad:pad + wid] = img
    return canvas[:, top:top + heg, left:left + wid]


def data_augmentation(data,label,pad=10):
    """Randomly flip and shift-crop each (data, label) pair with identical transforms.

    For every sample, independently and each with probability 0.5:
      * flip along the last (width) axis;
      * zero-pad by `pad` pixels per side and crop a window of the original size
        at a random offset in [0, 2*pad], i.e. a random translation with zero fill.
    The same flip decision and crop offset are applied to the sample and its label.

    Randomness comes from the global NumPy RNG (`np.random`), so seed it for
    reproducible augmentation.

    Args:
        data:  4-D tensor indexed as (sample, channel, height, width).
        label: tensor with exactly the same shape as `data`.
        pad:   maximum shift in pixels (default 10, the original hard-coded value).

    Returns:
        (data_aug, label_aug): new tensors shaped like the inputs. The inputs
        themselves are left unmodified.

    Raises:
        AssertionError: if `data` and `label` differ in shape.
    """
    assert data.shape == label.shape, 'data and label must have identical shapes'
    data_aug  = torch.zeros_like(data)
    label_aug = torch.zeros_like(label)
    for ind in range(len(data)):
        # Work on local references; never write back into the caller's tensors.
        sample, target = data[ind], label[ind]
        if np.random.rand() > .5:  # horizontal flip
            sample = torch.flip(sample, dims=[2])
            target = torch.flip(target, dims=[2])
        if np.random.rand() > .5:  # random crop
            dim_1, dim_2 = np.random.randint(pad*2+1, size=2)
            sample = _shifted_crop(sample, pad, dim_1, dim_2)
            target = _shifted_crop(target, pad, dim_1, dim_2)
        data_aug[ind]  = sample
        label_aug[ind] = target
    return data_aug,label_aug


In [None]:
# Smoke-test the augmentation on the current `data` / `label` tensors; as the
# cell's last expression, the returned (data_aug, label_aug) tuple is displayed.
data_augmentation(data,label)

In [None]:
# Row-wise k-space energy profile, averaged over all validation slices.
# The archive is read inside a context manager so the NpzFile handle is closed.
with np.load('/mnt/shared_a/fastMRI/knee_singlecoil_val.npz') as val_npz:
    val_full = torch.tensor(val_npz['data'], dtype=torch.float)
# Size the accumulator from the tensor that is actually processed, so the cell
# does not depend on `valdata` having been loaded by an earlier cell.
energy_vec_val = torch.zeros(val_full.shape[1])
# Scale every slice to unit peak magnitude.
for ind in range(val_full.shape[0]):
    val_full[ind,:,:]  = val_full[ind,:,:]/val_full[ind,:,:].abs().max()
# Accumulate |FFT|^2 summed along dim 1 (one value per k-space row).
for img in val_full:
    img_fft = torch.fft.fftn(img,dim=(0,1),norm='ortho')
    energy_vec_val += torch.sum(torch.square(torch.abs(img_fft)).to(torch.float),dim=1)
energy_vec_val /= len(val_full)              # mean over slices
energy_vec_val /= torch.sum(energy_vec_val)  # normalise so the entries sum to 1
energy_vec_val = torch.fft.fftshift(energy_vec_val)  # move the zero-frequency row to the centre

In [None]:
# Scatter the normalised energy profile against its index, on a log-scaled y axis.
plt.scatter(range(len(energy_vec_val)),energy_vec_val)
plt.yscale('log')

In [None]:
from unet_train import prepare_data

In [None]:
# Call prepare_data (imported from unet_train in the previous cell); as the last
# expression of the cell, its return value is displayed.
# NOTE(review): the meaning of mode='prob' / unet_inchans=1 is defined in unet_train, not visible here.
prepare_data(mode='prob',unet_inchans=1)

# train test split prep

In [None]:
# Load the two ground-truth image archives and stack them along the last (image) axis.
data_dir1 = '/home/huangz78/data/data_gt.npz'
data_dir2 = '/mnt/shared_b/data/fastMRI/singlecoil_train/expanded_gt.npz'
# NpzFile objects keep the archive open and re-read a member on every access;
# open both in a context manager and read each 'imgdata' array exactly once.
with np.load(data_dir1) as data1, np.load(data_dir2) as data2:
    imgs1 = data1['imgdata']
    print('file1',data1.files)
    print(imgs1.shape)
    imgs2 = data2['imgdata']
    print('file2',data2.files)
    print(imgs2.shape)

# data = np.concatenate((data1['imgdata'],data2['imgdata']),axis=2)
# Current ordering: archive 2 first, archive 1 last.
data = np.concatenate((imgs2,imgs1),axis=2)
del data1
del data2
del imgs1, imgs2  # drop the temporaries so only `data` holds the images

In [None]:
# Shuffled, fixed-seed split of the 199+1014 image indices with 1000 training indices.
# NOTE(review): the next cell reassigns traininds/testinds to a fixed contiguous
# split, so this random split is discarded when the cells are run in order.
# NOTE(review): train_test_split is already imported in the first cell.
from sklearn.model_selection import train_test_split
imgNum = 199+1014
traininds, testinds = train_test_split(np.arange(imgNum),random_state=0,shuffle=True,train_size=1000)

In [None]:
# Fixed contiguous split: indices 0..1013 for training, the following 199 for testing.
traininds = np.arange(1014)
testinds  = np.arange(1014, 1014 + 199)

In [None]:
# Slice the image stack into train / test parts and allocate zero-filled
# buffers with the image axis moved to the front: (num_images, dim0, dim1).
# The *_y / *_yfull buffers are complex (cfloat); *_x / *_xfull are real (float).
dtyp = torch.float
Dtyp = torch.cfloat
trainimgs = data[:,:,traininds]
testimgs  = data[:,:,testinds]

train_shape = (trainimgs.shape[2], trainimgs.shape[0], trainimgs.shape[1])
test_shape  = (testimgs.shape[2],  testimgs.shape[0],  testimgs.shape[1])

train_y     = torch.zeros(train_shape, dtype=Dtyp)
train_yfull = torch.zeros(train_shape, dtype=Dtyp)
train_x     = torch.zeros(train_shape, dtype=dtyp)
train_xfull = torch.zeros(train_shape, dtype=dtyp)

test_y      = torch.zeros(test_shape, dtype=Dtyp)
test_yfull  = torch.zeros(test_shape, dtype=Dtyp)
test_x      = torch.zeros(test_shape, dtype=dtyp)
test_xfull  = torch.zeros(test_shape, dtype=dtyp)

#### Load an image and reshape it into the correct shape

In [None]:
# Collect the names of all non-hidden .h5 files in the dataset directory.
fastMRI_path = '/mnt/shared_b/data/fastMRI/singlecoil_train/'
# fastMRI_path = '/mnt/shared_b/data/fastMRI/singlecoil_val/'
# fastMRI_path = '/mnt/shared_b/data/fastMRI/singlecoil_test_v2/' # ALL files in this directory are broken!
# Later cells open volumes by bare file name, so make this the working directory.
os.chdir(fastMRI_path)
doculist = [
    file
    for file in os.listdir(fastMRI_path)
    if file.endswith('.h5') and not file.startswith('.')
]
print(len(doculist))

In [None]:
# Open the first listed volume read-only and print the names of its datasets.
# filename = 'file1000605.h5'
# filename = 'file1000568.h5'
filename = doculist[0]
# The handle `f` is left open on purpose: the next cell reads from it and closes it.
f = h5py.File(fastMRI_path + filename, 'r')
print(f.keys())

In [None]:
# Display every frame of the open volume's 'reconstruction_rss' dataset,
# then close the file handle opened by the previous cell.
frames = f['reconstruction_rss']
for ind in range(frames.shape[0]):
    plt.imshow(frames[ind,:,:], origin='lower')
    plt.title(f'{filename}, frame {ind}')
    plt.colorbar()
    plt.show()
f.close()

In [None]:
# Scratch arithmetic. NOTE(review): presumably 13 slices per volume (2*diff+1 with
# diff = 6, as used in the dataset cells) times 973 volumes — confirm or delete this cell.
13*973

In [None]:
# Report the number of 'reconstruction_esc' frames in every listed volume.
# Files are opened by bare name, so the working directory must be the dataset directory.
for file in doculist:
    try:
        # The context manager closes the handle even if the dataset lookup fails.
        with h5py.File(file,'r') as f:
            print('filename: ', file, '\t', 'Number of images: ', f['reconstruction_esc'].shape[0])
    except Exception as err:
        # Best-effort scan: report the failure and its cause, then continue.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt stop the loop.
        print(file,'failed to be opened', err)

In [None]:
# what volumes should we pick? 'diff' below means the number of volumes we should select
# For 20 randomly chosen volumes, show the two slices `diff` frames below and
# above the middle frame, to judge how far from the centre slices stay usable.
diff = 6
indset = np.random.choice(range(len(doculist)),size=20,replace=False)
for ind in indset:
    # Resolve the name before the try block so the error message always names this file.
    filename = doculist[ind]
    try:
        # The context manager closes the handle even when reading or plotting fails.
        with h5py.File(filename, 'r') as f:
            total_frames = f['reconstruction_esc'].shape[0]
            print(filename,'number of frames: ', total_frames)

            for offset in (-diff, diff):
                slice_ind = total_frames//2 + offset
                im = torch.tensor(f['reconstruction_esc'][slice_ind,:,:])
                plt.title(filename+'  slice: '+str(slice_ind))
                plt.imshow(im,origin='lower')
                plt.colorbar()
                plt.show()
    except Exception as err:
        # Best-effort preview: report the failure and its cause, then continue.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt stop the loop.
#         doculist.remove(filename)
        print('failed to open the file: ',filename, err)

In [None]:
# create training/validation/testing dataset
# Each selected volume contributes its 2*diff+1 central 'reconstruction_esc' slices.
diff = 6
# The validation and test archives are produced by re-running this cell with a
# different `doculist_current`. The split must therefore be identical on every run:
# permute a *sorted* listing with a fixed-seed local generator, otherwise two runs
# draw different partitions and the saved validation / test sets can overlap.
split_rng = np.random.RandomState(0)
doculist_sorted = sorted(doculist)
fileinds = split_rng.permutation(len(doculist_sorted))
doculist_val  = [doculist_sorted[i] for i in fileinds[0:99]] # validation
doculist_test = [doculist_sorted[i] for i in fileinds[99:]]  # test
# doculist_current = doculist_val  # validation
doculist_current = doculist_test
imgdata = np.zeros((len(doculist_current)*(2*diff+1),320,320))
imgind = 0
prev_handle = None
for filename in doculist_current:
    # Close the previous volume before opening the next one instead of leaking
    # one handle per file.
    if prev_handle is not None:
        prev_handle.close()
    f = prev_handle = h5py.File(filename, 'r')
    fsize = f['reconstruction_esc'].shape[0]
    print(filename, fsize)
    for ind in np.arange(-diff,diff+1,1):
        # h5py already returns a NumPy array; no torch round-trip is needed.
        imgdata[imgind,:,:] = f['reconstruction_esc'][fsize//2+ind,:,:]
        imgind += 1
# NOTE: the last handle `f` is deliberately left open — the later cell that builds
# `y = F.fftn(im, ...)` reads f['reconstruction_rss'] and closes it.
print(len(doculist_current)*(2*diff+1),imgind)

In [None]:
# save training dataset
# NOTE(review): the header says "training" but the active path is the test archive;
# it has to match whichever list `doculist_current` pointed to when `imgdata` was built.
# filename = '/mnt/shared_a/data/fastMRI/knee_singlecoil_train.npz'
# filename = '/mnt/shared_a/fastMRI/knee_singlecoil_val_2.npz'
filename = '/mnt/shared_a/fastMRI/knee_singlecoil_test_2.npz'
# The image stack is stored under the key 'data'.
np.savez(filename,data=imgdata)

In [None]:
# Take frame 22 of the currently open volume `f`, show it, and compute its
# orthonormal 2-D FFT. Relies on `f` having been left open by an earlier cell.
im = torch.tensor(f['reconstruction_rss'][22,:,:])
plt.imshow(im, cmap='gray')
plt.colorbar()
y = F.fftn(im, dim=(0,1), norm='ortho')
f.close()

# Record the image size and give both k-space data and image a trailing axis.
imgHeg, imgWid = y.shape[0], y.shape[1]
if y.dim() < 3:
    y = y.view((imgHeg, imgWid, 1))
x_star = im.view(imgHeg, imgWid, -1)

##### load Siddhant's greedy mask

In [None]:
### load greedy mask provided by Siddhant
# Five result archives; each stores its mask array under the default key 'arr_0'.
mask_paths = (
    '/mnt/shared_b/gautamsi/mri-sampling/simulation-results/greedy_fastmri_mp50.npz',
    '/mnt/shared_b/gautamsi/mri-sampling/simulation-results/greedy_fastmri_mp100.npz',
    '/mnt/shared_b/gautamsi/mri-sampling/simulation-results/greedy_fastmri_mp150.npz',
    '/mnt/shared_b/gautamsi/mri-sampling/simulation-results/greedy_fastmri_mp198.npz',
    '/mnt/shared_b/gautamsi/mri-sampling/simulation-results/greedy_fastmri_mp199.npz',
)
file1, file2, file3, file4, file5 = (np.load(path) for path in mask_paths)
mask1, mask2, mask3, mask4, mask5 = (
    npz['arr_0'] for npz in (file1, file2, file3, file4, file5)
)
# Preview the first mask of the first archive.
plt.imshow(mask1[:,:,0])

In [None]:
# Concatenate the masks of all archives column-wise into one (320, 199) label
# matrix: column k holds entry [0, :, .] of the k-th mask overall.
masks = [mask1,mask2,mask3,mask4,mask5]
mask_label = np.zeros((320,199))
ind = 0  # first free column of mask_label
for maskfile in masks:
    masknum = maskfile.shape[2]
    # Copy this archive's masks as one block of columns.
    mask_label[:, ind:ind+masknum] = maskfile[0, :, :masknum]
    ind += masknum

np.savez('/home/huangz78/data/data_gt_greedymask.npz',mask=mask_label)

### At what sampling ratio does random sampling overtake low-frequency sampling?
    - Conclusion: 30% base, 10% additional

In [None]:
# At what sampling ratio does random sampling overtake low-frequency sampling?
# Conclusion: 30% base, 10% additional
#
# For each base ratio on a grid, compares the relative L2 reconstruction error of
#   (a) a mask built by mask_naiveRand with fix = imgHeg*(base_r + rand_r), other = 0, and
#   (b) a mask built by mask_prob with fix = imgHeg*base_r, other = imgHeg*rand_r,
#       averaged over Rep draws,
# using TotalVariationRecon for both.
# NOTE(review): `img`, `yraw`, `imgHeg` and `imgWid` are not defined in this cell; they
# must come from other cells (hidden notebook state). `mask_naiveRand`, `mask_prob` and
# `TotalVariationRecon` are not defined in this notebook — presumably provided by
# `from utils import *`; confirm.
# lamda = 5e-7
np.random.seed(0)
roll_flag = True
mps = np.ones((1,imgHeg,imgWid))  # all-ones array passed as the second argument of the reconstruction
# x_recon = np.fft.fftshift(np.real(TotalVariationRecon(ksp, mps, lamda, weights=mask).run()))
# x_recon = np.fft.fftshift(np.real(L1WaveletRecon(ksp, mps, lamda, weights=mask).run()))
# l1wavelet: -9,-8.5; tv: -6.5,-6.3   (NOTE(review): presumably log10 ranges of lamda that worked — confirm)
lamda = 10**(-6.31)

base_r_grid = np.linspace(.25,.35,10)  # base sampling ratios to sweep
rand_r = 0.1                           # additional ratio added on top of the base
error_rand = np.zeros(base_r_grid.size); error_freq = np.zeros(base_r_grid.size)

Rep = 10  # number of mask_prob draws averaged per grid point

ind = 0
for base_r in base_r_grid:
#     base_r = 0.3;
    total_r   = base_r + rand_r    # total ratio handed to the mask_naiveRand baseline
    mask_freq,_,_ = mask_naiveRand(imgHeg,fix=imgHeg*total_r,other=0,roll=roll_flag)
    mask_freq = mask_freq.numpy()
    y_freq = np.reshape(np.diag(mask_freq)@yraw,(-1,imgHeg,imgWid)) # diag(mask) @ yraw scales the rows of yraw by the mask entries
    
    rep = 0
    while rep < Rep:
        # NOTE(review): the seed is the current wall-clock second, so the np.random.seed(0)
        # above presumably does not make these masks reproducible, and repetitions that
        # start within the same second receive the same seed — confirm how mask_prob uses `seed`.
        mask_rand = mask_prob(img,fix=imgHeg*base_r,other=imgHeg*rand_r,roll=roll_flag,seed=int(time.strftime('%S')))
        y_rand = np.reshape(np.diag(mask_rand)@yraw,(-1,imgHeg,imgWid)) # same row masking as for y_freq
        x_recon_rand = np.fft.fftshift( np.real(TotalVariationRecon(y_rand,mps,lamda,show_pbar=False,max_iter=50).run()) )
        # relative L2 error: ||x_recon - img||_2 / ||img||_2, accumulated over repetitions
        error_rand[ind] += np.sqrt( np.sum((x_recon_rand.flatten()-img.flatten())**2) )/np.sqrt( np.sum( (img.flatten())**2 ))
        rep += 1
    error_rand[ind] /= Rep  # mean over the Rep repetitions
    
    x_recon_freq = np.fft.fftshift( np.real(TotalVariationRecon(y_freq,mps,lamda,show_pbar=False,max_iter=50).run()) )
    error_freq[ind] = np.sqrt( np.sum((x_recon_freq.flatten()-img.flatten())**2) )/np.sqrt( np.sum( (img.flatten())**2 ))
    # error_rand = np.mean(np.abs(x_recon_rand.flatten()-img.flatten())) 
    # error_freq = np.mean(np.abs(x_recon_freq.flatten()-img.flatten())) 

    print('rand.     mask recon. error = ' , error_rand[ind])
    print('low.freq. mask recon. error = ' , error_freq[ind])
    ind += 1

# Error of both strategies as a function of the base sampling ratio.
plt.figure()
plt.plot(base_r_grid,error_rand,label='rand')
plt.plot(base_r_grid,error_freq,label='freq')
plt.legend()
plt.show()

### Sanity test
  - Shepp-Logan phantom
  - brain image from class material
  - the following cells mostly load images into the notebook

In [None]:
# Load frame 22 of one local validation volume and compute its orthonormal 2-D FFT.
fastMRI_path = '/Users/leonardohuang/Desktop/msu_research/code/data/singlecoil_val/'
# sys.path.append(fastMRI_path)
imgHeg   = 320
imgWid   = 320
DType    = torch.cfloat

# Volumes are opened by bare file name, so switch into the dataset directory.
os.chdir(fastMRI_path)
filename = 'file1001557.h5'
# Context manager: the handle is released even if the dataset read fails.
with h5py.File(filename, 'r') as f:
    print(f.keys())
    im = torch.tensor(f['reconstruction_rss'][22,:,:])

plt.imshow(im,cmap='gray')
plt.colorbar()
y = F.fftn(im,dim=(0,1),norm='ortho')

In [None]:
# Shepp-Logan phantom: read the GIF, keep its first channel as a float tensor,
# record the image size and compute the orthonormal 2-D FFT.
im = image.imread('/Users/leonardohuang/Desktop/msu_research/code/data/phantom.gif')
im = torch.tensor(im[:,:,0]).float()
imgHeg, imgWid = im.shape
DType  = torch.cfloat
y = F.fftn(im, dim=(0,1), norm='ortho')
# Start from a clean figure before showing the phantom.
plt.clf()
plt.imshow(im)
plt.colorbar()

In [None]:
def rescale_sp(x):
    """Affinely rescale `x` so that its magnitude range [min|x|, max|x|] maps to [0, 255].

    The scale K and offset B are computed from the minimum and maximum of |x|
    and then applied to `x` itself (not to |x|), i.e. the result is K*x + B.
    Negative entries can therefore map below 0.

    Args:
        x: array-like of numbers; must be non-empty (np.max raises on empty input).

    Returns:
        A new array K*x + B; the input is not modified. If every entry has the
        same magnitude the range is degenerate and an all-zero array of the same
        shape is returned instead of dividing by zero (which produced NaN/inf).
    """
    magnitude = np.abs(x)
    maxval = np.max(magnitude)
    minval = np.min(magnitude)
    span = maxval - minval
    if span == 0:
        # Constant-magnitude input: there is no range to stretch.
        arr = np.asarray(x)
        return np.zeros_like(arr, dtype=np.result_type(arr.dtype, np.float64))
    K = 255./span
    B = - (minval*255.)/span
    return K*x+B

### prepare useful FastMRI dataset

In [None]:
# Load the selected image stack; `data` is the array stored under the key 'imgdata'.
# NOTE(review): `imgdata` is rebound to the NpzFile here (it names an ndarray in the
# dataset-building cell), and `data` is likewise reused for several different objects
# across the notebook — later cells see whichever assignment ran last.
imgdata = np.load('/home/huangz78/data/imgdata.npz')
data = imgdata['imgdata']

In [None]:
# Build the network dataset for 199 images: a mask label per image, the row-masked
# k-space restricted to a central band of rows, and the magnitude of its inverse FFT.
# NOTE(review): `mask_prob` and `mask_naiveRand` are not defined in this notebook —
# presumably provided by `from utils import *`; their exact semantics are not visible here.
base = .3   # ratio passed (times imgHeg) as `fix` to both mask helpers
addi = .1   # additional ratio passed (times imgHeg) as `other` to mask_prob
imgHeg = 320
labels = np.zeros((199,int(imgHeg*(1-base))))   # mask entries outside the central band
suby   = np.zeros((199,int(imgHeg*base),320,2)) # central-band k-space; last axis = (real, imag)
ifftimgs  = np.zeros((199,320,320))             # |ifft| of the row-masked k-space

# Row indices of the central band, centred on imgHeg/2.
coreInds = np.arange(int(imgHeg/2)-int(imgHeg*base/2), int(imgHeg/2)+int(imgHeg*base/2))
mask_low,_,_ = mask_naiveRand(imgHeg,fix=int(imgHeg*base),other=0,roll=True)
mask_low = mask_low.numpy()

np.random.seed(2021)
# NOTE: `ind`, `img`, `yraw` and `y` stay defined after the loop and are read by other cells.
for ind in range(199):
    img = data[:,:,ind]
    fullmask = mask_prob(img,fix=imgHeg*base,other=imgHeg*addi,roll=True)
    labels[ind,:] = fullmask[np.setdiff1d(np.arange(imgHeg),coreInds)] # labels for high freq
    
    yraw = np.fft.fftshift(np.fft.fftn(img,norm='ortho'))  # centred, orthonormal 2-D FFT
    y = np.diag(mask_low)@yraw # subsampled y (rows of yraw scaled by mask_low)
    # NOTE(review): fftshift is used to undo the earlier fftshift; that is exact only
    # for even sizes (true for 320) — np.fft.ifftshift is the general inverse.
    ifftimgs[ind,:,:] = np.abs(np.fft.ifftn(np.fft.fftshift(y),norm='ortho')) # ifft imgs
    suby[ind,:,:,0] = np.real(y[coreInds,:]) # subsampled y real
    suby[ind,:,:,1] = np.imag(y[coreInds,:]) # subsampled y imag


In [None]:
# Persist the network dataset under the keys 'labels', 'sub_y' and 'ifftimgs'.
filepath = '/home/huangz78/data/datafornn.npz'
np.savez(filepath,labels=labels,sub_y=suby,ifftimgs=ifftimgs)

In [None]:
# Visual check on image 3: original image, log-magnitude of the row-masked
# k-space, and the magnitude of its inverse FFT, followed by the sampled mask.
img = data[:,:,3]
fullmask = mask_prob(img,fix=imgHeg*base,other=imgHeg*addi,roll=True)
# The former `labels[ind,:] = ...` line was removed: `ind` is a stale loop variable
# here, so it overwrote a row of the dataset labels with the mask of image 3.
# This cell only visualises and must not modify `labels`.

mask_low,_,_ = mask_naiveRand(imgHeg,fix=int(imgHeg*base),other=0,roll=True)
mask_low = mask_low.numpy()
yraw = np.fft.fftshift(np.fft.fftn(img,norm='ortho'))          # centred, orthonormal 2-D FFT
y = np.diag(mask_low)@yraw                                     # rows of yraw scaled by mask_low
xifft = np.abs(np.fft.ifftn(np.fft.fftshift(y),norm='ortho'))  # magnitude of the inverse FFT

plt.figure(figsize=(5,10))
plt.subplot(311)
plt.title('orig img')
plt.imshow(img)
plt.colorbar()

plt.subplot(312)
plt.title('naive masked y')
# Entries of y that are exactly zero give log(0) = -inf.
plt.imshow(np.log(np.abs(y)))
plt.colorbar()

plt.subplot(313)
plt.title('ifft img')
plt.imshow(xifft)
plt.colorbar()

plt.show()

kplot(fullmask)

### manually select images by printing all images for view

In [None]:
# Rebuild the file list from the local validation directory, skipping hidden entries.
fastMRI_path = '/Users/leonardohuang/Desktop/msu_research/code/data/singlecoil_val/'
# sys.path.append(fastMRI_path)
imgHeg, imgWid = 320, 320
DType = torch.cfloat

doculist = [file for file in os.listdir(fastMRI_path) if not file.startswith('.')]
# docutrain,docutest = train_test_split(doculist,train_size=int(len(doculist)*.8), random_state=1024)
# print('Number of image documents for training: ', len(docutrain))
# print('Number of image documents for testing : ', len(docutest) )

In [None]:
# Interactive selection of one slice per volume.
# For every file, slice `mainInd` is shown and the user types a number: 1 accepts that
# slice; anything else shows slices 18..min(32, fileNum)-1 and the user then types the
# index of the slice to keep. The chosen index is stored in `data[filename]`.
# NOTE(review): files are opened by bare name, so the working directory must already be
# the dataset directory; a non-integer answer makes int(...) raise and leaves `f` open.
data = {}          # filename -> chosen slice index (rebinds the name `data`)
mainInd = int(26)  # default slice offered first
fileind = 0        # progress counter for the status line
for filename in doculist:
    print("image {} out of {}".format(fileind+1,len(doculist)))
    f  = h5py.File(filename,'r')    # closed at the end of this iteration
    fileNum = f['reconstruction_rss'].shape[0]
    if mainInd < fileNum:
        plt.clf()
        plt.imshow(f['reconstruction_rss'][mainInd,:,:])
        plt.title('{0}: slice {1}'.format(filename,mainInd))
        plt.colorbar()
        plt.show()
        indicator = input()  # 1 = accept the default slice
        goodind = mainInd
        if int(indicator) != 1:
            # Default rejected: show the candidate range, then ask for an index.
            for ind in range(18,min(32,fileNum),1):
                plt.clf()
                plt.imshow(f['reconstruction_rss'][ind,:,:])
                plt.title('{0}: slice {1}'.format(filename,ind))
                plt.colorbar()
                plt.show()
            goodind = input()
        data[filename] = int(goodind)
    else:
        # Volume has too few slices for the default index; it gets no entry in `data`.
        print(filename," fails to load ", mainInd, " slice!!!")
    f.close()
    fileind += 1

#     for pic in range(fileNum):   

In [None]:
# Keep a shallow copy of the current `data` under a separate name, so that a later
# reassignment of `data` does not lose it.
database = data.copy()

In [None]:
# Show every 'reconstruction_rss' slice of the volume stored at position 51 of `database`.
selected_file = list(database.keys())[51]
# Context manager: the handle is released even if plotting fails part-way.
with h5py.File(selected_file,'r') as f:
    frames = f['reconstruction_rss']
    for ind in range(frames.shape[0]):
        plt.clf()
        plt.imshow(frames[ind,:,:])
        # Title with the file actually opened; the stale global `filename`
        # named a different volume.
        plt.title('{0}: slice {1}'.format(selected_file,ind))
        plt.colorbar()
        plt.show()

In [None]:
# Dump the filename -> slice-index mapping as the text repr of the dict, in the
# current working directory. `with` closes the file even if the write fails.
with open("dict.txt","w") as f:
    f.write( str(database) )

In [None]:
# Save `imgdata` under the key 'imgdata'; np.savez appends '.npz', giving 'imgdata.npz'
# in the current working directory.
# NOTE(review): `imgdata` is bound to different objects in this notebook (an ndarray in
# the dataset-building cell, an NpzFile in the imgdata-loading cell) — confirm which one
# is in scope before running this cell.
np.savez('imgdata',imgdata=imgdata)