Skip to content

Commit

Permalink
Gradient realtime_shim orientation and nan problems (#323)
Browse files Browse the repository at this point in the history
* Add function to determine the slice orientation

* Correct docstring

* Move imports

* Flip axis depending on anatomical orientation

* Avoid division by 0 when input pressure data is flat

* Add tests for coronal and sagital orientations

* Add test for ambiguous slice orientation

* Add a method to use the ImageOrientationText tag in the json as default sinc eit is more reliable

* Update test/cli/test_cli_realtime_shim.py

* Add comments to describe tests

* Use defined orientations instead of filenames and change import order

* Add comments about nans and change import order
  • Loading branch information
po09i committed Dec 10, 2021
1 parent b23a0a9 commit 24a422d
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 10 deletions.
43 changes: 38 additions & 5 deletions shimmingtoolbox/cli/realtime_shim.py
Expand Up @@ -2,14 +2,15 @@
# -*- coding: utf-8 -*-

import click
import json
import nibabel as nib
import numpy as np
import os
import nibabel as nib
import json

from shimmingtoolbox.shim.realtime_shim import realtime_shim
from shimmingtoolbox.pmu import PmuResp
from shimmingtoolbox.utils import create_output_dir
from shimmingtoolbox.coils.coordinates import get_main_orientation

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])

