Skip to content

Commit

Permalink
Merge pull request #110 from simeks/regularization-map
Browse files Browse the repository at this point in the history
Enabled setting the regularization weights per voxel
  • Loading branch information
simeks committed Oct 25, 2019
2 parents 4ad2d8b + c27cc81 commit bb948d4
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 12 deletions.
25 changes: 22 additions & 3 deletions src/deform/registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ bool RegistrationCommand::_parse_arguments(void)
_args.add_option("constraint_mask", "--constraint_mask", "Path to the constraint mask");
_args.add_option("constraint_values", "--constraint_values", "Path to the constraint values");
_args.add_group();
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
_args.add_option("regularization_map", "-rm, --regularization_map", "Path to a map of voxel-wise regularization terms");
_args.add_group();
#endif
_args.add_option("jacobian", "-j, --jacobian", "Path to the resulting jacobian");
_args.add_option("transform", "-t, --transform", "Path to the transformed version of the first moving volume");
_args.add_group();
Expand Down Expand Up @@ -173,7 +177,7 @@ int RegistrationCommand::_execute(void)
}
else if (!constraint_mask_file.empty() || !constraint_values_file.empty()) {
// Just a check to make sure the user didn't forget something
LOG(Error) << "No constraints used, to use constraints, specify both a mask and a vectorfield";
LOG(Error) << "No constraints used, to use constraints, specify both a mask and a vector field";
return EXIT_FAILURE;
}

Expand All @@ -199,6 +203,18 @@ int RegistrationCommand::_execute(void)
return EXIT_FAILURE;
}

#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
std::string regularization_map_file = _args.get<std::string>("regularization_map", "");
LOG(Info) << "Regularization map: '" << regularization_map_file << "'";

stk::Volume regularization_map;
if (!regularization_map_file.empty()) {
regularization_map = stk::read_volume(regularization_map_file.c_str());
if (!regularization_map.valid())
return EXIT_FAILURE;
}
#endif

#ifdef DF_USE_CUDA
bool use_gpu = _args.is_set("use_gpu");
#endif
Expand All @@ -215,10 +231,13 @@ int RegistrationCommand::_execute(void)
initial_displacement,
constraint_mask,
constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
regularization_map,
#endif
_args.get<int>("num_threads", 0)
#ifdef DF_USE_CUDA
#ifdef DF_USE_CUDA
, use_gpu
#endif
#endif
);
}
catch (std::exception& e) {
Expand Down
2 changes: 1 addition & 1 deletion src/deform_lib/config.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

//#define DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS
//#define DF_ENABLE_REGULARIZATION_WEIGHT_MAP
#define DF_ENABLE_REGULARIZATION_WEIGHT_MAP

// Use the SSE version of linear_at<float>
// Does not make any real difference in performance on msvc2017
Expand Down
2 changes: 1 addition & 1 deletion src/deform_lib/cost_functions/regularizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct Regularizer
}

#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
void set_weight_map(stk::VolumeFloat& map) { _weight_map = map; }
void set_weight_map(const stk::VolumeFloat& map) { _weight_map = map; }
#endif // DF_ENABLE_REGULARIZATION_WEIGHT_MAP

