@@ -94,7 +94,8 @@ def __init__(self,
                  thresholdSearchString=None,
                  modeNames=None,
                  pallidumAsWM=True,
-                 savePosteriors=False
+                 savePosteriors=False,
+                 tpToBaseTransforms=None,
                  ):
 
         # Store input parameters as class variables
@@ -112,6 +113,16 @@ def __init__(self,
         self.modeNames = modeNames
         self.pallidumAsWM = pallidumAsWM
         self.savePosteriors = savePosteriors
+        self.tpToBaseTransforms = tpToBaseTransforms
+
+        # Check if all time point to base transforms are identity matrices.
+        # If so, we can derive a combined 4D mask during preprocessing.
+        self.allIdentityTransforms = True
+        if tpToBaseTransforms is not None:
+            for tp, transform in enumerate(self.tpToBaseTransforms):
+                if not np.allclose(transform.matrix, np.eye(4)):
+                    self.allIdentityTransforms = False
+
 
         # Initialize some objects
         self.probabilisticAtlas = ProbabilisticAtlas()
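
The check above only requires that each entry of `tpToBaseTransforms` exposes a 4x4 `.matrix` attribute (as surfa affines do). A minimal sketch of the same logic in isolation, using a hypothetical `FakeTransform` stand-in rather than the real transform class:

```python
# Minimal sketch of the identity check, assuming each transform exposes a
# 4x4 .matrix attribute (as surfa affines do). FakeTransform is hypothetical.
import numpy as np

class FakeTransform:
    def __init__(self, matrix):
        self.matrix = np.asarray(matrix)

tpToBaseTransforms = [FakeTransform(np.eye(4)),
                      FakeTransform(np.eye(4) + 1e-9)]  # identity up to noise

# np.allclose tolerates tiny floating-point error, so transforms that are
# numerically identity still allow the fast combined-4D-mask path.
allIdentityTransforms = all(np.allclose(t.matrix, np.eye(4))
                            for t in tpToBaseTransforms)
print(allIdentityTransforms)  # True
```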
@@ -190,6 +201,7 @@ def constructAndRegisterSubjectSpecificTemplate(self):
                                                        templateFileName=templateFileName)
         self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer)
 
+
     def preProcess(self):
 
         # construct sstModel
@@ -205,32 +217,92 @@ def preProcess(self):
         self.sstModel.imageBuffers, self.sstModel.transform, self.sstModel.voxelSpacing, self.sstModel.cropping = readCroppedImages(self.sstFileNames, templateFileName, self.imageToImageTransformMatrix)
 
         self.imageBuffersList = []
-        for imageFileNames in self.imageFileNamesList:
-            imageBuffers, _, _, _ = readCroppedImages(imageFileNames, templateFileName, self.imageToImageTransformMatrix)
-            self.imageBuffersList.append(imageBuffers)
-
-        # Put everything in a big 4-D matrix to derive one consistent mask across all time points
-        imageSize = self.sstModel.imageBuffers.shape[:3]
-        numberOfContrasts = self.sstModel.imageBuffers.shape[-1]
-        self.combinedImageBuffers = np.zeros(imageSize + (numberOfContrasts * (1 + self.numberOfTimepoints),))
-        self.combinedImageBuffers[..., 0:numberOfContrasts] = self.sstModel.imageBuffers
-        for timepointNumber in range(self.numberOfTimepoints):
-            self.combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
-                                           (timepointNumber + 2) * numberOfContrasts] = self.imageBuffersList[timepointNumber]
-
-        self.combinedImageBuffers, self.sstModel.mask = maskOutBackground(self.combinedImageBuffers, self.sstModel.modelSpecifications.atlasFileName,
-                                                                          self.sstModel.transform,
-                                                                          self.sstModel.modelSpecifications.maskingProbabilityThreshold,
-                                                                          self.sstModel.modelSpecifications.maskingDistance,
-                                                                          self.probabilisticAtlas,
-                                                                          self.sstModel.voxelSpacing)
-        combinedImageBuffers = logTransform(self.combinedImageBuffers, self.sstModel.mask)
-
-        # Retrieve the masked sst and time points
-        self.sstModel.imageBuffers = combinedImageBuffers[..., 0:numberOfContrasts]
-        for timepointNumber in range(self.numberOfTimepoints):
-            self.imageBuffersList[timepointNumber] = combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
-                                                                               (timepointNumber + 2) * numberOfContrasts]
+        self.voxelSpacings = []
+        self.transforms = []
+        self.masks = []
+        self.croppings = []
+        if self.allIdentityTransforms:
+
+            self.imageBuffersList = []
+            for imageFileNames in self.imageFileNamesList:
+                imageBuffers, _, _, _ = readCroppedImages(imageFileNames, templateFileName,
+                                                          self.imageToImageTransformMatrix)
+                self.imageBuffersList.append(imageBuffers)
+
+            # Put everything in a big 4-D matrix to derive one consistent mask across all time points
+            imageSize = self.sstModel.imageBuffers.shape[:3]
+            numberOfContrasts = self.sstModel.imageBuffers.shape[-1]
+            self.combinedImageBuffers = np.zeros(imageSize + (numberOfContrasts * (1 + self.numberOfTimepoints),))
+            self.combinedImageBuffers[..., 0:numberOfContrasts] = self.sstModel.imageBuffers
+            for timepointNumber in range(self.numberOfTimepoints):
+                self.combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
+                                               (timepointNumber + 2) * numberOfContrasts] = self.imageBuffersList[timepointNumber]
+
+            self.combinedImageBuffers, self.sstModel.mask = maskOutBackground(self.combinedImageBuffers,
+                                                                              self.sstModel.modelSpecifications.atlasFileName,
+                                                                              self.sstModel.transform,
+                                                                              self.sstModel.modelSpecifications.maskingProbabilityThreshold,
+                                                                              self.sstModel.modelSpecifications.maskingDistance,
+                                                                              self.probabilisticAtlas,
+                                                                              self.sstModel.voxelSpacing)
+            combinedImageBuffers = logTransform(self.combinedImageBuffers, self.sstModel.mask)
+
+            # Retrieve the masked sst and time points
+            self.sstModel.imageBuffers = combinedImageBuffers[..., 0:numberOfContrasts]
+            for timepointNumber in range(self.numberOfTimepoints):
+                self.imageBuffersList[timepointNumber] = combinedImageBuffers[..., (timepointNumber + 1) * numberOfContrasts:
                                                                                    (timepointNumber + 2) * numberOfContrasts]
+                self.masks.append(self.sstModel.mask)
+                self.croppings.append(self.sstModel.cropping)
+                self.voxelSpacings.append(self.sstModel.voxelSpacing)
+                self.transforms.append(self.sstModel.transform)
+
+        else:
+
+            for timepointNumber, imageFileNames in enumerate(self.imageFileNamesList):
+
+                # Compute the transformation from the population atlas (p) to this time point (tp), passing through the template space (s).
+                # The result needs to be a vox-to-vox transform, like self.imageToImageTransformMatrix, so we concatenate:
+                #   self.imageToImageTransformMatrix -> population to template space (vox-to-vox)
+                #   tmpS.geom.vox2world -> template space (vox-to-world)
+                #   self.tpToBaseTransforms[timepointNumber].inv() -> template to time point space (world-to-world)
+                #   tmpTp.geom.world2vox -> time point space (world-to-vox)
+                tmpTp = sf.load_volume(imageFileNames[0])
+                tmpS = sf.load_volume(os.path.join(self.savePath, "base", "template_coregistered.mgz"))
+                pToTpTransform = tmpTp.geom.world2vox @ self.tpToBaseTransforms[timepointNumber].inv() @ tmpS.geom.vox2world @ self.imageToImageTransformMatrix
+
+                imageBuffers, transform, voxelSpacing, cropping = readCroppedImages(imageFileNames, templateFileName, pToTpTransform.matrix)
+
+                self.imageBuffersList.append(imageBuffers)
+                self.voxelSpacings.append(voxelSpacing)
+                self.transforms.append(transform)
+                self.croppings.append(cropping)
+
+            # Derive the mask for the sst model
+            imageBuffer, self.sstModel.mask = maskOutBackground(self.sstModel.imageBuffers, self.sstModel.modelSpecifications.atlasFileName,
+                                                                self.sstModel.transform,
+                                                                self.sstModel.modelSpecifications.maskingProbabilityThreshold,
+                                                                self.sstModel.modelSpecifications.maskingDistance,
+                                                                self.probabilisticAtlas,
+                                                                self.sstModel.voxelSpacing)
+            self.sstModel.imageBuffers = logTransform(imageBuffer, self.sstModel.mask)
+
+            # Derive one mask for each time point model
+            for timepointNumber in range(self.numberOfTimepoints):
+                imageBuffer, timepointMask = maskOutBackground(self.imageBuffersList[timepointNumber],
+                                                               self.sstModel.modelSpecifications.atlasFileName,
+                                                               self.transforms[timepointNumber],
+                                                               self.sstModel.modelSpecifications.maskingProbabilityThreshold,
+                                                               self.sstModel.modelSpecifications.maskingDistance,
+                                                               self.probabilisticAtlas,
+                                                               self.voxelSpacings[timepointNumber])
+                imageBuffer = logTransform(imageBuffer, timepointMask)
+                self.imageBuffersList[timepointNumber] = imageBuffer
+                self.masks.append(timepointMask)
 
         # construct timepoint models
         self.constructTimepointModels()
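
The comment in the non-identity branch describes a chain of four affines whose product applies right to left. A sketch of that composition with plain 4x4 numpy arrays standing in for the surfa geometries (the matrices here are illustrative placeholders, not the actual samseg/surfa objects):

```python
# Sketch of the population-atlas -> time-point voxel transform chain.
# Plain 4x4 numpy arrays stand in for the surfa geometries; all values
# are illustrative, not the actual samseg/surfa API.
import numpy as np

def scaling(s):
    # vox2world of an axis-aligned volume with isotropic voxel size s
    return np.diag([s, s, s, 1.0])

popToTemplateVox = np.eye(4)                  # imageToImageTransformMatrix (vox -> vox)
templateVox2World = scaling(1.0)              # tmpS.geom.vox2world
templateToTpWorld = np.linalg.inv(np.eye(4))  # tpToBaseTransforms[tp].inv() (world -> world)
tpWorld2Vox = np.linalg.inv(scaling(1.2))     # tmpTp.geom.world2vox

# Matrix products apply right to left: atlas voxel -> template voxel ->
# template world -> time-point world -> time-point voxel.
pToTp = tpWorld2Vox @ templateToTpWorld @ templateVox2World @ popToTemplateVox
print(pToTp.round(3))
```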
@@ -272,7 +344,7 @@ def fitModel(self):
                 ax.set_title('sst after bias field correction')
                 for timepointNumber in range(self.numberOfTimepoints):
                     ax = axs.ravel()[2 + timepointNumber]
-                    ax.hist(self.imageBuffersList[timepointNumber][self.sstModel.mask, contrastNumber], bins)
+                    ax.hist(self.imageBuffersList[timepointNumber][self.masks[timepointNumber], contrastNumber], bins)
                     ax.grid()
                     ax.set_title('time point ' + str(timepointNumber))
                 axsList.append(axs)
@@ -331,6 +403,17 @@ def fitModel(self):
         # For the GMM part, I'm using the *average* number of voxels assigned to the components in each mixture (class) of the
         # SST segmentation, so that all the components in each mixture are well-regularized (and tiny components don't get to do
         # whatever they want)
+        #
+        # Note that we need to take into account the possible resolution difference between the SST and each time point.
+        # Here we assume that the time points have similar resolutions; otherwise the mean might not be the best choice.
+        # Scale the latent number of measurements by the voxel spacing ratio between the subject-specific template and the mean time point resolution.
+        meanTimePointResolution = 0
+        for t in range(self.numberOfTimepoints):
+            meanTimePointResolution += np.prod(self.timepointModels[t].voxelSpacing)
+        meanTimePointResolution /= self.numberOfTimepoints
+        voxelSpacingRatio = np.prod(self.sstModel.voxelSpacing) / meanTimePointResolution
+        print("Voxel spacing ratio: " + str(voxelSpacingRatio))
+
         K0 = self.sstModel.modelSpecifications.K  # Stiffness population -> latent position
         K1 = self.strengthOfLatentDeformationHyperprior * K0  # Stiffness latent position -> each time point
         sstEstimatedNumberOfVoxelsPerGaussian = np.sum(self.sstModel.optimizationHistory[-1]['posteriorsAtEnd'], axis=0) * \
@@ -348,11 +431,11 @@ def fitModel(self):
             sstEstimatedNumberOfVoxelsInClass = np.sum(sstEstimatedNumberOfVoxelsPerGaussian[gaussianNumbers])
 
             self.latentMixtureWeightsNumberOfMeasurements[
-                classNumber] = self.strengthOfLatentGMMHyperprior * sstEstimatedNumberOfVoxelsInClass
+                classNumber] = self.strengthOfLatentGMMHyperprior * sstEstimatedNumberOfVoxelsInClass * voxelSpacingRatio
 
             averageSizeOfComponents = sstEstimatedNumberOfVoxelsInClass / numberOfComponents
-            self.latentMeansNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents
-            self.latentVariancesNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents
+            self.latentMeansNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents * voxelSpacingRatio
+            self.latentVariancesNumberOfMeasurements[gaussianNumbers] = self.strengthOfLatentGMMHyperprior * averageSizeOfComponents * voxelSpacingRatio
 
         # Estimating the mode of the latentVariance posterior distribution (which is Wishart) requires a stringent condition
         # on latentVariancesNumberOfMeasurements so that the mode is actually defined
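
A worked example of the `voxelSpacingRatio` scaling introduced above, with illustrative spacings (1 mm isotropic SST, two 1.2 mm isotropic time points; the strengths and voxel counts are made-up values):

```python
# Worked example of the voxelSpacingRatio scaling (illustrative numbers).
import numpy as np

sstVoxelSpacing = np.array([1.0, 1.0, 1.0])
timepointSpacings = [np.array([1.2, 1.2, 1.2]), np.array([1.2, 1.2, 1.2])]

meanTimePointResolution = np.mean([np.prod(s) for s in timepointSpacings])  # 1.728 mm^3
voxelSpacingRatio = np.prod(sstVoxelSpacing) / meanTimePointResolution      # ~0.579

# Each time-point voxel covers ~1.7x the volume of an SST voxel, so the same
# anatomy is represented by fewer voxels; scaling the latent numbers of
# measurements by ~0.579 keeps the hyperprior strength comparable.
strengthOfLatentGMMHyperprior = 1.0      # illustrative value
sstEstimatedNumberOfVoxelsInClass = 1e4  # illustrative value
latentWeight = (strengthOfLatentGMMHyperprior *
                sstEstimatedNumberOfVoxelsInClass * voxelSpacingRatio)
print(round(voxelSpacingRatio, 3), round(latentWeight, 1))  # 0.579 5787.0
```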
@@ -441,8 +524,8 @@ def fitModel(self):
                 import matplotlib.pyplot as plt  # avoid importing matplotlib by default
                 plt.ion()
                 self.timepointModels[timepointNumber].biasField.downSampleBasisFunctions([1, 1, 1])
-                timepointBiasFields = self.timepointModels[timepointNumber].biasField.getBiasFields(self.sstModel.mask)
-                timepointData = self.imageBuffersList[timepointNumber][self.sstModel.mask, :] - timepointBiasFields[self.sstModel.mask, :]
+                timepointBiasFields = self.timepointModels[timepointNumber].biasField.getBiasFields(self.masks[timepointNumber])
+                timepointData = self.imageBuffersList[timepointNumber][self.masks[timepointNumber], :] - timepointBiasFields[self.masks[timepointNumber], :]
                 for contrastNumber in range(self.sstModel.gmm.numberOfContrasts):
                     axs = axsList[contrastNumber]
                     ax = axs.ravel()[2 + timepointNumber]
@@ -469,6 +552,7 @@ def fitModel(self):
         #
         # The parameter estimation happens in a (potentially) downsampled image grid, so it's important to work in the same space
         # when measuring and updating the latentDeformation
+
         transformUsedForEstimation = gems.KvlTransform(
             requireNumpyArray(self.sstModel.optimizationHistory[-1]['downSampledTransformMatrix']))
         mesh_collection = gems.KvlMeshCollection()
@@ -484,6 +568,7 @@ def fitModel(self):
                 self.probabilisticAtlas.mapPositionsFromTemplateToSubjectSpace(positionInTemplateSpace, transformUsedForEstimation))
         mesh_collection.set_positions(referencePosition, timepointPositions)
 
+
         # Read mesh in sst warp
         mesh = self.probabilisticAtlas.getMesh(latentAtlasFileName, transformUsedForEstimation)
@@ -733,16 +818,21 @@ def generateSubjectSpecificTemplate(self):
 
             # Read in the various time point images, and compute the average
             numberOfTimepoints = len(contrastImageFileNames)
-            image0 = gems.KvlImage(contrastImageFileNames[0])
-            imageBuffer = image0.getImageBuffer().copy()
+            image0 = sf.load_volume(contrastImageFileNames[0])
+            imageBuffer = image0.transform(affine=self.tpToBaseTransforms[0])
+            # Make sure that we are averaging only non-zero voxels
+            count = np.zeros(imageBuffer.shape)
+            count[imageBuffer > 0] += 1
             for timepointNumber in range(1, numberOfTimepoints):
-                imageBuffer += gems.KvlImage(contrastImageFileNames[timepointNumber]).getImageBuffer()
-            imageBuffer /= numberOfTimepoints
+                tmp = sf.load_volume(contrastImageFileNames[timepointNumber]).transform(affine=self.tpToBaseTransforms[timepointNumber]).data
+                imageBuffer += tmp
+                count[tmp > 0] += 1
+            # Make sure that we are not dividing by zero for, e.g., background voxels
+            imageBuffer[count > 0] /= count[count > 0]
 
-            # Create an ITK image and write to disk
-            sst = gems.KvlImage(requireNumpyArray(imageBuffer))
+            # Write image to disk
             sstFilename = os.path.join(sstDir, 'mode%02d_average.mgz' % (contrastNumber + 1))
-            sst.write(sstFilename, image0.transform_matrix)
+            imageBuffer.save(sstFilename)
 
             #
             sstFileNames.append(sstFilename)
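
A small self-contained sketch of the zero-aware averaging above, assuming zeros mark voxels with no data (e.g., padding introduced by resampling into base space); the toy 2x2 "volumes" are illustrative:

```python
# Sketch of the zero-aware averaging: voxels that are zero in a resampled
# time point should not drag the template average toward zero.
import numpy as np

vol0 = np.array([[10., 0.], [4., 0.]])  # time point 0, resampled to base space
vol1 = np.array([[12., 6.], [0., 0.]])  # time point 1

total = vol0.copy()
count = np.zeros(total.shape)
count[total > 0] += 1

total += vol1
count[vol1 > 0] += 1

# Divide only where at least one time point contributed, avoiding 0/0.
average = total.copy()
average[count > 0] /= count[count > 0]
print(average)  # [[11. 6.] [4. 0.]]
```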
@@ -790,8 +880,9 @@ def constructTimepointModels(self):
                 pallidumAsWM=self.pallidumAsWM,
                 savePosteriors=self.savePosteriors
             ))
-            self.timepointModels[timepointNumber].mask = self.sstModel.mask
+
+            self.timepointModels[timepointNumber].mask = self.masks[timepointNumber]
             self.timepointModels[timepointNumber].imageBuffers = self.imageBuffersList[timepointNumber]
-            self.timepointModels[timepointNumber].voxelSpacing = self.sstModel.voxelSpacing
-            self.timepointModels[timepointNumber].transform = self.sstModel.transform
-            self.timepointModels[timepointNumber].cropping = self.sstModel.cropping
+            self.timepointModels[timepointNumber].voxelSpacing = self.voxelSpacings[timepointNumber]
+            self.timepointModels[timepointNumber].transform = self.transforms[timepointNumber]
+            self.timepointModels[timepointNumber].cropping = self.croppings[timepointNumber]