Expand Down Expand Up @@ -92,35 +93,67 @@ def realtime_shim_cli(fname_fmap, fname_mask_anat_static, fname_mask_anat_riro,
# To output to the gradient coord system, axes need some inversions. The gradient coordinate system is defined by
# the frequency, phase and slice encode directions.
# TODO: More thorough tests

# Load json
fname_json = fname_anat.rsplit('.nii', 1)[0] + '.json'
with open(fname_json) as json_file:
json_data = json.load(json_file)

if 'ImageOrientationText' in json_data:
# Tag in private dicom header (0051,100E) indicates the slice orientation, if it exists, it will appear in the
# json under 'ImageOrientationText' tag
orientation_text = json_data['ImageOrientationText']
orientation = orientation_text[:3].upper()
else:
# Find orientation with the ImageOrientationPatientDICOM tag, this is less reliable since it can fail if there
# are 2 highest cosines. It will raise an exception if there is a problem
orientation = get_main_orientation(json_data['ImageOrientationPatientDICOM'])

if orientation == 'SAG':
slice_static_corr = -slice_static_corr
slice_riro_corr = -slice_riro_corr
elif orientation == 'COR':
freq_static_corr = -freq_static_corr
freq_riro_corr = -freq_riro_corr
else:
# TRA
pass

phase_encode_is_positive = _get_phase_encode_direction_sign(fname_anat)
if not phase_encode_is_positive:
freq_static_corr = -freq_static_corr
phase_static_corr = -phase_static_corr
freq_riro_corr = -freq_riro_corr
phase_riro_corr = -phase_riro_corr

# Avoid division by 0 so there are no nans in the output text file. Nans can brick the sequence.
if not np.isclose(pressure_rms, 0):
slice_riro_corr /= pressure_rms
phase_riro_corr /= pressure_rms
freq_riro_corr /= pressure_rms

# Write to a text file
fname_zcorrections = os.path.join(fname_output, 'zshim_gradients.txt')
file_gradients = open(fname_zcorrections, 'w')
for i_slice in range(slice_static_corr.shape[-1]):
file_gradients.write(f'corr_vec[0][{i_slice}]= {slice_static_corr[i_slice]:.6f}\n')
file_gradients.write(f'corr_vec[1][{i_slice}]= {slice_riro_corr[i_slice] / pressure_rms:.12f}\n')
file_gradients.write(f'corr_vec[1][{i_slice}]= {slice_riro_corr[i_slice]:.12f}\n')
file_gradients.write(f'corr_vec[2][{i_slice}]= {mean_p:.3f}\n')
file_gradients.close()

fname_ycorrections = os.path.join(fname_output, 'yshim_gradients.txt')
file_gradients = open(fname_ycorrections, 'w')
for i_slice in range(phase_static_corr.shape[-1]):
file_gradients.write(f'corr_vec[0][{i_slice}]= {phase_static_corr[i_slice]:.6f}\n')
file_gradients.write(f'corr_vec[1][{i_slice}]= {phase_riro_corr[i_slice] / pressure_rms:.12f}\n')
file_gradients.write(f'corr_vec[1][{i_slice}]= {phase_riro_corr[i_slice]:.12f}\n')
file_gradients.write(f'corr_vec[2][{i_slice}]= {mean_p:.3f}\n')
file_gradients.close()

fname_xcorrections = os.path.join(fname_output, 'xshim_gradients.txt')
file_gradients = open(fname_xcorrections, 'w')
for i_slice in range(freq_static_corr.shape[-1]):
file_gradients.write(f'corr_vec[0][{i_slice}]= {freq_static_corr[i_slice]:.6f}\n')
file_gradients.write(f'corr_vec[1][{i_slice}]= {freq_riro_corr[i_slice] / pressure_rms:.12f}\n')
file_gradients.write(f'corr_vec[1][{i_slice}]= {freq_riro_corr[i_slice]:.12f}\n')
file_gradients.write(f'corr_vec[2][{i_slice}]= {mean_p:.3f}\n')
file_gradients.close()

Expand Down
30 changes: 30 additions & 0 deletions shimmingtoolbox/coils/coordinates.py
Expand Up @@ -207,3 +207,33 @@ def _resample_3d(nii_3d, nii_to_vox_map, order, mode, cval, out_class):
nii_resampled_3d = nib_resample_from_to(nii_3d, nii_to_vox_map, order=order, mode=mode, cval=cval,
out_class=out_class)
return nii_resampled_3d.get_fdata()


def get_main_orientation(cosines: list):
""" Returns the orientation of the slice axis by looking at the ImageOrientationPatientDICOM JSON tag
Args:
cosines (list): list of 6 elements. The first 3 represent the x, y, z cosines of the first row. The last 3
represent the x, y, z cosines of the first column. This can be found in ImageOrientationPatientDICOM so it
should be LPS coordinates.
Returns:
str: 'SAG', 'COR' or 'TRA'
"""
cosines_row = cosines[:3]
cosines_col = cosines[3:]
cosines_slice = np.cross(cosines_row, cosines_col)

slice_abs = np.abs(cosines_slice)
# list containing where the max cosine is
index_max = np.where(slice_abs == slice_abs.max())[0]

if len(index_max) != 1:
raise NotImplementedError("Ambiguous slice orientation")

orientations = {0: 'SAG',
1: 'COR',
2: 'TRA'}

return orientations[index_max[0]]
87 changes: 87 additions & 0 deletions test/cli/test_cli_realtime_shim.py
Expand Up @@ -132,3 +132,90 @@ def test_phase_encode_wrong_tag_value():
with pytest.raises(ValueError,
match="Unexpected value for PhaseEncodingDirection:"):
_get_phase_encode_direction_sign(fname_nii)


def test_cli_realtime_shim_sag_anat():
"""We do not have a sagittal orientation in testing_data so we change the json manually to test for the SAG case"""
runner = CliRunner()
with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp:
# Specify output for text file and figures
path_output = os.path.join(tmp, 'test_realtime_shim')

# Change json to have an orientation that is SAGITTAL
nii = nib.load(fname_anat)
with open(fname_json) as json_file:
json_data = json.load(json_file)
json_data['ImageOrientationPatientDICOM'] = [0, 1, 0, 0, 0, -1]
fname_json_sag = os.path.join(tmp, 'anat_sag.json')
with open(fname_json_sag, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=4)
fname_anat_sag = os.path.join(tmp, 'anat_sag.nii.gz')
nib.save(nii, fname_anat_sag)

# Run the CLI
result = runner.invoke(realtime_shim_cli, ['--fmap', fname_fieldmap,
'--output', path_output,
'--resp', fname_resp,
'--anat', fname_anat_sag],
catch_exceptions=False)

assert len(os.listdir(path_output)) != 0
assert result.exit_code == 0


def test_cli_realtime_shim_cor_anat():
"""We do not have a coronal orientation in testing_data so we change the json manually to test for the COR case"""
runner = CliRunner()
with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp:
# Specify output for text file and figures
path_output = os.path.join(tmp, 'test_realtime_shim')

# Change json to have an orientation that is CORONAL
nii = nib.load(fname_anat)
with open(fname_json) as json_file:
json_data = json.load(json_file)
json_data['ImageOrientationPatientDICOM'] = [1, 0, 0, 0, 0, -1]
fname_json_cor = os.path.join(tmp, 'anat_cor.json')
with open(fname_json_cor, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=4)
fname_anat_cor = os.path.join(tmp, 'anat_cor.nii.gz')
nib.save(nii, fname_anat_cor)

# Run the CLI
result = runner.invoke(realtime_shim_cli, ['--fmap', fname_fieldmap,
'--output', path_output,
'--resp', fname_resp,
'--anat', fname_anat_cor],
catch_exceptions=False)

assert len(os.listdir(path_output)) != 0
assert result.exit_code == 0


def test_cli_realtime_shim_tra_orient_text():
"""Add a json tag ImageOrientationText with 'Tra'"""
runner = CliRunner()
with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp:
# Specify output for text file and figures
path_output = os.path.join(tmp, 'test_realtime_shim')

# Add tag to json that says it is axial
nii = nib.load(fname_anat)
with open(fname_json) as json_file:
json_data = json.load(json_file)
json_data['ImageOrientationText'] = 'Tra'
fname_json_text = os.path.join(tmp, 'anat_text.json')
with open(fname_json_text, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=4)
fname_anat_text = os.path.join(tmp, 'anat_text.nii.gz')
nib.save(nii, fname_anat_text)

# Run the CLI
result = runner.invoke(realtime_shim_cli, ['--fmap', fname_fieldmap,
'--output', path_output,
'--resp', fname_resp,
'--anat', fname_anat_text],
catch_exceptions=False)

assert len(os.listdir(path_output)) != 0
assert result.exit_code == 0
37 changes: 32 additions & 5 deletions test/test_coordinates.py
@@ -1,16 +1,15 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*

import nibabel as nib
import numpy as np
import math
import os
import nibabel as nib
import pytest

from shimmingtoolbox.coils.coordinates import generate_meshgrid
from shimmingtoolbox.coils.coordinates import phys_gradient
from shimmingtoolbox.coils.coordinates import phys_to_vox_gradient
from shimmingtoolbox.coils.coordinates import resample_from_to
from shimmingtoolbox import __dir_testing__
from shimmingtoolbox.coils.coordinates import generate_meshgrid, phys_gradient, phys_to_vox_gradient, resample_from_to
from shimmingtoolbox.coils.coordinates import get_main_orientation


fname_fieldmap = os.path.join(__dir_testing__, 'ds_b0', 'sub-realtime', 'fmap', 'sub-realtime_fieldmap.nii.gz')
Expand Down Expand Up @@ -227,3 +226,31 @@ def test_resample_from_to_5d():
# If there isn't an error, then there is a problem
print('\nWrong dimensions but does not throw an error.')
assert False


def test_get_main_orientation_tra():
tra_orientation = [1, 0, 0, 0, 1, 0]
orientation = get_main_orientation(tra_orientation)

assert orientation == 'TRA'


def test_get_main_orientation_sag():
sag_orientation = [0, 0, 1, 0, 1, 0]
orientation = get_main_orientation(sag_orientation)

assert orientation == 'SAG'


def test_get_main_orientation_cor():
cor_orientation = [1, 0, 0, 0, 0, -1]
orientation = get_main_orientation(cor_orientation)

assert orientation == 'COR'


def test_get_main_orientation_not_imp():
cor_orientation = [1 / np.sqrt(2), 1 / np.sqrt(2), 0, 0, 0, -1]

with pytest.raises(NotImplementedError, match="Ambiguous slice orientation"):
get_main_orientation(cor_orientation)

0 comments on commit 24a422d

Please sign in to comment.