# Model-based CT reconstruction example

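This notebook solves a basic CT reconstruction problem on MNIST digits using `torchskradon` together with `torch.optim`: we simulate sinograms with a differentiable Radon transform and then recover the images by gradient-based minimization of the mismatch between simulated and measured sinograms.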

In [None]:
%pip install torch matplotlib
%pip install -i https://test.pypi.org/simple/ torchskradon

In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from torchskradon.functional import skradon

We choose our computing device and load the MNIST test dataset.

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download (on first run) and load the MNIST test split into ./data
data_root = "data"
os.makedirs(data_root, exist_ok=True)
test_dataset = torchvision.datasets.MNIST(
    root=data_root,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

batch_size = 8

# shuffle=False: the first batch always contains the same test images, so runs are reproducible
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Only the first batch is used; the digit labels are not needed for reconstruction
data_iter = iter(test_loader)
images, _ = next(data_iter)

In [None]:
# Show the ground-truth digits of the loaded batch side by side.
# squeeze=False keeps `axs` 2-D, so indexing also works when only a single image is loaded.
n_images = images.shape[0]
fig, axs = plt.subplots(1, n_images, figsize=(1.25 * n_images, 2), squeeze=False)
for i in range(n_images):
    ax = axs[0, i]
    ax.imshow(images[i, 0].cpu().detach().numpy(), cmap="gray")
    ax.set_title(f"Image {i + 1}")
    ax.axis("off")
plt.show()

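Now we create our CT measurement data from these MNIST images. For the sake of brevity, we cheat and use the same discretization for simulation and (later) reconstruction. This is the so-called "inverse crime", which makes the reconstruction problem easier than it would be with real measurements.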

In [None]:
# One projection per degree: linspace yields 0, 1, ..., 180, and the 180° endpoint is dropped
# because it duplicates the 0° projection (symmetry of the Radon transform)
angles_with_endpoint = torch.linspace(0.0, 180.0, 181)
theta = angles_with_endpoint[:-1]
sinograms = skradon(images, theta=theta, circle=False)

In [None]:
# Show the simulated sinogram of every image in the batch.
# squeeze=False keeps `axs` 2-D, so indexing also works when only a single sinogram exists.
n_sinograms = sinograms.shape[0]
fig, axs = plt.subplots(1, n_sinograms, figsize=(2.5 * n_sinograms, 3), squeeze=False)
for i in range(n_sinograms):
    ax = axs[0, i]
    # aspect="auto" lets each (non-square) sinogram fill its panel instead of being shrunk
    # to a thin strip -- presumably rows are detector bins and columns the 180 angles; verify
    ax.imshow(sinograms[i, 0].cpu().detach().numpy(), cmap="gray", aspect="auto")
    ax.set_title(f"Sinogram {i + 1}")
    ax.axis("off")
plt.tight_layout()
plt.show()

We wrap the Radon transform in a `torch.nn.Module`. This is not strictly necessary here, but it shows how torchskradon functions can be used within PyTorch's model abstraction.

In [None]:
class RadonModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around ``skradon``.

    The module has no trainable parameters; it only remembers the projection
    angles, so that ``model(x)`` returns the sinogram of ``x``.

    Args:
        theta: projection angles, forwarded unchanged to ``skradon``
            (in this notebook, angles in degrees).
    """

    def __init__(self, theta):
        super(RadonModel, self).__init__()
        # Plain attribute (not a parameter or buffer): ``.to(device)`` leaves it untouched.
        self.theta = theta

    def forward(self, x):
        """Return the sinogram of ``x``, computed by ``skradon`` with ``circle=False``."""
        sinogram = skradon(x, theta=self.theta, circle=False)
        return sinogram

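Finally, we reconstruct the CT images from their sinograms. The Radon model has no trainable parameters; instead, we treat the pixel values of the reconstruction as the optimization variables. We adjust them with Adam so that their simulated sinograms match the measured ones.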

In [None]:
# Move data to the compute device
images = images.to(device)
sinograms = sinograms.to(device)

# Optimization settings
n_iterations = 200
learning_rate = 0.01
log_every = 10

# Initialize the reconstruction with zeros (or random initialization);
# it is the only tensor being optimized
reco = torch.zeros_like(images, requires_grad=True, device=device)

# Set up optimizer, data-fidelity loss and the forward operator
optimizer = torch.optim.Adam([reco], lr=learning_rate)
loss_fn = torch.nn.MSELoss()
model = RadonModel(theta=theta).to(device)

print("Starting image reconstruction...")
print(f"Target sinogram shape: {sinograms.shape}")
print(f"Reconstruction shape: {reco.shape}")

# Reconstruction loop
for i in range(n_iterations):
    optimizer.zero_grad()

    # Forward pass through the Radon model: sinogram of the current reconstruction
    # (same operator and angles that simulated the measured data)
    pred_sinogram = model(reco)

    # Compute loss between predicted and target sinograms
    loss_value = loss_fn(pred_sinogram, sinograms)

    # Backward pass: gradients w.r.t. the reconstruction's pixels
    loss_value.backward()

    # Update reconstruction
    optimizer.step()

    # Log periodically, and always after the final iteration
    if i % log_every == 0 or i == n_iterations - 1:
        print(f"Iteration {i}, Loss: {loss_value.item():.6f}")

print("Reconstruction completed!")

In [None]:
# Visualize results: image-space error, then ground truth (top row) vs. reconstruction (bottom row)
with torch.no_grad():
    mse = torch.nn.functional.mse_loss(reco, images)
print(f"Reconstruction MSE: {mse.item():.6f}")

# Detach once and move to CPU for plotting
gt_np = images.detach().cpu().numpy()
reco_np = reco.detach().cpu().numpy()

# One shared gray-value range for all panels: independent autoscaling would make a
# reconstruction with wrong intensities look identical to its ground truth
vmin = min(gt_np.min(), reco_np.min())
vmax = max(gt_np.max(), reco_np.max())

n_images = gt_np.shape[0]
# squeeze=False keeps `axes` 2-D even when only a single image is shown
fig, axes = plt.subplots(2, n_images, figsize=(12, 6), squeeze=False)
for i in range(n_images):
    # Ground-Truth image (1-based titles, consistent with the earlier plots)
    axes[0, i].imshow(gt_np[i, 0], cmap="gray", vmin=vmin, vmax=vmax)
    axes[0, i].set_title(f"Ground-Truth {i + 1}")
    axes[0, i].axis("off")

    # Reconstructed image
    axes[1, i].imshow(reco_np[i, 0], cmap="gray", vmin=vmin, vmax=vmax)
    axes[1, i].set_title(f"Reconstruction {i + 1}")
    axes[1, i].axis("off")

plt.tight_layout()
plt.show()
