Skip to content

Commit

Permalink
Add test in cli for 3rd order and allow to use the 3rd order form the…
Browse files Browse the repository at this point in the history
… CLI
  • Loading branch information
po09i committed Nov 7, 2023
1 parent b1201a2 commit 826b974
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 54 deletions.
5 changes: 4 additions & 1 deletion shimmingtoolbox/cli/b0shim.py
Expand Up @@ -882,7 +882,7 @@ def _load_coils(coils, order, fname_constraints, nii_fmap, scanner_shim_settings
list_coils.append(Coil(nii_coil_profiles.get_fdata(), nii_coil_profiles.affine, constraints))

# Create the spherical harmonic coil profiles of the scanner
if 0 <= order <= 2:
if 0 <= order <= 3:

if os.path.isfile(fname_constraints):
with open(fname_constraints) as json_file:
Expand Down Expand Up @@ -943,6 +943,9 @@ def _initial_in_bounds(coefs, bounds):
elif order == 2:
# Order 2 has 5 more coefficients than order 1 (f0, X, Y, Z, Z2, ZX, ZY, X2-Y2, XY)
initial_coefs = [0] * 9
elif order == 3:
# It seems that ShimSettings only support upto the 2nd order
initial_coefs = [0] * 9
else:
initial_coefs = None

Expand Down
55 changes: 27 additions & 28 deletions shimmingtoolbox/coils/coil.py
Expand Up @@ -98,13 +98,13 @@ def load_constraints(self, constraints):

for i_channel in range(self.dim[3]):
if constraints["coef_channel_minmax"][i_channel] is None:
constraints["coef_channel_minmax"][i_channel] = (-np.inf, np.inf)
constraints["coef_channel_minmax"][i_channel] = [-np.inf, np.inf]
if constraints["coef_channel_minmax"][i_channel][0] is None:
constraints["coef_channel_minmax"][i_channel] = \
(-np.inf, constraints["coef_channel_minmax"][i_channel][1])
[-np.inf, constraints["coef_channel_minmax"][i_channel][1]]
if constraints["coef_channel_minmax"][i_channel][1] is None:
constraints["coef_channel_minmax"][i_channel] = \
(constraints["coef_channel_minmax"][i_channel][0], np.inf)
[constraints["coef_channel_minmax"][i_channel][0], np.inf]

if key_name == "coef_sum_max":
if constraints["coef_sum_max"] is None:
Expand All @@ -117,6 +117,7 @@ def load_constraints(self, constraints):

class ScannerCoil(Coil):
"""Coil class for scanner coils as they require extra arguments"""

def __init__(self, dim_volume, affine, constraints, order, manufacturer=""):

self.order = order
Expand All @@ -134,36 +135,29 @@ def __init__(self, dim_volume, affine, constraints, order, manufacturer=""):
sph_coil_profile = self._create_coil_profile(dim_volume, manufacturer)
# Restricts the constraints to the specified order
constraints['coef_channel_minmax'] = restrict_sph_constraints(constraints['coef_channel_minmax'], self.order)
if order == 3:
constraints['coef_channel_minmax'].extend([[None, None]] * (sph_coil_profile.shape[-1] - 9))

super().__init__(sph_coil_profile, affine, constraints)

def _create_coil_profile(self, dim, manufacturer=None):
# Define profile for Tx (constant volume)
profile_order_0 = -np.ones(dim)

# Create spherical harmonics coil profiles
if self.order == 0:
# f0 --> [1]
sph_coil_profile = profile_order_0[..., np.newaxis]
# f0, orders
mesh1, mesh2, mesh3 = generate_meshgrid(dim, self.affine)
if manufacturer == 'SIEMENS':
sph_coil_profile = siemens_basis(mesh1, mesh2, mesh3, orders=tuple(range(self.order + 1)),
shim_cs=self.coord_system)
elif manufacturer == 'GE':
sph_coil_profile = ge_basis(mesh1, mesh2, mesh3, orders=tuple(range(self.order + 1)),
shim_cs=self.coord_system)
elif manufacturer == 'PHILIPS':
sph_coil_profile = philips_basis(mesh1, mesh2, mesh3, orders=tuple(range(self.order + 1)),
shim_cs=self.coord_system)
else:
# f0, orders
mesh1, mesh2, mesh3 = generate_meshgrid(dim, self.affine)
if manufacturer == 'SIEMENS':
profile_orders = siemens_basis(mesh1, mesh2, mesh3, orders=tuple(range(1, self.order + 1)),
shim_cs=self.coord_system)
elif manufacturer == 'GE':
profile_orders = ge_basis(mesh1, mesh2, mesh3, orders=tuple(range(1, self.order + 1)),
shim_cs=self.coord_system)
elif manufacturer == 'PHILIPS':
profile_orders = philips_basis(mesh1, mesh2, mesh3, orders=tuple(range(1, self.order + 1)),
shim_cs=self.coord_system)
else:
logger.warning(f"{manufacturer} manufacturer not implemented. Outputting in Hz, uT/m, uT/m^2 for order "
f"0, 1 and 2 respectively")
profile_orders = siemens_basis(mesh1, mesh2, mesh3, orders=tuple(range(1, self.order + 1)),
shim_cs=self.coord_system)

sph_coil_profile = np.concatenate((profile_order_0[..., np.newaxis], profile_orders), axis=3)
logger.warning(f"{manufacturer} manufacturer not implemented. Outputting in Hz, uT/m, uT/m^2 for order "
f"0, 1 and 2 respectively")
sph_coil_profile = siemens_basis(mesh1, mesh2, mesh3, orders=tuple(range(self.order + 1)),
shim_cs=self.coord_system)

return sph_coil_profile

Expand Down Expand Up @@ -256,7 +250,12 @@ def restrict_sph_constraints(bounds, order):
elif order == 2:
# f0, ch1, ch2, ch3, ch4, ch5, ch6, ch7, ch8 -- > [9]
minmax_out = bounds[:9]
elif order == 3:
if len(bounds) >= 16:
minmax_out = bounds[:16]
else:
minmax_out = bounds
else:
raise NotImplementedError("Order must be between 0 and 2")
raise NotImplementedError("Order must be between 0 and 3")

return minmax_out
18 changes: 9 additions & 9 deletions shimmingtoolbox/coils/spher_harm_basis.py
Expand Up @@ -394,11 +394,7 @@ def convert_spher_harm_to_array(spher_harm_dict):

spher_harm = []
for order in sorted(spher_harm_dict.keys()):
sph = spher_harm_dict[order]
if order == 0:
spher_harm.append(sph[..., np.newaxis])
else:
spher_harm.append(sph)
spher_harm.append(spher_harm_dict[order])

spher_harm = np.concatenate(spher_harm, axis=-1)

Expand Down Expand Up @@ -472,8 +468,12 @@ def _get_scaling_factors(orders):

for i in range(_channels_per_order(order)):
field = sh[:, :, :, i_ch]
scaling_factors[i_ch] = (GYROMAGNETIC_RATIO * ((r[order][i] * 0.001) ** order) /
field[iref[order][i]][0])
if order != 0:
scaling_factors[i_ch] = (GYROMAGNETIC_RATIO * ((r[order][i] * 0.001) ** order) /
field[iref[order][i]][0])
else:
scaling_factors[i_ch] = -1 / field[iref[order][i]][0]

i_ch += 1

return scaling_factors
Expand All @@ -487,8 +487,8 @@ def _check_basis_inputs(x, y, z, orders):
if not (x.shape == y.shape == z.shape):
raise RuntimeError("Input arrays X, Y, and Z must be identically sized")

if max(orders) >= 3:
raise NotImplementedError("Spherical harmonics not implemented for order 3 and up")
if max(orders) >= 4:
raise NotImplementedError("Spherical harmonics not implemented for order 4 and up")


def _channels_per_order(order):
Expand Down
31 changes: 29 additions & 2 deletions test/cli/test_cli_b0shim.py
Expand Up @@ -54,7 +54,7 @@ def _define_inputs(fmap_dim):
mask = shapes(anat, 'cube',
center_dim1=int(nx / 2),
center_dim2=int(ny / 2),
len_dim1=10, len_dim2=10, len_dim3=nz - 10)
len_dim1=30, len_dim2=30, len_dim3=nz - 5)

