Skip to content

Commit

Permalink
Restructuring of the internals for regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
simeks committed Oct 3, 2019
1 parent 689c334 commit a901bba
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 86 deletions.
89 changes: 14 additions & 75 deletions src/deform/regularize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

#include <deform_lib/regularize.h>
#include <deform_lib/filters/resample.h>
#include <deform_lib/registration/volume_pyramid.h>
#include <deform_lib/registration/voxel_constraints.h>

#include "deform/command.h"

Expand Down Expand Up @@ -43,97 +41,38 @@ int RegularisationCommand::_execute(void)
using namespace std::chrono;
auto t_start = high_resolution_clock::now();

VolumePyramid deformation_pyramid;
deformation_pyramid.set_level_count(pyramid_levels);
LOG(Info) << "Input: '" << _args.positional("deformation") << "'";
stk::VolumeFloat3 src = stk::read_volume(_args.positional("deformation").c_str());

{
LOG(Info) << "Input: '" << _args.positional("deformation") << "'";

stk::Volume src = stk::read_volume(_args.positional("deformation").c_str());
if (!src.valid()) return 1;

if (src.voxel_type() != stk::Type_Float3) {
LOG(Error) << "Invalid voxel type for deformation field, expected float3";
return 1;
}

#ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS
deformation_pyramid.build_from_base_with_residual(src, filters::downsample_vectorfield_by_2);
#else
deformation_pyramid.build_from_base(src, filters::downsample_vectorfield_by_2);
#endif
}

