Skip to content

Commit

Permalink
Changes to plots
Browse files Browse the repository at this point in the history
  • Loading branch information
po09i committed Aug 29, 2023
1 parent 4cdc061 commit 6f6da9f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 67 deletions.
151 changes: 85 additions & 66 deletions shimmingtoolbox/shim/sequencer.py
Expand Up @@ -57,6 +57,8 @@ class Sequencer(object):
regularization. A negative value will favour high currents (not preferred). Only relevant
for 'least_squares' opt_method.
path_output (str): Path to the directory to output figures. Set logging level to debug to output debug
index_shimmed: Indexes of ``slices`` that have been shimmed
index_not_shimmed: Indexes of ``slices`` that have not been shimmed
"""

def __init__(self, slices, mask_dilation_kernel, mask_dilation_kernel_size, reg_factor, path_output):
Expand Down Expand Up @@ -84,6 +86,8 @@ def __init__(self, slices, mask_dilation_kernel, mask_dilation_kernel_size, reg_
self.reg_factor = reg_factor
self.path_output = path_output
self.optimizer = None
self.index_shimmed = []
self.index_not_shimmed = []

def optimize(self, masks_fmap):
"""
Expand All @@ -102,10 +106,12 @@ def optimize(self, masks_fmap):
# If there is nothing to shim in this shim group
if np.all(masks_fmap[..., i] == 0):
coefs.append(np.zeros(self.optimizer.merged_coils.shape[-1]))
self.index_not_shimmed.append(i)

# Otherwise optimize
else:
coefs.append(self.optimizer.optimize(masks_fmap[..., i]))
self.index_shimmed.append(i)

return np.array(coefs)

Expand Down Expand Up @@ -332,7 +338,8 @@ def get_resampled_masks(self):
for i in range(n_shims))

# We need to transpose the mask to have the good dimensions
masks_fmap_dilated = np.array([results_mask[it][1].get_fdata() for it in range(n_shims)]).transpose(1, 2, 3, 0)
masks_fmap_dilated = np.array([results_mask[it][1].get_fdata() for it in range(n_shims)]).transpose(1, 2, 3,
0)
masks_fmap = np.array([results_mask[it][0].get_fdata() for it in range(n_shims)]).transpose(1, 2, 3, 0)

