In [None]:
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import cupy as cp

In [None]:
import ct_projector.projector.cupy as ct_projector
import ct_projector.projector.cupy.tomo as ct_tomo

In [None]:
# Load the sample CT volume and prepend a leading singleton axis, giving img[0, z, y, x].
# NOTE(review): the .npy file carries no voxel spacing; unit spacing in (dz, dy, dx) order is assumed.
spacing = [1, 1, 1]
img = np.expand_dims(np.load('./CTA0296_7_2.npy'), axis=0)

In [None]:
# Show the central z, y and x slices of the input volume.
# img is indexed [0, z, y, x] and spacing is (dz, dy, dx), so each panel's aspect
# ratio is (row spacing) / (column spacing).
panels = [
    ('central z slice', img[0, img.shape[1] // 2, ...], spacing[1] / spacing[2]),
    ('central y slice', img[0, :, img.shape[2] // 2, :], spacing[0] / spacing[2]),
    ('central x slice', img[0, ..., img.shape[3] // 2], spacing[0] / spacing[1]),
]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (title, plane, aspect) in zip(axes, panels):
    ax.imshow(plane, cmap='gray', aspect=aspect)
    ax.set_title(f'Input: {title}')
plt.show()

In [None]:
# Set up the projector: start from the settings in tomo.cfg, then override the image
# grid with the loaded volume. img is indexed [0, z, y, x]; spacing is (dz, dy, dx).
projector = ct_projector.ct_projector()
projector.from_file('./tomo.cfg')
projector.nx = img.shape[3]
projector.ny = img.shape[2]
projector.nz = img.shape[1]
projector.dx = spacing[2]
projector.dy = spacing[1]
projector.dz = spacing[0]
projector.cx = 0
# Half the y extent of the volume: ny * dy / 2. Uses dy (spacing[1]) to match ny;
# the previous spacing[2] paired the y size with the x spacing.
projector.cy = img.shape[2] * spacing[1] / 2
# NOTE(review): hard-coded z value with no visible derivation -- confirm against tomo.cfg geometry.
projector.cz = 11

for name in vars(projector):
    print(name, '=', getattr(projector, name))

In [None]:
# Tomosynthesis sweep: 9 views from -12 to +12 degrees in 3-degree steps (radians).
angles = np.arange(-12, 13, 3) * np.pi / 180
n_views = len(angles)

# Source positions, one (x, y, z) row per view, on an arc of radius dso in the x-z plane.
srcs = np.stack([projector.dso * np.sin(angles),
                 np.zeros(n_views),
                 projector.dso * np.cos(angles)], axis=1)

# The same detector center is used for every view; one (x, y, z) row per view.
det_centers = np.tile([0, projector.nv * projector.dv / 2, projector.dso - projector.dsd],
                      (n_views, 1))

# Detector u/v direction vectors, one row per view: shape (n_views, 3).
# (The former np.array([1, 0, 0] * n).T produced a flat length-3n array; .T is a no-op on 1-D.)
det_us = np.tile([1, 0, 0], (n_views, 1))
det_vs = np.tile([0, 1, 0], (n_views, 1))

In [None]:
# Select the GPU for both CuPy and ct_projector before any device allocation.
DEVICE_ID = 0
cp.cuda.Device(DEVICE_ID).use()
ct_projector.set_device(DEVICE_ID)

# Very important: the arrays must be in C order on the device. cp.array(..., order='C')
# guarantees this even when the host array is not C-contiguous (e.g. a transposed view).
cuimg = cp.array(img, cp.float32, order='C')
cusrcs = cp.array(srcs, cp.float32, order='C')
cudet_centers = cp.array(det_centers, cp.float32, order='C')

In [None]:
# Register the distance-driven forward and back projectors with the same per-view geometry.
geometry = {'det_center': cudet_centers, 'src': cusrcs}
projector.set_projector(ct_tomo.distance_driven_fp, **geometry)
projector.set_backprojector(ct_tomo.distance_driven_bp, **geometry)

In [None]:
# Forward-project the volume with the registered distance-driven projector.
# (The GPU device was selected earlier, before the arrays were copied to it.)
cufp = projector.fp(cuimg)

In [None]:
fp = cufp.get()
# Show the middle view of the sweep (index 4 for the 9 angles defined above, i.e. 0 degrees).
# NOTE(review): assumes fp is laid out [batch, view, ...] -- confirm against ct_projector.
center_view = len(angles) // 2
plt.imshow(fp[0, center_view, ...], cmap='gray')
plt.title(f'Forward projection, view {center_view}')
plt.show()

In [None]:
# Back-project the simulated projections (cufp) with the registered distance-driven backprojector.
cubp = projector.bp(cufp)

In [None]:
bp = cubp.get()
# Central z, y and x slices of the back-projection. bp is indexed like img ([0, z, y, x])
# and the aspect ratios use the same (dz, dy, dx) spacing.
panels = [
    ('central z slice', bp[0, bp.shape[1] // 2, ...], spacing[1] / spacing[2]),
    ('central y slice', bp[0, :, bp.shape[2] // 2, :], spacing[0] / spacing[2]),
    ('central x slice', bp[0, ..., bp.shape[3] // 2], spacing[0] / spacing[1]),
]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (title, plane, aspect) in zip(axes, panels):
    ax.imshow(plane, cmap='gray', aspect=aspect)
    ax.set_title(f'Back-projection: {title}')
plt.show()