Skip to content

Commit

Permalink
ENH: separated atlas into its own directory and added option
Browse files Browse the repository at this point in the history
  • Loading branch information
rkwitt committed Mar 31, 2017
1 parent c208896 commit bb244a4
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 51 deletions.
108 changes: 57 additions & 51 deletions code/qs_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
requiredNamed = parser.add_argument_group('required named arguments')

requiredNamed.add_argument('--moving-image', nargs='+', required=True, metavar=('m1', 'm2, m3...'),
help='List of moving images, seperated by space.')
help='List of moving images, seperated by space.')
requiredNamed.add_argument('--target-image', nargs='+', required=True, metavar=('t1', 't2, t3...'),
help='List of target images, seperated by space.')
help='List of target images, seperated by space.')
requiredNamed.add_argument('--output-prefix', nargs='+', required=True, metavar=('o1', 'o2, o3...'),
help='List of registration output prefixes for every moving/target image pair, seperated by space. Preferred to be a directory (e.g. /some_path/output_dir/)')
help='List of registration output prefixes for every moving/target image pair, seperated by space. Preferred to be a directory (e.g. /some_path/output_dir/)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for prediction network (default: 64)')
parser.add_argument('--n-GPU', type=int, default=1, metavar='N',
Expand All @@ -50,45 +50,47 @@
help='time steps for geodesic shooting. Ignore this option to use the default step size used by the registration model.')
parser.add_argument('--affine-align', action='store_true', default=False,
help='Perform affine registration to align moving and target images to ICBM152 atlas space. Require niftireg.')
parser.add_argument('--atlas', default="../data/atlas/icbm152.nii",
help="Atlas to use for (affine) pre-registration")

args = parser.parse_args()


# check validity of input arguments from command line
def check_args(args):
# number of input images/output prefix consistency check
n_moving_images = len(args.moving_image)
n_target_images = len(args.target_image)
n_output_prefix = len(args.output_prefix)
if (n_moving_images != n_target_images):
print('The number of moving images is not consistent with the number of target images!')
sys.exit(1)
elif (n_moving_images != n_output_prefix ):
print('The number of output prefix is not consistent with the number of input images!')
sys.exit(1)

# number of GPU check (positive integers)
if (args.n_GPU <= 0):
print('Number of GPUs must be positive!')
sys.exit(1)

# geodesic shooting step check (positive integers)
if (args.shoot_steps < 0):
print('Shooting steps (--shoot-steps) is negative. Using model default step.')
# number of input images/output prefix consistency check
n_moving_images = len(args.moving_image)
n_target_images = len(args.target_image)
n_output_prefix = len(args.output_prefix)
if (n_moving_images != n_target_images):
print('The number of moving images is not consistent with the number of target images!')
sys.exit(1)
elif (n_moving_images != n_output_prefix ):
print('The number of output prefix is not consistent with the number of input images!')
sys.exit(1)

# number of GPU check (positive integers)
if (args.n_GPU <= 0):
print('Number of GPUs must be positive!')
sys.exit(1)

# geodesic shooting step check (positive integers)
if (args.shoot_steps < 0):
print('Shooting steps (--shoot-steps) is negative. Using model default step.')
#enddef


def create_net(args, network_config):
net_single = prediction_network.net(network_config['network_feature']).cuda();
net_single.load_state_dict(network_config['state_dict'])
net_single = prediction_network.net(network_config['network_feature']).cuda();
net_single.load_state_dict(network_config['state_dict'])

if (args.n_GPU > 1) :
device_ids=range(0, args.n_GPU)
net = torch.nn.DataParallel(net_single, device_ids=device_ids).cuda()
else:
net = net_single
if (args.n_GPU > 1) :
device_ids=range(0, args.n_GPU)
net = torch.nn.DataParallel(net_single, device_ids=device_ids).cuda()
else:
net = net_single

return net;
return net;
#enddef


Expand All @@ -101,47 +103,48 @@ def preprocess_image(image_pyca):


def write_result(result, output_prefix):
common.SaveITKImage(result['I1'], output_prefix+"I1.mhd")
common.SaveITKField(result['phiinv'], output_prefix+"phiinv.mhd")
common.SaveITKImage(result['I1'], output_prefix+"I1.mhd")
common.SaveITKField(result['phiinv'], output_prefix+"phiinv.mhd")
#enddef


