In [1]:
# Reload edited modules (e.g. the project's utils/, pdhg/, data/ packages) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd

In [3]:
# Torch device used for data loading, PDHG reconstruction and metric computation in the sweep below.
default_device = "cpu"
# default_device = "cuda"  # switch to this line to run the sweep on a GPU

In [None]:
import os

from utils.makepath import makepath

# Locate the preprocessed MRI data: <root_dir>/tmp/mri_data, with root_dir two
# levels up ("..", "..").
# NOTE(review): presumably makepath joins its parts like os.path.join, relative
# to the kernel's working directory — confirm.
root_dir = makepath("..", "..")
data_dir = makepath(root_dir, "tmp", "mri_data")

# Last expression is rendered by Jupyter: the files available in data_dir.
os.listdir(data_dir)

In [5]:
# Dataset settings handed to data.mri.data_loader.get_dataset in the sweep cell.
# NOTE(review): test_file_name embeds the sample count (302); keep it in sync
# with test_num_samples.
data_config = dict(
    data_path=data_dir,
    data_scale_factor=1000,
    test_file_name="x_true_test_302.pt",
    test_num_samples=302,
)

In [6]:
import matplotlib.pyplot as plt
def plot_image_with_color_bar(x: torch.Tensor, title=None):
    """Display the magnitude of ``x[0]`` as a grayscale image with a colour bar.

    Args:
        x: tensor whose first entry along dim 0 is plotted; ``.abs()`` is taken,
            so complex input is shown as magnitude. NOTE(review): dim 0 is
            presumably the batch axis — confirm against callers.
        title: optional figure title (None leaves it blank).
    """
    first_entry = x[0]
    # Detach from autograd and move to host memory so numpy/matplotlib can use it.
    magnitude = first_entry.abs().detach().cpu().numpy()
    plt.imshow(magnitude, cmap="gray")
    plt.colorbar()
    plt.title(title)
    plt.show()

In [7]:
# Debug plots (disabled).
# NOTE(review): x_corrupted / x_true are only defined inside the sweep loop
# below, so on a fresh kernel these lines fail if uncommented here.
# plot_image_with_color_bar(x_corrupted, "x_corrupted")
# plot_image_with_color_bar(x_true, "x_true")

In [None]:
# Sweep every (acceleration factor R, Gaussian noise sigma) configuration.
# For each test sample, ternary-search the scalar TV weight lambda that
# maximises `best_metric`, then append the best (PSNR, SSIM, lambda) row to a
# per-configuration CSV. A companion .log file records the search trace.
#
# Imports are hoisted out of the loops: the original re-ran them for every
# configuration and every sample, which also hid the cell's dependencies.
from datetime import datetime

from tqdm.auto import tqdm

from data.mri.data_loader import get_dataset
from pdhg.mri_pdhg import MriPdhgTorchTest
from pdhg.scalar_searcher import ScalarSearcher
from utils.metrics import ImageMetricsEvaluator
from utils.search import ternary_search_1d

# Search settings shared by every configuration (loop-invariant).
best_metric = "SSIM"      # metric name passed to ScalarSearcher as best_metric
range_lambda = (0, 0.2)   # lambda search interval
num_search_iters = 20     # ternary-search iterations per sample
batch_size = 1            # one image per search