return masks_fmap_dilated, masks_fmap
Expand Down Expand Up @@ -1209,6 +1216,7 @@ def eval(self, coef_static, coef_riro, mean_p, pressure_rms):
shim_trace_static = []
shim_trace_riro = []
unshimmed_trace = []
mae_unshimmed_trace = []
mask_full_binary = np.clip(np.ceil(resample_from_to(self.nii_static_mask,
nii_target,
order=0,
Expand Down Expand Up @@ -1240,26 +1248,34 @@ def eval(self, coef_static, coef_riro, mean_p, pressure_rms):
# Calculate the sum over the ROI
# TODO: Calculate the sum of mask_fmap_cs[..., i_shim] and divide by that (If the roi is bigger due to
# interpolation, it should not count more). Possibly use soft mask?
sum_shimmed_static = np.sum(np.abs(masked_shim_static[..., i_t, i_shim]))
sum_shimmed_static_riro = np.sum(np.abs(masked_shim_static_riro[..., i_t, i_shim]))
sum_shimmed_riro = np.sum(np.abs(masked_shim_riro[..., i_t, i_shim]))
sum_unshimmed = np.sum(np.abs(masked_unshimmed[..., i_t, i_shim]))

if sum_shimmed_static_riro > sum_unshimmed:
rmse_shimmed_static = calculate_metric_within_mask(masked_shim_static[..., i_t, i_shim],
mask_fmap_cs[..., i_shim].astype(bool),
metric='rmse')
rmse_shimmed_static_riro = calculate_metric_within_mask(masked_shim_static_riro[..., i_t, i_shim],
mask_fmap_cs[..., i_shim].astype(bool),
metric='rmse')
rmse_shimmed_riro = calculate_metric_within_mask(masked_shim_riro[..., i_t, i_shim],
mask_fmap_cs[..., i_shim].astype(bool),
metric='rmse')
rmse_unshimmed = calculate_metric_within_mask(masked_unshimmed[..., i_t, i_shim],
mask_fmap_cs[..., i_shim].astype(bool),
metric='rmse')

if rmse_shimmed_static_riro > rmse_unshimmed:
logger.warning("Verify the shim parameters. Some give worse results than no shim.\n"
f"i_shim: {i_shim}, i_t: {i_t}")

logger.debug(f"\ni_shim: {i_shim}, t: {i_t}"
f"\nunshimmed: {sum_unshimmed}, shimmed static: {sum_shimmed_static}, "
f"shimmed static+riro: {sum_shimmed_static_riro}\n"
logger.debug(f"\nRMSE: i_shim: {i_shim}, t: {i_t}"
f"\nunshimmed: {rmse_unshimmed}, shimmed static: {rmse_shimmed_static}, "
f"shimmed static+riro: {rmse_shimmed_static_riro}\n"
f"Static currents:\n{coef_static[i_shim]}\n"
f"Riro currents:\n{coef_riro[i_shim] * (self.acq_pressures[i_t] - mean_p)}\n")

# Create a 1D list of the sum of the shimmed and unshimmed maps
shim_trace_static.append(sum_shimmed_static)
shim_trace_static_riro.append(sum_shimmed_static_riro)
shim_trace_riro.append(sum_shimmed_riro)
unshimmed_trace.append(sum_unshimmed)
shim_trace_static.append(rmse_shimmed_static)
shim_trace_static_riro.append(rmse_shimmed_static_riro)
shim_trace_riro.append(rmse_shimmed_riro)
unshimmed_trace.append(rmse_unshimmed)

# reshape to slice x timepoint
nt = unshimmed.shape[3]
Expand All @@ -1272,7 +1288,8 @@ def eval(self, coef_static, coef_riro, mean_p, pressure_rms):
if self.path_output is not None:
# Plot before vs after shimming averaged on time
shimmed_mask_avg = np.zeros(mask_full_binary.shape)
np.divide(np.sum(np.mean(masked_shim_static_riro, axis=3), axis=3), np.sum(mask_fmap_cs, axis=3), where=mask_full_binary.astype(bool), out=shimmed_mask_avg)
np.divide(np.sum(np.mean(masked_shim_static_riro, axis=3), axis=3), np.sum(mask_fmap_cs, axis=3),
where=mask_full_binary.astype(bool), out=shimmed_mask_avg)
self.plot_full_mask(np.mean(unshimmed, axis=3), shimmed_mask_avg, mask_full_binary)

# Plot STD over time before and after shimming
Expand All @@ -1281,18 +1298,15 @@ def eval(self, coef_static, coef_riro, mean_p, pressure_rms):
if logger.level <= getattr(logging, 'DEBUG') and self.path_output is not None:
# plot results
i_slice = 0
i_shim = 0
i_shim = self.index_shimmed[0] if self.index_shimmed else n_shim - 1
i_t = 0
while np.all(masked_unshimmed[..., i_slice, i_t, i_shim] == np.zeros(masked_unshimmed.shape[:2])):
i_shim += 1
if i_shim >= n_shim - 1:
break

self.plot_static_riro(masked_unshimmed, masked_shim_static, masked_shim_static_riro, unshimmed,
shimmed_static,
shimmed_static_riro, i_slice=i_slice, i_shim=i_shim, i_t=i_t)
self.plot_currents(coef_static, riro=coef_riro * pressure_rms)
self.plot_shimmed_trace(unshimmed_trace, shim_trace_static, shim_trace_riro, shim_trace_static_riro)
self.plot_pressure_and_field(unshimmed_trace)
self.plot_pressure_and_unshimmed_field(unshimmed_trace)
self.print_rt_metrics(unshimmed, shimmed_static, shimmed_static_riro, shimmed_riro, mask_fmap_cs)
# Save shimmed result
nii_shimmed_static_riro = nib.Nifti1Image(shimmed_static_riro, self.nii_fieldmap.affine,
Expand Down Expand Up @@ -1373,25 +1387,32 @@ def plot_static_riro(self, masked_unshimmed, masked_shim_static, masked_shim_sta
fig.colorbar(im)
ax.set_title("masked_unshimmed")

min_value = min(shimmed_static_riro[..., i_slice, i_t, i_shim].min(),
shimmed_static[..., i_slice, i_t, i_shim].min(),
unshimmed[..., i_slice, i_t, i_shim].min())
max_value = max(shimmed_static_riro[..., i_slice, i_t, i_shim].max(),
shimmed_static[..., i_slice, i_t, i_shim].max(),
unshimmed[..., i_slice, i_t, i_shim].max())

ax = fig.add_subplot(2, 3, 4)
im = ax.imshow(np.rot90(shimmed_static_riro[..., i_slice, i_t, i_shim]))
im = ax.imshow(np.rot90(shimmed_static_riro[..., i_slice, i_t, i_shim]), vmin=min_value, vmax=max_value)
fig.colorbar(im)
ax.set_title("shim static + riro")
ax = fig.add_subplot(2, 3, 5)
im = ax.imshow(np.rot90(shimmed_static[..., i_slice, i_t, i_shim]))
im = ax.imshow(np.rot90(shimmed_static[..., i_slice, i_t, i_shim]), vmin=min_value, vmax=max_value)
fig.colorbar(im)
ax.set_title(f"shim static")
ax = fig.add_subplot(2, 3, 6)
im = ax.imshow(np.rot90(unshimmed[..., i_slice, i_t]))
im = ax.imshow(np.rot90(unshimmed[..., i_slice, i_t]), vmin=min_value, vmax=max_value)
fig.colorbar(im)
ax.set_title(f"unshimmed")
fname_figure = os.path.join(self.path_output, 'fig_realtime_masked_shimmed_vs_unshimmed.png')
fig.savefig(fname_figure)
logger.debug(f"Saved figure: {fname_figure}")

def plot_pressure_and_field(self, unshimmed_trace):
def plot_pressure_and_unshimmed_field(self, unshimmed_trace):
"""
Plot respiratory trace, acquisition time pressure points and the scaled B0 field
Plot respiratory trace, acquisition time pressure points and the B0 field RMSE
Args:
unshimmed_trace (np.ndarray): field in the ROI for each shim volume
Expand All @@ -1404,24 +1425,18 @@ def plot_pressure_and_field(self, unshimmed_trace):
pmu_timestamps_curated = pmu_timestamps[indexes]
pmu_pressures_curated = pmu_pressures[indexes]

# Remove slices not being shimmed
shim_to_display = []
shim_to_remove = []
curated_unshimmed_trace = copy.deepcopy(unshimmed_trace)
for i_shim in range(len(self.slices)):
if np.all(unshimmed_trace[i_shim] != 0):
shim_to_display.append(i_shim)
else:
shim_to_remove.append(i_shim)
curated_unshimmed_trace = np.delete(curated_unshimmed_trace, shim_to_remove, axis=0)
# Select slices shimmed
curated_unshimmed_trace = unshimmed_trace[self.index_shimmed]

# Get the b0 field in the same units as the pressure reading
n_plots = len(shim_to_display)
n_plots = len(self.index_shimmed)
max_diff_field = 0
for i_plot in range(n_plots):
diff_field = curated_unshimmed_trace[i_plot].max() - curated_unshimmed_trace[i_plot].min()
if abs(max_diff_field) < abs(diff_field):
max_diff_field = diff_field
min_field = curated_unshimmed_trace[i_plot].min()
max_field = curated_unshimmed_trace[i_plot].max()

diff_pressure = pmu_pressures_curated.max() - pmu_pressures_curated.min()
scaling = max_diff_field / diff_pressure
Expand All @@ -1430,7 +1445,8 @@ def plot_pressure_and_field(self, unshimmed_trace):
curated_unshimmed_trace_scaled = np.zeros_like(curated_unshimmed_trace)
for i_plot in range(n_plots):
avg_b0field = np.mean(curated_unshimmed_trace[i_plot])
curated_unshimmed_trace_scaled[i_plot] = (curated_unshimmed_trace[i_plot] - avg_b0field) / scaling + avg_pressure
curated_unshimmed_trace_scaled[i_plot] = (curated_unshimmed_trace[
i_plot] - avg_b0field) / scaling + avg_pressure

# Find y limits
perc = (self.pmu.max - self.pmu.min) / 20
Expand All @@ -1444,29 +1460,29 @@ def plot_pressure_and_field(self, unshimmed_trace):
ax.plot((pmu_timestamps_curated - pmu_timestamps_curated[0]) / 1000, pmu_pressures_curated,
label='Pressure Trace')
ax.plot((self.acq_timestamps - pmu_timestamps_curated[0]) / 1000, curated_unshimmed_trace_scaled[i_plot],
label='B0_field')
label='RMSE over the not shimmed ROI')
ax.scatter((self.acq_timestamps - pmu_timestamps_curated[0]) / 1000, self.acq_pressures, color='red',
label='Fieldmap timepoints')
ax.legend()
ax.set_ylim(ylim)
ax.set_yticklabels([])
ax.set_yticks([])
ax.set_yticks([pmu_pressures_curated.min(), pmu_pressures_curated.max()],
[min_field.astype(int), max_field.astype(int)])
ax.set_xlabel('Time (s)')
ax.set_title(f"Slices: {self.slices[shim_to_display[i_plot]]}")
ax.set_title(f"Slices: {self.slices[self.index_shimmed[i_plot]]}")

# Place suptitle
# 10 = 2%, 5 = 4%, 1 = 10%
top = 1 - (0.1 / (n_plots / 1.5))
fig.tight_layout(rect=[0, 0.03, 1, top])

# Save figure
fname_figure = os.path.join(self.path_output, 'fig_trace_pressures.png')
fname_figure = os.path.join(self.path_output, 'fig_not_shimmed_trace_vs_pressure.png')
fig.savefig(fname_figure, bbox_inches='tight')
logger.debug(f"Saved figure: {fname_figure}")

def plot_shimmed_trace(self, unshimmed_trace, shim_trace_static, shim_trace_riro, shim_trace_static_riro):
"""
Plot shimmed and unshimmed sum over the roi for each shim
Plot shimmed and unshimmed rmse over the roi for each shim
Args:
unshimmed_trace (np.ndarray): array with the trace of the nii_fieldmap data
Expand All @@ -1476,33 +1492,33 @@ def plot_shimmed_trace(self, unshimmed_trace, shim_trace_static, shim_trace_riro
"""

min_value = min(
shim_trace_static_riro[:, :].min(),
shim_trace_static[:, :].min(),
shim_trace_riro[:, :].min(),
unshimmed_trace[:, :].min()
shim_trace_static_riro[self.index_shimmed, :].min(),
shim_trace_static[self.index_shimmed, :].min(),
shim_trace_riro[self.index_shimmed, :].min(),
unshimmed_trace[self.index_shimmed, :].min()
)
max_value = max(
shim_trace_static_riro[:, :].max(),
shim_trace_static[:, :].max(),
shim_trace_riro[:, :].max(),
unshimmed_trace[:, :].max()
shim_trace_static_riro[self.index_shimmed, :].max(),
shim_trace_static[self.index_shimmed, :].max(),
shim_trace_riro[self.index_shimmed, :].max(),
unshimmed_trace[self.index_shimmed, :].max()
)

# Calc ysize
n_shims = len(unshimmed_trace)
n_shims = len(self.index_shimmed)
ysize = n_shims * 4.7
fig = Figure(figsize=(10, ysize), tight_layout=True)
for i_shim in range(n_shims):
ax = fig.add_subplot(n_shims, 1, i_shim + 1)
for i, i_shim in enumerate(self.index_shimmed):
ax = fig.add_subplot(n_shims, 1, i + 1)
ax.plot(shim_trace_static_riro[i_shim, :], label='shimmed static + riro')
ax.plot(shim_trace_static[i_shim, :], label='shimmed static')
ax.plot(shim_trace_riro[i_shim, :], label='shimmed_riro')
ax.plot(unshimmed_trace[i_shim, :], label='unshimmed')
ax.set_xlabel('Timepoints')
ax.set_ylabel('Sum over the ROI')
ax.set_ylabel('RMSE over the ROI')
ax.legend()
ax.set_ylim(min_value, max_value)
ax.set_title(f"Unshimmed vs shimmed values: shim {i_shim}")
ax.set_ylim([min_value, max_value])
ax.set_title(f"Unshimmed vs shimmed values: shim {self.slices[i_shim]}")
fname_figure = os.path.join(self.path_output, 'fig_trace_shimmed_vs_unshimmed.png')
fig.savefig(fname_figure)
logger.debug(f"Saved figure: {fname_figure}")
Expand Down Expand Up @@ -1622,20 +1638,21 @@ def plot_full_time_std(self, unshimmed, masked_shim_static_riro, mask_fmap_cs, m
mask (np.ndarray): Binary mask in the fieldmap space shaped (x, y, z)
"""
# Transform shimmed field map to shape (x, y, z, time)
sum_mask_fmap_cs = np.sum(mask_fmap_cs, axis=3)
sum_mask_fmap_cs = np.sum(mask_fmap_cs, axis=3)
mask_extended = np.repeat(mask[..., np.newaxis], masked_shim_static_riro.shape[-2], axis=-1)

# Transpose is used to cater to numpy division order
# (3, 2, 4) / (3, 2) Does not work
# (4, 2, 3) / (2, 3) Does work
#* Using out parameter in np.divide() prevents inconsistent results
# * Using out parameter in np.divide() prevents inconsistent results
shimmed_masked = np.zeros(mask_extended.shape)
np.divide(np.sum(masked_shim_static_riro, axis=-1).T, sum_mask_fmap_cs.T, where=mask.T.astype(bool), out=shimmed_masked.T)
np.divide(np.sum(masked_shim_static_riro, axis=-1).T, sum_mask_fmap_cs.T, where=mask.T.astype(bool),
out=shimmed_masked.T)

std_shimmed_masked = np.std(shimmed_masked, axis=-1, dtype=np.float64)
std_unshimmed = np.std(unshimmed, axis=-1, dtype=np.float64)

## Plot
# Plot
mt_unshimmed = montage(np.mean(unshimmed, axis=-1))
mt_unshimmed_masked = montage(std_unshimmed * mask)
mt_shimmed_masked = montage(std_shimmed_masked)
Expand All @@ -1644,12 +1661,14 @@ def plot_full_time_std(self, unshimmed, masked_shim_static_riro, mask_fmap_cs, m
metric_shimmed_mean = calculate_metric_within_mask(std_shimmed_masked, mask, metric='mean')

# Remove the outliners to calculate the colorbar limits
# Necessary because some STD are much higher and are not visible on the heatmap, they are still considered in the metric
shim_limit = np.percentile(mt_shimmed_masked[mt_shimmed_masked !=0], 90)
unshim_limit = np.percentile(mt_unshimmed_masked[mt_unshimmed_masked !=0], 90)
# Necessary because some STD are much higher and are not visible on the heatmap, they are still considered in
# the metric
shim_limit = np.percentile(mt_shimmed_masked[mt_shimmed_masked != 0], 90)
unshim_limit = np.percentile(mt_unshimmed_masked[mt_unshimmed_masked != 0], 90)

min_value = min(mt_unshimmed_masked.min(), mt_shimmed_masked.min())
max_value = max(mt_unshimmed_masked[mt_unshimmed_masked < unshim_limit].max(), mt_shimmed_masked[mt_shimmed_masked < shim_limit].max())
max_value = max(mt_unshimmed_masked[mt_unshimmed_masked < unshim_limit].max(),
mt_shimmed_masked[mt_shimmed_masked < shim_limit].max())

fig = Figure(figsize=(9, 6))
fig.suptitle(f"Fieldmaps\nFieldmap Coordinate System\n\u0394B\u2080 STD over time ")
Expand Down

0 comments on commit 6f6da9f

Please sign in to comment.