forked from InsightSoftwareConsortium/ITK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
itkFEMFiniteDifferenceFunctionLoad.h
282 lines (218 loc) · 10.8 KB
/
itkFEMFiniteDifferenceFunctionLoad.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
/*=========================================================================
Program: Insight Segmentation & Registration Toolkit
Module: itkFEMFiniteDifferenceFunctionLoad.h
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Insight Software Consortium. All rights reserved.
See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef _itkFEMFiniteDifferenceFunctionLoad_h_
#define _itkFEMFiniteDifferenceFunctionLoad_h_
#include "itkFEMLoadElementBase.h"
#include "itkImage.h"
#include "itkTranslationTransform.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkNeighborhoodIterator.h"
#include "itkNeighborhoodIterator.h"
#include "itkNeighborhoodInnerProduct.h"
#include "itkDerivativeOperator.h"
#include "itkForwardDifferenceOperator.h"
#include "itkLinearInterpolateImageFunction.h"
#include "vnl/vnl_math.h"
#include "itkDemonsRegistrationFunction.h"
#include "itkMeanSquareRegistrationFunction.h"
#include "itkNCCRegistrationFunction.h"
#include "itkMIRegistrationFunction.h"
namespace itk
{
namespace fem
{

/**
 * \class FiniteDifferenceFunctionLoad
 * \brief General image pair load that uses the itkFiniteDifferenceFunctions.
 *
 * This load computes FEM gravity loads by using derivatives provided
 * by itkFiniteDifferenceFunctions (e.g. mean squares intensity difference).
 *
 * NOTE(review): older documentation calls the load function "Fg" (as required
 * by the FEMLoad standards); the load entry points declared in this header are
 * Fe(FEMVectorType, FEMVectorType) and Fe1(VectorType). The input layout below
 * presumably applies to those inputs -- confirm against the .txx.
 *
 * We assume the vector input is of size 2*ImageDimension.
 * The 0 to ImageDimension-1 elements contain the position, p,
 * in the reference (moving) image. The next ImageDimension to 2*ImageDimension-1
 * elements contain the value of the vector field at that point, v(p).
 *
 * The metrics return both a scalar similarity value and a vector-valued
 * derivative. The derivative is what gives us the force to drive the FEM
 * registration. These values are computed with respect to some region in the
 * fixed image; the region size may be set by the user by calling
 * SetMetricRadius. As the metric derivative computation evolves, performance
 * should improve and more functionality will be available (such as scale
 * selection).
 *
 * Usage order: the images, metric radius and deformation field are copied
 * into the difference function when SetDifferenceFunction()/SetMetric() is
 * called, so set them first. Only SetMovingImage() and SetFixedImage() forward
 * their new value to an already-installed difference function.
 */
template<class TMoving,class TFixed>
class FiniteDifferenceFunctionLoad : public LoadElement
{
  FEM_CLASS(FiniteDifferenceFunctionLoad,LoadElement)
public:
  // Necessary typedefs for dealing with images BEGIN
  typedef typename LoadElement::Float Float;
  typedef TMoving MovingImageType;
  typedef typename MovingImageType::ConstPointer MovingConstPointer;
  typedef MovingImageType* MovingPointer;
  typedef TFixed FixedImageType;
  typedef FixedImageType* FixedPointer;
  typedef typename FixedImageType::ConstPointer FixedConstPointer;

  /** Dimensionality of input and output data is assumed to be the same. */
  itkStaticConstMacro(ImageDimension, unsigned int,
                      MovingImageType::ImageDimension);

  typedef ImageRegionIteratorWithIndex<MovingImageType> MovingRegionIteratorType;
  typedef ImageRegionIteratorWithIndex<FixedImageType> FixedRegionIteratorType;

  typedef NeighborhoodIterator<MovingImageType>
    MovingNeighborhoodIteratorType;
  typedef typename MovingNeighborhoodIteratorType::IndexType
    MovingNeighborhoodIndexType;
  typedef typename MovingNeighborhoodIteratorType::RadiusType
    MovingRadiusType;
  typedef typename MovingNeighborhoodIteratorType::RadiusType
    RadiusType;
  typedef NeighborhoodIterator<FixedImageType>
    FixedNeighborhoodIteratorType;
  typedef typename FixedNeighborhoodIteratorType::IndexType
    FixedNeighborhoodIndexType;
  typedef typename FixedNeighborhoodIteratorType::RadiusType
    FixedRadiusType;