bool use_constraints = false;
stk::Volume constraints_mask, constraints_values;
if (!constraint_mask_file.empty() && !constraint_values_file.empty()) {
if (!constraint_mask_file.empty()) {
LOG(Info) << "Constraint mask: '" << constraint_mask_file << "'";
LOG(Info) << "Constraint values: '" << constraint_values_file << "'";

constraints_mask = stk::read_volume(constraint_mask_file.c_str());
if (!constraints_mask.valid()) return 1;

}
if (!constraint_values_file.empty()) {
LOG(Info) << "Constraint values: '" << constraint_values_file << "'";
constraints_values = stk::read_volume(constraint_values_file.c_str());
if (!constraints_values.valid()) return 1;

use_constraints = true;
}
else {
constraints_mask = stk::VolumeUChar(deformation_pyramid.volume(0).size(), uint8_t{0});
constraints_values = stk::VolumeFloat3(deformation_pyramid.volume(0).size(), float3{0, 0, 0});
}

VolumePyramid constraints_mask_pyramid, constraints_pyramid;
voxel_constraints::build_pyramids(
constraints_mask,
constraints_values,
stk::VolumeFloat3 out = regularization(
src,
precision,
pyramid_levels,
constraints_mask_pyramid,
constraints_pyramid
constraints_mask,
constraints_values
);

// Initialization is only needed if we have constraints
if (use_constraints) {
// Perform initialization at the coarsest resolution
stk::VolumeFloat3 def = deformation_pyramid.volume(pyramid_levels-1);
initialize_regularization(
def,
constraints_mask_pyramid.volume(pyramid_levels-1),
constraints_pyramid.volume(pyramid_levels-1)
);
if (!out.valid()) {
return 1;
}

for (int l = pyramid_levels-1; l >= 0; --l) {
stk::VolumeFloat3 def = deformation_pyramid.volume(l);

LOG(Info) << "Performing regularization level " << l;

do_regularization(
def,
constraints_mask_pyramid.volume(l),
constraints_pyramid.volume(l),
precision
);

if (l != 0) {
dim3 upsampled_dims = deformation_pyramid.volume(l - 1).size();
deformation_pyramid.set_volume(l - 1,
#ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS
filters::upsample_vectorfield(def, upsampled_dims, deformation_pyramid.residual(l - 1))
#else
filters::upsample_vectorfield(def, upsampled_dims)
#endif
);
}
else {
deformation_pyramid.set_volume(0, def);
}
}
auto t_end = high_resolution_clock::now();
int elapsed = int(round(duration_cast<duration<double>>(t_end - t_start).count()));
LOG(Info) << "Regularization completed in " << elapsed / 60 << ":" << std::right << std::setw(2) << std::setfill('0') << elapsed % 60;
LOG(Info) << "Writing to '" << output_file << "'";
stk::write_volume(output_file.c_str(), deformation_pyramid.volume(0));
stk::write_volume(output_file.c_str(), out);

return EXIT_SUCCESS;
}
89 changes: 89 additions & 0 deletions src/deform_lib/regularize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,92 @@ void do_regularization(
}
}


stk::VolumeFloat3 regularization(
const stk::VolumeFloat3& df,
float precision,
int pyramid_levels,
stk::VolumeUChar constraints_mask,
stk::VolumeFloat3 constraints_values
)
{
VolumePyramid deformation_pyramid;
deformation_pyramid.set_level_count(pyramid_levels);

if (!df.valid()) return stk::Volume();

if (df.voxel_type() != stk::Type_Float3) {
LOG(Error) << "Invalid voxel type for deformation field, expected float3";
return stk::Volume();
}

// Clone to avoid directly modifying our input
#ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS
deformation_pyramid.build_from_base_with_residual(df.clone(), filters::downsample_vectorfield_by_2);
#else
deformation_pyramid.build_from_base(df.clone(), filters::downsample_vectorfield_by_2);
#endif

bool use_constraints = false;

if (!constraints_mask.valid()) {
constraints_mask = stk::VolumeUChar(deformation_pyramid.volume(0).size(), uint8_t{0});
}
else {
use_constraints = true;
}

if (!constraints_values.valid()) {
constraints_values = stk::VolumeFloat3(deformation_pyramid.volume(0).size(), float3{0, 0, 0});
}

VolumePyramid constraints_mask_pyramid, constraints_pyramid;
voxel_constraints::build_pyramids(
constraints_mask,
constraints_values,
pyramid_levels,
constraints_mask_pyramid,
constraints_pyramid
);

// Initialization is only needed if we have constraints
if (use_constraints) {
// Perform initialization at the coarsest resolution
stk::VolumeFloat3 def = deformation_pyramid.volume(pyramid_levels-1);
initialize_regularization(
def,
constraints_mask_pyramid.volume(pyramid_levels-1),
constraints_pyramid.volume(pyramid_levels-1)
);
}

for (int l = pyramid_levels-1; l >= 0; --l) {
stk::VolumeFloat3 def = deformation_pyramid.volume(l);

LOG(Info) << "Performing regularization level " << l;

do_regularization(
def,
constraints_mask_pyramid.volume(l),
constraints_pyramid.volume(l),
precision
);

if (l != 0) {
dim3 upsampled_dims = deformation_pyramid.volume(l - 1).size();
deformation_pyramid.set_volume(l - 1,
#ifdef DF_ENABLE_DISPLACEMENT_FIELD_RESIDUALS
filters::upsample_vectorfield(def, upsampled_dims, deformation_pyramid.residual(l - 1))
#else
filters::upsample_vectorfield(def, upsampled_dims)
#endif
);
}
else {
deformation_pyramid.set_volume(0, def);
}
}

return deformation_pyramid.volume(0);
}

19 changes: 8 additions & 11 deletions src/deform_lib/regularize.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
#pragma once

void initialize_regularization(
stk::VolumeFloat3& def,
const stk::VolumeUChar& constraints_mask,
const stk::VolumeFloat3& constraints_values
);
void do_regularization(
stk::VolumeFloat3& def,
const stk::VolumeUChar& constraints_mask,
const stk::VolumeFloat3& constraints_values,
float precision
);
stk::VolumeFloat3 regularization(
const stk::VolumeFloat3& df,
float precision,
int pyramid_levels,
stk::VolumeUChar constraints_mask,
stk::VolumeFloat3 constraints_values
);

0 comments on commit a901bba

Please sign in to comment.