Commit a08ab49

Merge pull request freesurfer#1078 from ste93ste/nf-samseg-long-no-resampling
Nf: samseg-long segmentation in native space
2 parents: 0366fe1 + 4435541

3 files changed (+162, -48 lines)

python/gems/SamsegLongitudinal.py

Lines changed: 135 additions & 44 deletions
@@ -94,7 +94,8 @@ def __init__(self,
                  thresholdSearchString=None,
                  modeNames=None,
                  pallidumAsWM=True,
 -                savePosteriors=False
 +                savePosteriors=False,
 +                tpToBaseTransforms=None,
                  ):
 
         # Store input parameters as class variables
@@ -112,6 +113,16 @@ def __init__(self,
         self.modeNames = modeNames
         self.pallidumAsWM = pallidumAsWM
         self.savePosteriors = savePosteriors
+        self.tpToBaseTransforms = tpToBaseTransforms
+
+        # Check if all time point to base transforms are identity matrices.
+        # If so, we can derive a combined 4-D mask during preprocessing.
+        self.allIdentityTransforms = True
+        if tpToBaseTransforms is not None:
+            for transform in self.tpToBaseTransforms:
+                if not np.allclose(transform.matrix, np.eye(4)):
+                    self.allIdentityTransforms = False
+
 
         # Initialize some objects
         self.probabilisticAtlas = ProbabilisticAtlas()
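For reference, the identity check above only relies on each transform exposing a `.matrix` attribute. A minimal sketch of the behavior, using a hypothetical stand-in class rather than the real transform objects:

    import numpy as np

    class FakeTransform:
        # Hypothetical stand-in: the real transforms also expose a .matrix attribute
        def __init__(self, matrix):
            self.matrix = matrix

    shifted = np.eye(4)
    shifted[0, 3] = 0.5  # half-voxel translation along x

    transforms = [FakeTransform(np.eye(4)), FakeTransform(shifted)]
    allIdentityTransforms = all(np.allclose(t.matrix, np.eye(4)) for t in transforms)
    print(allIdentityTransforms)  # False -> per-time-point masks are derived instead of a combined 4-D mask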
@@ -190,6 +201,7 @@ def constructAndRegisterSubjectSpecificTemplate(self):
                                                  templateFileName=templateFileName)
         self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer)
 
+
     def preProcess(self):
 
         # construct sstModel
@@ -205,32 +217,92 @@ def preProcess(self):
         self.sstModel.imageBuffers, self.sstModel.transform, self.sstModel.voxelSpacing, self.sstModel.cropping = readCroppedImages(self.sstFileNames, templateFileName, self.imageToImageTransformMatrix)
 
         self.imageBuffersList = []
-        for imageFileNames in self.imageFileNamesList:
-            imageBuffers, _, _, _ = readCroppedImages(imageFileNames, templateFileName, self.imageToImageTransformMatrix)
-            self.imageBuffersList.append(imageBuffers)
-
-        # Put everything in a big 4-D matrix to derive one consistent mask across all time points
-        imageSize = self.sstModel.imageBuffers.shape[:3]
-        numberOfContrasts = self.sstModel.imageBuffers.shape[-1]
-        self.combinedImageBuffers = np.zeros(imageSize + (numberOfContrasts * (1 + self.numberOfTimepoints),))
-        self.combinedImageBuffers[..., 0:numberOfContrasts] = self.sstModel.imageBuffers
-        for timepointNumber in range(self.numberOfTimepoints):
-            self.combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
-                                           (timepointNumber + 2) * numberOfContrasts] = self.imageBuffersList[timepointNumber]
-
-        self.combinedImageBuffers, self.sstModel.mask = maskOutBackground(self.combinedImageBuffers, self.sstModel.modelSpecifications.atlasFileName,
-                                                                          self.sstModel.transform,
-                                                                          self.sstModel.modelSpecifications.maskingProbabilityThreshold,
-                                                                          self.sstModel.modelSpecifications.maskingDistance,
-                                                                          self.probabilisticAtlas,
-                                                                          self.sstModel.voxelSpacing)
-        combinedImageBuffers = logTransform(self.combinedImageBuffers, self.sstModel.mask)
-
-        # Retrieve the masked sst and time points
-        self.sstModel.imageBuffers = combinedImageBuffers[..., 0:numberOfContrasts]
-        for timepointNumber in range(self.numberOfTimepoints):
-            self.imageBuffersList[timepointNumber] = combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
-                                                                               (timepointNumber + 2) * numberOfContrasts]
+        self.voxelSpacings = []
+        self.transforms = []
+        self.masks = []
+        self.croppings = []
+        if self.allIdentityTransforms:
+
+            self.imageBuffersList = []
+            for imageFileNames in self.imageFileNamesList:
+                imageBuffers, _, _, _ = readCroppedImages(imageFileNames, templateFileName,
+                                                          self.imageToImageTransformMatrix)
+                self.imageBuffersList.append(imageBuffers)
+
+            # Put everything in a big 4-D matrix to derive one consistent mask across all time points
+            imageSize = self.sstModel.imageBuffers.shape[:3]
+            numberOfContrasts = self.sstModel.imageBuffers.shape[-1]
+            self.combinedImageBuffers = np.zeros(imageSize + (numberOfContrasts * (1 + self.numberOfTimepoints),))
+            self.combinedImageBuffers[..., 0:numberOfContrasts] = self.sstModel.imageBuffers
+            for timepointNumber in range(self.numberOfTimepoints):
+                self.combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
+                                               (timepointNumber + 2) * numberOfContrasts] = self.imageBuffersList[timepointNumber]
+
+            self.combinedImageBuffers, self.sstModel.mask = maskOutBackground(self.combinedImageBuffers,
+                                                                              self.sstModel.modelSpecifications.atlasFileName,
+                                                                              self.sstModel.transform,
+                                                                              self.sstModel.modelSpecifications.maskingProbabilityThreshold,
+                                                                              self.sstModel.modelSpecifications.maskingDistance,
+                                                                              self.probabilisticAtlas,
+                                                                              self.sstModel.voxelSpacing)
+            combinedImageBuffers = logTransform(self.combinedImageBuffers, self.sstModel.mask)
+
+            # Retrieve the masked sst and time points
+            self.sstModel.imageBuffers = combinedImageBuffers[..., 0:numberOfContrasts]
+            for timepointNumber in range(self.numberOfTimepoints):
+                self.imageBuffersList[timepointNumber] = combinedImageBuffers[...,
+                                                         (timepointNumber + 1) * numberOfContrasts:
+                                                         (timepointNumber + 2) * numberOfContrasts]
+                self.masks.append(self.sstModel.mask)
+                self.croppings.append(self.sstModel.cropping)
+                self.voxelSpacings.append(self.sstModel.voxelSpacing)
+                self.transforms.append(self.sstModel.transform)
+
+        else:
+
+            for timepointNumber, imageFileNames in enumerate(self.imageFileNamesList):
+
+                # Compute the transformation from population atlas (p) to time point (tp), passing through the template space (s).
+                # The transformation needs to be a vox-to-vox transform, like self.imageToImageTransformMatrix.
+                # We need to concatenate the following transformations:
+                #   self.imageToImageTransformMatrix               -> population to template space (vox to vox)
+                #   tmpS.geom.vox2world                            -> template space (vox to world)
+                #   self.tpToBaseTransforms[timepointNumber].inv() -> template to time point space (world to world)
+                #   tmpTp.geom.world2vox                           -> time point space (world to vox)
+                tmpTp = sf.load_volume(imageFileNames[0])
+                tmpS = sf.load_volume(os.path.join(self.savePath, "base", "template_coregistered.mgz"))
+                pToTpTransform = tmpTp.geom.world2vox @ self.tpToBaseTransforms[timepointNumber].inv() @ tmpS.geom.vox2world @ self.imageToImageTransformMatrix
+
+                imageBuffers, transform, voxelSpacing, cropping = readCroppedImages(imageFileNames, templateFileName, pToTpTransform.matrix)
+
+                self.imageBuffersList.append(imageBuffers)
+                self.voxelSpacings.append(voxelSpacing)
+                self.transforms.append(transform)
+                self.croppings.append(cropping)
+
+            # Derive mask for sst model
+            imageBuffer, self.sstModel.mask = maskOutBackground(self.sstModel.imageBuffers,
+                                                                self.sstModel.modelSpecifications.atlasFileName,
+                                                                self.sstModel.transform,
+                                                                self.sstModel.modelSpecifications.maskingProbabilityThreshold,
+                                                                self.sstModel.modelSpecifications.maskingDistance,
+                                                                self.probabilisticAtlas,
+                                                                self.sstModel.voxelSpacing)
+            self.sstModel.imageBuffers = logTransform(imageBuffer, self.sstModel.mask)
+
+            # Derive one mask for each time point model
+            for timepointNumber in range(self.numberOfTimepoints):
+                imageBuffer, timepointMask = maskOutBackground(self.imageBuffersList[timepointNumber],
+                                                               self.sstModel.modelSpecifications.atlasFileName,
+                                                               self.transforms[timepointNumber],
+                                                               self.sstModel.modelSpecifications.maskingProbabilityThreshold,
+                                                               self.sstModel.modelSpecifications.maskingDistance,
+                                                               self.probabilisticAtlas,
+                                                               self.voxelSpacings[timepointNumber])
+                imageBuffer = logTransform(imageBuffer, timepointMask)
+                self.imageBuffersList[timepointNumber] = imageBuffer
+                self.masks.append(timepointMask)
 
         # construct timepoint models
         self.constructTimepointModels()
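The chained pToTpTransform above is ordinary 4x4 matrix composition, read right to left: atlas voxels -> template voxels -> template world -> time point world -> time point voxels. A small numpy sketch with made-up translation matrices standing in for the real geometries (which the commit reads from the volume headers via surfa):

    import numpy as np

    def translation(tx, ty, tz):
        # Homogeneous 4x4 translation matrix
        m = np.eye(4)
        m[:3, 3] = [tx, ty, tz]
        return m

    atlasToTemplateVox = np.eye(4)                      # stand-in for self.imageToImageTransformMatrix
    templateVoxToWorld = translation(-128, -128, -128)  # stand-in for tmpS.geom.vox2world
    tpToBaseWorld = translation(2, 0, 0)                # stand-in for a tpToBaseTransforms entry
    tpWorldToVox = translation(128, 128, 128)           # stand-in for tmpTp.geom.world2vox

    # Same composition order as the commit; .inv() becomes np.linalg.inv here
    pToTp = tpWorldToVox @ np.linalg.inv(tpToBaseWorld) @ templateVoxToWorld @ atlasToTemplateVox
    print(pToTp[:3, 3])  # [-2.  0.  0.]: the net effect is a -2 shift along x, in voxel coordinates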
@@ -272,7 +344,7 @@ def fitModel(self):
                 ax.set_title('sst after bias field correction')
             for timepointNumber in range(self.numberOfTimepoints):
                 ax = axs.ravel()[2 + timepointNumber]
-                ax.hist(self.imageBuffersList[timepointNumber][self.sstModel.mask, contrastNumber], bins)
+                ax.hist(self.imageBuffersList[timepointNumber][self.masks[timepointNumber], contrastNumber], bins)
                 ax.grid()
                 ax.set_title('time point ' + str(timepointNumber))
             axsList.append(axs)
@@ -331,6 +403,17 @@ def fitModel(self):
         # For the GMM part, I'm using the *average* number of voxels assigned to the components in each mixture (class) of the
         # SST segmentation, so that all the components in each mixture are well-regularized (and tiny components don't get to do
         # whatever they want)
+        #
+        # Note that we need to take into account the possible resolution difference between the SST and each time point.
+        # Here we assume that the time points have similar resolutions; otherwise the mean might not be the best choice.
+        # Scale the latent number of measurements by the voxel spacing ratio between the subject-specific template and the mean time point resolution.
+        meanTimePointResolution = 0
+        for t in range(self.numberOfTimepoints):
+            meanTimePointResolution += np.prod(self.timepointModels[t].voxelSpacing)
+        meanTimePointResolution /= self.numberOfTimepoints
+        voxelSpacingRatio = np.prod(self.sstModel.voxelSpacing) / meanTimePointResolution
+        print("Voxel spacing ratio: " + str(voxelSpacingRatio))
+
         K0 = self.sstModel.modelSpecifications.K  # Stiffness population -> latent position
         K1 = self.strengthOfLatentDeformationHyperprior * K0  # Stiffness latent position -> each time point
         sstEstimatedNumberOfVoxelsPerGaussian = np.sum(self.sstModel.optimizationHistory[-1]['posteriorsAtEnd'], axis=0) * \
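The ratio computed above is just the SST voxel volume divided by the mean voxel volume over time points; the next hunk multiplies the latent numbers of measurements by it. A quick sketch with assumed spacings:

    import numpy as np

    sstVoxelSpacing = np.array([1.0, 1.0, 1.0])           # assumed: 1 mm isotropic SST grid
    timepointVoxelSpacings = [np.array([1.0, 1.0, 1.2]),  # assumed per-time-point resolutions
                              np.array([1.0, 1.0, 1.0])]

    meanTimePointResolution = np.mean([np.prod(s) for s in timepointVoxelSpacings])  # (1.2 + 1.0) / 2 = 1.1
    voxelSpacingRatio = np.prod(sstVoxelSpacing) / meanTimePointResolution
    print(voxelSpacingRatio)  # ~0.909: hyperprior counts shrink to match the coarser time points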
@@ -348,11 +431,11 @@ def fitModel(self):
             sstEstimatedNumberOfVoxelsInClass = np.sum(sstEstimatedNumberOfVoxelsPerGaussian[gaussianNumbers])
 
             self.latentMixtureWeightsNumberOfMeasurements[
-                classNumber] = self.strengthOfLatentGMMHyperprior * sstEstimatedNumberOfVoxelsInClass
+                classNumber] = self.strengthOfLatentGMMHyperprior * sstEstimatedNumberOfVoxelsInClass * voxelSpacingRatio
 
             averageSizeOfComponents = sstEstimatedNumberOfVoxelsInClass / numberOfComponents
-            self.latentMeansNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents
-            self.latentVariancesNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents
+            self.latentMeansNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents * voxelSpacingRatio
+            self.latentVariancesNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents * voxelSpacingRatio
 
         # Estimating the mode of the latentVariance posterior distribution (which is Wishart) requires a stringent condition
         # on latentVariancesNumberOfMeasurements so that the mode is actually defined
@@ -441,8 +524,8 @@ def fitModel(self):
                 import matplotlib.pyplot as plt  # avoid importing matplotlib by default
                 plt.ion()
                 self.timepointModels[timepointNumber].biasField.downSampleBasisFunctions([1, 1, 1])
-                timepointBiasFields = self.timepointModels[timepointNumber].biasField.getBiasFields(self.sstModel.mask)
-                timepointData = self.imageBuffersList[timepointNumber][self.sstModel.mask, :] - timepointBiasFields[self.sstModel.mask, :]
+                timepointBiasFields = self.timepointModels[timepointNumber].biasField.getBiasFields(self.masks[timepointNumber])
+                timepointData = self.imageBuffersList[timepointNumber][self.masks[timepointNumber], :] - timepointBiasFields[self.masks[timepointNumber], :]
                 for contrastNumber in range(self.sstModel.gmm.numberOfContrasts):
                     axs = axsList[contrastNumber]
                     ax = axs.ravel()[2 + timepointNumber]
@@ -469,6 +552,7 @@ def fitModel(self):
         #
         # The parameter estimation happens in a (potentially) downsampled image grid, so it's important to work in the same space
         # when measuring and updating the latentDeformation
+
         transformUsedForEstimation = gems.KvlTransform(
             requireNumpyArray(self.sstModel.optimizationHistory[-1]['downSampledTransformMatrix']))
         mesh_collection = gems.KvlMeshCollection()
@@ -484,6 +568,7 @@ def fitModel(self):
                 self.probabilisticAtlas.mapPositionsFromTemplateToSubjectSpace(positionInTemplateSpace, transformUsedForEstimation))
         mesh_collection.set_positions(referencePosition, timepointPositions)
 
+
         # Read mesh in sst warp
         mesh = self.probabilisticAtlas.getMesh(latentAtlasFileName, transformUsedForEstimation)
 
@@ -733,16 +818,21 @@ def generateSubjectSpecificTemplate(self):
 
         # Read in the various time point images, and compute the average
         numberOfTimepoints = len(contrastImageFileNames)
-        image0 = gems.KvlImage(contrastImageFileNames[0])
-        imageBuffer = image0.getImageBuffer().copy()
+        image0 = sf.load_volume(contrastImageFileNames[0])
+        imageBuffer = image0.transform(affine=self.tpToBaseTransforms[0])
+        # Make sure that we are averaging only non-zero voxels
+        count = np.zeros(imageBuffer.shape)
+        count[imageBuffer > 0] += 1
         for timepointNumber in range(1, numberOfTimepoints):
-            imageBuffer += gems.KvlImage(contrastImageFileNames[timepointNumber]).getImageBuffer()
-        imageBuffer /= numberOfTimepoints
+            tmp = sf.load_volume(contrastImageFileNames[timepointNumber]).transform(affine=self.tpToBaseTransforms[timepointNumber]).data
+            imageBuffer += tmp
+            count[tmp > 0] += 1
+        # Make sure that we are not dividing by zero for, e.g., background voxels
+        imageBuffer[count > 0] /= count[count > 0]
 
-        # Create an ITK image and write to disk
-        sst = gems.KvlImage(requireNumpyArray(imageBuffer))
+        # Write image to disk
         sstFilename = os.path.join(sstDir, 'mode%02d_average.mgz' % (contrastNumber + 1))
-        sst.write(sstFilename, image0.transform_matrix)
+        imageBuffer.save(sstFilename)
 
         #
         sstFileNames.append(sstFilename)
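The count-based averaging above divides each voxel by the number of time points that contributed a non-zero value there, rather than by the total number of time points, so partial field-of-view overlap does not darken the template. A toy numpy illustration of the same idea:

    import numpy as np

    # Two toy "resampled time points"; zeros mark voxels outside a time point's field of view
    tp1 = np.array([0.0, 10.0, 20.0])
    tp2 = np.array([12.0, 0.0, 22.0])

    total = tp1 + tp2
    count = (tp1 > 0).astype(float) + (tp2 > 0).astype(float)

    average = np.zeros_like(total)
    average[count > 0] = total[count > 0] / count[count > 0]  # background voxels stay zero
    print(average)  # [12. 10. 21.]: each voxel averaged over its contributing time points only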
@@ -790,8 +880,9 @@ def constructTimepointModels(self):
                 pallidumAsWM=self.pallidumAsWM,
                 savePosteriors=self.savePosteriors
             ))
-            self.timepointModels[timepointNumber].mask = self.sstModel.mask
+
+            self.timepointModels[timepointNumber].mask = self.masks[timepointNumber]
             self.timepointModels[timepointNumber].imageBuffers = self.imageBuffersList[timepointNumber]
-            self.timepointModels[timepointNumber].voxelSpacing = self.sstModel.voxelSpacing
-            self.timepointModels[timepointNumber].transform = self.sstModel.transform
-            self.timepointModels[timepointNumber].cropping = self.sstModel.cropping
+            self.timepointModels[timepointNumber].voxelSpacing = self.voxelSpacings[timepointNumber]
+            self.timepointModels[timepointNumber].transform = self.transforms[timepointNumber]
+            self.timepointModels[timepointNumber].cropping = self.croppings[timepointNumber]

python/gems/SamsegLongitudinalLesion.py

Lines changed: 4 additions & 2 deletions
@@ -36,7 +36,8 @@ def __init__(self,
                  numberOfPseudoSamplesVariance=500,
                  rho=50,
                  intensityMaskingPattern=None,
-                 intensityMaskingSearchString='Cortex'
+                 intensityMaskingSearchString='Cortex',
+                 tpToBaseTransforms=None,
                  ):
         SamsegLongitudinal.__init__(self,
             imageFileNamesList=imageFileNamesList,
@@ -62,7 +63,8 @@ def __init__(self,
             thresholdSearchString=thresholdSearchString,
             modeNames=modeNames,
             pallidumAsWM=pallidumAsWM,
-            savePosteriors=savePosteriors
+            savePosteriors=savePosteriors,
+            tpToBaseTransforms=tpToBaseTransforms
         )
 
         self.numberOfSamplingSteps = numberOfSamplingSteps