  // IMAGE DATA
  typedef typename MovingImageType::PixelType MovingPixelType;
  typedef typename FixedImageType::PixelType FixedPixelType;
  typedef Float PixelType;
  typedef Float ComputationType;
  typedef Image< PixelType, itkGetStaticConstMacro(ImageDimension) > ImageType;
  typedef itk::Vector<float,itkGetStaticConstMacro(ImageDimension)> VectorType;
  typedef vnl_vector<Float> FEMVectorType;
  typedef Image< VectorType, itkGetStaticConstMacro(ImageDimension) > DeformationFieldType;
  typedef typename DeformationFieldType::Pointer DeformationFieldTypePointer;
  typedef NeighborhoodIterator<DeformationFieldType>
    FieldIteratorType;
  // Necessary typedefs for dealing with images END

  /** PDEDeformableRegistrationFilterFunction type. */
  typedef PDEDeformableRegistrationFunction<FixedImageType,MovingImageType,
                                            DeformationFieldType> FiniteDifferenceFunctionType;
  typedef typename FiniteDifferenceFunctionType::Pointer FiniteDifferenceFunctionTypePointer;
  typedef typename FiniteDifferenceFunctionType::TimeStepType TimeStepType;

  typedef MeanSquareRegistrationFunction<FixedImageType,MovingImageType,
                                         DeformationFieldType> MeanSquareRegistrationFunctionType;
  typedef DemonsRegistrationFunction<FixedImageType,MovingImageType,
                                     DeformationFieldType> DemonsRegistrationFunctionType;
  typedef NCCRegistrationFunction<FixedImageType,MovingImageType,
                                  DeformationFieldType> NCCRegistrationFunctionType;
  typedef MIRegistrationFunction<FixedImageType,MovingImageType,
                                 DeformationFieldType> MIRegistrationFunctionType;

  // FUNCTIONS

  /** Install the FiniteDifferenceFunction used to calculate updates at image
   * pixels.
   *
   * The function is handed the fixed image, moving image, metric radius and
   * deformation field currently stored in this load. InitializeIteration() is
   * then called on it, and it is stored as the active difference function.
   * Set those members before calling this method.
   *
   * \param drfp The difference function; it is dereferenced unconditionally,
   *             so it must not be null. */
  void SetDifferenceFunction( FiniteDifferenceFunctionTypePointer drfp)
  {
    drfp->SetFixedImage(m_FixedImage);
    drfp->SetMovingImage(m_MovingImage);
    drfp->SetRadius(m_MetricRadius);
    drfp->SetDeformationField(m_DeformationField);
    drfp->InitializeIteration();
    this->m_DifferenceFunction=drfp;
  }

  /** Install the metric (difference function) via SetDifferenceFunction().
   *
   * Afterwards, m_FixedSize is set from the largest possible region of the
   * deformation field. The deformation field is dereferenced here, so
   * SetDeformationField() must have been called first. */
  void SetMetric( FiniteDifferenceFunctionTypePointer drfp )
  {
    this->SetDifferenceFunction( drfp );
    m_FixedSize=m_DeformationField->GetLargestPossibleRegion().GetSize();
  }

  /** Define the reference (moving) image.
   *
   * Also records its largest-possible-region size. If a difference function is
   * already installed, the new image is forwarded to it. */
  void SetMovingImage(MovingImageType* R)
  {
    m_MovingImage = R;
    m_MovingSize=m_MovingImage->GetLargestPossibleRegion().GetSize();
    if (this->m_DifferenceFunction) this->m_DifferenceFunction->SetMovingImage(m_MovingImage);
    // this->InitializeIteration();
  }

  /** Define the target (fixed) image.
   *
   * Also records its largest-possible-region size. If a difference function is
   * already installed, the new fixed image is forwarded to it. */
  void SetFixedImage(FixedImageType* T)
  {
    m_FixedImage=T;
    m_FixedSize=T->GetLargestPossibleRegion().GetSize();
    // Forward the *fixed* image; previously the moving image was passed here
    // by mistake, silently mis-configuring the metric.
    if (this->m_DifferenceFunction) this->m_DifferenceFunction->SetFixedImage(m_FixedImage);
    // this->InitializeIteration();
  }

  /** Return the raw moving-image pointer stored by SetMovingImage(). */
  MovingPointer GetMovingImage() { return m_MovingImage; }
  /** Return the raw fixed-image pointer stored by SetFixedImage(). */
  FixedPointer GetFixedImage() { return m_FixedImage; }

