In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import os
from tqdm.notebook import tqdm
from typing import Callable, Union, Dict, Any

from networks.mri_pdhg_net import MriPdhgNet
from utils.makepath import makepath as mkp
from config.config_loader import load_config
from data.mri.data_loader import get_data_loader, get_dataset
from scripts.epoch import perform_epoch
from scripts.mri.test.test_utils import get_config_and_model
from scripts.logger import Logger
from utils.metrics import ImageMetricsEvaluator
from scripts.mri.mri_iteration import MriIteration

In [4]:
# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA or CPU-only machines. To force a specific device,
# assign it explicitly afterwards (e.g. device = "cpu").
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [5]:
# Repository root, two directory levels above this notebook; configs and
# checkpoints are resolved relative to it (e.g. ../../pretrained/...).
root_dir = mkp(os.pardir, os.pardir)

In [6]:
# Load the U-TV network from the `u_tv_fixed` checkpoint together with its config.
# Alternative checkpoint: model_id="u_tv", state_dict_file="new_u_tv_model_state_dict_30.pth".
u_tv_config, u_tv_net = get_config_and_model(
    model_id="u_tv_fixed",
    state_dict_file="model_state_dict_43.pth",
    device=device,
    root_dir=root_dir,
)
u_tv_net.eval()
u_tv_data_config = u_tv_config["data"]

# Report the step sizes stored on the loaded network.
print(f"U-TV sigma: {u_tv_net.sigma}")
print(f"U-TV tau: {u_tv_net.tau}")

Config loaded from file ../../pretrained/u_tv_fixed/config.yaml
Config loaded from dict
Loading model on device: mps
Loading model on device: mps
Norm of operator A: 1
Norm of gradient operator nabla: 2.8284270763397217
L: 3.0
PDHG net device: mps
Using my U-Net implementation!
Using my own UNet2d class
Loaded model state dict from ../../pretrained/u_tv_fixed/model_state_dict_43.pth
u_tv_fixed model loaded
U-TV sigma: 0.12612056732177734
U-TV tau: 0.8809912204742432


In [7]:
# u_tgv_config, u_tgv_net = get_config_and_model(
#     # model_id="u_tgv",
#     # state_dict_file="new_u_tgv_model_state_dict_30.pth",
#     model_id="u_tgv_fixed",
#     state_dict_file="model_state_dict_72.pth",
#     device=device, root_dir=root_dir)
# u_tgv_data_config = u_tgv_config["data"]
# u_tgv_net.eval()
# print(f"U-TGV sigma: {u_tgv_net.sigma}")
# print(f"U-TGV tau: {u_tgv_net.tau}")

In [8]:
# u_tgv_net

In [9]:
# assert u_tv_data_config == u_tgv_data_config, "Data config must be the same for both models"
# Data settings of the evaluated model.
# NOTE(review): data_config is only read by the commented-out get_dataset cell below;
# test_and_save_results builds its loader from config["data"] instead.
data_config = u_tv_data_config
# data_config = u_tgv_data_config

In [10]:
# test_dataset = get_dataset(
#     action="test",
#     dataset_type="preprocessed",
#     data_config=data_config,
#     device=device,
#     acceleration_factor_R=4,
#     gaussian_noise_standard_deviation_sigma=0.05
# )

In [11]:
def get_zero_filled(
    batch_kdata,
    batch_kmask,
    batch_x,
    batch_csmap
):
    """Zero-filled baseline exposing the same call signature as a reconstruction model.

    ``batch_kdata``, ``batch_kmask`` and ``batch_csmap`` are ignored; ``batch_x`` is
    handed back untouched. Presumably ``batch_x`` is the loader's zero-filled
    reconstruction -- confirm against data.mri.data_loader.

    Returns:
        The tuple ``(batch_x, None)``. NOTE(review): the ``None`` stands in for the
        second value a network returns; its meaning is defined by MriIteration.
    """
    reconstruction = batch_x
    extra_output = None
    return reconstruction, extra_output

In [12]:
# How complex-valued images are turned into real ones before metrics are computed.
# Also used in the results file names. Alternative: "view_as_real".
complex_to_real_conversion = "abs"

metrics_evaluator = ImageMetricsEvaluator(
    device=device,
    complex_to_real_conversion=complex_to_real_conversion,
)

NOTE: Complex-to-real conversion method: abs


