In [23]:
# Reload edited project modules (networks/, data/, scripts/, utils/, ...) automatically
# before each cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
import torch
import os
from tqdm.notebook import tqdm
from typing import Callable, Union, Dict, Any

from networks.mri_pdhg_net import MriPdhgNet
from utils.makepath import makepath as mkp
from config.config_loader import load_config
from data.mri.data_loader import get_data_loader, get_dataset
from scripts.epoch import perform_epoch
from scripts.mri.test.test_utils import get_config_and_model
from scripts.mri.logger import Logger
from utils.metrics import ImageMetricsEvaluator
from scripts.mri.mri_iteration import MriIteration

In [25]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA. Assign device = "cpu"
# explicitly to force CPU execution.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [26]:
# Repository root relative to this notebook's working directory (two levels up);
# config and pretrained files are resolved from here, e.g. ../../pretrained/u_tv/config.yaml.
root_dir = mkp("..", "..")

In [27]:
# Load the pretrained U-TV network (checkpoint from epoch 30) together with its config,
# and switch it to evaluation mode for testing.
u_tv_config, u_tv_net = get_config_and_model(
    model_id="u_tv",
    state_dict_file="new_u_tv_model_state_dict_30.pth",
    device=device,
    root_dir=root_dir,
)
u_tv_data_config = u_tv_config["data"]
u_tv_net.eval()

# Report the network's sigma / tau values (presumably the learned PDHG step sizes).
print(f"U-TV sigma: {u_tv_net.sigma}")
print(f"U-TV tau: {u_tv_net.tau}")
print(f"U-TV tau: {u_tv_net.tau}")

Config loaded from file ../../pretrained/u_tv/config.yaml
Config loaded from dict
Loading model on device: cuda
Norm of operator A: 1
Norm of gradient operator nabla: 2.8284270763397217
L: 3.0
PDHG net device: cuda
u_tv model loaded
U-TV sigma: 0.42191851139068604
U-TV tau: 0.263347327709198


In [28]:
# Load the pretrained U-TGV network (checkpoint from epoch 30) together with its config,
# and switch it to evaluation mode for testing.
u_tgv_config, u_tgv_net = get_config_and_model(
    model_id="u_tgv",
    state_dict_file="new_u_tgv_model_state_dict_30.pth",
    device=device,
    root_dir=root_dir,
)
u_tgv_data_config = u_tgv_config["data"]
u_tgv_net.eval()

# Report the network's sigma / tau values (presumably the learned PDHG step sizes).
print(f"U-TGV sigma: {u_tgv_net.sigma}")
print(f"U-TGV tau: {u_tgv_net.tau}")

Config loaded from file ../../pretrained/u_tgv/config.yaml
Config loaded from dict
Loading model on device: cuda
Norm of operator A: 1
Norm of gradient operator nabla: 2.8284270763397217
L: 3.0
PDHG net device: cuda
u_tgv model loaded
U-TGV sigma: 0.19924500584602356
U-TGV tau: 0.5576607584953308


In [29]:
# Both models must use an identical data configuration; otherwise their test
# metrics would not be comparable. After this check either config can be used.
assert u_tv_data_config == u_tgv_data_config, "Data config must be the same for both models"
data_config = u_tv_data_config

In [30]:
# Preprocessed test split corrupted with acceleration factor R = 4 and Gaussian
# noise of standard deviation 0.05.
# NOTE(review): `test_dataset` is not referenced by any later cell visible here
# (test_and_save_results builds its own loader), so this cell mainly acts as a
# sanity check of the data via the shapes/ranges it prints — confirm before removing.
test_dataset = get_dataset(
    data_config=data_config,
    action="test",
    dataset_type="preprocessed",
    acceleration_factor_R=4,
    gaussian_noise_standard_deviation_sigma=0.05,
    device=device,
)



Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 5.752422566729365e-06
max abs val of corrupted: 1.2349826097488403

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB



