In [1]:
import sys
sys.path.append('/host/d/Github/')  # add the path to your own Example_UNet folder
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline
import ismrmrd
import ismrmrdtools
import ismrmrdtools.coils as coils
import ismrmrdtools.transform as transform
import scipy.ndimage
import ismrmrdtools.sense as sense
import h5py
import os
import nibabel as nb
import Diffusion_denoising_thin_slice.functions_collection as ff
main_path = '/host/d/Data/NYU_MR/multicoil_train'  # change to your own data path


  from ._conv import register_converters as _register_converters
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Collect every raw HDF5 k-space file under <main_path>/data.
files = ff.find_all_target_files(['*.h5'], os.path.join(main_path, 'data'))
print(f'Total number of training files: {len(files)}')

Total number of training files: 197


In [3]:
def ReadData_multicoil(filename):
    """Load multi-coil k-space and resize its last axis to the XML reconSpace.x size.

    Parameters
    ----------
    filename : str
        Path to an HDF5 file with a ``kspace`` dataset and an
        ``ismrmrd_header`` XML dataset.

    Returns
    -------
    all_data : np.ndarray, complex64
        K-space unpacked as ``(nslices, ncoils, Ny, Nx)``. If the last axis
        differs from ``reconSpace.matrixSize.x`` it is center-cropped (or
        zero-padded) to that size in the image domain and transformed back.
    y0, y1 : int
        Start/stop of the reconSpace window inside encodedSpace along y,
        computed from the header only (not applied here).
    """
    with h5py.File(filename, 'r') as f:
        kspace = f['kspace'][()].astype('complex64')   # (nslices, ncoils, d1, d2)
        xml_bytes = f['ismrmrd_header'][()]            # raw XML string (bytes)
    print('kspace shape:', kspace.shape)

    header = ismrmrd.xsd.CreateFromDocument(xml_bytes)
    enc = header.encoding[0]

    eNx_xml = enc.encodedSpace.matrixSize.x
    eNy_xml = enc.encodedSpace.matrixSize.y
    eNz_xml = enc.encodedSpace.matrixSize.z
    rNx_xml = enc.reconSpace.matrixSize.x
    rNy_xml = enc.reconSpace.matrixSize.y
    rNz_xml = enc.reconSpace.matrixSize.z
    print('eNx,eNy,eNz:', eNx_xml, eNy_xml, eNz_xml, 'rNx,rNy,rNz:', rNx_xml, rNy_xml, rNz_xml)

    nslices, ncoils, Ny, Nx = kspace.shape
    print('nslices,ncoils,Nx,Ny:', nslices, ncoils, Nx, Ny)

    # NOTE(review): in the logged runs the last axis (368/372) equals
    # encodedSpace.y, while encodedSpace.x (640) is the second-to-last axis,
    # yet the last axis is resized to reconSpace.x here. Confirm this is the
    # intended axis.
    if Nx != rNx_xml:
        # Resize in the image domain, then return to k-space.
        img = transform.transform_kspace_to_image(kspace, dim=[-2, -1])
        if Nx > rNx_xml:
            print("Removing oversampling according to XML...")
            x0 = (Nx - rNx_xml) // 2
            img = img[..., x0:x0 + rNx_xml]
        else:
            # A center crop with a negative start index would silently keep
            # the wrong number of columns, so zero-pad symmetrically instead.
            print("Zero-padding to reconSpace size according to XML...")
            pad_before = (rNx_xml - Nx) // 2
            pad_after = rNx_xml - Nx - pad_before
            pad_width = [(0, 0)] * (img.ndim - 1) + [(pad_before, pad_after)]
            img = np.pad(img, pad_width)
        kspace = transform.transform_image_to_kspace(img, dim=[-2, -1])

    # The FFT helpers may promote precision; restore complex64.
    all_data = kspace.astype('complex64')
    print('all_data shape:', all_data.shape)

    y0 = int((eNy_xml - rNy_xml) / 2)
    y1 = int((eNy_xml - rNy_xml) / 2 + rNy_xml)

    return all_data, y0, y1


