Skip to content

Commit

Permalink
Optimizer improvements (#129)
Browse files Browse the repository at this point in the history
* Flattened optimnizer parameters to remove unmasked elements, and upgraded basic_lsq to properly use the scipy lsq algorithm -- may not have worked fully before, but does now

* Renamed basic_lsq to lsq_optimizer and fixed variable names

* Updated docstring and added assertion for misalignment on lsq_optimizer

* Catchin missed snake_case

* Refactored indexing for readability style

* Tweaked sequential.py style

* Fixed line lengths

* Update shimmingtoolbox/optimizer/lsq_optimizer.py

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

* Update shimmingtoolbox/optimizer/lsq_optimizer.py

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>

Co-authored-by: Alex Dastous <47249340+po09i@users.noreply.github.com>
  • Loading branch information
Lincoln Craven-Brightman and po09i committed Oct 9, 2020
1 parent ed2f2fd commit fbb2eae
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 69 deletions.
59 changes: 0 additions & 59 deletions shimmingtoolbox/optimizer/basic_lsq.py

This file was deleted.

20 changes: 12 additions & 8 deletions shimmingtoolbox/optimizer/basic_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

class Optimizer(object):
"""
Optimizer object that stores coil profiles and optimizes an unshimmed volume given a mask. Use optimize(args) to optimize a given mask.
Optimizer object that stores coil profiles and optimizes an unshimmed volume given a mask.
Use optimize(args) to optimize a given mask.
For basic optimizer, uses unbounded pseudo-inverse.
Attributes:
X (int): Amount of pixels in the X direction
Expand Down Expand Up @@ -60,23 +62,25 @@ def optimize(self, unshimmed, mask, mask_origin=(0, 0, 0), bounds=None):
unshimmed (numpy.ndarray): (X, Y, Z) 3d array of unshimmed volume
mask (numpy.ndarray): (X, Y, Z) 3d array of integers marking volume for optimization -- 0 indicates unused
mask_origin (tuple): Origin of mask if mask volume does not cover unshimmed volume
bounds (list): List of ``(min, max)`` pairs for each coil channels. None
is used to specify no bound.
bounds (list): List of ``(min, max)`` pairs for each coil channels. UNUSED in pseudo-inverse
"""
# Check for sizing errors
self._check_sizing(unshimmed, mask, mask_origin=mask_origin, bounds=bounds)

# Set up output currents and optimize
output = np.zeros(self.N)

mx, my, mz = mask_origin
mX, mY, mZ = mask.shape
mask_range = tuple([slice(mask_origin[i], mask_origin[i] + mask.shape[i]) for i in range(3)])
mask_vec = mask.reshape((-1,))

# Simple pseudo-inverse optimization
profile_mat = np.reshape(np.transpose(self.coils[mx:mx+mX, my:my+mY, mz:mz+mZ], axes=(3, 0, 1, 2)), (self.N, -1)).T # mV x N
unshimmed_vec = np.reshape(unshimmed[mx:mx+mX, my:my+mY, mz:mz+mZ], (mX * mY * mZ,)) # mV
# Reshape coil profile: X, Y, Z, N --> [mask.shape], N
# --> N, [mask.shape] --> N, mask.size --> mask.size, N --> masked points, N
coil_mat = np.reshape(np.transpose(self.coils[mask_range], axes=(3, 0, 1, 2)),
(self.N, -1)).T[mask_vec != 0, :] # masked points x N
unshimmed_vec = np.reshape(unshimmed[mask_range], (-1,))[mask_vec != 0] # mV'

output = -1 * scipy.linalg.pinv(profile_mat) @ unshimmed_vec # N x mV @ mV
output = -1 * scipy.linalg.pinv(coil_mat) @ unshimmed_vec # N x mV' @ mV'

return output

Expand Down
66 changes: 66 additions & 0 deletions shimmingtoolbox/optimizer/lsq_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-

import numpy as np
import scipy.optimize as opt

from shimmingtoolbox.optimizer.basic_optimizer import Optimizer


class LSQ_Optimizer(Optimizer):

def _residuals(self, coef, unshimmed_vec, coil_mat):
"""
Objective function to minimize
Args:
coef (numpy.ndarray): 1D array of channel coefficients
unshimmed_vec (numpy.ndarray): 1D flattened array (point) of the masked unshimmed map
coil_mat (numpy.ndarray): 2D flattened array (point, channel) of masked coils
(axis 0 must align with unshimmed_vec)
Returns:
numpy.ndarray: Residuals for least squares optimization -- equivalent to flattened shimmed vector
"""
self._error_if(unshimmed_vec.shape[0] != coil_mat.shape[0],
(f'Unshimmed ({unshimmed_vec.shape}) and coil ({coil_mat.shape})'
' arrays do not align on axis 0'))
return unshimmed_vec + np.sum(coil_mat * coef, axis=1, keepdims=False)

def optimize(self, unshimmed, mask, mask_origin=(0, 0, 0), bounds=None):
"""
Optimize unshimmed volume by varying current to each channel
Args:
unshimmed (numpy.ndarray): 3D B0 map
mask (numpy.ndarray): 3D integer mask used for the optimizer (only consider voxels with non-zero values).
mask_origin (tuple): Mask origin if mask volume does not cover unshimmed volume
bounds (list): List of ``(min, max)`` pairs for each coil channels. None
is used to specify no bound.
Returns:
numpy.ndarray: Coefficients corresponding to the coil profiles that minimize the objective function
(coils.size)
"""

# Check for sizing errors
self._check_sizing(unshimmed, mask, mask_origin=mask_origin, bounds=bounds)

mask_range = tuple([slice(mask_origin[i], mask_origin[i] + mask.shape[i]) for i in range(3)])
mask_vec = mask.reshape((-1,))

# Simple pseudo-inverse optimization
# Reshape coil profile: X, Y, Z, N --> [mask.shape], N
# --> N, [mask.shape] --> N, mask.size --> mask.size, N --> masked points, N
coil_mat = np.reshape(np.transpose(self.coils[mask_range], axes=(3, 0, 1, 2)),
(self.N, -1)).T[mask_vec != 0, :] # masked points x N
unshimmed_vec = np.reshape(unshimmed[mask_range], (-1,))[mask_vec != 0] # mV'

# Set up output currents and optimize
currents_0 = np.zeros(self.N)
currents_sp = opt.least_squares(self._residuals, currents_0,
args=(unshimmed_vec, coil_mat), bounds=np.array(bounds).T)

currents = currents_sp.x

return currents
4 changes: 2 additions & 2 deletions shimmingtoolbox/optimizer/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from shimmingtoolbox.optimizer.basic_lsq import BasicLSQ
from shimmingtoolbox.optimizer.lsq_optimizer import LSQ_Optimizer


def sequential_zslice(unshimmed, coils, full_mask, z_slices, bounds=None):
Expand All @@ -24,7 +24,7 @@ def sequential_zslice(unshimmed, coils, full_mask, z_slices, bounds=None):
"""
z_slices.reshape(z_slices.size)
currents = np.zeros((z_slices.size, coils.shape[3]))
optimizer = BasicLSQ(coils)
optimizer = LSQ_Optimizer(coils)
for i in range(z_slices.size):
z = z_slices[i]
mask = full_mask[:, :, z:z+1]
Expand Down

0 comments on commit fbb2eae

Please sign in to comment.