In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from typing import Callable, Union, Dict, Any

import torch
from tqdm.notebook import tqdm

from networks.mri_pdhg_net import MriPdhgNet
from utils.makepath import makepath as mkp
from config.config_loader import load_config
from data.mri.data_loader import get_data_loader, get_dataset
from scripts.mri.epoch import perform_epoch
from scripts.mri.test_utils import get_data_config_and_models
from scripts.mri.logger import Logger
from utils.metrics import ImageMetricsEvaluator

In [3]:
# Prefer the GPU, but fall back to the CPU when CUDA is not available so the
# notebook still runs end-to-end on machines without a usable GPU.
# To force CPU execution, assign "cpu" here instead.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Path two directory levels up, built with the project's path helper; later
# cells pass it as `root_dir` when loading models and configuration files.
root_dir = mkp("..", "..")

In [5]:
# Fetch the data configuration and the two networks evaluated in this
# notebook (bound to `u_tv_net` and `u_tgv_net`) for the chosen device.
data_config, u_tv_net, u_tgv_net = get_data_config_and_models(
    device=device, root_dir=root_dir)

Config loaded from file ../../tmp/mri_model_09_14-14_37-good_TV-sigma_to_0_2-R_from_4/config.yaml
Config loaded from dict
Norm of operator A: 1
Norm of gradient operator nabla: 2.8284270763397217
L: 3.0
u_tv model loaded
Config loaded from file ../../tmp/mri_model_09_12-23_02-good_TGV-sigma_to_0_2-R_from_4/config.yaml
Config loaded from dict
Norm of operator A: 1
Norm of gradient operator nabla: 2.8284270763397217
L: 3.0
u_tgv model loaded


In [6]:
# Show the `sigma` and `tau` attributes of the loaded U-TV network.
print(f"U-TV sigma: {u_tv_net.sigma}")
print(f"U-TV tau: {u_tv_net.tau}")

U-TV sigma: 0.3413746953010559
U-TV tau: 0.3254814147949219


In [7]:
# Show the `sigma` and `tau` attributes of the loaded U-TGV network.
print(f"U-TGV sigma: {u_tgv_net.sigma}")
print(f"U-TGV tau: {u_tgv_net.tau}")

U-TGV sigma: 0.16954700648784637
U-TGV tau: 0.6553410887718201


In [8]:
# Build the test dataset for acceleration factor R = 4 and Gaussian noise
# standard deviation 0.05.
# NOTE(review): `test_dataset` is not referenced by any later cell visible
# here; `test_and_save_results` creates its own data loader. Confirm whether
# this cell is still needed beyond inspecting the printed data statistics.
test_dataset = get_dataset(
    action="test",
    data_config=data_config,
    device=device,
    acceleration_factor_R=4,
    gaussian_noise_standard_deviation_sigma=0.05
)



Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Corrupted data shape: torch.Size([302, 320, 320])
min abs val of corrupted: 5.752422566729365e-06
max abs val of corrupted: 1.2349826097488403

Corrupted kdata shape: torch.Size([302, 1, 320, 320])

kmasks shape: torch.Size([302, 320, 320])



In [9]:
def get_zero_filled(
    batch_kdata,
    batch_kmask,
    batch_x,
    batch_csmap
):
    """Baseline "model" that hands back its input image batch untouched.

    Only ``batch_x`` is used; ``batch_kdata``, ``batch_kmask`` and
    ``batch_csmap`` are accepted and ignored. Presumably the four-argument
    signature mirrors how the evaluation code calls a model, so this function
    can be passed wherever a network is expected -- confirm against the caller.

    Returns:
        A 2-tuple ``(batch_x, None)``.
    """
    reconstruction, auxiliary_output = batch_x, None
    return reconstruction, auxiliary_output

In [10]:
# Conversion mode handed to the metrics evaluator below. The same value is
# also embedded in the result directory names built by `test_and_save_results`.
complex_to_real_conversion = "abs"
# Alternative mode:
# complex_to_real_conversion = "view_as_real"

# Evaluator shared by every test run in this notebook.
metrics_evaluator = ImageMetricsEvaluator(
    complex_to_real_conversion=complex_to_real_conversion, device=device)

NOTE: Complex-to-real conversion method: abs


In [11]:
def test_and_save_results(
        config: Dict[str, Any],
        model: Union[MriPdhgNet, Callable],
        model_name: str,
        general_action: str,
        acceleration_factor_R: int,
        sigma: float,
        force_overwrite: bool = False
):
    """Run one evaluation pass of ``model`` and log its metrics under ./results.

    Besides its arguments, the function reads the notebook-level names
    ``device``, ``complex_to_real_conversion`` and ``metrics_evaluator``.

    Args:
        config: Configuration dictionary. Its ``"data"`` entry goes to
            ``get_data_loader``; the whole dictionary goes to ``Logger``.
        model: Network or plain callable handed to ``perform_epoch``.
        model_name: Directory name used below ``./results/test_set``.
        general_action: Action string forwarded to the data loader and used
            as the prefix of the run tag (the notebook passes ``"test"``).
        acceleration_factor_R: Acceleration factor forwarded to the data
            loader and encoded in the run tag.
        sigma: Gaussian noise standard deviation forwarded to the data loader
            and encoded in the run tag with two decimals.
        force_overwrite: Forwarded unchanged to ``Logger``.

    Returns:
        The value returned by ``perform_epoch`` for this run.
    """
    print(f"Testing with {general_action} data, R = {acceleration_factor_R}, sigma = {sigma}")

    loader = get_data_loader(
        data_config=config["data"],
        action=general_action,
        dataset_type="preprocessed",
        device=device,
        acceleration_factor_R=acceleration_factor_R,
        gaussian_noise_standard_deviation_sigma=sigma,
    )

    # Tag identifying this run; every "." is turned into "_" afterwards, so
    # e.g. a sigma of 0.05 appears as "sigma_0_05".
    run_tag = f"{general_action}-{complex_to_real_conversion}-R_{acceleration_factor_R}-sigma_{sigma:.2f}"
    run_tag = run_tag.replace(".", "_")

    run_logger = Logger(action=run_tag, config=config, force_overwrite=force_overwrite)

    # Point the logger at a per-model, per-run directory instead of the save
    # directory it derived from the config.
    run_logger.save_dir = mkp(".", "results", "test_set", model_name, run_tag)
    os.makedirs(run_logger.save_dir, exist_ok=True)
    print(f"NOTE: New save directory: {run_logger.save_dir}")

    run_logger.init_metrics_logging_options()
    # Log intermediate results at every iteration.
    run_logger.log_freq_by_iter["intermediate"] = 1

    # Evaluation only: no optimizer or learning-rate scheduler is supplied.
    return perform_epoch(
        data_iterator=tqdm(loader, desc="Testing"),
        model=model,
        is_training=False,
        metrics_evaluator=metrics_evaluator,
        logger=run_logger,
        sets_tqdm_postfix=True,
    )

