In [None]:
import fastmri
from fastmri.data import transforms

import h5py
import numpy as np
from matplotlib import pyplot as plt

import torch
from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union
from fastmri.data.transforms import to_tensor, center_crop, complex_center_crop

import glob
from tqdm import tqdm

import os

# Spatial size (rows, cols) the multi-coil data is center-cropped to in image space
# before being transformed back to k-space; also the minimum size a source file
# must have to be processed.
RECON_SIZE = (320, 320)
# Spatial size the RSS reconstruction is center-cropped to before it is stored.
RSS_RECON_SIZE = (320, 320)
# Number of coils kept along the coil axis of the written k-space.
NUM_COMPRESSED_COILS = 16 # 4
# Only source files with exactly this many coils are processed.
NUM_LIMIT_COILS = 16
assert NUM_COMPRESSED_COILS <= NUM_LIMIT_COILS

# Source is the fastMRI brain multi-coil validation set; a random subset of it is
# written to DST_VAL and a second, disjoint subset to DST_TEST.

SRC = '/mnt/e/nyu_fastmri_brain/multicoil_val/' # directory of NYU fastMRI brain multi-coil validation dataset
DST_VAL = '/mnt/e/nyu_fastmri_brain/multibrain_dataset/multicoil_val/' # directory of compressed validation dataset
DST_TEST = '/mnt/e/nyu_fastmri_brain/multibrain_dataset/multicoil_test/' # directory of compressed test dataset

# Create both output directories up front (no error if they already exist).
os.makedirs(DST_VAL, exist_ok=True)
os.makedirs(DST_TEST, exist_ok=True)

# Seed NumPy's global RNG so the np.random.choice draw of file indices is repeatable.
np.random.seed(42)
# Fraction of the eligible files drawn for EACH of the validation and test subsets.
sample_rate = 0.2

# Pass 1: count the source files that pass the eligibility filter, then draw the
# random, disjoint validation / test index sets.
#
# A file is eligible when it has exactly NUM_LIMIT_COILS coils, is at least
# RECON_SIZE in both spatial dimensions, and its "acquisition" attribute is "AXT2".
# The directory is listed once; the original re-ran glob.glob on every iteration,
# which rescans the directory for each file. Order does not matter for counting.
src_files = glob.glob(SRC + 'file*.h5')

total_num = 0
for file_name in tqdm(src_files):
    with h5py.File(file_name, 'r') as hf:
        n_slice, n_coil, n_x, n_y = hf['kspace'].shape
        if n_coil != NUM_LIMIT_COILS or n_x < RECON_SIZE[0] or n_y < RECON_SIZE[1]:
            continue
        if hf.attrs["acquisition"] != "AXT2":
            continue

        total_num += 1

print(total_num)

# Number of files per subset; twice that many distinct indices are drawn (without
# replacement) so the first half becomes validation and the second half test.
num_per_split = int(total_num * sample_rate)
select_idx = np.random.choice(np.arange(total_num), size=num_per_split * 2, replace=False)

val_select_idx = select_idx[:num_per_split]
test_select_idx = select_idx[num_per_split:]
print(len(select_idx))
print(select_idx)

# Pass 2: walk the eligible files again and write the selected ones to DST_VAL or
# DST_TEST with center-cropped k-space, a recomputed RSS reconstruction and updated
# 'max' / 'norm' attributes.
#
# The file list is built once and sorted: glob.glob returns files in a
# filesystem-dependent order, so without sorting the mapping from a drawn index to a
# file (and therefore the split itself) is not reproducible even with a fixed seed.
src_files = sorted(glob.glob(SRC + 'file*.h5'))

# Set membership instead of scanning the index arrays for every file.
val_selected = set(np.asarray(val_select_idx).tolist())
test_selected = set(np.asarray(test_select_idx).tolist())