nii_mask = nib.Nifti1Image(mask.astype(np.uint8), nii_anat.affine)

Expand Down Expand Up @@ -158,7 +158,7 @@ def test_cli_dynamic_coils(self, nii_fmap, nii_anat, nii_mask, fm_data, anat_dat
assert os.path.isfile(os.path.join(tmp, "coefs_coil0_Dummy_coil.txt"))

def test_cli_dynamic_sph_order_0(self, nii_fmap, nii_anat, nii_mask, fm_data, anat_data):
"""Test cli with scanner coil profiles of order 1 with default constraints"""
"""Test cli with scanner coil profiles of order 0 with default constraints"""
with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp:
# Save the inputs to the new directory
fname_fmap = os.path.join(tmp, 'fmap.nii.gz')
Expand All @@ -184,6 +184,33 @@ def test_cli_dynamic_sph_order_0(self, nii_fmap, nii_anat, nii_mask, fm_data, an
assert res.exit_code == 0
assert os.path.isfile(os.path.join(tmp, "coefs_coil0_Prisma_fit.txt"))

def test_cli_dynamic_sph_order_3(self, nii_fmap, nii_anat, nii_mask, fm_data, anat_data):
"""Test cli with scanner coil profiles of order 3 with default constraints"""
with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp:
# Save the inputs to the new directory
fname_fmap = os.path.join(tmp, 'fmap.nii.gz')
fname_fm_json = os.path.join(tmp, 'fmap.json')
fname_mask = os.path.join(tmp, 'mask.nii.gz')
fname_anat = os.path.join(tmp, 'anat.nii.gz')
fname_anat_json = os.path.join(tmp, 'anat.json')
_save_inputs(nii_fmap=nii_fmap, fname_fmap=fname_fmap,
nii_anat=nii_anat, fname_anat=fname_anat,
nii_mask=nii_mask, fname_mask=fname_mask,
fm_data=fm_data, fname_fm_json=fname_fm_json,
anat_data=anat_data, fname_anat_json=fname_anat_json)

runner = CliRunner()
res = runner.invoke(b0shim_cli, ['dynamic',
'--fmap', fname_fmap,
'--anat', fname_anat,
'--mask', fname_mask,
'--scanner-coil-order', '2',
'--output', tmp],
catch_exceptions=False)

assert res.exit_code == 0
assert os.path.isfile(os.path.join(tmp, "coefs_coil0_Prisma_fit.txt"))

def test_cli_dynamic_coils_and_sph(self, nii_fmap, nii_anat, nii_mask, fm_data, anat_data):
"""Test cli with input coil and scanner coil"""
with tempfile.TemporaryDirectory(prefix='st_' + pathlib.Path(__file__).stem) as tmp:
Expand Down
35 changes: 21 additions & 14 deletions test/test_spher_harm_basis.py
Expand Up @@ -17,40 +17,47 @@

@pytest.mark.parametrize('x,y,z', dummy_data)
def test_normal_siemens_basis(x, y, z):
basis = siemens_basis(x, y, z)
basis = siemens_basis(x, y, z, orders=(0, 1, 2, 3))

# Test for shape
assert (np.all(basis.shape == (x.shape[0], x.shape[1], x.shape[2], 8)))
assert (np.all(basis.shape == (x.shape[0], x.shape[1], x.shape[2], 13)))
# X, Y, Z, Z2, ZX, ZY, X2 - Y2, XY
assert np.allclose(basis[:, 1, 1, 0], [4.25774785e-02, 0, -4.25774785e-02])
assert np.allclose(basis[1, :, 1, 1], [-4.25774785e-02, 0, 4.25774785e-02])
assert np.allclose(basis[1, 1, :, 2], [4.25774785e-02, 0, -4.25774785e-02])
assert np.allclose(basis[1, 1, :, 3], [4.25774785e-05, 0.00000000e+00, 4.25774785e-05])
assert np.allclose(basis[:, 1, :, 4], np.array([[8.5154957e-05, 0.0000000e+00, -8.5154957e-05],
assert np.allclose(basis[:, 1, 1, 0], [-1, -1, -1])
assert np.allclose(basis[:, 1, 1, 1], [4.25774785e-02, 0, -4.25774785e-02])
assert np.allclose(basis[1, :, 1, 2], [-4.25774785e-02, 0, 4.25774785e-02])
assert np.allclose(basis[1, 1, :, 3], [4.25774785e-02, 0, -4.25774785e-02])
assert np.allclose(basis[1, 1, :, 4], [4.25774785e-05, 0.00000000e+00, 4.25774785e-05])
assert np.allclose(basis[:, 1, :, 5], np.array([[8.5154957e-05, 0.0000000e+00, -8.5154957e-05],
[-0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
[-8.5154957e-05, -0.0000000e+00, 8.5154957e-05]]))
assert np.allclose(basis[1, :, :, 5], np.array([[-8.5154957e-05, -0.0000000e+00, 8.5154957e-05],
assert np.allclose(basis[1, :, :, 6], np.array([[-8.5154957e-05, -0.0000000e+00, 8.5154957e-05],
[0.0000000e+00, -0.0000000e+00, -0.0000000e+00],
[8.5154957e-05, 0.0000000e+00, -8.5154957e-05]]))
assert np.allclose(basis[:, :, 1, 6], np.array([[0, 4.25774785e-05, 0],
assert np.allclose(basis[:, :, 1, 7], np.array([[0, 4.25774785e-05, 0],
[-4.25774785e-05, 0.00000000e+00, -4.25774785e-05],
[0, 4.25774785e-05, 0]]))
assert np.allclose(basis[:, :, 1, 7], np.array([[-8.51549570e-05, 0, 8.51549570e-05],
assert np.allclose(basis[:, :, 1, 8], np.array([[-8.51549570e-05, 0, 8.51549570e-05],
[0, 0, 0],
[8.51549570e-05, -0.00000000e+00, -8.51549570e-05]]))
# TODO: add tests for order 3


@pytest.mark.parametrize('x,y,z', dummy_data)
def test_siemens_basis(x, y, z):
basis = siemens_basis(x, y, z, orders=(1,))
print(basis.shape)
assert np.all(basis.shape == (3, 3, 3, 3))


@pytest.mark.parametrize('x,y,z', dummy_data)
def test_create_scanner_coil_order3(x, y, z):
with pytest.raises(NotImplementedError, match="Spherical harmonics not implemented for order 3 and up"):
siemens_basis(x, y, z, orders=(3,))
def test_create_siemens_basis_order3(x, y, z):
basis = siemens_basis(x, y, z, orders=(3,))
assert np.all(basis.shape == (3, 3, 3, 4))


@pytest.mark.parametrize('x,y,z', dummy_data)
def test_create_siemens_basis_order4(x, y, z):
with pytest.raises(NotImplementedError, match="Spherical harmonics not implemented for order 4 and up"):
siemens_basis(x, y, z, orders=(4,))


def test_siemens_basis_resample():
Expand Down

0 comments on commit 826b974

Please sign in to comment.