# Introduction to CT

In this exercise sheet we get to know the computed tomography (CT) reconstruction problem.

## Load Data

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
%matplotlib inline

batch_size = 4

# MNIST test split. ToTensor turns each digit into a float tensor of shape
# (channels, height, width) with values in [0, 1]; these digits act as our phantoms.
transform_test = transforms.Compose([transforms.ToTensor()])
mnist_test = datasets.MNIST('/data', train=False, download=True, transform=transform_test)

# Only the very first batch is used: its images become the test phantoms.
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size)
batch, labels = next(iter(test_loader))

# Drop the single channel axis of every image -> list of 2D numpy arrays.
phantoms = [image[0].numpy() for image in batch]


##########################################
# TODO: show first phantom
...
##########################################

## Computed Tomography

In computed tomography, the reconstruction problem is to obtain a tomographic slice image from a set of projections. A projection is formed by drawing a set of parallel rays through the 2D object of interest and assigning the integral of the object's contrast along each ray to a single pixel of the projection. A single projection of a 2D object is therefore one-dimensional. To reconstruct the object, several projections must be acquired, each corresponding to a different angle of the rays with respect to the object. The collection of projections over all angles is called a sinogram. It is a linear transform of the original image (the Radon transform), which is why it can be written as a matrix–vector product, as done below.

In [None]:
import numpy as np
from skimage.transform import radon

n, m = 28, 28  # phantom height and width

##########################################
# TODO: specify number of angles!
angles = ...
##########################################

# Detector bins per projection. This should equal the number of rows that
# radon(..., circle=False) returns for an n x m image: it zero-pads the image to
# its diagonal, ceil(sqrt(2) * 28) = 40.
detectors = 40

##########################################
# TODO: create operator matrix filled with zeros
# (each column below receives one flattened (detectors, angles) sinogram,
#  and there is one column per pixel, i.e. n * m columns)
operator = ...
##########################################

# Projection angles in degrees, evenly spaced over [0, 180).
theta = np.linspace(0.0, 180.0, angles, endpoint=False)

# The Radon transform is linear, so column i of the system matrix is simply the
# flattened sinogram of the i-th unit image (a single bright pixel).
for i, unit in enumerate(np.eye(n * m)):
    operator[:, i] = radon(unit.reshape(n, m), theta, circle=False).reshape(-1)

In [None]:
# Seed NumPy's global RNG so that the noise added below, and therefore every PSNR
# value reported by the reconstruction cells that follow, is reproducible across runs.
np.random.seed(0)

sinograms = []
for phantom in phantoms:
    plt.figure(figsize=(10, 4))

    # clean image
    plt.subplot(1, 3, 1)
    plt.title('Phantom')
    plt.imshow(phantom)

    # clean sinogram
    
    ##########################################
    # TODO: multiply operator matrix with phantom to get sinogram
    # (later cells pass each stored sinogram to iradon and flatten it with
    #  .reshape(-1), so keep it as a 2D array -- presumably (detectors, angles),
    #  matching how the operator's columns were flattened)
    sinogram = ...
    ##########################################
    
    plt.subplot(1, 3, 2)
    plt.title('Clean sinogram')
    plt.imshow(sinogram)

    # noisy sinogram
    plt.subplot(1, 3, 3)
    
    ##########################################
    # TODO: add noise to the sinogram
    sinogram += ...
    ##########################################
    
    sinograms.append(sinogram)
    plt.title('Noisy sinogram')
    plt.imshow(sinogram)

## Direct Inverse

In [None]:
# `compare_psnr` was removed from `skimage.measure` in scikit-image 0.18; since 0.16
# the same metric lives at `skimage.metrics.peak_signal_noise_ratio`. Bind it under
# the old name, because the FBP and TSVD cells below call `compare_psnr` as well.
try:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
except ImportError:  # scikit-image < 0.16
    from skimage.measure import compare_psnr