current_idx = 0  # running index over ELIGIBLE files only; matches the drawn indices
for file_name in tqdm(src_files):
    with h5py.File(file_name, 'r') as hf:
        n_slice, n_coil, n_x, n_y = hf['kspace'].shape
        if n_coil != NUM_LIMIT_COILS or n_x < RECON_SIZE[0] or n_y < RECON_SIZE[1]:
            continue
        if hf.attrs["acquisition"] != "AXT2":
            continue

        # This file is eligible: consume one index regardless of whether it is kept.
        eligible_idx = current_idx
        current_idx += 1

        # Decide the destination BEFORE loading k-space, so files that are not
        # selected skip the FFT / crop / RSS work entirely.
        if eligible_idx in val_selected:
            dst_dir = DST_VAL
        elif eligible_idx in test_selected:
            dst_dir = DST_TEST
        else:
            continue

        # Crop in image space to RECON_SIZE, then return to k-space.
        kspace = fastmri.fft2c(
            complex_center_crop(fastmri.ifft2c(to_tensor(hf['kspace'][()])), RECON_SIZE)
        ).numpy()

        # Reinterpret the trailing (real, imag) float pair as one complex64 value.
        kspace_complex = kspace.view(dtype=np.complex64)[..., 0]
        # Keep only the first NUM_COMPRESSED_COILS coils. Slicing the source (rather
        # than only the destination buffer) avoids a broadcast error when
        # NUM_COMPRESSED_COILS is smaller than the number of coils in the file.
        kspace_compressed = np.ascontiguousarray(kspace_complex[:, :NUM_COMPRESSED_COILS])

        # Root-sum-of-squares reconstruction, one slice at a time, computed from the
        # coils that are actually stored so the target matches the written k-space.
        reconstruction_rss = np.zeros(shape=((n_slice,) + RECON_SIZE))
        for i in range(n_slice):
            reconstruction_rss[i, ...] = fastmri.rss(
                fastmri.complex_abs(fastmri.ifft2c(to_tensor(kspace_compressed[i]))), dim=0
            )

        reconstruction_rss = center_crop(reconstruction_rss, RSS_RECON_SIZE).astype(np.float32)
        max_new = reconstruction_rss.max()
        norm_new = np.linalg.norm(reconstruction_rss)

        # Same file name as the source, placed in the chosen destination directory.
        dst_path = os.path.join(dst_dir, os.path.basename(file_name))
        with h5py.File(dst_path, 'w') as hfd:
            # Mirror every dataset of the source file, substituting the recomputed
            # 'kspace' and 'reconstruction_rss'; all other datasets are copied as-is.
            for key in list(hf.keys()):
                if key == 'kspace':
                    hfd.create_dataset('kspace', data=kspace_compressed, maxshape=kspace_compressed.shape)
                elif key == 'reconstruction_rss':
                    hfd.create_dataset('reconstruction_rss', data=reconstruction_rss, maxshape=reconstruction_rss.shape)
                else:
                    hfd.create_dataset(key, data=hf[key][()])

            # Copy the file-level attributes, overriding 'max' and 'norm' with the
            # values of the new reconstruction.
            global_attrs = dict(hf.attrs)
            global_attrs['max'] = max_new
            global_attrs['norm'] = norm_new
            hfd.attrs.update(global_attrs)

In [None]:

import fastmri
from fastmri.data import transforms

import h5py
import numpy as np
from matplotlib import pyplot as plt

import torch
from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union
from fastmri.data.transforms import to_tensor, center_crop, complex_center_crop

import glob
from tqdm import tqdm

import os

# Spatial size (rows, cols) the multi-coil data is center-cropped to in image space
# before being transformed back to k-space; also the minimum size a source file
# must have to be processed.
RECON_SIZE = (320, 320)
# Spatial size the RSS reconstruction is center-cropped to before it is stored.
RSS_RECON_SIZE = (320, 320)
# Number of coils kept along the coil axis of the written k-space.
NUM_COMPRESSED_COILS = 16 # 4
# Only source files with exactly this many coils are processed.
NUM_LIMIT_COILS = 16
assert NUM_COMPRESSED_COILS <= NUM_LIMIT_COILS

# Source is the fastMRI brain multi-coil train set; a random subset of it is
# written to DST.

SRC = '/mnt/e/nyu_fastmri_brain/multicoil_train/' # directory of NYU fastMRI brain multi-coil train dataset
DST = '/mnt/e/nyu_fastmri_brain/multibrain_dataset/multicoil_train/' # directory of compressed train dataset

# Create the output directory up front (no error if it already exists).
os.makedirs(DST, exist_ok=True)

# Seed NumPy's global RNG so the np.random.choice draw of file indices is repeatable.
np.random.seed(42)
# Fraction of the eligible training files that are drawn and written to DST.
sample_rate = 0.2

# Pass 1: count the training files that pass the eligibility filter, then draw the
# random subset of indices that will be written.
#
# A file is eligible when it has exactly NUM_LIMIT_COILS coils, is at least
# RECON_SIZE in both spatial dimensions, and its "acquisition" attribute is "AXT2".
# The directory is listed once; the original re-ran glob.glob on every iteration,
# which rescans the directory for each file. Order does not matter for counting.
src_files = glob.glob(SRC + 'file*.h5')

