In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# TODO merge with simulate_canonical_examples.ipynb

# choose the model to plot: "adc" or "t1inv"
model_name = "adc"

# choose the SNR values to plot, e.g. (10, 20, 30, 40, 50)
SNR = (20,)

# choose the n_train of the simulations to plot
n_train = 10**5

# choose the designs to plot
designs = ("super", "crlb", "tadred")

# choose the fits to plot
fits = ("super", "crlb", "tadred_lsq", "tadred_nn")

# neat names for each fit for the plots (same order as `fits`)
fits_neat = ("Super design", "CRLB", "TADRED-LSQ", "TADRED-NN")
assert len(fits_neat) == len(fits), "fits_neat must give one name per entry of fits"

# Project directory where everything is saved. Set the ED_MRI_BASEDIR environment
# variable to point elsewhere instead of editing this cell
# (another location used before: '/Users/paddyslator/python/ED_MRI/').
basedir = os.environ.get(
    "ED_MRI_BASEDIR", "/home/blumberg/Bureau/z_Automated_Measurement/Output/tst/"
)

# simulation output directory for the first SNR value
proj_dir = Path(basedir, f"{model_name}_simulations_n_train_{n_train}_SNR_{SNR[0]}")

# define location to save figures, TODO new figures?
fig_dir = Path(proj_dir, "figures_plot_canonical_simulated_examples")
fig_dir.mkdir(parents=True, exist_ok=True)

# base filename for saving figures
fig_basename = f"{model_name}_simulations_n_train_{n_train}"

In [None]:
# load the data into dictionaries, keyed first by SNR value

signals = {}  # signals[SNR_val][design]: test signals
fitted_parameters = {}  # fitted_parameters[SNR_val][fit]: fitted parameters
gt_parameters = {}  # gt_parameters[SNR_val]: ground-truth parameters
acq_params = {}  # acq_params[SNR_val][design]: optimised acquisition protocol

for SNR_val in SNR:
    # Each SNR has its own simulation directory, built with the same name template
    # the config cell uses for `proj_dir`, so that every SNR loads its own data
    # rather than always re-reading the first SNR's directory.
    this_dir = Path(basedir, f"{model_name}_simulations_n_train_{n_train}_SNR_{SNR_val}")

    # test signals
    signals[SNR_val] = {
        des: np.load(Path(this_dir, f"signals_{des}.npy")) for des in designs
    }
    # optimised acquisition protocols
    acq_params[SNR_val] = {
        des: np.load(Path(this_dir, f"acq_params_{des}.npy")) for des in designs
    }
    # fitted parameters for each fitting approach
    fitted_parameters[SNR_val] = {
        fit: np.load(Path(this_dir, f"fit_{fit}.npy")) for fit in fits
    }
    # ground truth parameters
    gt_parameters[SNR_val] = np.load(Path(this_dir, "parameters_gt.npy"))

In [None]:
# plot the ground truth parameters against fitted parameters

# Per-model axis range and labels. Raw strings keep the mathtext backslash in
# "\mu" from being parsed as an (invalid) Python escape sequence.
if model_name == "adc":
    axis_max = 3.5
    gt_label = "Ground truth D\n" + r"($\mu$m$^2$s$^{-1}$)"
    pred_label = "predicted D\n" + r"($\mu$m$^2$s$^{-1}$)"
elif model_name == "t1inv":
    axis_max = 7.5
    gt_label = "ground truth T1 ($s$)"
    pred_label = "predicted T1 ($s$)"
else:
    raise ValueError(f"no ground-truth plot settings for model_name={model_name!r}")

# one row per fit, one column per SNR
n_rows = len(fits)
fig, ax = plt.subplots(
    n_rows, len(SNR), figsize=[6 * len(SNR), 5 * n_rows], squeeze=False
)
colors = ("tab:blue", "tab:orange", "tab:green", "tab:red")