def GetMask_multicoil(eNyMask, rate, all_data, seed = None):
    """Build two disjoint random undersampling masks along the last (kx) axis.

    A central band of ``2 * int(eNyMask / 2)`` columns is always sampled with
    weight 1. For every slice, the remaining columns are shuffled and two
    disjoint sets of ``int(len(outer) / rate)`` columns are drawn, one for each
    mask, and set to ``rate``.

    Parameters
    ----------
    eNyMask : int
        Width of the fully sampled central band.
    rate : int or float
        Undersampling factor; also the value written into sampled outer columns.
    all_data : ndarray
        K-space; only ``shape[0]`` (slices), ``shape[-2]`` and ``shape[-1]`` are used.
    seed : int, optional
        Seed for a local ``np.random.RandomState``. When None, the global numpy
        RNG is used. The global RNG state is never reseeded by this function.

    Returns
    -------
    cmask : ndarray, float32, shape (1, 1, Ny, Nx)
        Central-band mask.
    mask1, mask2 : ndarray, float32, shape (nslices, 1, Ny, Nx)
        Independent undersampling masks (central band + disjoint outer columns).
    mask : ndarray
        ``(mask1 + mask2) / 2``.
    """
    nx = all_data.shape[-1]
    cx = int(nx / 2)
    half_band = int(eNyMask / 2)
    x0 = cx - half_band
    x1 = cx + half_band

    # Column indices outside the central band: candidates for random sampling.
    outer = np.concatenate((np.arange(0, x0), np.arange(x1, nx)))
    nlines = int(len(outer) / rate)

    # A local RandomState produces the same shuffles as np.random.seed(seed)
    # followed by np.random.shuffle, without resetting the global RNG.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    cmask = np.zeros([1, 1, all_data.shape[-2], nx], np.float32)
    cmask[:, :, :, x0:x1] = 1

    mask1 = np.tile(cmask, (all_data.shape[0], 1, 1, 1))
    mask2 = np.tile(cmask, (all_data.shape[0], 1, 1, 1))

    for i in range(all_data.shape[0]):
        # Shuffle a fresh sorted copy each slice so the draw depends only on
        # the RNG stream.
        inds = outer.copy()
        rng.shuffle(inds)
        # Consecutive slices of one permutation, so the two masks are disjoint.
        mask1[i, :, :, inds[:nlines]] = rate
        mask2[i, :, :, inds[nlines:nlines * 2]] = rate

    mask = (mask1 + mask2) / 2

    return cmask, mask1, mask2, mask

def GetCsms_multicoil(all_data, cmask):
    """Estimate coil sensitivity maps from the centrally masked k-space.

    Parameters
    ----------
    all_data : ndarray, shape (nslices, ncoils, Ny, Nx)
        Fully sampled k-space.
    cmask : ndarray, shape (1, 1, Ny, Nx)
        Central (calibration) mask, e.g. from GetMask_multicoil.

    Returns
    -------
    csms : ndarray, shape (nslices, ncoils, Ny, Nx)
        Low-resolution coil images divided by their root-sum-of-squares over
        the coil axis. Pixels where that RSS is zero are set to 0 rather than NaN.
    """
    # cmask has singleton slice/coil axes, so it broadcasts over both.
    coil_data = all_data * cmask
    coil_images = transform.transform_kspace_to_image(coil_data, (-2, -1))
    # |x|**2 keeps the sum real, so the float32 cast does not raise a
    # ComplexWarning or discard an imaginary part.
    sos = np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=1)).astype(np.float32)
    sos = sos[:, np.newaxis, ...]
    csms = np.divide(coil_images, sos, out=np.zeros_like(coil_images), where=sos > 0)

    return csms


def DirectRecon_multicoil(k_data, csms):
    """Coil-combine k-space into one image per slice.

    Each coil's k-space is inverse-FFT'd over the last two axes, multiplied by
    the conjugate sensitivity map and summed over the coil axis (axis 1). The
    result keeps a singleton coil axis: (nslices, 1, Ny, Nx).
    """
    coil_images = transform.transform_kspace_to_image(k_data, (-2, -1))
    combined = (np.conj(csms) * coil_images).sum(axis=1)
    return np.expand_dims(combined, axis=1)