In [31]:
def get_zero_filled(batch_kdata, batch_kmask, batch_x, batch_csmap):
    """Zero-filled baseline "model": return the input image unchanged.

    Accepts the same four batch arguments as a reconstruction network so it can be
    passed as ``model`` to ``test_and_save_results`` (assumes MriIteration calls the
    model with these four arguments — TODO confirm). Only ``batch_x`` is used.

    Returns:
        Tuple ``(batch_x, None)``; the second slot is always ``None``.
    """
    reconstruction = batch_x
    return reconstruction, None

In [32]:
# How ImageMetricsEvaluator turns complex images into real ones before computing
# metrics. Currently "abs"; the other option used in this project is "view_as_real".
# The chosen value is also embedded in the result directory names.
complex_to_real_conversion = "abs"

metrics_evaluator = ImageMetricsEvaluator(
    complex_to_real_conversion=complex_to_real_conversion,
    device=device,
)

NOTE: Complex-to-real conversion method: abs


In [37]:
def test_and_save_results(
        config: Dict[str, Any],
        model: Union[MriPdhgNet, Callable],
        model_name: str,
        general_action: str,
        acceleration_factor_R: int,
        sigma: float,
        force_overwrite: bool = False
):
    """Evaluate ``model`` for one (R, sigma) corruption setting and log the metrics.

    Builds a preprocessed data loader for ``general_action`` with the given
    acceleration factor and noise level, then runs one non-training epoch via
    ``perform_epoch`` under ``torch.no_grad()``. The logger's output is written to
    ``./results/test_set/<model_name>/<specific_action>``.

    Uses the notebook-level globals ``device``, ``complex_to_real_conversion`` and
    ``metrics_evaluator``.

    Args:
        config: Model config. ``config["data"]`` configures the data loader, and the
            whole dict is passed to ``Logger``.
        model: Network or plain callable baseline, wrapped by ``MriIteration``.
            ``torch.nn.Module`` instances are switched to eval mode.
        model_name: Sub-directory name under ``results/test_set``.
        general_action: Dataset split passed to ``get_data_loader`` (e.g. "test").
        acceleration_factor_R: Acceleration factor R passed to the data loader.
        sigma: Standard deviation of the Gaussian noise added to the data.
        force_overwrite: Forwarded to ``Logger``.

    Returns:
        The value returned by ``perform_epoch`` (the averaged metrics).
    """
    print(f"Testing with {general_action} data, R = {acceleration_factor_R}, sigma = {sigma}")

    test_data_loader = get_data_loader(
        data_config=config["data"],
        action=general_action,
        dataset_type="preprocessed",
        device=device,
        acceleration_factor_R=acceleration_factor_R,
        gaussian_noise_standard_deviation_sigma=sigma
    )

    # Encode the evaluation setting in the run name, e.g. "test-abs-R_4-sigma_0_05".
    # Dots are replaced so the name can safely be used as a directory/file prefix.
    specific_action = f"{general_action}-{complex_to_real_conversion}-R_{acceleration_factor_R}-sigma_{sigma:.2f}".replace(".", "_")

    test_logger = Logger(
        action=specific_action,
        config=config,
        force_overwrite=force_overwrite
    )

    # Redirect the logger away from its default save directory into the test results tree.
    test_logger.save_dir = mkp(".", "results", "test_set", model_name, specific_action)
    os.makedirs(test_logger.save_dir, exist_ok=True)
    print(f"NOTE: New save directory: {test_logger.save_dir}")

    test_logger.init_metrics_logging_options()
    # Log intermediate metrics at every iteration.
    test_logger.log_freq_by_iter["intermediate"] = 1

    # Do not rely on another cell having called .eval(): testing must run in
    # inference mode (dropout/normalization layers, if any).
    if isinstance(model, torch.nn.Module):
        model.eval()

    test_data_iterator = tqdm(test_data_loader, desc="Testing")
    mri_iteration = MriIteration(model=model, metrics_evaluator=metrics_evaluator)

    # Close the progress bar even if the epoch raises, so a failed run does not
    # leave a dangling tqdm widget behind.
    try:
        with torch.no_grad():
            avg_metrics = perform_epoch(
                data_iterator=test_data_iterator,
                is_training=False,
                logger=test_logger,
                perform_iteration=mri_iteration.perform_iteration,
                sets_tqdm_postfix=True
            )
    finally:
        test_data_iterator.close()
    return avg_metrics

