# MBIRJAX CUDA Installation Verification Notebook

This notebook verifies that mbirjax is correctly installed with CUDA support by:
1. Checking package imports
2. Verifying CUDA/GPU availability
3. Running an actual CT reconstruction on synthetic data
4. Visualizing the results

**Expected outcome**: All cells should execute without errors, and the final reconstruction should complete using GPU acceleration.

## 1. Package Imports and Version Check

In [None]:
import sys
import time

import numpy as np

# Report interpreter and NumPy versions up front for troubleshooting.
for label, version in (("Python", sys.version), ("NumPy", np.__version__)):
    print(f"{label} version: {version}")

In [None]:
import jax
import jax.numpy as jnp

# Print the installed JAX version (print's default separator supplies the space).
print("JAX version:", jax.__version__)

In [None]:
import mbirjax as mj

# Not every build exposes __version__, so fall back to a placeholder string.
mbirjax_version = getattr(mj, '__version__', 'unknown')
print(f"MBIRJAX version: {mbirjax_version}")

## 2. CUDA/GPU Device Detection

In [None]:
# Enumerate devices per platform.
# NOTE: jax.devices() with no argument returns only the *default* backend's
# devices. When a GPU is present, that list contains no CPU devices at all.
# Query each platform explicitly so the counts below are accurate.
default_backend = jax.default_backend()
all_devices = jax.devices()  # devices of the default backend only

gpu_error = None
try:
    gpu_devices = jax.devices('gpu')
except RuntimeError as exc:
    # JAX raises this when no GPU backend can be initialised
    # (e.g. a CPU-only jaxlib, or a CUDA/driver problem).
    gpu_devices = []
    gpu_error = exc
cpu_devices = jax.devices('cpu')

print(f"Default backend: {default_backend}")
print(f"Default-backend devices: {len(all_devices)}")
print(f"GPU devices: {len(gpu_devices)}")
print(f"CPU devices: {len(cpu_devices)}")

if gpu_devices:
    print("\n✓ GPU(s) detected:")
    for i, gpu in enumerate(gpu_devices):
        print(f"  GPU {i}: {gpu}")
else:
    print("\n✗ WARNING: No GPU detected! MBIRJAX will use CPU (slower).")
    if gpu_error is not None:
        # Surface JAX's own explanation; it usually pinpoints the install problem.
        print(f"  Reason reported by JAX: {gpu_error}")

## 3. JAX GPU Backend Verification

In [None]:
# Confirm that JAX can actually run a computation on the first detected GPU.
if gpu_devices:
    print("Testing JAX GPU computation...")

    target_gpu = gpu_devices[0]
    with jax.default_device(target_gpu):
        # Create the input inside the context as well, so both the input and
        # the matmul are placed on the target GPU.
        test_x = jnp.ones((1000, 1000))
        test_y = jnp.dot(test_x, test_x)
        test_y.block_until_ready()  # JAX dispatches asynchronously; wait for completion

    # `Array.device()` is deprecated in recent JAX (it became a property).
    # `Array.devices()` returns a set of devices and works across versions.
    result_device = next(iter(test_y.devices()))
    print(f"Computation device: {result_device}")
    print(f"Device platform: {result_device.platform}")

    # Every entry of ones(1000, 1000) @ ones(1000, 1000) is exactly 1000.
    value_ok = bool(jnp.all(test_y == 1000.0))
    print(f"Result correct: {value_ok}")

    if result_device.platform == 'gpu' and value_ok:
        print("\n✓ JAX GPU backend is working!")
    elif result_device.platform != 'gpu':
        print("\n✗ WARNING: Computation did not run on GPU!")
    else:
        print("\n✗ WARNING: GPU computation returned an incorrect result!")
else:
    print("Skipping GPU test - no GPU available")

## 4. Create Synthetic CT Data

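We create a simple three-disk phantom and derive a synthetic sinogram from it. The sinogram is only a crude approximation (not a true Radon transform); it just needs to give the reconstructor well-formed input for this installation test.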

In [None]:
# Acquisition geometry for the synthetic test data.
num_views = 180          # number of projection angles
num_det_channels = 256   # detector width, in pixels
num_slices = 4           # number of slices (detector rows) to reconstruct

# Half-open sweep over [0, pi): endpoint=False avoids repeating 0° as 180°.
angles = np.linspace(0, np.pi, num_views, endpoint=False).astype(np.float32)

print(f"Number of views: {num_views}")
print(f"Detector channels: {num_det_channels}")
print(f"Number of slices: {num_slices}")

first_deg = np.degrees(angles[0])
last_deg = np.degrees(angles[-1])
print(f"Angles range: {first_deg:.1f}° to {last_deg:.1f}°")

In [None]:
# Synthetic phantom: a large "body" disk containing two smaller disks of
# different intensity, identical in every slice. Disks are painted in list
# order, so later ones overwrite earlier ones where they overlap.
phantom_size = num_det_channels

# Open coordinate grids centred on the image: y varies down the rows,
# x varies across the columns.
lo, hi = -phantom_size // 2, phantom_size // 2
y, x = np.ogrid[lo:hi, lo:hi]

# Layout is (slices, rows, cols).
phantom = np.zeros((num_slices, phantom_size, phantom_size), dtype=np.float32)

body_radius = phantom_size // 3
feature_radius = phantom_size // 8
feature_offset = phantom_size // 6

# Each entry: (center_x, center_y, radius, value)
disks = [
    (0, 0, body_radius, 0.5),                    # large circle (body)
    (feature_offset, 0, feature_radius, 1.0),    # brighter feature at +x
    (-feature_offset, 0, feature_radius, 0.8),   # dimmer feature at -x
]
for cx, cy, radius, value in disks:
    inside = (x - cx) ** 2 + (y - cy) ** 2 <= radius ** 2
    phantom[:, inside] = value

print(f"Phantom shape: {phantom.shape}")
print(f"Phantom value range: [{phantom.min():.2f}, {phantom.max():.2f}]")

In [None]:
# Build a crude synthetic sinogram with shape (views, slices, channels).
# NOTE: this is NOT a true Radon transform. Each view blends the phantom's two
# axis-aligned projections, weighted by |cos| and |sin| of the view angle. That
# is enough to give the reconstructor smooth, plausible input for testing.
sinogram_shape = (num_views, num_slices, num_det_channels)

np.random.seed(42)

# Axis-aligned line integrals (the loop below does not change them); each has
# shape (num_slices, phantom_size).
sum_over_cols = phantom.sum(axis=2)
sum_over_rows = phantom.sum(axis=1)

sinogram = np.zeros(sinogram_shape, dtype=np.float32)
for view_idx, theta in enumerate(angles):
    blended = sum_over_cols * abs(np.cos(theta)) + sum_over_rows * abs(np.sin(theta))
    sinogram[view_idx] = blended[:, :num_det_channels]

# Add small Gaussian noise, then clamp so every measurement is non-negative.
noise = np.random.randn(*sinogram_shape).astype(np.float32) * 0.01
sinogram = np.clip(sinogram + noise, 0, None)

print(f"Sinogram shape: {sinogram.shape}")
print(f"Sinogram value range: [{sinogram.min():.4f}, {sinogram.max():.4f}]")

## 5. Run MBIRJAX Reconstruction

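This is the key test: running an actual mbirjax reconstruction, which is GPU-accelerated when a GPU is available. The first run includes JAX JIT compilation, so the reported time overstates steady-state performance.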

In [None]:
# Build the parallel-beam model from the sinogram geometry and view angles.
print("Creating ParallelBeamModel...")
ct_model = mj.ParallelBeamModel(sinogram_shape, angles)

# Reconstruction settings, applied in a single set_params call.
recon_settings = dict(
    sharpness=0.0,   # regularization sharpness
    verbose=0,       # minimal output
    snr_db=30.0,     # signal-to-noise ratio, in dB
)
ct_model.set_params(**recon_settings)

print("Model created successfully!")
print(f"Sinogram shape: {sinogram_shape}")

In [None]:
# Run the reconstruction and time it end to end.
# NOTE: the first call includes JAX JIT compilation, so this time overstates
# steady-state cost. Re-running the cell may give a better estimate.
print("Running reconstruction...")
start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals

recon_result, recon_dict = ct_model.recon(
    sinogram,
    print_logs=False,
    weights=None,
)

# JAX dispatches work asynchronously. Copying to a NumPy array waits for all
# pending device work, so the timer below covers the full reconstruction.
recon_result = np.array(recon_result)

elapsed_time = time.perf_counter() - start_time

print("\n✓ Reconstruction completed!")
print(f"  Time: {elapsed_time:.3f} seconds")
print(f"  Output shape: {recon_result.shape}")
print(f"  Output dtype: {recon_result.dtype}")
print(f"  Value range: [{recon_result.min():.4f}, {recon_result.max():.4f}]")

In [None]:
# Sanity-check the reconstruction: no NaN/Inf values and the expected shape.
print("Verification checks:")

# Check for NaN values
has_nan = np.isnan(recon_result).any()
print(f"  Contains NaN: {has_nan} {'✗ PROBLEM' if has_nan else '✓ OK'}")

# Check for Inf values
has_inf = np.isinf(recon_result).any()
print(f"  Contains Inf: {has_inf} {'✗ PROBLEM' if has_inf else '✓ OK'}")

# Check shape against the model's own recon_shape instead of assuming a layout.
# mbirjax stores volumes as (rows, cols, slices), so the slice count is on the
# LAST axis. This notebook's phantom is different: it puts slices first.
expected_shape = tuple(ct_model.get_params('recon_shape'))
shape_ok = (
    recon_result.ndim == 3
    and recon_result.shape == expected_shape
    and recon_result.shape[-1] == num_slices
)
print(f"  Shape correct: {shape_ok} {'✓ OK' if shape_ok else '✗ PROBLEM'}")
print(f"    got {recon_result.shape}, expected {expected_shape}")

if not has_nan and not has_inf and shape_ok:
    print("\n✓ All verification checks passed!")
else:
    print("\n✗ Some verification checks failed!")

## 6. Visualize Results

In [None]:
import matplotlib.pyplot as plt

# Compare the middle slice of the phantom, its sinogram, and the reconstruction.
slice_idx = num_slices // 2

# The phantom is laid out (slices, rows, cols), but mbirjax returns the
# reconstruction as (rows, cols, slices). The slice index therefore goes on
# the LAST axis of recon_result.
recon_slice = recon_result[:, :, slice_idx]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original phantom
im0 = axes[0].imshow(phantom[slice_idx], cmap='viridis')
axes[0].set_title(f'Original Phantom (slice {slice_idx})')
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0], shrink=0.8)

# Sinogram (one slice): sinogram axes are (views, slices, channels)
im1 = axes[1].imshow(sinogram[:, slice_idx, :], cmap='viridis', aspect='auto')
axes[1].set_title(f'Sinogram (slice {slice_idx})')
axes[1].set_xlabel('Detector channel')
axes[1].set_ylabel('Projection angle')
plt.colorbar(im1, ax=axes[1], shrink=0.8)

# Reconstructed image
im2 = axes[2].imshow(recon_slice, cmap='viridis')
axes[2].set_title(f'MBIRJAX Reconstruction (slice {slice_idx})')
axes[2].axis('off')
plt.colorbar(im2, ax=axes[2], shrink=0.8)

plt.tight_layout()
plt.show()

In [None]:
# Display every reconstructed slice side by side.
# squeeze=False keeps `axes` 2-D even when num_slices == 1, so no special case.
fig, axes = plt.subplots(1, num_slices, figsize=(4*num_slices, 4), squeeze=False)

for i, ax in enumerate(axes[0]):
    # mbirjax volumes are (rows, cols, slices): slice i lives on the last axis.
    im = ax.imshow(recon_result[:, :, i], cmap='viridis')
    ax.set_title(f'Slice {i}')
    ax.axis('off')
    plt.colorbar(im, ax=ax, shrink=0.8)

plt.suptitle('All Reconstructed Slices', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Summary

In [None]:
divider = "=" * 60

print(divider)
print("MBIRJAX CUDA Installation Verification Summary")
print(divider)

print("\nPackage versions:")
for pkg_name, pkg_version in (
    ("Python", sys.version.split()[0]),
    ("NumPy", np.__version__),
    ("JAX", jax.__version__),
    ("MBIRJAX", mbirjax_version),
):
    print(f"  {pkg_name}: {pkg_version}")

print("\nHardware:")
print(f"  GPU devices: {len(gpu_devices)}")
# The loop body simply never runs when no GPU was found.
for i, gpu in enumerate(gpu_devices):
    print(f"    GPU {i}: {gpu}")

print("\nReconstruction test:")
print(f"  Sinogram shape: {sinogram_shape}")
print(f"  Reconstruction time: {elapsed_time:.3f} seconds")
print(f"  Output shape: {recon_result.shape}")

# Overall pass requires a GPU plus finite, correctly shaped output.
all_ok = bool(gpu_devices) and not (has_nan or has_inf) and shape_ok
verdict = '✓ ALL TESTS PASSED!' if all_ok else '✗ SOME TESTS FAILED'
print(f"\n{verdict}")
print(divider)