In [4]:
def ReadData_singlecoil(filename):
    """
    Read single-coil MRI k-space from a fastMRI-style h5 file that contains:
    - /kspace           <- single-coil k-space, shape (nslices, d1, d2)
    - /ismrmrd_header   <- ISMRMRD XML header

    The in-plane axes are reordered to (Ny, Nx) by matching (d1, d2) against
    encodedSpace.matrixSize. The Nx axis is then center-cropped (or
    zero-padded) in the image domain to reconSpace.matrixSize.x.

    Output:
    - all_data: ndarray of shape (nslices, 1, Ny, Nx), complex64
                (same layout as the multi-coil reader, with a singleton coil axis)
    - y0, y1: start/stop of the reconSpace window inside encodedSpace along y
              (computed from the header only; not applied here)
    """

    # 1. Load k-space + header
    with h5py.File(filename, 'r') as f:
        kspace = f['kspace'][()].astype('complex64')   # shape = (nslices, d1, d2)
        xml_bytes = f['ismrmrd_header'][()]            # raw XML string (bytes)

    # 2. Parse the ISMRMRD XML header
    header = ismrmrd.xsd.CreateFromDocument(xml_bytes)
    enc = header.encoding[0]

    # Matrix sizes in encodedSpace and reconSpace
    eNx = enc.encodedSpace.matrixSize.x
    eNy = enc.encodedSpace.matrixSize.y
    rNx = enc.reconSpace.matrixSize.x
    rNy = enc.reconSpace.matrixSize.y
    print('eNx, eNy, rNx, rNy:', eNx, eNy, rNx, rNy)

    # 3. Identify the (Ny, Nx) dimension order
    nslices, d1, d2 = kspace.shape
    print('nslices, d1, d2:', nslices, d1, d2)

    if (d1 == eNy and d2 == eNx):
        print("kspace already in (Ny, Nx) order")
        kspace_reordered = kspace
    elif (d1 == eNx and d2 == eNy):
        print("kspace in (Nx, Ny) order; transposing to (Ny, Nx)")
        kspace_reordered = np.transpose(kspace, (0, 2, 1))
    else:
        # Cannot infer reliably; assume (Ny, Nx) = (d1, d2).
        print(f"[Warning] kspace shape ({d1}, {d2}) does not match encodedSpace ({eNy}, {eNx}).")
        print("          Using raw order as (Ny, Nx). You should double-check this.")
        kspace_reordered = kspace

    Ny = kspace_reordered.shape[1]
    Nx = kspace_reordered.shape[2]
    print('After reorder: Nx, Ny =', Nx, Ny)

    # 4. Resize the Nx axis to reconSpace.x (image-domain crop/pad)
    if Nx != rNx:
        img = transform.transform_kspace_to_image(kspace_reordered, dim=[-2, -1])
        if Nx > rNx:
            print("Removing oversampling according to XML...")
            x0 = (Nx - rNx) // 2
            img = img[..., x0:x0 + rNx]
        else:
            # A center crop with a negative start index would silently keep
            # the wrong number of columns, so zero-pad symmetrically instead.
            print("Zero-padding to reconSpace size according to XML...")
            pad_before = (rNx - Nx) // 2
            pad_after = rNx - Nx - pad_before
            pad_width = [(0, 0)] * (img.ndim - 1) + [(pad_before, pad_after)]
            img = np.pad(img, pad_width)
        kspace_reordered = transform.transform_image_to_kspace(img, dim=[-2, -1])

    # 5. Add a singleton coil axis -> (nslices, 1, Ny, Nx)
    all_data = kspace_reordered[:, None, :, :]
    print('all_data shape:', all_data.shape)

    # 6. reconSpace window along y
    y0 = int((eNy - rNy) / 2)
    y1 = y0 + rNy
    print('y0, y1:', y0, y1)

    return all_data.astype('complex64'), y0, y1

def GetMask_singlecoil(eNyMask, rate, all_data, seed = None):
    """Build two disjoint random undersampling masks along the ky (second-to-last) axis.

    The usable ky extent ``ny`` is one past the last ky row whose magnitude of
    summed k-space (over axes 0, 1 and -1) is non-zero. A central band of
    ``2 * int(eNyMask / 2)`` rows around ``ny / 2`` is always sampled with
    weight 1. For every slice, the remaining rows below ``ny`` are shuffled and
    two disjoint sets of ``int(len(outer) / rate)`` rows are set to ``rate``,
    one set per mask.

    Parameters
    ----------
    eNyMask : int
        Width of the fully sampled central band.
    rate : int or float
        Undersampling factor; also the value written into sampled outer rows.
    all_data : ndarray, shape (nslices, ncoils, Ny, Nx)
        K-space used to find ``ny`` and the output shape.
    seed : int, optional
        Seed for a local ``np.random.RandomState``. When None, the global numpy
        RNG is used. The global RNG state is never reseeded by this function.

    Returns
    -------
    cmask, mask1, mask2, mask
        Central mask (1, 1, Ny, Nx), two masks (nslices, 1, Ny, Nx), and their mean.

    Raises
    ------
    ValueError
        If every ky row of the summed k-space is zero.
    """
    profile = np.abs(np.sum(all_data, (0, 1, -1)))
    nonzero_rows = np.where(profile > 0)[0]
    if nonzero_rows.size == 0:
        raise ValueError('GetMask_singlecoil: all_data has no non-zero ky lines')
    ny = nonzero_rows[-1] + 1
    cy = int(ny / 2)
    half_band = int(eNyMask / 2)
    y0 = cy - half_band
    y1 = cy + half_band

    # Row indices outside the central band: candidates for random sampling.
    outer = np.concatenate((np.arange(y0), np.arange(y1, ny)))
    nlines = int(len(outer) / rate)

    # A local RandomState produces the same shuffles as np.random.seed(seed)
    # followed by np.random.shuffle, without resetting the global RNG.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    cmask = np.zeros([1, 1, all_data.shape[-2], all_data.shape[-1]], np.float32)
    cmask[:, :, y0:y1, :] = 1
    mask1 = np.tile(cmask, (all_data.shape[0], 1, 1, 1))
    mask2 = np.tile(cmask, (all_data.shape[0], 1, 1, 1))

    for i in range(all_data.shape[0]):
        # Shuffle a fresh sorted copy each slice so the draw depends only on
        # the RNG stream.
        inds = outer.copy()
        rng.shuffle(inds)
        # Consecutive slices of one permutation, so the two masks are disjoint.
        mask1[i, :, inds[:nlines], :] = rate
        mask2[i, :, inds[nlines:nlines * 2], :] = rate

    mask = (mask1 + mask2) / 2

    return cmask, mask1, mask2, mask