  /** Define the metric region size.
   * Only stored here; it reaches the difference function at the next
   * SetDifferenceFunction()/SetMetric() call. */
  void SetMetricRadius(MovingRadiusType T) {m_MetricRadius = T; }
  /** Get the metric region size. */
  MovingRadiusType GetMetricRadius() { return m_MetricRadius; }

  /** Set/Get methods for the number of integration points to use
   * in each 1-dimensional line integral when evaluating the load.
   * This value is passed to the load implementation.
   */
  void SetNumberOfIntegrationPoints(unsigned int i){ m_NumberOfIntegrationPoints=i;}
  unsigned int GetNumberOfIntegrationPoints(){ return m_NumberOfIntegrationPoints;}

  /** Set the direction of the gradient (uphill or downhill).
   * E.g. the mean squares metric should be minimized while NCC and PR should be maximized.
   * Note: the value is stored in a `float` member.
   */
  void SetSign(Float s) {m_Sign=s;}
  /** Set the sigma in a gaussian measure. */
  void SetTemp(Float s) {m_Temp=s;}
  /** Scaling of the similarity energy term */
  void SetGamma(Float s) {m_Gamma=s;}

  /** Set/Get the FEM solution object queried by GetSolution(i, which). */
  void SetSolution(Solution::ConstPointer ptr) { m_Solution=ptr; }
  Solution::ConstPointer GetSolution() { return m_Solution; }

  /** Return solution value \a i from solution vector \a which.
   *
   * The value is read via m_Solution->GetSolutionValue(i, which), so a
   * solution must have been set.
   * FIXME - historically commented as "WE ASSUME THE 2ND VECTOR (INDEX 1) HAS
   * THE INFORMATION WE WANT", yet the default here is which=0.
   * NOTE(review): confirm which index callers rely on. */
  Float GetSolution(unsigned int i,unsigned int which=0)
  {
    return m_Solution->GetSolutionValue(i,which);
  }

  FiniteDifferenceFunctionLoad(); // cannot be private until we always use smart pointers

  /** Evaluate the metric for the given elements.
   * Presumably uses the current solution scaled by \a step -- see the .txx. */
  Float EvaluateMetricGivenSolution ( Element::ArrayType* el, Float step=1.0);

  /**
   * Compute the image based load - implemented with ITK metric derivatives.
   */
  VectorType Fe1(VectorType);
  FEMVectorType Fe(FEMVectorType,FEMVectorType);

  /** Factory function that allocates a new load with plain `new`. */
  static Baseclass* NewFiniteDifferenceFunctionLoad(void)
  { return new FiniteDifferenceFunctionLoad; }

  /** Set the deformation field.
   * Only stored here; it reaches the difference function at the next
   * SetDifferenceFunction()/SetMetric() call. */
  void SetDeformationField( DeformationFieldTypePointer df)
  { m_DeformationField=df;}

  /** Get the deformation field. */
  DeformationFieldTypePointer GetDeformationField() { return m_DeformationField;}

  void InitializeIteration();
  void InitializeMetric();
  void PrintCurrentEnergy();
  double GetCurrentEnergy();
  void SetCurrentEnergy( double e = 0.0);

protected:

private:
  MovingPointer m_MovingImage;   // raw pointer; lifetime managed by the caller
  FixedPointer m_FixedImage;     // raw pointer; lifetime managed by the caller
  MovingRadiusType m_MetricRadius; /** used by the metric to set region size for fixed image*/
  typename MovingImageType::SizeType m_MovingSize;
  typename FixedImageType::SizeType m_FixedSize;
  unsigned int m_NumberOfIntegrationPoints;
  unsigned int m_SolutionIndex;
  unsigned int m_SolutionIndex2;
  Float m_Temp;
  Float m_Gamma;
  typename Solution::ConstPointer m_Solution;
  float m_GradSigma;
  float m_Sign;                  // set from a Float in SetSign(); stored as float
  float m_WhichMetric;
  FiniteDifferenceFunctionTypePointer m_DifferenceFunction;
  typename DeformationFieldType::Pointer m_DeformationField;

  /** Dummy static int that enables automatic registration
      with FEMObjectFactory. */
  static const int DummyCLID;
};

}} // end namespace fem/itk
#ifndef ITK_MANUAL_INSTANTIATION
#include "itkFEMFiniteDifferenceFunctionLoad.txx"
#endif
#endif