In [38]:
# Config for the zero-filled baseline, loaded in evaluation (non-training) mode.
zero_filled_config_path = mkp(root_dir, "config", "zero_filled_config.yaml")
zero_filled_config = load_config(
    config_choice=zero_filled_config_path,
    root_dir=root_dir,
    is_training=False,
)

Config loaded from file ../../config/zero_filled_config.yaml


In [40]:
# Evaluate one model over a grid of acceleration factors and noise levels.
# To evaluate a different model, change the selection below, e.g.
#   zero_filled_config, get_zero_filled, "zero_filled"
#   u_tgv_config, u_tgv_net, "u_tgv"
eval_config, eval_model, eval_model_name = u_tv_config, u_tv_net, "u_tv"

ACCELERATION_FACTORS_R = [6, 8]
NOISE_SIGMAS = [0.05, 0.10, 0.15, 0.20]

# Keep the averaged metrics for each (R, sigma) setting in the kernel instead of
# discarding the values returned by test_and_save_results.
avg_metrics_by_setting = {}
for acceleration_factor_R in ACCELERATION_FACTORS_R:
    for sigma in NOISE_SIGMAS:
        avg_metrics_by_setting[(acceleration_factor_R, sigma)] = test_and_save_results(
            config=eval_config,
            model=eval_model,
            model_name=eval_model_name,
            general_action="test",
            acceleration_factor_R=acceleration_factor_R,
            sigma=sigma,
            # Forwarded to Logger; presumably True overwrites existing result files.
            force_overwrite=False,
        )

Testing with test data, R = 6, sigma = 0.05


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 3.362683173691039e-06
max abs val of corrupted: 1.25584876537323

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_6-sigma_0_05
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_6-sigma_0_05
File 'results/test_set/u_tv/test-abs-R_6-sigma_0_05/test-abs-R_6-sigma_0_05_epoch_metrics.cs

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 6, sigma = 0.1


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.467540070938412e-06
max abs val of corrupted: 1.2514926195144653

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_6-sigma_0_10
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_6-sigma_0_10
File 'results/test_set/u_tv/test-abs-R_6-sigma_0_10/test-abs-R_6-sigma_0_10_epoch_metrics.c

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 6, sigma = 0.15


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.625301931402646e-06
max abs val of corrupted: 1.2493090629577637

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_6-sigma_0_15
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_6-sigma_0_15
File 'results/test_set/u_tv/test-abs-R_6-sigma_0_15/test-abs-R_6-sigma_0_15_epoch_metrics.

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 6, sigma = 0.2


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 4.8182196223933715e-06
max abs val of corrupted: 1.2492433786392212

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_6-sigma_0_20
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_6-sigma_0_20
File 'results/test_set/u_tv/test-abs-R_6-sigma_0_20/test-abs-R_6-sigma_0_20_epoch_metrics.

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.05


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.1736816506745527e-06
max abs val of corrupted: 1.1920750141143799

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_05
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_8-sigma_0_05
File 'results/test_set/u_tv/test-abs-R_8-sigma_0_05/test-abs-R_8-sigma_0_05_epoch_metrics

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.1


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 5.431713361758739e-06
max abs val of corrupted: 1.1884340047836304

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_10
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_8-sigma_0_10
File 'results/test_set/u_tv/test-abs-R_8-sigma_0_10/test-abs-R_8-sigma_0_10_epoch_metrics.c

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.15


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 4.317465482017724e-06
max abs val of corrupted: 1.189072847366333

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_15
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_8-sigma_0_15
File 'results/test_set/u_tv/test-abs-R_8-sigma_0_15/test-abs-R_8-sigma_0_15_epoch_metrics.c

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.2


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.0816048618144123e-06
max abs val of corrupted: 1.191650629043579

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_20
Save directory: pretrained/u_tv
Force overwrite: True
Current epoch: 0
Please initialize the logging options.
NOTE: New save directory: results/test_set/u_tv/test-abs-R_8-sigma_0_20
File 'results/test_set/u_tv/test-abs-R_8-sigma_0_20/test-abs-R_8-sigma_0_20_epoch_metrics.c

Testing:   0%|          | 0/302 [00:00<?, ?it/s]