In [1]:
# Automatically re-import edited project modules (utils, data, encoding_objects, ...)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch

from utils.makepath import makepath
from data.mri.mri_data import DataUtil
from data.mri.naming import get_test_file_name
from encoding_objects.cart_2d_enc_obj import Cart2DEncObj

In [3]:
# Repository root; assumes the notebook runs two directory levels below it — TODO confirm.
root_dir = makepath("..", "..")
os.listdir(path=root_dir)

['scripts',
 'requirements.txt',
 'mri.egg-info',
 'venv',
 'README.md',
 'figures',
 'config',
 'LICENSE',
 'utils',
 'networks',
 'tmp',
 'dyn_mri_test.py',
 '.gitignore',
 'gradops',
 'pyproject.toml',
 'gifs',
 'data',
 'pdhg',
 'data_lib',
 'wandb',
 'encoding_objects',
 '.git']

In [5]:
# Directory holding the ground-truth image tensors for each split.
data_dir = makepath(root_dir, "tmp", "mri_data")
os.listdir(path=data_dir)

['x_true_val_150.pt',
 'test',
 'x_true_train_3000.pt',
 'x_true_example.pt',
 'x_true_all_3452.pt',
 'BACKUP_data',
 'x_true_test_302.pt']

In [6]:
# Which data split to corrupt. Each split's sample count is fixed by its file on disk
# (x_true_test_302.pt, x_true_val_150.pt), so derive it from the split name instead of
# keeping two hand-edited values in sync.
NUM_SAMPLES_PER_SPLIT = {"test": 302, "val": 150}

action = "val"  # must be a key of NUM_SAMPLES_PER_SPLIT
num_samples = NUM_SAMPLES_PER_SPLIT[action]

# NOTE: despite "test" in the variable name, this names the file of whichever split
# `action` selects.
x_true_test_file_name = f"x_true_{action}_{num_samples}.pt"

In [7]:
# torch.load unpickles the file, so only point this at trusted data.
x_true_test_path = makepath(data_dir, x_true_test_file_name)
x_true_test = torch.load(x_true_test_path)

In [8]:
# (num_samples, height, width) as displayed in the output.
x_true_test.size()

torch.Size([150, 320, 320])

In [9]:
# Multiplier applied to the ground-truth images before corruption. The raw magnitudes are
# small (max |x| is about 1.24e-3 for the val split), presumably why they are rescaled —
# confirm against the training setup.
scale_factor = 10 ** 3

In [10]:
# New tensor; the unscaled x_true_test is kept for the range checks.
scaled_x_true_test = scale_factor * x_true_test

In [11]:
# Written to the current working directory (not data_dir).
scaled_file_name = f"scaled_x_true_{action}_{num_samples}-scale_{scale_factor}.pt"
torch.save(scaled_x_true_test, scaled_file_name)

In [12]:
# Smallest magnitude in the unscaled images, printed as a plain Python float.
print(f"Min abs value of x_true_test: {torch.min(torch.abs(x_true_test)).item()}")

Min abs value of x_true_test: 4.595295699516555e-09


In [13]:
# Largest magnitude in the unscaled images, printed as a plain Python float.
print(f"Max abs value of x_true_test: {torch.max(torch.abs(x_true_test)).item()}")

Max abs value of x_true_test: 0.001241995021700859


In [14]:
# Encoding operator with default construction (presumably Cartesian 2D, per the class name).
# Used only for the single-slice forward-operator sanity check.
encoder = Cart2DEncObj()

In [15]:
# First image only, keeping a leading dimension of size 1: (1, height, width).
x_true_test_slice = x_true_test[:1]
x_true_test_slice.size()

torch.Size([1, 320, 320])

In [16]:
# Apply the forward operator to one slice with csm=None and mask=None (presumably single-coil,
# fully sampled — confirm against Cart2DEncObj). The output gains an extra axis.
k_true_test_slice = encoder.apply_A(x_true_test_slice, mask=None, csm=None)
k_true_test_slice.size()

torch.Size([1, 1, 320, 320])

In [17]:
# Check that the forward operator produces complex-valued k-space data.
k_true_test_slice.dtype

torch.complex64

In [18]:
# Data helper built without a data config, on CPU; its get_corrupted_data method is what
# the corruption loop uses to generate each corrupted slice.
data_util = DataUtil(data_config=None, device="cpu")

In [19]:
# Encoding-object attribute exposed by DataUtil. NOTE(review): not referenced by any later
# visible cell — possibly leftover from exploration; confirm before removing.
EncObj = data_util.EncObj

In [20]:
# No coil sensitivity maps are used (the forward-operator check also passes csm=None).
# NOTE(review): this variable is not referenced by any later visible cell.
coil_sensitivity_map = None

