diff --git a/setup.py b/setup.py index 2a501d73a..aba63801e 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ entry_points={ 'console_scripts': [ "st_download_data=shimmingtoolbox.cli.download_data:download_data", + "st_realtime_zshim=shimmingtoolbox.cli.realtime_zshim:realtime_zshim", "st_dicom_to_nifti=shimmingtoolbox.cli.dicom_to_nifti:dicom_to_nifti_cli", ] }, @@ -37,6 +38,8 @@ "matplotlib~=3.1.2", "pytest~=4.6.3", "pytest-cov~=2.5.1", + "sklearn~=0.0", + "nilearn~=0.6.2" ], extras_require={ 'docs': ["sphinx>=1.6", "sphinx_rtd_theme>=0.2.4"], diff --git a/shimmingtoolbox/cli/realtime_zshim.py b/shimmingtoolbox/cli/realtime_zshim.py new file mode 100644 index 000000000..4550a5707 --- /dev/null +++ b/shimmingtoolbox/cli/realtime_zshim.py @@ -0,0 +1,268 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +import click +import numpy as np +import os +import nibabel as nib +import json +from sklearn.linear_model import LinearRegression +from nilearn.image import resample_img +# TODO: remove matplotlib and dirtesting import +from matplotlib.figure import Figure +from shimmingtoolbox import __dir_testing__ + +from shimmingtoolbox.optimizer.sequential import sequential_zslice +from shimmingtoolbox.load_nifti import get_acquisition_times +from shimmingtoolbox.pmu import PmuResp +from shimmingtoolbox import __dir_shimmingtoolbox__ + +CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) + + +@click.command( + context_settings=CONTEXT_SETTINGS, + help=f"Perform realtime z-shimming." +) +@click.option('-coil', 'fname_coil', required=True, type=click.Path(), + help="Coil basis to use for shimming. Enter multiple files if " + "you wish to use more than one set of shim coils (eg: " + "Siemens gradient/shim coils and external custom coils).") +@click.option('-fmap', 'fname_fmap', required=True, type=click.Path(), + help="B0 fieldmap. For realtime shimming, this should be a 4d file (4th dimension being time") +@click.option('-mask', 'fname_mask', type=click.Path(), + help="3D nifti file with voxels between 0 and 1 used to weight the spatial region to shim.") +@click.option('-resp', 'fname_resp', type=click.Path(), + help="Siemens respiratory file containing pressure data.") +@click.option('-anat', 'fname_anat', type=click.Path(), + help="Filename of the anatomical image to apply the correction.") +# TODO: Remove json file as input +@click.option('-json', 'fname_json', type=click.Path(), + help="Filename of json corresponding BIDS sidecar.") +@click.option("-verbose", is_flag=True, help="Be more verbose.") +def realtime_zshim(fname_coil, fname_fmap, fname_mask, fname_resp, fname_json, fname_anat, verbose=True): + """ + + Args: + fname_coil: Pointing to coil profile. 4-dimensional: x, y, z, coil. + fname_fmap: + fname_mask: + fname_resp: + verbose: + + Returns: + + """ + # Load coil + # When using only z channel (corresponding to index 0) TODO:Remove + # coil = np.expand_dims(nib.load(fname_coil).get_fdata()[:, :, :, 0], -1) + # When using all channels TODO: Keep + coil = nib.load(fname_coil).get_fdata() + + # Load fieldmap + nii_fmap = nib.load(fname_fmap) + fieldmap = nii_fmap.get_fdata() + + # TODO: Error handling might move to API + if fieldmap.ndim != 4: + raise RuntimeError("fmap must be 4d (x, y, z, t)") + nx, ny, nz, nt = fieldmap.shape + + # Load mask + # TODO: check good practice below + if fname_mask is not None: + mask = nib.load(fname_mask).get_fdata() + else: + mask = np.ones_like(fieldmap) + + # Load anat + nii_anat = nib.load(fname_anat) + anat = nii_anat.get_fdata() + if anat.ndim != 3: + raise RuntimeError("Anatomical image must be in 3d") + + # Shim using sequencer and optimizer + n_coils = coil.shape[-1] + currents = np.zeros([n_coils, nt]) + shimmed = np.zeros_like(fieldmap) + masked_fieldmaps = np.zeros_like(fieldmap) + for i_t in range(nt): + currents[:, i_t] = sequential_zslice(fieldmap[..., i_t], coil, mask, z_slices=np.array(range(nz)), + bounds=[(-np.inf, np.inf)] * n_coils) + shimmed[..., i_t] = fieldmap[..., i_t] + np.sum(currents[:, i_t] * coil, axis=3, keepdims=False) + masked_fieldmaps[..., i_t] = mask * fieldmap[..., i_t] + + # Fetch PMU timing + # TODO: Add json to fieldmap instead of asking for another json file + with open(fname_json) as json_file: + json_data = json.load(json_file) + acq_timestamps = get_acquisition_times(nii_fmap, json_data) + pmu = PmuResp(fname_resp) + # TODO: deal with saturation + acq_pressures = pmu.interp_resp_trace(acq_timestamps) + 2048 # [0, 4095] + + # TODO: + # fit PMU and fieldmap values + # do regression to separate static componant and RIRO component + # output coefficient with proper scaling + # field(i_vox) = a(i_vox) * acq_pressures + b(i_vox) + # could use: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html + # Note: strong spatial autocorrelation on the a and b coefficients. Ie: two adjacent voxels are submitted to similar + # static B0 field and RIRO component. --> we need to find a way to account for that + # solution 1: post-fitting regularization. + # pros: easy to implement + # cons: fit is less robust to noise + # solution 2: accounting for regularization during fitting + # pros: fitting more robust to noise + # cons: (from Ryan): regularized fitting took a lot of time on Matlab + + # Shim using PMU + riro = np.zeros_like(fieldmap[:, :, :, 0]) + static = np.zeros_like(fieldmap[:, :, :, 0]) + for i_x in range(fieldmap.shape[0]): + for i_y in range(fieldmap.shape[1]): + for i_z in range(fieldmap.shape[2]): + # TODO: Fit for -masked_field? + reg = LinearRegression().fit(acq_pressures.reshape(-1, 1), -masked_fieldmaps[i_x, i_y, i_z, :]) + riro[i_x, i_y, i_z] = reg.coef_ + static[i_x, i_y, i_z] = reg.intercept_ + + # Resample masked_fieldmaps, riro and static to target anatomical image + nii_masked_fmap = nib.Nifti1Image(masked_fieldmaps, nii_fmap.affine) + nii_riro = nib.Nifti1Image(riro, nii_fmap.affine) + nii_static = nib.Nifti1Image(static, nii_fmap.affine) + + target_affine = nii_anat.affine[:-1, :-1] + nii_resampled_fmap = resample_img(nii_masked_fmap, target_affine=target_affine, interpolation='nearest') + nii_resampled_riro = resample_img(nii_riro, target_affine=target_affine, interpolation='nearest') + nii_resampled_static = resample_img(nii_static, target_affine=target_affine, interpolation='nearest') + + nib.save(nii_resampled_fmap, os.path.join(__dir_shimmingtoolbox__, 'resampled_fmap.nii.gz')) + nib.save(nii_resampled_riro, os.path.join(__dir_shimmingtoolbox__, 'resampled_riro.nii.gz')) + nib.save(nii_resampled_static, os.path.join(__dir_shimmingtoolbox__, 'resampled_static.nii.gz')) + + # ================ PLOTS ================ + + # Calculate masked shim for spherical harmonics plot + masked_shimmed = np.zeros_like(shimmed) + for i_t in range(nt): + masked_shimmed[..., i_t] = mask * shimmed[..., i_t] + + # Plot unshimmed vs shimmed and their mask for spherical harmonics + i_t = 0 + fig = Figure(figsize=(10, 10)) + ax = fig.add_subplot(2, 2, 1) + im = ax.imshow(masked_fieldmaps[:-1, :-1, 0, i_t]) + fig.colorbar(im) + ax.set_title("Masked unshimmed") + ax = fig.add_subplot(2, 2, 2) + im = ax.imshow(masked_shimmed[:-1, :-1, 0, i_t]) + fig.colorbar(im) + ax.set_title("Masked shimmed") + ax = fig.add_subplot(2, 2, 3) + im = ax.imshow(fieldmap[:-1, :-1, 0, i_t]) + fig.colorbar(im) + ax.set_title("Unshimmed") + ax = fig.add_subplot(2, 2, 4) + im = ax.imshow(shimmed[:-1, :-1, 0, i_t]) + fig.colorbar(im) + ax.set_title("Shimmed") + fname_figure = os.path.join(__dir_shimmingtoolbox__, 'realtime_zshim_sphharm_shimmed.png') + fig.savefig(fname_figure) + + # Plot the coil coefs through time + fig = Figure(figsize=(10, 10)) + for i_coil in range(n_coils): + ax = fig.add_subplot(n_coils, 1, i_coil + 1) + ax.plot(np.arange(nt), currents[i_coil, :]) + ax.set_title(f"Channel {i_coil}") + fname_figure = os.path.join(__dir_shimmingtoolbox__, 'realtime_zshim_sphharm_currents.png') + fig.savefig(fname_figure) + + # Plot Static and RIRO + fig = Figure(figsize=(10, 10)) + ax = fig.add_subplot(2, 1, 1) + im = ax.imshow(riro[:-1, :-1, 0]) + fig.colorbar(im) + ax.set_title("RIRO") + ax = fig.add_subplot(2, 1, 2) + im = ax.imshow(static[:-1, :-1, 0]) + fig.colorbar(im) + ax.set_title("Static") + fname_figure = os.path.join(__dir_shimmingtoolbox__, 'realtime_zshim_riro_static.png') + fig.savefig(fname_figure) + + # Calculate fitted and shimmed for pressure fitted plot + fitted_fieldmap = riro * acq_pressures + static + shimmed_pressure_fitted = np.expand_dims(fitted_fieldmap, 2) + masked_fieldmaps + + # Plot pressure fitted fieldmap + fig = Figure(figsize=(10, 10)) + ax = fig.add_subplot(3, 1, 1) + im = ax.imshow(masked_fieldmaps[:-1, :-1, 0, i_t]) + fig.colorbar(im) + ax.set_title("fieldmap") + ax = fig.add_subplot(3, 1, 2) + im = ax.imshow(fitted_fieldmap[:-1, :-1, i_t]) + fig.colorbar(im) + ax.set_title("Fit") + ax = fig.add_subplot(3, 1, 3) + im = ax.imshow(shimmed_pressure_fitted[:-1, :-1, 0, i_t]) + fig.colorbar(im) + ax.set_title("Shimmed (fit + fieldmap") + fname_figure = os.path.join(__dir_shimmingtoolbox__, 'realtime_zshim_pressure_fitted.png') + fig.savefig(fname_figure) + + # Reshape pmu datapoints to fit those of the acquisition + pmu_times = np.linspace(pmu.start_time_mdh, pmu.stop_time_mdh, len(pmu.data)) + pmu_times_within_range = pmu_times[pmu_times > acq_timestamps[0]] + pmu_data_within_range = pmu.data[pmu_times > acq_timestamps[0]] + pmu_data_within_range = pmu_data_within_range[pmu_times_within_range < acq_timestamps[fieldmap.shape[3] - 1]] + pmu_times_within_range = pmu_times_within_range[pmu_times_within_range < acq_timestamps[fieldmap.shape[3] - 1]] + + # Calc fieldmap average within mask + fieldmap_avg = np.zeros([fieldmap.shape[3]]) + for i_time in range(nt): + masked_array = np.ma.array(fieldmap[:, :, :, i_time], mask=mask == False) + fieldmap_avg[i_time] = np.ma.average(masked_array) + + # Plot pmu vs B0 in masked region + fig = Figure(figsize=(10, 10)) + ax = fig.add_subplot(211) + ax.plot(acq_timestamps / 1000, acq_pressures, label='Interpolated pressures') + # ax.plot(pmu_times / 1000, pmu.data, label='Raw pressures') + ax.plot(pmu_times_within_range / 1000, pmu_data_within_range, label='Pmu pressures') + ax.legend() + ax.set_title("Pressure [-2048, 2047] vs time (s) ") + ax = fig.add_subplot(212) + ax.plot(acq_timestamps / 1000, fieldmap_avg, label='Mean B0') + ax.legend() + ax.set_title("Fieldmap average over unmasked region (Hz) vs time (s)") + fname_figure = os.path.join(__dir_shimmingtoolbox__, 'realtime_zshim_pmu_vs_B0.png') + fig.savefig(fname_figure) + + # Show anatomical image + fig = Figure(figsize=(10, 10)) + ax = fig.add_subplot(2, 1, 1) + im = ax.imshow(anat[:-1, :-1, 10]) + fig.colorbar(im) + ax.set_title("Anatomical image [:-1, :-1, 10]") + ax = fig.add_subplot(2, 1, 2) + im = ax.imshow(nii_resampled_fmap.get_fdata()[0, :-1, :-1, 0]) + fig.colorbar(im) + ax.set_title("Resampled fieldmap [0, :-1, :-1, 0]") + fname_figure = os.path.join(__dir_shimmingtoolbox__, 'reatime_zshime_anat.png') + fig.savefig(fname_figure) + + return fname_figure + +# Debug +# fname_coil = os.path.join(__dir_testing__, 'test_realtime_zshim', 'coil_profile.nii.gz') +# fname_fmap = os.path.join(__dir_testing__, 'test_realtime_zshim', 'sub-example_fieldmap.nii.gz') +# fname_mask = os.path.join(__dir_testing__, 'test_realtime_zshim', 'mask.nii.gz') +# fname_resp = os.path.join(__dir_testing__, 'realtime_zshimming_data', 'PMUresp_signal.resp') +# fname_json = os.path.join(__dir_testing__, 'test_realtime_zshim', 'sub-example_magnitude1.json') +# # fname_coil='/Users/julien/code/shimming-toolbox/shimming-toolbox/test_realtime_zshim/coil_profile.nii.gz' +# # fname_fmap='/Users/julien/code/shimming-toolbox/shimming-toolbox/test_realtime_zshim/sub-example_fieldmap.nii.gz' +# # fname_mask='/Users/julien/code/shimming-toolbox/shimming-toolbox/test_realtime_zshim/mask.nii.gz' +# realtime_zshim(fname_coil, fname_fmap, fname_mask, fname_resp, fname_json) diff --git a/shimmingtoolbox/shim/__init__.py b/shimmingtoolbox/shim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/cli/test_cli_realtime_zshim.py b/test/cli/test_cli_realtime_zshim.py new file mode 100644 index 000000000..c8f79b150 --- /dev/null +++ b/test/cli/test_cli_realtime_zshim.py @@ -0,0 +1,69 @@ +#!usr/bin/env python3 +# coding: utf-8 + +import os +import pathlib +import tempfile +import nibabel as nib +import numpy as np + +from click.testing import CliRunner +from shimmingtoolbox.cli.realtime_zshim import realtime_zshim +from shimmingtoolbox.masking.shapes import shapes +from shimmingtoolbox.masking.threshold import threshold +from shimmingtoolbox.coils.coordinates import generate_meshgrid +from shimmingtoolbox.coils.siemens_basis import siemens_basis +from shimmingtoolbox import __dir_testing__ + + +def test_cli_realtime_zshim(): + with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp: + runner = CliRunner() + + fname_fieldmap = os.path.join(__dir_testing__, 'realtime_zshimming_data', 'nifti', 'sub-example', 'fmap', + 'sub-example_fieldmap.nii.gz') + nii_fmap = nib.load(fname_fieldmap) + fmap = nii_fmap.get_fdata() + affine = nii_fmap.affine + + # Set up mask + # Cube + # nx, ny, nz, nt = fmap.shape + # mask = shapes(fmap[:, :, :, 0], 'cube', + # center_dim1=int(fmap.shape[0] / 2 - 8), + # center_dim2=int(fmap.shape[1] / 2 - 5), + # len_dim1=15, len_dim2=25, len_dim3=nz) + # Threshold + fname_mag = os.path.join(__dir_testing__, 'realtime_zshimming_data', 'nifti', 'sub-example', 'fmap', + 'sub-example_magnitude1.nii.gz') + mag = nib.load(fname_mag).get_fdata() + mask = threshold(mag, thr=50) + + nii_mask = nib.Nifti1Image(mask.astype(int)[:, :, :, 0], affine) + fname_mask = os.path.join(tmp, 'mask.nii.gz') + nib.save(nii_mask, fname_mask) + + # Set up coils + coord_phys = generate_meshgrid(fmap.shape[0:3], affine) + coil_profile = siemens_basis(coord_phys[0], coord_phys[1], coord_phys[2]) + + nii_coil = nib.Nifti1Image(coil_profile, affine) + fname_coil = os.path.join(tmp, 'coil_profile.nii.gz') + nib.save(nii_coil, fname_coil) + + # Path for resp data + fname_resp = os.path.join(__dir_testing__, 'realtime_zshimming_data', 'PMUresp_signal.resp') + + # Path for json file + fname_json = os.path.join(__dir_testing__, 'realtime_zshimming_data', 'nifti', 'sub-example', 'fmap', + 'sub-example_magnitude1.json') + + # Path for mag anat image + fname_anat = os.path.join(__dir_testing__, 'realtime_zshimming_data', 'nifti', 'sub-example', 'anat', + 'sub-example_unshimmed_e1.nii.gz') + + result = runner.invoke(realtime_zshim, ['-fmap', fname_fieldmap, '-coil', fname_coil, '-mask', fname_mask, + '-resp', fname_resp, '-json', fname_json, '-anat', fname_anat], + catch_exceptions=False) + + assert result.exit_code == 0