for j, SNR_val in enumerate(SNR):
    for k, fit in enumerate(fits):
        this_ax = ax[k, j]
        # cycle the colours so that more than four fits does not raise IndexError
        fit_color = colors[k % len(colors)]

        this_ax.plot(
            gt_parameters[SNR_val],
            fitted_parameters[SNR_val][fit],
            ".",
            markersize=1,
            color=fit_color,
        )
        # identity line: perfect estimates lie on it
        this_ax.plot((0, axis_max), (0, axis_max), "k")

        if k == 0:
            this_ax.set_title("SNR = " + str(SNR_val), fontsize=32, weight="bold")
        if k == n_rows - 1:
            this_ax.set_xlabel(gt_label, fontsize=32)
        if j == 0:
            this_ax.set_ylabel(
                fits_neat[k] + "\n" + pred_label, fontsize=32, color=fit_color
            )

        this_ax.set_ylim([0, axis_max])
        this_ax.set_xlim([0, axis_max])
        this_ax.tick_params(axis="both", labelsize=24)

fig.savefig(Path(fig_dir, fig_basename + "_gt_v_est.pdf"), dpi=300)

In [None]:
# Define metrics and calculate approach errors

# Each metric dictionary maps fit -> {SNR_val: value}; the loop below fills them.
MSE, MAE, bias, variance = {}, {}, {}, {}


def mean_squared_error(x, y):
    """Mean of the squared differences between `x` and `y` along axis 0.

    For 2-D inputs this gives one value per column; for 1-D inputs a scalar.
    """
    residual = x - y
    return np.mean(np.square(residual), axis=0)


def mean_absolute_error(x, y):
    """Mean absolute difference between `x` and `y`, taken over all elements."""
    abs_residual = np.absolute(x - y)
    return abs_residual.mean()


def calculate_bias(gt, pred):
    """Mean signed error over all elements, computed as gt - pred.

    A positive value means the predictions are, on average, below the ground truth.
    """
    signed_error = gt - pred
    return np.mean(signed_error)


def calculate_variance(pred):
    """Population variance (ddof=0) of `pred`, taken over all elements.

    The variance is the mean of the *squared* deviations from the mean. The
    previous expression, ``np.mean(np.mean(pred) - pred) ** 2``, squared the
    mean deviation instead, which is always ~0.
    """
    return np.var(pred)


for fit in fits:
    # start one {SNR_val: value} dictionary per metric for this fit
    MSE[fit], MAE[fit], bias[fit], variance[fit] = {}, {}, {}, {}
    for SNR_val in SNR:
        gt_vals = gt_parameters[SNR_val]
        pred_vals = fitted_parameters[SNR_val][fit]
        MSE[fit][SNR_val] = mean_squared_error(gt_vals, pred_vals)
        MAE[fit][SNR_val] = mean_absolute_error(gt_vals, pred_vals)
        bias[fit][SNR_val] = calculate_bias(gt_vals, pred_vals)
        variance[fit][SNR_val] = calculate_variance(pred_vals)

In [None]:
# barplot of MSE and MAE: one group of bars per SNR, one bar per fit

# (metric dictionary, y-axis label) for each panel
bar_metrics = ((MSE, "Mean Squared Error"), (MAE, "Mean Absolute Error"))
n_plots = len(bar_metrics)

fig, axs = plt.subplots(1, n_plots, figsize=[6 * n_plots, 5], squeeze=False)
axs = axs.ravel()

bar_width = 0.15

nbars = len(SNR)

# `i` is not used in this cell; it is kept only because later cells have read it
# as hidden state. TODO remove once nothing depends on it.
i = SNR[0]

# x position of the first fit's bar in each SNR group; fit k is shifted k bar
# widths to the right
bars = np.arange(nbars)

for (metric, ylabel), this_ax in zip(bar_metrics, axs):
    for k, fit in enumerate(fits):
        # heights in SNR order; assumes each metric value is a scalar or a
        # single-element array (one-parameter models) - TODO confirm for models
        # with several parameters
        heights = np.ravel([metric[fit][SNR_val] for SNR_val in SNR])
        this_ax.bar(bars + k * bar_width, heights, width=bar_width, label=fits_neat[k])

    this_ax.set_ylabel(ylabel)
    # centre each tick under its group of len(fits) bars
    this_ax.set_xticks(bars + (len(fits) - 1) / 2 * bar_width)
    this_ax.set_xticklabels(["SNR = " + str(SNR_val) for SNR_val in SNR])
    this_ax.legend()
    this_ax.set_yscale("log")

fig.savefig(Path(fig_dir, fig_basename + "_metrics_barplot.pdf"))