def GetCsms_singlecoil(all_data, cmask):
    """Estimate sensitivity maps from the centrally masked k-space.

    Same computation as the multi-coil estimator: low-resolution coil images
    divided by their root-sum-of-squares over the coil axis (axis 1). With a
    single coil this reduces to img / |img|, i.e. a unit-magnitude phase map.
    Pixels where the RSS is zero are set to 0 rather than NaN.

    Parameters
    ----------
    all_data : ndarray, shape (nslices, ncoils, Ny, Nx)
    cmask : ndarray, shape (1, 1, Ny, Nx)

    Returns
    -------
    csms : ndarray, shape (nslices, ncoils, Ny, Nx)
    """
    # cmask has singleton slice/coil axes, so it broadcasts over both.
    coil_data = all_data * cmask
    coil_images = transform.transform_kspace_to_image(coil_data, (-2, -1))
    # |x|**2 keeps the sum real, so the float32 cast does not raise a
    # ComplexWarning or discard an imaginary part.
    sos = np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=1)).astype(np.float32)
    sos = sos[:, np.newaxis, ...]
    csms = np.divide(coil_images, sos, out=np.zeros_like(coil_images), where=sos > 0)

    return csms

def DirectRecon_singlecoil(k_data, csms):
    """Inverse-FFT single-coil k-space over its last two axes.

    ``csms`` is accepted only so the signature matches the multi-coil
    reconstruction; it is not used.
    """
    image = transform.transform_kspace_to_image(k_data, (-2, -1))
    return image

In [4]:
eNyMask = 48   # width of the fully sampled central k-space band (see GetMask_*)
rate = 4       # undersampling factor passed to GetMask_*
senseDir = os.path.join(main_path, 'sense')
reconDir = os.path.join(main_path, 'undersample_' + str(rate))
refDir = os.path.join(main_path, 'ref')

# exist_ok avoids the check-then-create race and is a no-op for existing
# directories. Set any of these to None to skip creating it.
for out_dir in (reconDir, refDir, senseDir):
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

In [None]:
# Pick the multi-coil or single-coil pipeline from the data path.
is_multicoil = 'multicoil' in main_path
ReadData = ReadData_multicoil if is_multicoil else ReadData_singlecoil
GetMask = GetMask_multicoil if is_multicoil else GetMask_singlecoil
GetCsms = GetCsms_multicoil if is_multicoil else GetCsms_singlecoil
DirectRecon = DirectRecon_multicoil if is_multicoil else DirectRecon_singlecoil
coil = 'multicoil' if is_multicoil else 'singlecoil'
print('Using {} functions'.format(coil))

N_RANDOM = 1           # number of independent mask realisations per file
N_FILES = 63           # only the first N_FILES files are processed
SKIP_EXISTING = False  # True: skip files whose recon2 NIfTI already exists

