I'll do my best to explain the process of learning the sensitivity image using a (very) simple residual CNN.

In [None]:
# first let's get our imports sorted
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import os
import sys
import sirf.STIR as stir

import numpy as np
import matplotlib.pyplot as plt

# Make the project's `source` directory importable. This assumes the notebook
# is run from a direct subdirectory of the repository root (e.g. notebooks/).
dir_path = os.path.dirname(os.getcwd())
source_path = os.path.join(dir_path, 'source')
# Prepend (rather than append) so the local `data` package cannot be shadowed
# by an installed package of the same name; skip it if the cell is re-run.
if source_path not in sys.path:
    sys.path.insert(0, source_path)

from data.ellipses import EllipsesDataset

In [None]:
# Load the template images and sinogram from <repo>/data/template_data
data_path = os.path.join(dir_path, 'data', 'template_data')
emission_image, attenuation_image = (
    stir.ImageData(os.path.join(data_path, fname))
    for fname in ('emission.hv', 'attenuation.hv')
)
template_sinogram = stir.AcquisitionData(
    os.path.join(data_path, 'template_sinogram.hs'))

In [None]:
num_samples = 128
batch_size = 16

# The generated training set is cached on disk so it only has to be built once
x_cache = os.path.join(data_path, f'X_train_n{num_samples}.pt')
y_cache = os.path.join(data_path, f'y_train_n{num_samples}.pt')

if not (os.path.exists(x_cache) and os.path.exists(y_cache)):
    dataloader = DataLoader(
        EllipsesDataset(attenuation_image, template_sinogram,
                        num_samples=num_samples, generate_non_attenuated_sensitivity=False),
        batch_size=batch_size, shuffle=True)

    # Collect every (inputs, targets) batch, then stack along the sample axis
    input_batches, target_batches = [], []
    for batch_inputs, batch_targets in dataloader:
        input_batches.append(batch_inputs)
        target_batches.append(batch_targets)

    X_train = torch.cat(input_batches, dim=0)
    y_train = torch.cat(target_batches, dim=0)

    # Save so the next run can skip generation
    torch.save(X_train, x_cache)
    torch.save(y_train, y_cache)
else:
    X_train = torch.load(x_cache)
    y_train = torch.load(y_cache)


In [None]:
class SimpleSkipCNN(nn.Module):
    """Small fully-convolutional network with two learnable-weight skip connections.

    Four 3x3 convolutions (stride 1, padding 1) keep the spatial size of the
    input unchanged. A final ReLU keeps the output non-negative.

    Args:
        in_channels: number of input channels (default 2).
        out_channels: number of output channels (default 1).
        device: device on which *all* parameters are placed.
    """

    def __init__(self, in_channels=2, out_channels=1, device='cpu'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.LeakyReLU(0.2)
        # The output activation is registered as a submodule, not built inside
        # forward(), so that model.modules() can find it (e.g. to attach hooks).
        # ReLU has no parameters, so the state_dict keys are unchanged.
        self.out_relu = nn.ReLU()

        # Learnable scalar weights for the two skip connections
        self.skip_weight1 = nn.Parameter(torch.ones(1))
        self.skip_weight2 = nn.Parameter(torch.ones(1))

        # Move every parameter (convolutions included) to the requested device;
        # previously only the skip weights were placed on `device`.
        self.to(device)

    def forward(self, x):
        """Map ``x`` (N, in_channels, H, W) or unbatched (in_channels, H, W) to a non-negative output."""
        x1 = self.conv1(x)
        x = self.relu(x1)

        # First skip: weighted conv1 features added after conv2
        x = self.conv2(x) + self.skip_weight1 * x1
        x = self.relu(x)

        # Second skip: weighted sum of the current features and conv1 features
        x2 = x + x1
        x = self.conv3(x) + self.skip_weight2 * x2
        x = self.relu(x)

        # No skip connection here as we're reducing to output channels
        x = self.conv4(x)

        # ReLU on the output ensures non-negativity
        return self.out_relu(x)


In [None]:
# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the whole model explicitly so that every registered parameter,
# including the convolution weights, lives on `device`.
model = SimpleSkipCNN(in_channels=2, out_channels=1, device=device).to(device)

In [None]:
# Let's see what we get from the untrained model.
# X_train[0] is a single sample without a batch axis (X_train is the batches
# concatenated along dim 0). This is display-only, so no autograd graph is needed.
with torch.no_grad():
    out = model(X_train[0].to(device))

In [None]:
# Show the first output channel of the untrained model's prediction
untrained_img = out[0].detach().cpu().numpy()
plt.imshow(untrained_img)
plt.title('Untrained model output')

In [None]:
# Unsurprisingly, it's just a load of rubbish.
# Let's see if we can train it to do something useful.

def train_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: the network to train (switched to training mode).
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters. It is stepped once per batch.
        device: device the batches are moved to.

    Returns:
        The average of the per-batch loss values, as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # Targets may lack the channel axis the model outputs have, e.g.
        # (B, H, W) vs (B, 1, H, W). Without this, loss functions such as
        # MSELoss broadcast to (B, B, H, W) and compare every output with every
        # target in the batch.
        if targets.dim() == outputs.dim() - 1:
            targets = targets.unsqueeze(1)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_epoch: dataloader yielded no batches")
    return total_loss / num_batches

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)