In [13]:
def test_and_save_results(
        config: Dict[str, Any],
        model: Union[MriPdhgNet, Callable],
        model_name: str,
        general_action: str,
        acceleration_factor_R: int,
        sigma: float,
        force_overwrite: bool = False
):
    """Evaluate ``model`` on one (R, sigma) data setting and log its metrics.

    Builds a preprocessed data loader for ``general_action`` with the given
    acceleration factor and noise level, creates a ``Logger`` whose action name
    encodes the setting, and runs a single non-training epoch via
    ``perform_epoch`` under ``torch.no_grad()``.

    Relies on the notebook globals ``device``, ``complex_to_real_conversion`` and
    ``metrics_evaluator``, which must be defined before this is called.

    Args:
        config: Full config dict; ``config["data"]`` configures the data loader
            and the whole dict is passed to the ``Logger``.
        model: A network (e.g. ``MriPdhgNet``) or a plain callable baseline such as
            ``get_zero_filled``; it is wrapped in ``MriIteration``.
        model_name: Only referenced by the commented-out save-directory override.
        general_action: Data split to evaluate, e.g. ``"test"`` or ``"val"``.
        acceleration_factor_R: Acceleration factor R forwarded to the loader.
        sigma: Gaussian noise standard deviation forwarded to the loader.
        force_overwrite: Forwarded to ``Logger``.

    Returns:
        The value returned by ``perform_epoch`` (its averaged metrics).
    """
    print(f"Testing with {general_action} data, R = {acceleration_factor_R}, sigma = {sigma}")

    # Guarantee inference mode for networks even if the caller forgot .eval();
    # plain callables (e.g. the zero-filled baseline) have no mode to set.
    if isinstance(model, torch.nn.Module):
        model.eval()

    test_data_loader = get_data_loader(
        data_config=config["data"],
        action=general_action,
        dataset_type="preprocessed",
        device=device,
        acceleration_factor_R=acceleration_factor_R,
        gaussian_noise_standard_deviation_sigma=sigma
    )

    # Encode the setting in the action name; all dots become underscores so the
    # name is safe to use in file names (e.g. "test-abs-R_4-sigma_0_05").
    specific_action = f"{general_action}-{complex_to_real_conversion}-R_{acceleration_factor_R}-sigma_{sigma:.2f}".replace(".", "_")

    test_logger = Logger(
        action=specific_action,
        config=config,
        force_overwrite=force_overwrite
    )

    # NOTE(review): the logger reports a relative save directory
    # (e.g. "pretrained/u_tv_fixed"), which resolves against the notebook's
    # working directory rather than root_dir -- confirm this is intended.
    # # Adjust the save directory as required
    # test_logger.save_dir = mkp(".", "results", "test_set", model_name, specific_action)
    # os.makedirs(test_logger.save_dir, exist_ok=True)
    # print(f"NOTE: New save directory: {test_logger.save_dir}")
    print(f"Final save directory: {test_logger.save_dir}")

    test_logger.init_metrics_logging_options()
    # Set the intermediate log frequency to 1
    test_logger.log_freq_by_iter["intermediate"] = 1

    mri_iteration = MriIteration(model=model, metrics_evaluator=metrics_evaluator)

    # perform_epoch receives the tqdm factory and presumably builds its own
    # progress bar from it, so no extra (never-iterated) tqdm wrapper is created
    # here; such a wrapper only left a stale "Testing: 0%" bar in the output.
    with torch.no_grad():
        avg_metrics = perform_epoch(
            data_loader=test_data_loader,
            perform_iteration=mri_iteration.perform_iteration,
            is_training=False,
            logger=test_logger,
            tqdm=tqdm,
            sets_tqdm_postfix=False
        )
    return avg_metrics

In [14]:
# Config for the zero-filled baseline, loaded in evaluation (non-training) mode.
zero_filled_config_path = mkp(root_dir, "config", "zero_filled_config.yaml")
zero_filled_config = load_config(
    root_dir=root_dir,
    is_training=False,
    config_choice=zero_filled_config_path,
)

Config loaded from file ../../config/zero_filled_config.yaml


In [15]:
# Evaluate U-TV on the test split for every (acceleration factor R, noise sigma)
# pair. The averaged metrics returned for each setting are kept in
# avg_metrics_by_setting instead of being discarded.
#
# To evaluate another model, swap the config/model/model_name triple below:
#   config=zero_filled_config, model=get_zero_filled, model_name="zero_filled"
#   config=u_tgv_config, model=u_tgv_net, model_name="u_tgv"  (enable the U-TGV cell first)
# Use EVAL_ACTION = "val" for the validation split; set FORCE_OVERWRITE = True
# to overwrite existing results.
ACCELERATION_FACTORS_R = [4, 8]
NOISE_SIGMAS = [0.05, 0.10, 0.20]
EVAL_ACTION = "test"
FORCE_OVERWRITE = False

