diff --git a/src/deform/regularize.cpp b/src/deform/regularize.cpp index 7ffd32f..631a30e 100644 --- a/src/deform/regularize.cpp +++ b/src/deform/regularize.cpp @@ -3,8 +3,6 @@ #include #include -#include -#include #include "deform/command.h" @@ -43,97 +41,38 @@ int RegularisationCommand::_execute(void) using namespace std::chrono; auto t_start = high_resolution_clock::now(); - VolumePyramid deformation_pyramid; - deformation_pyramid.set_level_count(pyramid_levels); + LOG(Info) << "Input: '" << _args.positional("deformation") << "'"; + stk::VolumeFloat3 src = stk::read_volume(_args.positional("deformation").c_str()); - { - LOG(Info) << "Input: '" << _args.positional("deformation") << "'"; - - stk::Volume src = stk::read_volume(_args.positional("deformation").c_str()); - if (!src.valid()) return 1; - - if (src.voxel_type() != stk::Type_Float3) { - LOG(Error) << "Invalid voxel type for deformation field, expected float3"; - return 1; - } - - #ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS - deformation_pyramid.build_from_base_with_residual(src, filters::downsample_vectorfield_by_2); - #else - deformation_pyramid.build_from_base(src, filters::downsample_vectorfield_by_2); - #endif - } - - bool use_constraints = false; stk::Volume constraints_mask, constraints_values; - if (!constraint_mask_file.empty() && !constraint_values_file.empty()) { + if (!constraint_mask_file.empty()) { LOG(Info) << "Constraint mask: '" << constraint_mask_file << "'"; - LOG(Info) << "Constraint values: '" << constraint_values_file << "'"; - constraints_mask = stk::read_volume(constraint_mask_file.c_str()); if (!constraints_mask.valid()) return 1; - + } + if (!constraint_values_file.empty()) { + LOG(Info) << "Constraint values: '" << constraint_values_file << "'"; constraints_values = stk::read_volume(constraint_values_file.c_str()); if (!constraints_values.valid()) return 1; - - use_constraints = true; - } - else { - constraints_mask = stk::VolumeUChar(deformation_pyramid.volume(0).size(), uint8_t{0}); - constraints_values = stk::VolumeFloat3(deformation_pyramid.volume(0).size(), float3{0, 0, 0}); } - VolumePyramid constraints_mask_pyramid, constraints_pyramid; - voxel_constraints::build_pyramids( - constraints_mask, - constraints_values, + stk::VolumeFloat3 out = regularization( + src, + precision, pyramid_levels, - constraints_mask_pyramid, - constraints_pyramid + constraints_mask, + constraints_values ); - // Initialization is only needed if we have constraints - if (use_constraints) { - // Perform initialization at the coarsest resolution - stk::VolumeFloat3 def = deformation_pyramid.volume(pyramid_levels-1); - initialize_regularization( - def, - constraints_mask_pyramid.volume(pyramid_levels-1), - constraints_pyramid.volume(pyramid_levels-1) - ); + if (!out.valid()) { + return 1; } - for (int l = pyramid_levels-1; l >= 0; --l) { - stk::VolumeFloat3 def = deformation_pyramid.volume(l); - - LOG(Info) << "Performing regularization level " << l; - - do_regularization( - def, - constraints_mask_pyramid.volume(l), - constraints_pyramid.volume(l), - precision - ); - - if (l != 0) { - dim3 upsampled_dims = deformation_pyramid.volume(l - 1).size(); - deformation_pyramid.set_volume(l - 1, - #ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS - filters::upsample_vectorfield(def, upsampled_dims, deformation_pyramid.residual(l - 1)) - #else - filters::upsample_vectorfield(def, upsampled_dims) - #endif - ); - } - else { - deformation_pyramid.set_volume(0, def); - } - } auto t_end = high_resolution_clock::now(); int elapsed = int(round(duration_cast>(t_end - t_start).count())); LOG(Info) << "Regularization completed in " << elapsed / 60 << ":" << std::right << std::setw(2) << std::setfill('0') << elapsed % 60; LOG(Info) << "Writing to '" << output_file << "'"; - stk::write_volume(output_file.c_str(), deformation_pyramid.volume(0)); + stk::write_volume(output_file.c_str(), out); return EXIT_SUCCESS; } diff --git a/src/deform_lib/regularize.cpp b/src/deform_lib/regularize.cpp index ab10df8..1032901 100644 --- a/src/deform_lib/regularize.cpp +++ b/src/deform_lib/regularize.cpp @@ -177,3 +177,92 @@ void do_regularization( } } + +stk::VolumeFloat3 regularization( + const stk::VolumeFloat3& df, + float precision, + int pyramid_levels, + stk::VolumeUChar constraints_mask, + stk::VolumeFloat3 constraints_values +) +{ + VolumePyramid deformation_pyramid; + deformation_pyramid.set_level_count(pyramid_levels); + + if (!df.valid()) return stk::Volume(); + + if (df.voxel_type() != stk::Type_Float3) { + LOG(Error) << "Invalid voxel type for deformation field, expected float3"; + return stk::Volume(); + } + + // Clone to avoid directly modifying our input +#ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS + deformation_pyramid.build_from_base_with_residual(df.clone(), filters::downsample_vectorfield_by_2); +#else + deformation_pyramid.build_from_base(df.clone(), filters::downsample_vectorfield_by_2); +#endif + + bool use_constraints = false; + + if (!constraints_mask.valid()) { + constraints_mask = stk::VolumeUChar(deformation_pyramid.volume(0).size(), uint8_t{0}); + } + else { + use_constraints = true; + } + + if (!constraints_values.valid()) { + constraints_values = stk::VolumeFloat3(deformation_pyramid.volume(0).size(), float3{0, 0, 0}); + } + + VolumePyramid constraints_mask_pyramid, constraints_pyramid; + voxel_constraints::build_pyramids( + constraints_mask, + constraints_values, + pyramid_levels, + constraints_mask_pyramid, + constraints_pyramid + ); + + // Initialization is only needed if we have constraints + if (use_constraints) { + // Perform initialization at the coarsest resolution + stk::VolumeFloat3 def = deformation_pyramid.volume(pyramid_levels-1); + initialize_regularization( + def, + constraints_mask_pyramid.volume(pyramid_levels-1), + constraints_pyramid.volume(pyramid_levels-1) + ); + } + + for (int l = pyramid_levels-1; l >= 0; --l) { + stk::VolumeFloat3 def = deformation_pyramid.volume(l); + + LOG(Info) << "Performing regularization level " << l; + + do_regularization( + def, + constraints_mask_pyramid.volume(l), + constraints_pyramid.volume(l), + precision + ); + + if (l != 0) { + dim3 upsampled_dims = deformation_pyramid.volume(l - 1).size(); + deformation_pyramid.set_volume(l - 1, + #ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS + filters::upsample_vectorfield(def, upsampled_dims, deformation_pyramid.residual(l - 1)) + #else + filters::upsample_vectorfield(def, upsampled_dims) + #endif + ); + } + else { + deformation_pyramid.set_volume(0, def); + } + } + + return deformation_pyramid.volume(0); +} + diff --git a/src/deform_lib/regularize.h b/src/deform_lib/regularize.h index adeb866..9f78e48 100644 --- a/src/deform_lib/regularize.h +++ b/src/deform_lib/regularize.h @@ -1,13 +1,10 @@ #pragma once -void initialize_regularization( - stk::VolumeFloat3& def, - const stk::VolumeUChar& constraints_mask, - const stk::VolumeFloat3& constraints_values - ); -void do_regularization( - stk::VolumeFloat3& def, - const stk::VolumeUChar& constraints_mask, - const stk::VolumeFloat3& constraints_values, - float precision - ); +stk::VolumeFloat3 regularization( + const stk::VolumeFloat3& df, + float precision, + int pyramid_levels, + stk::VolumeUChar constraints_mask, + stk::VolumeFloat3 constraints_values +); +