#perform deformation prediction
def predict_image(args):
#initialize the network
#create prediction network
if (args.use_CPU_for_shooting):
mType = ca.MEM_HOST
mType = ca.MEM_HOST
else:
mType = ca.MEM_DEVICE
mType = ca.MEM_DEVICE

# load the prediction network
predict_network_config = torch.load('../network_configs/OASIS_predict.pth.tar')

prediction_net = create_net(args, predict_network_config);

batch_size = args.batch_size
patch_size = predict_network_config['patch_size']
input_batch = torch.zeros(batch_size, 2, patch_size, patch_size, patch_size).cuda()

# use correction network if required
if args.use_correction:
correction_network_config = torch.load('../network_configs/OASIS_correct.pth.tar');
correction_net = create_net(args, correction_network_config);
correction_network_config = torch.load('../network_configs/OASIS_correct.pth.tar');
correction_net = create_net(args, correction_network_config);
else:
correction_net = None;
correction_net = None;

# start prediction
for i in range(0, len(args.moving_image)):

common.Mkdir_p(os.path.dirname(args.output_prefix[i]))
if (args.affine_align):
# Perform affine registration to both moving and target image to the ICBM152 atlas space.
# Registration is done using Niftireg.
call(["reg_aladin",
"-noSym", "-speeeeed", "-ref", "icbm152.nii" ,
"-noSym", "-speeeeed", "-ref", args.atlas ,
"-flo", args.moving_image[i],
"-res", args.output_prefix[i]+"moving_affine.nii",
"-aff", args.output_prefix[i]+'moving_affine_transform.txt'])

call(["reg_aladin",
"-noSym", "-speeeeed" ,"-ref", "icbm152.nii" ,
"-noSym", "-speeeeed" ,"-ref", args.atlas ,
"-flo", args.target_image[i],
"-res", args.output_prefix[i]+"target_affine.nii",
"-aff", args.output_prefix[i]+'target_affine_transform.txt'])
Expand All @@ -151,7 +154,8 @@ def predict_image(args):
else:
moving_image = common.LoadITKImage(args.moving_image[i], mType)
target_image = common.LoadITKImage(args.target_image[i], mType)
#preprocessing of the image

#preprocessing of the image
moving_image_np = preprocess_image(moving_image);
target_image_np = preprocess_image(target_image);

Expand All @@ -163,21 +167,23 @@ def predict_image(args):
moving_image.setGrid(grid)
target_image.setGrid(grid)

m0 = util.predict_momentum(moving_image_np, target_image_np, input_batch, batch_size, patch_size, prediction_net);
# run actual prediction
m0 = util.predict_momentum(moving_image_np, target_image_np, input_batch, batch_size, patch_size, prediction_net);

#convert to registration space and perform registration
m0_reg = common.FieldFromNPArr(m0, mType);
registration_result = registration_methods.geodesic_shooting(moving_image, target_image, m0_reg, args.shoot_steps, mType, predict_network_config)
if (args.use_correction):
#perform correction
#convert to registration space and perform registration
m0_reg = common.FieldFromNPArr(m0, mType);
registration_result = registration_methods.geodesic_shooting(moving_image, target_image, m0_reg, args.shoot_steps, mType, predict_network_config)

#perform correction
if (args.use_correction):
target_inv_np = common.AsNPCopy(registration_result['I1_inv'])
m0_correct = util.predict_momentum(moving_image_np, target_inv_np, input_batch, batch_size, patch_size, correction_net);
m0 += m0_correct;
m0_reg = common.FieldFromNPArr(m0, mType);
registration_result = registration_methods.geodesic_shooting(moving_image, target_image, m0_reg, args.shoot_steps, mType, predict_network_config)
#endif
#endif

write_result(registration_result, args.output_prefix[i]);
write_result(registration_result, args.output_prefix[i]);
#enddef


Expand Down
Binary file modified code/vectormomentum/Code/Python/Libraries/CAvmCommon.pyc
Binary file not shown.
Binary file modified code/vectormomentum/Code/Python/Libraries/__init__.pyc
Binary file not shown.
File renamed without changes.

0 comments on commit bb244a4

Please sign in to comment.