num_epochs = 50
loss_fn = nn.MSELoss()
# Build the optimiser once, outside the epoch loop. Adam keeps running moment
# estimates per parameter, and re-creating it every epoch would reset that
# state and make every epoch restart Adam from scratch.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loss = []
for epoch in range(num_epochs):
    loss = train_epoch(model, train_loader, loss_fn, optimizer, device)
    train_loss.append(loss)
    print(f'Epoch {epoch} - Loss: {loss}')

In [None]:
# Mean training loss per epoch
plt.plot(train_loss)
plt.xlabel('Epoch')
plt.ylabel('Training loss (MSE)')

In [None]:
# Let's generate some test data - just ellipses for now
test_data = EllipsesDataset(attenuation_image, template_sinogram, num_samples=8, generate_non_attenuated_sensitivity=False)
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)

# Number of test samples to plot
num_to_show = 7

# Let's see how the model does on the test data
model.eval()
with torch.no_grad():
    for i, (X, y) in enumerate(test_loader):
        X, y = X.to(device), y.to(device)
        # Shared upper colour limit taken from the input's first channel and the
        # ground truth; converted to a plain float for matplotlib
        vmax = max(X[0, 0].max().item(), y[0].max().item())
        out = model(X)

        panels = (
            ('Input', X[0, 0]),
            ('Model output', out[0, 0]),
            ('Ground truth', y[0].squeeze()),
        )
        fig, axes = plt.subplots(1, 3, figsize=(8, 2))
        for ax, (title, img) in zip(axes, panels):
            im = ax.imshow(img.detach().cpu().numpy(), vmax=vmax)
            ax.set_title(title)
            # No ticks and a colorbar on every panel, not just the last one
            ax.set_xticks([])
            ax.set_yticks([])
            fig.colorbar(im, ax=ax)

        if i + 1 >= num_to_show:
            break


We're still a long way off but we can see that something is happening!

Below is a little bit of code for estimating how important different parts of your image are to the model. It isn't strictly necessary; it's mostly me experimenting a little.

In [None]:
# Saliency map: magnitude of the gradient of the mean model output with respect
# to each input pixel. The input is the last sample left over from the test loop
# above (X), not a dummy tensor. Work on a fresh leaf copy so X is not modified
# and re-running this cell does not accumulate into an old gradient.
input_tensor = X.detach().clone().requires_grad_()

# Forward pass to compute the output
output = model(input_tensor.to(device))

# Use the mean of the output as a single scalar target to differentiate
target = output.mean()

# Compute gradients of the target with respect to the input
target.backward()

# Absolute gradient with singleton axes squeezed out, summed over the input channels
saliency_map = input_tensor.grad.detach().abs().squeeze().sum(0)

In [None]:
def relu_backward_hook_function(module, grad_in, grad_out):
    """Guided-backprop hook: zero out negative gradients flowing back through a ReLU.

    Args:
        module: the module the hook is attached to.
        grad_in: tuple of gradients w.r.t. the module's inputs.
        grad_out: tuple of gradients w.r.t. the module's outputs (unused).

    Returns:
        A one-element tuple replacing ``grad_in`` with its non-negative part when
        ``module`` is an ``nn.ReLU``. Otherwise ``None``, which leaves the gradients unchanged.
    """
    if not isinstance(module, nn.ReLU):
        return None
    # Per the PyTorch hook docs, grad_in entries may be None (e.g. when the
    # input does not require grad); there is nothing to modify then.
    if grad_in[0] is None:
        return None
    return (torch.clamp(grad_in[0], min=0.0),)


def register_hooks(model):
    """Attach the guided-backprop hook to every ``nn.ReLU`` submodule of ``model``.

    Only modules that are ``nn.ReLU`` instances are hooked (not e.g. ``nn.LeakyReLU``).

    Returns:
        A list of hook handles. Call ``handle.remove()`` on each one to detach it again.
    """
    # register_full_backward_hook replaces the deprecated register_backward_hook
    return [
        module.register_full_backward_hook(relu_backward_hook_function)
        for module in model.modules()
        if isinstance(module, nn.ReLU)
    ]

In [None]:
# Guided backpropagation: like the saliency map, but negative gradients are
# zeroed as they flow backwards through each hooked nn.ReLU module.
# NOTE(review): register_hooks only hooks nn.ReLU modules; LeakyReLU
# activations in the model are not affected.
guided_hooks = register_hooks(model)
model.eval()  # Set the model to evaluation mode

# Clear any gradient already stored on the input (e.g. from the saliency
# computation); otherwise backward() would add the guided gradients on top of it.
input_tensor.grad = None

# Forward pass
output = model(input_tensor.to(device))

# Use the mean of the output as the scalar target for the backward pass
target = output.mean()

# Backward pass
target.backward()

# Extract the gradients w.r.t. the input, as modified by the hooks
gradients = input_tensor.grad.detach()

# If register_hooks returned hook handles, remove them so re-runs of this cell
# do not stack duplicate hooks and later cells see an unmodified model
for handle in guided_hooks or ():
    handle.remove()

In [None]:
# show both gradients
fig, ax = plt.subplots(1, 2)
ax[0].imshow(saliency_map.cpu().numpy(), cmap='hot')
ax[0].set_title('Saliency Map')
ax[1].imshow(gradients.squeeze().sum(0).cpu().numpy(), cmap='hot')
ax[1].set_title('Guided Backpropagation')