Skip to content

Commit

Permalink
Fix astra wrapper for recent astra versions (#457)
Browse files Browse the repository at this point in the history
* Add support for recent ASTRA versions in wrapper

* Make ASTRA center of rotation correction similar to Tomopy
  • Loading branch information
dmpelt authored and dgursoy committed Jan 16, 2020
1 parent 58b8b2d commit 59aa3f3
Showing 1 changed file with 36 additions and 62 deletions.
98 changes: 36 additions & 62 deletions source/tomopy/recon/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,26 +151,22 @@ def astra(tomo, center, recon, theta, **kwargs):
vol_geom = astra_mod.create_vol_geom((num_gridx, num_gridy))

# Number of GPUs to use
if proj_type == 'cuda':
if opts['gpu_list'] is not None:
import concurrent.futures as cf
gpu_list = opts['gpu_list']
ngpu = len(gpu_list)
_, slcs = mproc.get_ncore_slices(nslices, ngpu)
# execute recon on a thread per GPU
with cf.ThreadPoolExecutor(ngpu) as e:
for gpu, slc in zip(gpu_list, slcs):
e.submit(astra_rec_cuda, tomo[slc], center[slc], recon[slc],
theta, vol_geom, niter, proj_type, gpu, opts)
else:
astra_rec_cuda(tomo, center, recon, theta, vol_geom, niter,
proj_type, None, opts)
if proj_type == 'cuda' and opts['gpu_list'] is not None:
import concurrent.futures as cf
gpu_list = opts['gpu_list']
ngpu = len(gpu_list)
_, slcs = mproc.get_ncore_slices(nslices, ngpu)
# execute recon on a thread per GPU
with cf.ThreadPoolExecutor(ngpu) as e:
for gpu, slc in zip(gpu_list, slcs):
e.submit(astra_rec, tomo[slc], center[slc], recon[slc],
theta, vol_geom, niter, proj_type, gpu, opts)
else:
astra_rec_cpu(tomo, center, recon, theta, vol_geom, niter,
proj_type, opts)
astra_rec(tomo, center, recon, theta, vol_geom, niter,
proj_type, None, opts)


def astra_rec_cuda(tomo, center, recon, theta, vol_geom, niter, proj_type, gpu_index, opts):
def astra_rec(tomo, center, recon, theta, vol_geom, niter, proj_type, gpu_index, opts):
# Lazy import ASTRA
import astra as astra_mod
nslices, nang, ndet = tomo.shape
Expand All @@ -182,19 +178,35 @@ def astra_rec_cuda(tomo, center, recon, theta, vol_geom, niter, proj_type, gpu_i
cfg['option'] = {}
if gpu_index is not None:
cfg['option']['GPUindex'] = gpu_index
oc = None
const_theta = np.ones(nang)
proj_geom = astra_mod.create_proj_geom(
'parallel', 1.0, ndet, theta.astype(np.float64))
if hasattr(astra_mod, 'geom_postalignment'):
proj_geom_orig = proj_geom
for i in range(nslices):
if center[i] != oc:
oc = center[i]
proj_geom['option'] = {
'ExtraDetectorOffset':
(center[i] - ndet / 2.) * const_theta}
sino = tomo[i]
if proj_type=='cuda':
if hasattr(astra_mod, 'geom_postalignment'):
proj_geom = astra_mod.geom_postalignment(proj_geom_orig, -(center[i] - ndet / 2.))
else:
proj_geom['option'] = {
'ExtraDetectorOffset':
(center[i] - ndet / 2.) * const_theta}
else:
shft = int(np.round(ndet / 2. - center[i]))
if not shft == 0:
sino = np.roll(tomo[i], shft)
l = shft
r = ndet + shft
if l < 0:
l = 0
if r > ndet:
r = ndet
sino[:, :l] = 0
sino[:, r:] = 0
pid = astra_mod.create_projector(proj_type, proj_geom, vol_geom)
cfg['ProjectorId'] = pid
sid = astra_mod.data2d.link('-sino', proj_geom, tomo[i])
sid = astra_mod.data2d.link('-sino', proj_geom, sino)
cfg['ProjectionDataId'] = sid
vid = astra_mod.data2d.link('-vol', vol_geom, recon[i])
cfg['ReconstructionDataId'] = vid
Expand All @@ -206,44 +218,6 @@ def astra_rec_cuda(tomo, center, recon, theta, vol_geom, niter, proj_type, gpu_i
astra_mod.projector.delete(pid)


def astra_rec_cpu(tomo, center, recon, theta, vol_geom, niter, proj_type, opts):
# Lazy import ASTRA
import astra as astra_mod
nslices, nang, ndet = tomo.shape
cfg = astra_mod.astra_dict(opts['method'])
if 'extra_options' in opts:
cfg['option'] = opts['extra_options']
proj_geom = astra_mod.create_proj_geom(
'parallel', 1.0, ndet, theta.astype(np.float64))
pid = astra_mod.create_projector(proj_type, proj_geom, vol_geom)
sino = np.zeros((nang, ndet), dtype=np.float32)
sid = astra_mod.data2d.link('-sino', proj_geom, sino)
cfg['ProjectorId'] = pid
cfg['ProjectionDataId'] = sid
for i in range(nslices):
shft = int(np.round(ndet / 2. - center[i]))
if not shft == 0:
sino[:] = np.roll(tomo[i], shft)
l = shft
r = ndet + shft
if l < 0:
l = 0
if r > ndet:
r = ndet
sino[:, :l] = 0
sino[:, r:] = 0
else:
sino[:] = tomo[i]
vid = astra_mod.data2d.link('-vol', vol_geom, recon[i])
cfg['ReconstructionDataId'] = vid
alg_id = astra_mod.algorithm.create(cfg)
astra_mod.algorithm.run(alg_id, niter)
astra_mod.algorithm.delete(alg_id)
astra_mod.data2d.delete(vid)
astra_mod.data2d.delete(sid)
astra_mod.projector.delete(pid)


def _process_data(input_task, output_task, sinograms, slices):
import ufo.numpy as unp
num_sinograms, num_projections, width = sinograms.shape
Expand Down

0 comments on commit 59aa3f3

Please sign in to comment.