/// p : Position in fixed image
Expand Down
19 changes: 19 additions & 0 deletions src/deform_lib/registration/registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ stk::Volume registration(
const stk::Volume& initial_deformation,
const stk::Volume& constraint_mask,
const stk::Volume& constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
const stk::Volume& regularization_map,
#endif
const int num_threads
)
{
Expand Down Expand Up @@ -212,6 +215,13 @@ stk::Volume registration(
engine.set_landmarks(fixed_landmarks, moving_landmarks);
}

#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
if (regularization_map.valid()) {
validate_volume_properties(regularization_map, fixed_ref, "regularization map");
engine.set_regularization_weight_map(regularization_map);
}
#endif

using namespace std::chrono;
auto t_start = high_resolution_clock::now();
stk::Volume def = engine.execute();
Expand All @@ -236,6 +246,9 @@ stk::Volume registration(
const stk::Volume& initial_deformation,
const stk::Volume& constraint_mask,
const stk::Volume& constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
const stk::Volume& regularization_map,
#endif
const int num_threads
#ifdef DF_USE_CUDA
, bool use_gpu
Expand All @@ -256,6 +269,9 @@ stk::Volume registration(
initial_deformation,
constraint_mask,
constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
regularization_map,
#endif
num_threads
);
}else{
Expand All @@ -271,6 +287,9 @@ stk::Volume registration(
initial_deformation,
constraint_mask,
constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
regularization_map,
#endif
num_threads
);
#ifdef DF_USE_CUDA
Expand Down
3 changes: 3 additions & 0 deletions src/deform_lib/registration/registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ stk::Volume registration(
const stk::Volume& initial_deformation,
const stk::Volume& constraint_mask,
const stk::Volume& constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
const stk::Volume& regularization_map,
#endif
const int num_threads
#ifdef DF_USE_CUDA
, bool use_gpu
Expand Down
10 changes: 5 additions & 5 deletions src/deform_lib/registration/registration_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ void RegistrationEngine::build_regularizer(int level, Regularizer& binary_fn)
binary_fn.set_regularization_scale(_settings.levels[level].regularization_scale);
binary_fn.set_regularization_exponent(_settings.levels[level].regularization_exponent);

#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
if (_regularization_weight_map.volume(level).valid())
binary_fn.set_weight_map(_regularization_weight_map.volume(level));
#endif

stk::Volume df = _deformation_pyramid.volume(level);
if (!_settings.regularize_initial_displacement) {
// Clone the def, because the current copy will be changed when executing the optimizer
Expand Down Expand Up @@ -412,11 +417,6 @@ void RegistrationEngine::build_unary_function(int level, UnaryFunction& unary_fn
if (_fixed_mask_pyramid.levels() > 0) {
unary_fn.set_fixed_mask(_fixed_mask_pyramid.volume(level));
}

#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
if (_regularization_weight_map.volume(level).valid())
binary_fn.set_weight_map(_regularization_weight_map.volume(l));
#endif

auto const& moving_mask = _moving_mask_pyramid.levels() > 0 ? _moving_mask_pyramid.volume(level)
: stk::VolumeFloat();
Expand Down
17 changes: 15 additions & 2 deletions src/python_wrapper/_pydeform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ constraint_values: stk.Volume
Value for the constraints on the displacement.
Requires to provide `constraint_mask`.
regularization_map: stk.Volume
Map of voxel-wise regularization weights.
Should be the same shape as the fixed image.
settings: dict
Python dictionary containing the settings for the
registration.
Expand Down Expand Up @@ -303,6 +307,9 @@ stk::Volume registration_wrapper(
const stk::Volume& initial_displacement,
const stk::Volume& constraint_mask,
const stk::Volume& constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
const stk::Volume& regularization_map,
#endif
const py::object& settings,
const py::object& log,
const stk::LogLevel log_level,
Expand Down Expand Up @@ -407,10 +414,13 @@ stk::Volume registration_wrapper(
initial_displacement,
constraint_mask,
constraint_values,
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
regularization_map,
#endif
num_threads
#ifdef DF_USE_CUDA
#ifdef DF_USE_CUDA
, use_gpu
#endif
#endif
);

// This must be done before the `out_stream` goes out of scope
Expand Down Expand Up @@ -560,6 +570,9 @@ PYBIND11_MODULE(_pydeform, m)
py::arg("initial_displacement") = stk::Volume(),
py::arg("constraint_mask") = stk::Volume(),
py::arg("constraint_values") = stk::Volume(),
#ifdef DF_ENABLE_REGULARIZATION_WEIGHT_MAP
py::arg("regularization_map") = stk::Volume(),
#endif
py::arg("settings") = py::none(),
py::arg("log") = py::none(),
py::arg("log_level") = stk::LogLevel::Info,
Expand Down

0 comments on commit bb948d4

Please sign in to comment.