In [21]:
# For every (acceleration rate, noise level) pair, corrupt each slice of the scaled split and
# save the stacked image-space data, k-space data and undersampling masks to the current
# working directory.
import itertools

ACCELERATION_RATES = [4, 6, 8]
NOISE_STANDARD_DEVIATIONS = [0.05, 0.10, 0.15, 0.20]
# The corrupted sets are evaluation data, so they must be reproducible. torch's global RNG is
# re-seeded per configuration; re-running this cell then rewrites identical files.
# NOTE(review): assumes get_corrupted_data draws its randomness from torch — if it also uses
# numpy or Python's `random`, seed those here as well.
BASE_SEED = 0


def _corrupt_split(x_split, acceleration_rate_R, gaussian_noise_standard_deviation):
    """Corrupt every slice of ``x_split`` independently and stack the results.

    Each slice is given a leading singleton dimension, passed to
    ``data_util.get_corrupted_data``, and the first entry of each returned tensor is kept.

    Args:
        x_split: Tensor whose first dimension indexes the slices.
        acceleration_rate_R: Undersampling acceleration factor forwarded unchanged.
        gaussian_noise_standard_deviation: Noise level forwarded unchanged.

    Returns:
        ``(x_corrupted, kdata_corrupted, undersampling_kmasks)``: tensors stacked along a new
        dim 0, one entry per slice of ``x_split``.
    """
    x_slices, kdata_slices, kmask_slices = [], [], []
    for x_slice in x_split:  # iterates over dim 0, i.e. x_split[i]
        x_corrupted_slice, kdata_slice, kmask = data_util.get_corrupted_data(
            x_slice.unsqueeze(0), acceleration_rate_R, gaussian_noise_standard_deviation)
        x_slices.append(x_corrupted_slice[0])
        kdata_slices.append(kdata_slice[0])
        kmask_slices.append(kmask[0])
    return torch.stack(x_slices), torch.stack(kdata_slices), torch.stack(kmask_slices)


# itertools.product keeps the original nesting order: acceleration outer, noise inner.
for config_index, (acceleration_rate_R, gaussian_noise_standard_deviation) in enumerate(
        itertools.product(ACCELERATION_RATES, NOISE_STANDARD_DEVIATIONS)):
    torch.manual_seed(BASE_SEED + config_index)

    x_corrupted_test, kdata_corrupted_test, undersampling_kmasks_test = _corrupt_split(
        scaled_x_true_test, acceleration_rate_R, gaussian_noise_standard_deviation)
    print(f"x_corrupted_test.shape = {x_corrupted_test.shape}")
    print(f"kdata_corrupted_test.shape = {kdata_corrupted_test.shape}")
    print(f"undersampling_kmasks_test.shape = {undersampling_kmasks_test.shape}")

    # Every input slice must contribute exactly one entry to each saved tensor.
    num_slices = scaled_x_true_test.shape[0]
    assert (x_corrupted_test.shape[0] == kdata_corrupted_test.shape[0]
            == undersampling_kmasks_test.shape[0] == num_slices)

    # Save in current working directory
    outputs_by_prefix = {
        "x_corrupted": x_corrupted_test,
        "kdata_corrupted": kdata_corrupted_test,
        "undersampling_kmasks": undersampling_kmasks_test,
    }
    for file_prefix, tensor in outputs_by_prefix.items():
        torch.save(
            tensor, get_test_file_name(
                file_prefix, action, acceleration_rate_R, gaussian_noise_standard_deviation))

  kspace_data = torch.masked_select(kspace_data, mask.to(torch.bool)).view(


x_corrupted_test.shape = torch.Size([150, 320, 320])
kdata_corrupted_test.shape = torch.Size([150, 1, 320, 320])
undersampling_kmasks_test.shape = torch.Size([150, 320, 320])
x_corrupted_test.shape = torch.Size([150, 320, 320])
kdata_corrupted_test.shape = torch.Size([150, 1, 320, 320])
undersampling_kmasks_test.shape = torch.Size([150, 320, 320])
x_corrupted_test.shape = torch.Size([150, 320, 320])
kdata_corrupted_test.shape = torch.Size([150, 1, 320, 320])
undersampling_kmasks_test.shape = torch.Size([150, 320, 320])
x_corrupted_test.shape = torch.Size([150, 320, 320])
kdata_corrupted_test.shape = torch.Size([150, 1, 320, 320])
undersampling_kmasks_test.shape = torch.Size([150, 320, 320])
x_corrupted_test.shape = torch.Size([150, 320, 320])
kdata_corrupted_test.shape = torch.Size([150, 1, 320, 320])
undersampling_kmasks_test.shape = torch.Size([150, 320, 320])
x_corrupted_test.shape = torch.Size([150, 320, 320])
kdata_corrupted_test.shape = torch.Size([150, 1, 320, 320])
undersamplin