In [12]:
# Load `config/zero_filled_config.yaml` (located below `root_dir`) in
# non-training mode; the result is the config passed to the test runs below.
zero_filled_config = load_config(
    config_choice=mkp(root_dir, "config", "zero_filled_config.yaml"),
    is_training=False,
    root_dir=root_dir
)

Config loaded from file ../../config/zero_filled_config.yaml


In [13]:
# Evaluate one model over a sweep of noise levels at a fixed acceleration
# factor. To evaluate another model, replace the `model` / `model_name` pair
# with `get_zero_filled` / "zero_filled" or `u_tv_net` / "u_tv".
# NOTE(review): `zero_filled_config` is passed for every model choice,
# including the learned U-TGV network used here -- confirm this is intended.
for acceleration_factor_R in (4,):
    for sigma in (0.05, 0.1, 0.15, 0.2):
        test_and_save_results(
            config=zero_filled_config,
            model=u_tgv_net,
            model_name="u_tgv",
            general_action="test",
            acceleration_factor_R=acceleration_factor_R,
            sigma=sigma,
            force_overwrite=True,  # set to False to avoid forcing an overwrite
        )

Testing with test data, R = 4, sigma = 0.05


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Corrupted data shape: torch.Size([302, 320, 320])
min abs val of corrupted: 5.752422566729365e-06
max abs val of corrupted: 1.2349826097488403

Corrupted kdata shape: torch.Size([302, 1, 320, 320])

kmasks shape: torch.Size([302, 320, 320])

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_05
Save directory: ../../tmp/zero_filled
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_05/test-abs-R_4-sigma_0_05_epoch_metrics.csv' already exists.
Overwriting the file 'results/test_set/u_tgv/test-abs-R_4-sigma_0_05/test-abs-R_4-sigma_0_05_epoch_metrics.csv'...
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_05/test-abs-R_4-sigma_0_05_epoch_metrics.csv' initialized.
File 'results/test_set/u_tgv/test-ab

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

min_val = 9.105447134061251e-07
max_val = 1.3289061784744263
Testing with test data, R = 4, sigma = 0.1


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Corrupted data shape: torch.Size([302, 320, 320])
min abs val of corrupted: 2.8516878955997527e-06
max abs val of corrupted: 1.2406718730926514

Corrupted kdata shape: torch.Size([302, 1, 320, 320])

kmasks shape: torch.Size([302, 320, 320])

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_10
Save directory: ../../tmp/zero_filled
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_10/test-abs-R_4-sigma_0_10_epoch_metrics.csv' already exists.
Overwriting the file 'results/test_set/u_tgv/test-abs-R_4-sigma_0_10/test-abs-R_4-sigma_0_10_epoch_metrics.csv'...
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_10/test-abs-R_4-sigma_0_10_epoch_m

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

min_val = 1.5852201613597572e-06
max_val = 1.3109369277954102
Testing with test data, R = 4, sigma = 0.15


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Corrupted data shape: torch.Size([302, 320, 320])
min abs val of corrupted: 8.815023647912312e-06
max abs val of corrupted: 1.2514125108718872

Corrupted kdata shape: torch.Size([302, 1, 320, 320])

kmasks shape: torch.Size([302, 320, 320])

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_15
Save directory: ../../tmp/zero_filled
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_15/test-abs-R_4-sigma_0_15_epoch_metrics.csv' already exists.
Overwriting the file 'results/test_set/u_tgv/test-abs-R_4-sigma_0_15/test-abs-R_4-sigma_0_15_epoch_metrics.csv'...
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_15/test-abs-R_4-sigma_0_15_epoch_

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

min_val = 2.6396403427497717e-06
max_val = 1.3041800260543823
Testing with test data, R = 4, sigma = 0.2


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Corrupted data shape: torch.Size([302, 320, 320])
min abs val of corrupted: 1.3283834050525911e-06
max abs val of corrupted: 1.2621699571609497

Corrupted kdata shape: torch.Size([302, 1, 320, 320])

kmasks shape: torch.Size([302, 320, 320])

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_20
Save directory: ../../tmp/zero_filled
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_20/test-abs-R_4-sigma_0_20_epoch_metrics.csv' already exists.
Overwriting the file 'results/test_set/u_tgv/test-abs-R_4-sigma_0_20/test-abs-R_4-sigma_0_20_epoch_metrics.csv'...
File 'results/test_set/u_tgv/test-abs-R_4-sigma_0_20/test-abs-R_4-sigma_0_20_epoch_

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

KeyboardInterrupt: 