plt.figure(figsize=(15, 4))
for i, phantom in enumerate(phantoms):
    plt.subplot(1, len(phantoms), i+1)
    
    ##########################################
    # TODO: compute direct inverse by inverting the matrix
    # NOTE: `operator` presumably has detectors * angles rows and n * m = 784
    # columns. 784 is not a multiple of detectors = 40, so the matrix can never be
    # square and a plain inverse is undefined -- think about which generalized
    # inverse applies. Reshape the result to (n, m) to compare it with the phantom.
    x_rec = ...
    ##########################################
    
    plt.xlabel('PSNR: %.2f' % compare_psnr(phantom, x_rec))
    plt.imshow(x_rec)

## Filtered Back Projection

In [None]:
from skimage.transform import iradon

# One panel per phantom: filtered back projection of its noisy sinogram, labelled
# with the PSNR against the ground-truth phantom.
plt.figure(figsize=(15, 4))

for i, phantom in enumerate(phantoms):
    panel = plt.subplot(1, len(phantoms), i + 1)
    x_rec = iradon(sinograms[i], theta, circle=False)
    panel.set_xlabel('PSNR: %.2f' % compare_psnr(phantom, x_rec))
    panel.imshow(x_rec)

## Truncated SVD (TSVD)

In [None]:
##########################################
# TODO: Compute SVD of the operator and plot the singular values
# NOTE: `truncated_svd` in the next cell uses `V.T` as the matrix of right singular
# vectors, so `V` must hold them as *rows* -- i.e. the `Vh` factor exactly as
# np.linalg.svd returns it with its default arguments, as the (U, S, Vh) triple.
U, S, V = ...

##########################################

In [None]:
def truncated_svd(U, S, V, y, k, tol=1e-9):
    """Reconstruct x from measurements y using a truncated-SVD pseudo-inverse.

    With A = U @ diag(S) @ V, where V holds the right singular vectors as rows (the
    ``Vh`` returned by ``np.linalg.svd``), this returns

        x = sum over kept i of  (U[:, i] . y / S[i]) * V[i, :]

    An index i is kept when ``i < k`` and ``S[i] > tol``. Both full
    (``full_matrices=True``) and reduced (``full_matrices=False``) factors work,
    and the (N, M) pseudo-inverse matrix is never formed explicitly.

    Parameters
    ----------
    U : ndarray, shape (M, M) or (M, r)
        Left singular vectors as columns.
    S : array_like, shape (r,)
        Singular values, r = min(M, N).
    V : ndarray, shape (N, N) or (r, N)
        Right singular vectors as rows.
    y : ndarray, shape (M,) or (M, B)
        Measurement vector, or B measurement vectors stacked as columns.
    k : int
        Number of leading singular values to keep.
    tol : float, optional
        Singular values <= tol are treated as zero (default 1e-9).

    Returns
    -------
    ndarray, shape (N,) or (N, B)
        The regularised reconstruction.
    """
    S = np.asarray(S)
    r = len(S)

    # Invert only the kept singular values; dropped ones contribute nothing.
    keep = (np.arange(r) < k) & (S > tol)
    S_inv = np.zeros(r)
    S_inv[keep] = 1.0 / S[keep]

    # Apply V^T diag(S_inv) U^T to y factor by factor; the reshape broadcasts the
    # scaling over the columns when y holds several measurement vectors.
    coeffs = np.dot(U[:, :r].T, y)
    coeffs = coeffs * S_inv.reshape((r,) + (1,) * (coeffs.ndim - 1))
    return np.dot(V[:r].T, coeffs)


# Truncation levels as eighths of the spectrum: 1/8, 1/4, 1/2, 3/4, 7/8 and all
# singular values (integer floor division, exactly as len(S)//8, len(S)//4, ...).
truncation_levels = [eighths * len(S) // 8 for eighths in (1, 2, 4, 6, 7, 8)]

for k in truncation_levels:
    plt.figure(figsize=(15, 4))

    for i, phantom in enumerate(phantoms):
        plt.subplot(1, len(phantoms), i + 1)
        measurement = sinograms[i].reshape(-1)
        x_rec = truncated_svd(U, S, V, measurement, k).reshape(n, m)
        plt.title(r'$k=%d$' % k)
        plt.xlabel('PSNR: %.2f' % compare_psnr(phantom, x_rec))
        plt.imshow(x_rec)

    # Emit one row of reconstructions per truncation level.
    plt.show()