avg_metrics_by_setting = {}
for acceleration_factor_R in ACCELERATION_FACTORS_R:
    for sigma in NOISE_SIGMAS:
        avg_metrics_by_setting[(acceleration_factor_R, sigma)] = test_and_save_results(
            config=u_tv_config, model=u_tv_net, model_name="u_tv",
            general_action=EVAL_ACTION,
            acceleration_factor_R=acceleration_factor_R,
            sigma=sigma,
            force_overwrite=FORCE_OVERWRITE,
        )

Testing with test data, R = 4, sigma = 0.05


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 5.752422566729365e-06
max abs val of corrupted: 1.2349826097488403

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_05
Save directory: pretrained/u_tv_fixed
Force overwrite: False
Current epoch: 0
Please initialize the logging options.
Final save directory: pretrained/u_tv_fixed
File 'pretrained/u_tv_fixed/test-abs-R_4-sigma_0_05_epoch_metrics.csv' initialized.
File 'pretrained/u_tv_fixed

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 4, sigma = 0.1


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.8516878955997527e-06
max abs val of corrupted: 1.2406718730926514

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_10
Save directory: pretrained/u_tv_fixed
Force overwrite: False
Current epoch: 0
Please initialize the logging options.
Final save directory: pretrained/u_tv_fixed
File 'pretrained/u_tv_fixed/test-abs-R_4-sigma_0_10_epoch_metrics.csv' initialized.
File 'pretrained/u_tv_fixed

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 4, sigma = 0.2


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 1.3283834050525911e-06
max abs val of corrupted: 1.2621699571609497

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_4-sigma_0_20
Save directory: pretrained/u_tv_fixed
Force overwrite: False
Current epoch: 0
Please initialize the logging options.
Final save directory: pretrained/u_tv_fixed
File 'pretrained/u_tv_fixed/test-abs-R_4-sigma_0_20_epoch_metrics.csv' initialized.
File 'pretrained/u_tv_fixed

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.05


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.1736816506745527e-06
max abs val of corrupted: 1.1920750141143799

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_05
Save directory: pretrained/u_tv_fixed
Force overwrite: False
Current epoch: 0
Please initialize the logging options.
Final save directory: pretrained/u_tv_fixed
File 'pretrained/u_tv_fixed/test-abs-R_8-sigma_0_05_epoch_metrics.csv' initialized.
File 'pretrained/u_tv_fixe

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.1


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 5.431713361758739e-06
max abs val of corrupted: 1.1884340047836304

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_10
Save directory: pretrained/u_tv_fixed
Force overwrite: False
Current epoch: 0
Please initialize the logging options.
Final save directory: pretrained/u_tv_fixed
File 'pretrained/u_tv_fixed/test-abs-R_8-sigma_0_10_epoch_metrics.csv' initialized.
File 'pretrained/u_tv_fixed/

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/302 [00:00<?, ?it/s]

Testing with test data, R = 8, sigma = 0.2


Ground truth data shape: torch.Size([302, 320, 320])
min abs val of ground truth: 1.8633207901075366e-06
max abs val of ground truth: 1.385725498199463

Ground truth data shape: torch.Size([302, 320, 320])
Memory size of ground truth: 235.9375 MB
Corrupted data shape: torch.Size([302, 320, 320])
Memory size of corrupted data: 235.9375 MB
min abs val of corrupted: 2.0816048618144123e-06
max abs val of corrupted: 1.191650629043579

Corrupted kdata shape: torch.Size([302, 1, 320, 320])
Memory size of corrupted kdata: 235.9375 MB

kmasks shape: torch.Size([302, 320, 320])
Memory size of kmasks: 235.9375 MB

test_data_loader contains 302 batches.
Action: test-abs-R_8-sigma_0_20
Save directory: pretrained/u_tv_fixed
Force overwrite: False
Current epoch: 0
Please initialize the logging options.
Final save directory: pretrained/u_tv_fixed
File 'pretrained/u_tv_fixed/test-abs-R_8-sigma_0_20_epoch_metrics.csv' initialized.
File 'pretrained/u_tv_fixed/

Testing:   0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/302 [00:00<?, ?it/s]