# Tomo-SIREN 

Made for Jakob, to demystify neural fields.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vedranaa/fibre-pack/blob/main/fibre_packer_demo.ipynb)


Author: vand@dtu.dk, 2025


In [1]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
# !pip install phantominator -q
import phantominator
import scipy.interpolate

First, I need an image and its projections. I will work with normalized coordinates, so the image is defined on [-1, 1]x[-1, 1]. 

In [None]:
def project_image(image, thetas, nr_B=128, nr_S=None):
    '''Forward project the (square) image along parallel rays.

    Works in normalized coordinates: the image covers [-1, 1] x [-1, 1] and
    every sample that falls outside of it contributes 0.

    image: 2D grayscale image to project
    thetas: projection angles in radians
    nr_B : int, number of detector bins
    nr_S : int, number of samples along every projection ray, defaults to nr_B

    Returns: sinogram of shape (nr_thetas, nr_B); every row holds the nr_B
    projection values of one angle
    '''

    nr_S = nr_S or nr_B
    r, c = image.shape

    # Sampling grid in detector coordinates: b is the position on the
    # detector, s is the position along the ray. B and S are (nr_B, nr_S).
    b, s = np.linspace(-1, 1, nr_B), np.linspace(-1, 1, nr_S)
    B, S = np.meshgrid(b, s, indexing='ij')  # this indexing is used by torch

    # Rotate the grid by every angle at once; the added trailing axis
    # broadcasts against thetas, giving arrays of shape (nr_B, nr_S, nr_thetas)
    cos_a, sin_a = np.cos(thetas), np.sin(thetas)
    X = S[..., None] * cos_a - B[..., None] * sin_a
    Y = S[..., None] * sin_a + B[..., None] * cos_a

    # Interpolate using normalized coordinates
    interp = scipy.interpolate.RegularGridInterpolator(
        (np.linspace(-1, 1, r), np.linspace(-1, 1, c)),
        image,
        bounds_error=False,  # don't raise error for out-of-bounds
        fill_value=0  # fill with 0 for out-of-bounds
    )
    val = interp((X, Y))

    # Average along the ray, i.e. over the S axis (axis 1), so that one value
    # per detector bin remains. Averaging over axis 0 would collapse the
    # detector bins instead and return nr_S values per angle.
    p = val.mean(axis=1).T  # Projections in rows of sinogram
    return p
    


# Ground truth: a phantom image and its sinogram.
image = phantominator.shepp_logan(256)
nr_B = 128  # number of detector bins, also used by the cells further down
nr_thetas = 90
thetas = np.linspace(0., np.pi, nr_thetas, endpoint=False)  # angles in [0, pi)
# Use the shared nr_B here (it was hard-coded to 100), so that every sinogram
# row has exactly nr_B entries.
sinogram = project_image(image, thetas, nr_B=nr_B)

fig, ax = plt.subplots(1, 2)
ax[0].imshow(image)
ax[0].set_title('GT image')
ax[1].imshow(sinogram)
ax[1].set_title('GT sinogram')
plt.show()

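Then, I need a network. I use a SIREN: a small fully connected network with sine activations, which here takes a 2D coordinate as input and returns a single value.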

In [None]:
class SineLayer(nn.Module):
    '''Linear layer followed by a sine: sin(omega * linear(input)).

    See the paper, Sec. 3.2 (final paragraph), and the supplement, Sec. 1.5,
    for the discussion of omega.

    omega multiplies the output of the linear layer before the sine is taken.
    For the first layer it acts as a frequency factor; different signals may
    need a different value, so it is a hyperparameter.

    For all other layers (is_first=False) the initial weights are divided by
    omega, which keeps the magnitude of the activations constant but boosts
    the gradients to the weight matrix (see supplement Sec. 1.5).

    indim, outdim: input and output size of the linear layer
    bias: whether the linear layer has a bias
    is_first: selects the weight initialization, see above
    omega: factor applied before the sine
    '''

    def __init__(self, indim, outdim, bias=True, is_first=False, omega=30):
        super().__init__()
        self.omega = omega
        self.is_first = is_first
        self.indim = indim
        self.linear = nn.Linear(indim, outdim, bias=bias)

        # Weights are drawn uniformly from [-bound, bound]; the bias keeps the
        # default initialization of nn.Linear.
        if is_first:
            bound = 1 / indim
        else:
            bound = np.sqrt(6 / indim) / omega

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.linear(input)
        return torch.sin(self.omega * pre_activation)  # sine activation
    
    
class Siren(nn.Module):
    '''Stack of SineLayers: one first layer, nr_hidden hidden layers and an
    output layer.

    indim, outdim: size of the network input and output
    nr_hidden: number of hidden SineLayers (hiddendim -> hiddendim)
    hiddendim: width of all hidden layers
    outermost_linear: if True the output layer is a plain nn.Linear (no sine),
        otherwise it is one more SineLayer
    first_omega: omega of the first layer
    hidden_omega: omega of all other layers; also scales the initial weights
        of the linear output layer
    '''

    def __init__(self, indim=2, outdim=1, nr_hidden=3, hiddendim=256,
                 outermost_linear=True, first_omega=30, hidden_omega=30):
        super().__init__()

        # Layers are created in network order, from input to output.
        first = SineLayer(indim, hiddendim, is_first=True, omega=first_omega)
        hidden = [
            SineLayer(indim=hiddendim, outdim=hiddendim,
                      is_first=False, omega=hidden_omega)
            for _ in range(nr_hidden)
        ]

        if outermost_linear:
            last = nn.Linear(hiddendim, outdim)
            # Same weight range as a non-first SineLayer
            bound = np.sqrt(6 / hiddendim) / hidden_omega
            with torch.no_grad():
                last.weight.uniform_(-bound, bound)
        else:
            last = SineLayer(indim=hiddendim, outdim=outdim,
                             is_first=False, omega=hidden_omega)

        self.net = nn.Sequential(first, *hidden, last)

    def forward(self, coords):
        # Work on a detached copy of the coordinates that tracks gradients.
        # The copy is local to this call; only the network output is returned.
        tracked = coords.clone().detach().requires_grad_(True)
        return self.net(tracked)

# Create the network: 2 inputs, 1 output, 3 hidden sine layers of width 256.
model = Siren(indim=2, outdim=1, nr_hidden=3, hiddendim=256)

In [None]:
def image_coordinates(size):
    '''Normalized pixel coordinates of an image as a torch tensor.

    size: (rows, columns) of the image

    Returns: tensor of shape (rows * columns, 2). Both coordinates run from
    -1 to 1; pixels are listed row by row, i.e. the second coordinate changes
    fastest.
    '''
    nr_rows, nr_cols = size
    row_axis = torch.linspace(-1, 1, nr_rows)
    col_axis = torch.linspace(-1, 1, nr_cols)
    grids = torch.meshgrid(row_axis, col_axis, indexing='ij')
    return torch.stack([g.flatten() for g in grids], dim=1)

def show_predictions(model):
    '''Show the image predicted by the model at three resolutions.

    The model is evaluated at the normalized pixel coordinates of a 256x256,
    a 128x128 and a 64x64 image, and the three results are shown side by side.

    model: callable mapping a (nr_points, 2) coordinate tensor to one value
        per point
    '''
    resolutions = [(256, 256), (128, 128), (64, 64)]
    fig, axes = plt.subplots(1, len(resolutions), figsize=(12, 4))

    # Only evaluating here, so torch does not need to keep track of gradients
    with torch.no_grad():
        for panel, resolution in zip(axes, resolutions):
            values = model(image_coordinates(resolution))
            picture = values.reshape(resolution)
            panel.imshow(picture.detach().numpy())
            panel.set_title(f'Predicted image {resolution}')
        plt.show()

# Look at what the network predicts at this point of the notebook.
show_predictions(model)

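THIS IS HOW FAR I MANAGED IN THIS ITERATION. The cells below have not been ported to PyTorch yet: they still use TensorFlow (`tf`) and the helpers `posenc` and `init_model`, none of which is imported or defined in this notebook, so they will not run as they are.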

In [None]:
#  Tomography functions (instead of nerf rendering)

def project_model(model, theta, nr_B, nr_S=64, L_embed=6):
    '''Forward project the model.

    NOTE(review): this function uses `tf` and `posenc`, neither of which is
    imported or defined in the visible part of the notebook -- confirm where
    they are supposed to come from before running.

    model: callable evaluated on the encoded sample points
    theta: projection angle, passed to np.cos and np.sin (presumably in
        radians, like thetas used elsewhere -- confirm)
    nr_B: number of samples for b
    nr_S: number of samples for s
    L_embed: forwarded to posenc

    Returns: p, the sum over the last axis of the model output reshaped to
    (nr_B, nr_S), scaled by 1/nr_S
    '''

    # Sample positions in [-1, 1] for both coordinates
    b = tf.linspace(-1, 1, nr_B)
    s = tf.linspace(-1, 1, nr_S)  # TODO: s (or S) may be randomized (slightly permuted)
    # NOTE(review): with indexing='xy' the grids presumably have shape
    # (nr_S, nr_B), while the reshape further down assumes (nr_B, nr_S) --
    # confirm which axis ends up being summed, especially for nr_B != nr_S.
    B, S = tf.meshgrid(b, s, indexing='xy')  
    pts = tf.stack([B, S], -1)
    pts_flat = tf.reshape(pts, [-1, 2])  # one (b, s) pair per row

    cos_a = np.cos(theta)
    sin_a = np.sin(theta)
    # NOTE(review): this matrix has determinant -1, so it is a rotation
    # combined with a flip -- confirm that this is the intended convention.
    rot = np.array([[sin_a, cos_a], [cos_a, -sin_a]])
    pts_flat = tf.reduce_sum(pts_flat[..., None, :] * rot, -1)  # practically  matmul
   
    # Encode the transformed points and evaluate the model on them
    pts_flat = posenc(pts_flat, L_embed=L_embed)
    out = model(pts_flat)
    out = tf.reshape(out, (nr_B, nr_S))   

    w = 1/nr_S  # with randomized s weights will not be equal, but computed from s
    p = w * tf.reduce_sum(out, axis = -1)

    return p 

   
def evaluate_model(model, N=128, L_embed=6):
    '''Evaluate the model on the image grid.

    NOTE(review): this function uses `tf` and `posenc`, neither of which is
    imported or defined in the visible part of the notebook -- confirm where
    they are supposed to come from before running.

    model: callable evaluated on the encoded grid points
    N: number of grid points per side; the grid spans [-1, 1] in both
        directions
    L_embed: forwarded to posenc

    Returns: the model output reshaped to (N, N)
    '''

    l = tf.linspace(-1, 1, N)
    i, j = tf.meshgrid(l, l, indexing='xy')
    pts = tf.stack([i, j], -1)
    pts_flat = tf.reshape(pts, [-1, 2])  # one coordinate pair per row
    pts_flat = posenc(pts_flat, L_embed=L_embed)
    out = model(pts_flat)
    out = tf.reshape(out, (N, N))
    return out


In [None]:
# Fit model to projections
# NOTE(review): this cell uses `tf` and `init_model`, neither of which is
# imported or defined in the visible part of the notebook -- confirm where
# they are supposed to come from before running.
model = init_model()
optimizer = tf.keras.optimizers.Adam(5e-4)

N_iters = 20
nr_S = 128  # image side, as long as s is not randomized
losses = []  # one loss value per optimizer step

sinogram = tf.cast(sinogram, dtype=tf.float32)  # tf requires float32

# range(N_iters+1) makes N_iters + 1 passes over all projection angles
for i in range(N_iters+1):
    # Visit the projections in a new random order in every pass
    for j in np.random.permutation(thetas.size):

        theta = thetas[j]
        target = sinogram[j]
    
        # Mean squared error between the projected model and the target.
        # NOTE(review): p is computed with nr_B bins, so every sinogram row
        # has to have the same number of entries -- confirm.
        with tf.GradientTape() as tape:
            p = project_model(model, theta, nr_B=nr_B, nr_S=nr_S) 
            loss = tf.reduce_mean(tf.square(p - target))
        
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        losses.append(loss.numpy())


    # After every pass: show the current reconstruction and the loss so far
    reconstruction = evaluate_model(model, N=64)
        
    fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    ax[0].imshow(reconstruction)
    ax[0].set_title(f'Iteration: {i}')
    ax[1].plot(losses)
    ax[1].set_title('Loss')
    plt.show()

print('Done')

In [None]:
# Visualize the final result
# NOTE(review): this cell uses `tf` and calls .numpy() on the results, so it
# expects TensorFlow tensors; `tf` is not imported in the visible part of the
# notebook -- confirm before running.
# Evaluate the model on a grid with as many points per side as the image has rows
reconstruction = evaluate_model(model, N=image.shape[0])
# One predicted projection per angle, stacked into a sinogram
predicted_sinogram = tf.stack([project_model(model, t, nr_B=nr_B, nr_S=nr_S) for t in thetas])

# Top row: images. Bottom row: sinograms. Columns: ground truth, prediction, residual.
fig, ax = plt.subplots(2, 3, figsize=(15, 10))
ax[0, 0].imshow(image)
ax[0, 0].set_title(f'GT image, max:{image.max():.02}, sum:{image.sum():.02}')
ax[0, 1].imshow(reconstruction)
ax[0, 1].set_title(f'Reconstruction, max:{reconstruction.numpy().max():.02}, sum:{reconstruction.numpy().sum():.02}')
# Residuals use a fixed symmetric color range with the blue-white-red colormap
ax[0, 2].imshow(image - reconstruction, vmin=-1, vmax=1, cmap=plt.cm.bwr)
ax[0, 2].set_title(f'Residual image')
ax[1, 0].imshow(sinogram)
ax[1, 0].set_title(f'GT sinogram, max:{sinogram.numpy().max():.02}, sum:{sinogram.numpy().sum():.02}')
ax[1, 1].imshow(predicted_sinogram)
ax[1, 1].set_title(f'Predicted sinogram, max:{predicted_sinogram.numpy().max():.02}, sum:{predicted_sinogram.numpy().sum():.02}')
ax[1, 2].imshow(sinogram - predicted_sinogram, vmin=-0.1, vmax=0.1, cmap=plt.cm.bwr)
ax[1, 2].set_title(f'Residual sinogram')
plt.show()