In [None]:
# plot the acquisition scheme chosen by each design on the model signal

# first need to define the model function
if model_name == "adc":
    # model equation for simulation
    def model(D, bvals):
        """Mono-exponential signal exp(-bvals * D)."""
        signals = np.exp(-bvals * D)
        return signals

    # b * D must be dimensionless, so the b-value carries the inverse of the
    # units this notebook uses for D (um^2 s^-1)
    acq_xlabel = r"b-value (s $\mu$m$^{-2}$)"
    acq_xlabel_fontsize = 18

elif model_name == "t1inv":
    tr = 7  # repetition time - hard coded
    print("using a tr = " + str(tr))

    def model(T1, ti, tr):
        """Magnitude signal |1 - 2 exp(-ti / T1) + exp(-tr / T1)|."""
        signals = abs(1 - (2 * np.exp(-ti / T1)) + np.exp(-tr / T1))
        return signals

    acq_xlabel = "Inversion time (s)"
    acq_xlabel_fontsize = 20

else:
    raise ValueError(f"no signal model defined for model_name={model_name!r}")

# choose a set of parameters to plot the signal for
paramtest = gt_parameters[SNR[0]][5]


def signal_at(acq):
    """Model signal for `paramtest` at the acquisition settings `acq`."""
    if model_name == "t1inv":
        return model(paramtest, acq, tr)
    return model(paramtest, acq)


# One panel per SNR plus at least one spare panel (for the legend), in two rows.
# ceil((n + 1) / 2) columns guarantees 2 * n_cols >= n + 1 for any number of SNRs.
n_cols = int(np.ceil((len(SNR) + 1) / 2))
fig, ax = plt.subplots(2, n_cols, figsize=[20, 12])
ax = ax.flatten()

for j, SNR_val in enumerate(SNR):
    this_ax = ax[j]
    acq = acq_params[SNR_val]
    this_ax.set_title(
        "SNR = " + str(SNR_val), fontsize=24, weight="bold", y=1.0, pad=-25
    )

    # super design
    (l1,) = this_ax.plot(acq["super"], signal_at(acq["super"]), "k.", markersize=2)
    # tadred
    (l2,) = this_ax.plot(
        acq["tadred"],
        signal_at(acq["tadred"]),
        "o",
        color="tab:green",
        markeredgecolor="k",
        markersize=15,
    )
    # CRLB chosen acquisition settings
    (l3,) = this_ax.plot(
        acq["crlb"],
        signal_at(acq["crlb"]),
        "+",
        color="tab:orange",
        fillstyle="none",
        markersize=15,
        markeredgewidth=3,
    )

    this_ax.set_xlabel(acq_xlabel, fontsize=acq_xlabel_fontsize)
    # y-label only on the first panel of each row
    if j % n_cols == 0:
        this_ax.set_ylabel("Signal", fontsize=20)
    this_ax.tick_params(axis="both", labelsize=18)

# blank out every spare panel; the last one holds the legend
for spare_ax in ax[len(SNR):]:
    spare_ax.axis("off")
ax[-1].legend((l1, l2, l3), ["super-design", "TADRED", "CRLB"], fontsize=24, loc=10)

fig.savefig(Path(fig_dir, fig_basename + "_acq_params.pdf"), dpi=300)

In [None]:
# TO DO (maybe): PLOT BOXPLOTS OF THE BIAS
# Boxplot of the per-sample signed error (gt - pred, the same sign convention as
# calculate_bias) for one fit, with one box per SNR. The `bias` dictionary only
# holds the mean error per SNR, so it cannot show a spread.
boxplot_fit = "crlb"
errors_per_SNR = [
    np.ravel(gt_parameters[SNR_val] - fitted_parameters[SNR_val][boxplot_fit])
    for SNR_val in SNR
]

fig, ax = plt.subplots()
ax.boxplot(errors_per_SNR, notch=True)
# zero-error reference line spanning all boxes (boxplot places them at x = 1..len(SNR))
ax.plot([0.5, len(SNR) + 0.5], [0, 0])
ax.set_xticklabels(["SNR = " + str(SNR_val) for SNR_val in SNR])
ax.set_ylabel("Error: ground truth - predicted")
ax.set_title("Per-sample error, fit = " + boxplot_fit);