for acc_factor_R in [4, 8, 6]:
    for gaussian_noise_sigma in [0.05, 0.10, 0.20]:

        # Output file stem encoding the configuration. Dots (in sigma and the
        # lambda bounds) become "_" so the only "." is the file extension.
        cur_datetime = datetime.now().strftime("%d_%H_%M")
        csv_file_name = (
            f"test-abs-scalar_tv-best_{best_metric}-"
            f"R_{acc_factor_R}-"
            f"sigma_{gaussian_noise_sigma}-"
            f"lambda_{range_lambda[0]}_to_{range_lambda[1]}-"
            f"ternary_iters_{num_search_iters}-dt_{cur_datetime}"
        ).replace(".", "_")

        # Create/truncate both output files for this configuration.
        with open(csv_file_name + ".csv", "w") as f:
            f.write("id,PSNR,SSIM,lambda\n")
        log_file = csv_file_name + ".log"
        with open(log_file, "w"):
            pass

        def log_func(*args):
            """Append the space-joined str() of ``args`` as one line to ``log_file``.

            ``log_file`` is looked up when called, so it always targets the
            current configuration's log.
            """
            with open(log_file, "a") as f:
                f.write(" ".join(map(str, args)) + "\n")

        def ternary_search_1d_with_log(
                range_x, compute, best, num_iters,
                func=None, round_func=None):
            """Call ``ternary_search_1d`` with its trace routed to ``log_func``."""
            return ternary_search_1d(
                range_x, compute, best, num_iters,
                func=func, round_func=round_func, log_func=log_func)

        dataset = get_dataset(
            data_config=data_config,
            action="test",
            dataset_type="preprocessed",
            device=default_device,
            acceleration_factor_R=acc_factor_R,
            gaussian_noise_standard_deviation_sigma=gaussian_noise_sigma
        )
        # shuffle=False keeps dataset order, so the 1-based CSV id equals
        # dataset index + 1.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
        )

        # `sample_id` instead of `id`: assigning `id` at notebook top level
        # shadowed the builtin id() for every later cell.
        progress = tqdm(
            data_loader,
            desc=f"R={acc_factor_R}, sigma={gaussian_noise_sigma}")
        for sample_id, data in enumerate(progress, start=1):
            x_corrupted, x_true, kdata_corrupted, undersampling_kmask = data

            mri_tv_test = MriPdhgTorchTest(
                device=default_device,
                pdhg_algorithm="tv",
                num_iters=256,
                kdata=kdata_corrupted,
                kmask=undersampling_kmask,
                state=x_corrupted
            )

            # NOTE(review): rebuilt per sample as before; hoist out of the loop
            # only once ImageMetricsEvaluator is confirmed to be stateless.
            metrics_evaluator = ImageMetricsEvaluator(device=default_device)

            def compute_metrics_mri(denoised: torch.Tensor) -> tuple:
                """Return ``(psnr, ssim)`` of ``denoised`` against the current sample's ``x_true``."""
                psnr, ssim = metrics_evaluator.compute_torch_complex(x=denoised, x_true=x_true)
                return psnr, ssim

            scalar_searcher = ScalarSearcher(
                get_denoised=mri_tv_test.get_denoised,
                compute_metrics=compute_metrics_mri,
                best_metric=best_metric
            )

            log_func(f"BEGIN ID: {sample_id}")
            best_denoised, best_lambda, best_metrics = \
                scalar_searcher.brute_force_and_denoise_tv(
                    search_1d=ternary_search_1d_with_log,
                    range_lambda=range_lambda,
                    num_search_iters=num_search_iters,
                )
            # The trailing "\n" leaves a blank line between samples in the log.
            log_func(f"END ID: {sample_id}\n")

            # NOTE(review): if PSNR/SSIM/lambda come back as torch tensors, the
            # CSV will contain their repr (e.g. "tensor(...)") — confirm they
            # are Python floats.
            psnr, ssim = best_metrics
            with open(csv_file_name + ".csv", "a") as f:
                f.write(f"{sample_id},{psnr},{ssim},{best_lambda}\n")

In [10]:
# NOTE(review): all_results_df is never populated (its creation in the sweep
# cell is commented out); results are already appended row by row to
# csv_file_name + ".csv" there.
# all_results_df.to_csv(csv_file_name + ".csv")

In [11]:
# Disabled experiment: 2-D ternary search over the two TGV weights.
# NOTE(review): scalar_searcher_best_ssim is not defined in any visible cell.
# from utils.search import ternary_search_2d
# scalar_searcher_best_ssim.brute_force_and_denoise_tgv(
#     search_2d=ternary_search_2d,
#     range_lambda0=[0, 0.5],
#     range_lambda1=[0, 0.5],
#     num_iters=40,
# )

In [12]:
# Disabled experiment: 11x11 grid search over the two TGV weights.
# NOTE(review): scalar_searcher here would be the last per-sample searcher left
# over from the TV sweep (built on mri_tv_test); confirm it supports TGV search.
# from utils.search import grid_search_2d
# import numpy as np
# scalar_searcher.brute_force_and_denoise_tgv(
#     search_2d=grid_search_2d,
#     range_lambda0=np.linspace(0, 0.02, 11),
#     range_lambda1=np.linspace(0, 0.02, 11),
# )

In [13]:
# Disabled experiment: a searcher optimising PSNR instead of SSIM.
# NOTE(review): mri_tgv_test is not defined in any visible cell.
# scalar_searcher_best_psnr = ScalarSearcher(
#     get_denoised=mri_tgv_test.get_denoised,
#     compute_metrics=compute_metrics_mri,
#     best_metric="PSNR"
# )

In [14]:
# Disabled experiment: 2-D ternary TGV search with the PSNR searcher.
# NOTE(review): relies on the ternary_search_2d import in the (also commented)
# In [11] cell and on scalar_searcher_best_psnr from the cell above.
# scalar_searcher_best_psnr.brute_force_and_denoise_tgv(
#     search_2d=ternary_search_2d,
#     range_lambda0=[0, 0.5],
#     range_lambda1=[0, 0.5],
#     num_iters=40,
# )

In [15]:
# Disabled experiment: 11x11 TGV grid search with the PSNR searcher.
# NOTE(review): relies on the grid_search_2d / numpy imports in the (also
# commented) In [12] cell.
# scalar_searcher_best_psnr.brute_force_and_denoise_tgv(
#     search_2d=grid_search_2d,
#     range_lambda0=np.linspace(0, 0.02, 11),
#     range_lambda1=np.linspace(0, 0.02, 11),
# )