for random_n in range(N_RANDOM):
    # Slicing (rather than indexing range(63)) cannot raise IndexError when
    # fewer than N_FILES files were found.
    for filename in files[:N_FILES]:
        print(filename)
        name = os.path.basename(filename)[:-3]   # drop the '.h5' suffix
        # NOTE(review): names such as 'file1000010' contain no '-', so this
        # parses the last 8 characters ('e1000010') as hex. The seed is still
        # deterministic per file and per random_n.
        seed = int(name.split('-')[-1][-8:], 16) + random_n

        random_folder = os.path.join(reconDir, name, 'random_' + str(random_n))
        ref_save_folder = os.path.join(refDir, name)
        recon1_save_folder = os.path.join(random_folder, 'recon1')
        recon2_save_folder = os.path.join(random_folder, 'recon2')

        if SKIP_EXISTING and os.path.isfile(os.path.join(recon2_save_folder, 'img.nii.gz')):
            print('File {} already exists, skipping...'.format(name))
            continue

        all_data, y0, y1 = ReadData(filename)
        cmask, mask1, mask2, mask = GetMask(eNyMask, rate, all_data, seed)
        print('cmask shape:', cmask.shape, mask1.shape, mask2.shape, mask.shape)

        csms = GetCsms(all_data, cmask)

        # The masks have a singleton coil axis, so they broadcast over coils.
        recon1 = DirectRecon(all_data * mask1, csms)
        recon2 = DirectRecon(all_data * mask2, csms)
        ref = DirectRecon(all_data, csms)

        ff.make_folder([ref_save_folder, os.path.join(reconDir, name), random_folder, recon1_save_folder, recon2_save_folder])

        np.save(os.path.join(recon1_save_folder, 'img.npy'), recon1)
        np.save(os.path.join(recon2_save_folder, 'img.npy'), recon2)
        np.save(os.path.join(ref_save_folder, 'img.npy'), ref)

        # Also save magnitude images as NIfTI for viewing in ITK-SNAP.
        recon1_mag = np.abs(recon1)
        recon2_mag = np.abs(recon2)
        ref_mag = np.abs(ref)
        nb.save(nb.Nifti1Image(np.squeeze(recon1_mag), np.eye(4)), os.path.join(recon1_save_folder, 'img.nii.gz'))
        nb.save(nb.Nifti1Image(np.squeeze(recon2_mag), np.eye(4)), os.path.join(recon2_save_folder, 'img.nii.gz'))
        nb.save(nb.Nifti1Image(np.squeeze(ref_mag), np.eye(4)), os.path.join(ref_save_folder, 'img.nii.gz'))

Using multicoil functions
/host/d/Data/NYU_MR/multicoil_train/data/file1000010.h5


  arr = numpy.ndarray(selection.mshape, dtype=new_dtype)


kspace shape: (36, 15, 640, 368)
eNx,eNy,eNz: 640 368 1 rNx,rNy,rNz: 320 320 1
nslices,ncoils,Nx,Ny: 36 15 368 640
Removing oversampling according to XML...
all_data shape: (36, 15, 640, 320)
cmask shape: (1, 1, 640, 320) (36, 1, 640, 320) (36, 1, 640, 320) (36, 1, 640, 320)


  sos = np.sqrt(np.sum(coil_images * np.conj(coil_images), 1)).astype(np.float32)


/host/d/Data/NYU_MR/multicoil_train/data/file1000015.h5
kspace shape: (36, 15, 640, 372)
eNx,eNy,eNz: 640 372 1 rNx,rNy,rNz: 320 320 1
nslices,ncoils,Nx,Ny: 36 15 372 640
Removing oversampling according to XML...
all_data shape: (36, 15, 640, 320)
cmask shape: (1, 1, 640, 320) (36, 1, 640, 320) (36, 1, 640, 320) (36, 1, 640, 320)
/host/d/Data/NYU_MR/multicoil_train/data/file1000021.h5
kspace shape: (35, 15, 640, 368)
eNx,eNy,eNz: 640 368 1 rNx,rNy,rNz: 320 320 1
nslices,ncoils,Nx,Ny: 35 15 368 640
Removing oversampling according to XML...
all_data shape: (35, 15, 640, 320)
cmask shape: (1, 1, 640, 320) (35, 1, 640, 320) (35, 1, 640, 320) (35, 1, 640, 320)
/host/d/Data/NYU_MR/multicoil_train/data/file1000057.h5
kspace shape: (42, 15, 640, 368)
eNx,eNy,eNz: 640 368 1 rNx,rNy,rNz: 320 320 1
nslices,ncoils,Nx,Ny: 42 15 368 640
Removing oversampling according to XML...
all_data shape: (42, 15, 640, 320)
cmask shape: (1, 1, 640, 320) (42, 1, 640, 320) (42, 1, 640, 320) (42, 1, 640, 320)
/hos