### **Introduction**
The dataset consists of 1601 cone-beam projections acquired with a detector of 1006x1006 pixels. The scanned specimen is a [foraminifer](https://en.wikipedia.org/wiki/Foraminifera), a microscopic marine organism.

In [None]:
INPUT_PATH = '/dtu/3d-imaging-center/courses/CIL-QIM25_workshop/data/foraminifera/Amphi_13363_10X-40kV-LE1-20s-1p45micro.txrm'
CENTRE_OF_ROTATION_OFFSET = 0.046424131475367544

As a baseline, we will use the fast FDK algorithm for cone-beam filtered backprojection. This is usually sufficient when projections are plentiful. With few projections, however, the reconstruction becomes sensitive to noise and can exhibit streak artifacts.

Optimization-based iterative algorithms offer an alternative by balancing data fidelity against regularization. In this notebook, we compare two such approaches:
- FISTA with a nonnegativity constraint
- FISTA with total variation (TV) regularization (and nonnegativity)
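
Written out, and assuming the defaults of the CIL classes used later in this notebook (`LeastSquares` with scaling constant $c = 1$ and `FGP_TV` with weight $\alpha$), the two problems can be sketched as follows, where $A$ is the cone-beam projection operator, $b$ the absorption data and $\iota_{\ge 0}$ the indicator function of the nonnegative voxel values:

$$
\min_{x} \; \|Ax - b\|_2^2 + \iota_{\ge 0}(x)
\qquad \text{and} \qquad
\min_{x} \; \|Ax - b\|_2^2 + \alpha \, \mathrm{TV}(x) + \iota_{\ge 0}(x),
$$

with $\iota_{\ge 0}(x) = 0$ if every voxel of $x$ is nonnegative and $+\infty$ otherwise.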

### **Module imports**

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import qim3d

# CIL imports
import cil
from cil.io import ZEISSDataReader
from cil.framework import ImageGeometry, AcquisitionGeometry, BlockDataContainer
from cil.processors import TransmissionAbsorptionConverter
from cil.recon import FDK
from cil.optimisation.algorithms import GD, FISTA, SIRT, CGLS
from cil.optimisation.functions import LeastSquares, IndicatorBox, TotalVariation
from cil.optimisation.operators import BlockOperator, GradientOperator, IdentityOperator
from cil.plugins.tigre import ProjectionOperator
from cil.plugins.ccpi_regularisation.functions import FGP_TV
from cil.utilities.display import show_geometry, show2D
from cil.optimisation.utilities import callbacks

### **User parameters**

Here we introduce the configurable parameters that control the data used for the reconstruction experiments. The primary aim is to study how projection subsampling affects different algorithms and the resulting image analysis. Additional options for detector subsampling and vertical cropping allow faster processing during testing by reducing data size.

The parameters are:
- `PROJECTION_SUBSAMPLING`: controls the spacing between the projections. A value of $1$ uses all projections, and a value of $5$ reduces the number of projections by a factor of $5$.
- `DETECTOR_SUBSAMPLING`: controls the spacing between pixels in both the horizontal and vertical directions of the detector array. A value of $1$ uses all pixels, and a value of $2$ reduces the total number of pixels by a factor of $2^2$.
- `VERTICAL_CROP_RATIO`: fraction of slices to keep as a centered vertical crop of the volume. A value of $1$ corresponds to no cropping, and a value of $0.5$ discards $25\%$ from the top and $25\%$ from the bottom. To speed up computations during testing, a value such as $0.05$ can be used, though very small values (that would result in a single slice) will cause errors in the CIL reader.
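
As a rough check of how these parameters shrink the data, the following sketch reproduces the ROI arithmetic in plain Python. It assumes the reader applies each `(start, stop, step)` triple like a Python slice. The helper names `count_subsampled` and `vertical_crop_bounds` are hypothetical and exist only for this illustration.

```python
import math

# Full-resolution dataset dimensions quoted in the introduction.
NUM_PROJECTIONS = 1601
NUM_PIXELS_VERTICAL = 1006


def count_subsampled(n, step):
    """Number of items kept when taking every `step`-th item out of `n` (slice semantics assumed)."""
    return math.ceil(n / step)


def vertical_crop_bounds(num_pixels_vertical, crop_ratio):
    """Start and stop rows of the centred crop, using the same arithmetic as the ROI dictionary."""
    start = int((1 - crop_ratio) * num_pixels_vertical // 2)
    stop = int((1 + crop_ratio) * num_pixels_vertical // 2)
    return start, stop


kept_projections = count_subsampled(NUM_PROJECTIONS, 10)        # PROJECTION_SUBSAMPLING = 10
start, stop = vertical_crop_bounds(NUM_PIXELS_VERTICAL, 0.05)   # VERTICAL_CROP_RATIO = 0.05
kept_rows = count_subsampled(stop - start, 1)                   # DETECTOR_SUBSAMPLING = 1
print(kept_projections, (start, stop), kept_rows)
```

With the values shown, this keeps 161 of the 1601 projections and the 51 detector rows from 477 up to (but not including) 528.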

In [None]:
PROJECTION_SUBSAMPLING = 10
DETECTOR_SUBSAMPLING = 1
VERTICAL_CROP_RATIO = 0.05

### **Data reading and preprocessing**

In [None]:
DataReader = ZEISSDataReader

def get_pixel_nums():
    reader = DataReader(file_name=INPUT_PATH)
    num_pixels_horizontal = reader.get_geometry().pixel_num_h
    num_pixels_vertical = reader.get_geometry().pixel_num_v
    return num_pixels_horizontal, num_pixels_vertical

num_pixels_horizontal, num_pixels_vertical = get_pixel_nums()

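# Region of interest passed to the reader as (start, stop, step) per axis.
slice_dict = {
    # Keep every PROJECTION_SUBSAMPLING-th projection.
    'angle': (None, None, PROJECTION_SUBSAMPLING),
    # Centred vertical crop keeping a VERTICAL_CROP_RATIO fraction of the rows.
    'vertical': (
        int((1 - VERTICAL_CROP_RATIO) * num_pixels_vertical // 2),
        int((1 + VERTICAL_CROP_RATIO) * num_pixels_vertical // 2),
        DETECTOR_SUBSAMPLING,
    ),
    # Keep every DETECTOR_SUBSAMPLING-th detector column.
    'horizontal': (None, None, DETECTOR_SUBSAMPLING),
}
reader = DataReader(file_name=INPUT_PATH, roi=slice_dict)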

In [None]:
slice_dict

In [None]:
data = reader.read() # may result in error if there is only one slice in the vertical direction

In [None]:
print(data)

In [None]:
ag = data.geometry
horizontal_index = data.get_data_axes_order().index('horizontal')
scaled_offset = CENTRE_OF_ROTATION_OFFSET * (data.shape[horizontal_index] / num_pixels_horizontal)
ag.set_centre_of_rotation(offset=scaled_offset)

ig = ag.get_ImageGeometry()

The acquisition geometry is configured with a centre-of-rotation offset that was precalculated with the `image_sharpness` method of the CIL `CentreOfRotationCorrector`. Because that step is compute-intensive, the result is stored and hardcoded as `CENTRE_OF_ROTATION_OFFSET`. Before it is applied, the stored value is scaled by the ratio of the subsampled to the full detector width.

In [None]:
# qim3d.viz.slicer(data.array, color_map='grey', color_bar='slices', image_height=6, image_width=6)

In [None]:
data = TransmissionAbsorptionConverter(accelerated=False)(data)
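
The converter turns transmission measurements into absorption values following the Beer-Lambert law. A minimal pure-Python sketch of that relation is given below. The function `transmission_to_absorption` is hypothetical, and while the parameter names mirror CIL's `white_level` and `min_intensity`, the default values here are chosen for illustration and are not claimed to match CIL's.

```python
import math


def transmission_to_absorption(transmission, white_level=1.0, min_intensity=1e-6):
    """Beer-Lambert conversion: absorption = -log(transmission / white_level).

    Values are clamped to `min_intensity` after normalisation so the logarithm stays finite.
    """
    result = []
    for value in transmission:
        normalised = max(value / white_level, min_intensity)
        result.append(-math.log(normalised))
    return result


# Full transmission gives zero absorption; half transmission gives log(2).
absorption = transmission_to_absorption([1.0, 0.5, 0.0])
```

The third value shows why a lower threshold matters: a zero reading would otherwise produce an infinite absorption.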

### **FDK reconstruction**

In [None]:
data.reorder(order='tigre')
recon = FDK(data).run(verbose=0)

In [None]:
qim3d.viz.slicer(recon.array, color_map='grey', color_bar='slices', image_height=6, image_width=6)

### **Optimization-based iterative algorithms**

In [None]:
A = ProjectionOperator(image_geometry=ig, acquisition_geometry=ag)
x0 = ig.allocate(0.0)

### **FISTA with nonnegativity constraint**

In [None]:
F = LeastSquares(A, data)
G = IndicatorBox(lower=0.0, accelerated=False) # nonnegativity

`update_objective_interval` can be set to a higher value for faster iterations, at the cost of less frequent objective-value updates.

In [None]:
fista_nn = FISTA(f=F, g=G, initial=x0, update_objective_interval=1)

You can change the `iterations` parameter and rerun the cell below to continue performing iterations if the algorithm has not yet converged.

In [None]:
fista_nn.run(iterations=20, verbose=1, callbacks=[callbacks.TextProgressCallback()])

plt.plot(fista_nn.objective)
plt.gca().set_yscale('log')
plt.xlabel('Number of iterations')
plt.ylabel('Objective value')
plt.grid()

In [None]:
qim3d.viz.slicer(fista_nn.solution.array, color_map='grey', color_bar='slices', image_height=6, image_width=6)
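
To make the iteration concrete, here is a minimal sketch of FISTA with a nonnegativity projection on a tiny least-squares problem. It is a textbook version of the scheme written in plain Python, not CIL's implementation. The helper names (`matvec`, `rmatvec`, `fista_nonneg`) and the toy matrix are assumptions made for this illustration.

```python
def matvec(A, x):
    """Compute A @ x for a list-of-rows matrix."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def rmatvec(A, y):
    """Compute A^T @ y for a list-of-rows matrix."""
    return [sum(A[i][j] * y[i] for i in range(len(A))) for j in range(len(A[0]))]


def fista_nonneg(A, b, step, iterations):
    """Minimise ||Ax - b||^2 subject to x >= 0 with FISTA."""
    x = [0.0] * len(A[0])
    y = list(x)  # extrapolated point
    t = 1.0      # momentum parameter
    for _ in range(iterations):
        # Gradient of ||Ay - b||^2 at the extrapolated point: 2 A^T (Ay - b).
        residual = [r - bi for r, bi in zip(matvec(A, y), b)]
        grad = [2.0 * g for g in rmatvec(A, residual)]
        # Gradient step, then the proximal map of the indicator (projection onto x >= 0).
        x_new = [max(0.0, yi - step * gi) for yi, gi in zip(y, grad)]
        # Momentum update.
        t_new = (1.0 + (1.0 + 4.0 * t * t) ** 0.5) / 2.0
        y = [xn + ((t - 1.0) / t_new) * (xn - xo) for xn, xo in zip(x_new, x)]
        x, t = x_new, t_new
    return x


# Toy problem whose unconstrained solution has a negative second component.
A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [1.0, -1.0, 0.5]
# Step 1/L with L = 2 * lambda_max(A^T A) = 6 for this matrix.
x_toy = fista_nonneg(A, b, step=1.0 / 6.0, iterations=200)
```

For this toy problem the constrained minimiser is `[0.75, 0.0]`: the projection pins the second component at zero instead of letting it go negative.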

### **FISTA with TV regularization**

The parameter `alpha` can be modified to control the regularization strength.
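
To give a feel for what the TV term penalises, the following minimal sketch computes the total variation of three small 2D images. It assumes an isotropic, forward-difference discretisation with zero differences at the far boundary. The discretisation and boundary handling inside `FGP_TV` may differ, and the function `total_variation` is hypothetical.

```python
import math


def total_variation(image):
    """Isotropic TV of a 2D list-of-lists using forward differences.

    Differences that would reach past the last row or column are taken as zero.
    """
    rows, cols = len(image), len(image[0])
    tv = 0.0
    for i in range(rows):
        for j in range(cols):
            dx = image[i][j + 1] - image[i][j] if j + 1 < cols else 0.0
            dy = image[i + 1][j] - image[i][j] if i + 1 < rows else 0.0
            tv += math.hypot(dx, dy)
    return tv


flat = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]   # constant image
edge = [[0, 0, 1], [0, 0, 1], [0, 0, 1]]   # one sharp vertical edge
noisy = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # oscillating pattern
tv_flat, tv_edge, tv_noisy = (total_variation(im) for im in (flat, edge, noisy))
```

The constant image has zero TV, the single edge costs little, and the oscillating pattern costs the most. A larger `alpha` therefore pushes the reconstruction towards piecewise-constant regions with few sharp edges.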

In [None]:
alpha = 0.01
F = LeastSquares(A, data)
G = FGP_TV(alpha, device='gpu', nonnegativity=True)

In [None]:
fista_tv = FISTA(f=F, g=G, initial=x0, update_objective_interval=1)

In [None]:
fista_tv.run(iterations=20, verbose=1, callbacks=[callbacks.TextProgressCallback()])

plt.plot(fista_tv.objective)
plt.gca().set_yscale('log')
plt.xlabel('Number of iterations')
plt.ylabel('Objective value')
plt.grid()

In [None]:
qim3d.viz.slicer(fista_tv.solution.array, color_map='grey', color_bar='slices', image_height=6, image_width=6)

### **Export the reconstruction**

Choose which volume to export and use for the subsequent image analysis. Uncomment one of the options by removing the `#` in front.

In [None]:
#recon_to_save = recon.array # FDK
#recon_to_save = fista_nn.solution.array # FISTA with nonnegativity
#recon_to_save = fista_tv.solution.array # FISTA with total variation

Optionally set a preferred filename:

In [None]:
filename = './recon.tif'

In [None]:
qim3d.io.save(filename, recon_to_save, replace=True)