In [5]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import mbirjax
import os

# Define parameters for sinogram
num_views = 64
num_det_rows = 40
num_det_channels = 128
geometry_type = 'parallel'

# Define projection angles
start_angle = -np.pi / 2
end_angle = np.pi / 2
angles = torch.linspace(start_angle, end_angle, num_views)

# Initialize MBIR-JAX model
sinogram_shape = (num_views, num_det_rows, num_det_channels)
ct_model = mbirjax.ParallelBeamModel(sinogram_shape, angles.numpy())

# Generate Shepp-Logan phantom
print("Generating Phantom (x)...")
true_image = torch.tensor(np.array(ct_model.gen_modified_3d_sl_phantom()), dtype=torch.float32).clone()

# Compute forward projection (A*x = y)
print("Computing Sinogram (y = A*x)...")
y = torch.tensor(np.array(ct_model.forward_project(true_image.numpy())), dtype=torch.float32).clone()

# Save dataset (X-ray images)
dataset_dir = "xray_dataset"
os.makedirs(dataset_dir, exist_ok=True)
dataset_path = os.path.join(dataset_dir, "shepp_logan_xray.npz")
np.savez(dataset_path, sinogram=y.numpy(), angles=angles.numpy())

print(f"X-ray dataset saved to {dataset_path}")

# Load dataset
loaded_data = np.load(dataset_path)
y = torch.tensor(loaded_data["sinogram"], dtype=torch.float32).clone()
angles = torch.tensor(loaded_data["angles"], dtype=torch.float32)

# **MBIR-Based Iterative Reconstruction**
num_iterations = 2000  # Iteration count
learning_rate = .5  # Learning rate

# Initialize reconstructed image with zeros
x = torch.zeros_like(true_image)

print("Starting MBIR Reconstruction...")
for i in range(num_iterations):
    Ax = torch.tensor(np.array(ct_model.forward_project(x.numpy())), dtype=torch.float32).clone()  # Compute Ax (forward projection)
    residual = y - Ax  # Compute (y - Ax)
    
    # Compute loss ||y - Ax||^2
    loss = torch.norm(residual) ** 2
    
    # Compute gradient: Aᵀ(y - Ax) (back-projection step)
    gradient = torch.tensor(np.array(ct_model.back_project(residual.numpy())), dtype=torch.float32).clone()
    
    # Normalize gradient to avoid large updates
    gradient = gradient / (torch.max(torch.abs(gradient)) + 1e-6)
    
    # Adaptive step size to avoid divergence
    step_size = learning_rate / (torch.norm(gradient) + 1e-6)
    
    # Update x using gradient descent
    x = x + step_size * gradient
    
    # Print loss every 10 iterations
    if i % 10 == 0:
        print(f"Iteration {i}: Loss = {loss:.6f}")

# Convert to NumPy for visualization
mbir_reconstruction = x.numpy()

# Display Results
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(true_image.numpy()[:, :, true_image.shape[2] // 2], cmap='gray')
plt.title("Ground Truth (Phantom)")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(mbir_reconstruction[:, :, mbir_reconstruction.shape[2] // 2], cmap='gray')
plt.title("MBIR Reconstruction")
plt.axis("off")

plt.show()


Estimated memory required = 0.183GB full, 0.012GB update
Using TFRT_CPU_0 for main memory, TFRT_CPU_0 as worker.
views_per_batch = 256; pixels_per_batch = 2048
Generating Phantom (x)...
Computing Sinogram (y = A*x)...
X-ray dataset saved to xray_dataset/shepp_logan_xray.npz
Starting MBIR Reconstruction...
Iteration 0: Loss = 67116960.000000
Iteration 10: Loss = 60805316.000000
Iteration 20: Loss = 54846480.000000
Iteration 30: Loss = 49229312.000000


KeyboardInterrupt: 