total_num = 0
for file_name in tqdm(src_files):
    with h5py.File(file_name, 'r') as hf:
        n_slice, n_coil, n_x, n_y = hf['kspace'].shape
        if n_coil != NUM_LIMIT_COILS or n_x < RECON_SIZE[0] or n_y < RECON_SIZE[1]:
            continue
        if hf.attrs["acquisition"] != "AXT2":
            continue
        total_num += 1

print(total_num)

# Draw int(total_num * sample_rate) distinct indices into the eligible-file sequence.
num_selected = int(total_num * sample_rate)
select_idx = np.random.choice(np.arange(total_num), size=num_selected, replace=False)
print(len(select_idx))
print(select_idx)

# Pass 2: walk the eligible training files again and write the selected ones to DST
# with center-cropped k-space, a recomputed RSS reconstruction and updated
# 'max' / 'norm' attributes.
#
# The file list is built once and sorted: glob.glob returns files in a
# filesystem-dependent order, so without sorting the mapping from a drawn index to a
# file (and therefore the selected subset) is not reproducible even with a fixed seed.
src_files = sorted(glob.glob(SRC + 'file*.h5'))

# Set membership instead of scanning the index array for every file.
selected = set(np.asarray(select_idx).tolist())

current_idx = 0  # running index over ELIGIBLE files only; matches the drawn indices
for file_name in tqdm(src_files):
    with h5py.File(file_name, 'r') as hf:
        n_slice, n_coil, n_x, n_y = hf['kspace'].shape
        if n_coil != NUM_LIMIT_COILS or n_x < RECON_SIZE[0] or n_y < RECON_SIZE[1]:
            continue
        if hf.attrs["acquisition"] != "AXT2":
            continue

        # This file is eligible: consume one index regardless of whether it is kept.
        eligible_idx = current_idx
        current_idx += 1

        # Check the selection BEFORE loading k-space, so files that are not selected
        # skip the FFT / crop / RSS work entirely.
        if eligible_idx not in selected:
            continue

        # Crop in image space to RECON_SIZE, then return to k-space.
        kspace = fastmri.fft2c(
            complex_center_crop(fastmri.ifft2c(to_tensor(hf['kspace'][()])), RECON_SIZE)
        ).numpy()

        # Reinterpret the trailing (real, imag) float pair as one complex64 value.
        kspace_complex = kspace.view(dtype=np.complex64)[..., 0]
        # Keep only the first NUM_COMPRESSED_COILS coils. Slicing the source (rather
        # than only the destination buffer) avoids a broadcast error when
        # NUM_COMPRESSED_COILS is smaller than the number of coils in the file.
        kspace_compressed = np.ascontiguousarray(kspace_complex[:, :NUM_COMPRESSED_COILS])

        # Root-sum-of-squares reconstruction, one slice at a time, computed from the
        # coils that are actually stored so the target matches the written k-space.
        reconstruction_rss = np.zeros(shape=((n_slice,) + RECON_SIZE))
        for i in range(n_slice):
            reconstruction_rss[i, ...] = fastmri.rss(
                fastmri.complex_abs(fastmri.ifft2c(to_tensor(kspace_compressed[i]))), dim=0
            )

        reconstruction_rss = center_crop(reconstruction_rss, RSS_RECON_SIZE).astype(np.float32)
        max_new = reconstruction_rss.max()
        norm_new = np.linalg.norm(reconstruction_rss)

        # Same file name as the source, placed in the destination directory.
        dst_path = os.path.join(DST, os.path.basename(file_name))
        with h5py.File(dst_path, 'w') as hfd:
            # Mirror every dataset of the source file, substituting the recomputed
            # 'kspace' and 'reconstruction_rss'; all other datasets are copied as-is.
            for key in list(hf.keys()):
                if key == 'kspace':
                    hfd.create_dataset('kspace', data=kspace_compressed, maxshape=kspace_compressed.shape)
                elif key == 'reconstruction_rss':
                    hfd.create_dataset('reconstruction_rss', data=reconstruction_rss, maxshape=reconstruction_rss.shape)
                else:
                    hfd.create_dataset(key, data=hf[key][()])

            # Copy the file-level attributes, overriding 'max' and 'norm' with the
            # values of the new reconstruction.
            global_attrs = dict(hf.attrs)
            global_attrs['max'] = max_new
            global_attrs['norm'] = norm_new
            hfd.attrs.update(global_attrs)