I'll do my best to explain the process of learning the sensitivity image using a (very) simple residual CNN.

In [None]:
# first let's get our imports sorted
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import os
import sys
import sirf.STIR as stir

import matplotlib.pyplot as plt

dir_path = os.path.dirname(os.getcwd())
source_path = os.path.join(dir_path, 'source')
sys.path.append(source_path)

from data.ellipses import EllipsesDataset

In [None]:
# Locate the template files under <dir_path>/data/template_data ...
data_path = os.path.join(dir_path, 'data', 'template_data')
emission_path = os.path.join(data_path, 'emission.hv')
attenuation_path = os.path.join(data_path, 'attenuation.hv')
sinogram_path = os.path.join(data_path, 'template_sinogram.hs')

# ... and load them as SIRF objects (two images and one acquisition template)
emission_image = stir.ImageData(emission_path)
attenuation_image = stir.ImageData(attenuation_path)
template_sinogram = stir.AcquisitionData(sinogram_path)

In [None]:
num_samples = 32
batch_size = 8

# Dataset built from the attenuation image and the sinogram template
ellipses_dataset = EllipsesDataset(
    attenuation_image,
    template_sinogram,
    num_samples=num_samples,
    generate_non_attenuated_sensitivity=False,
)
dataloader = torch.utils.data.DataLoader(
    ellipses_dataset, batch_size=batch_size, shuffle=True
)

# Draw every batch once, then stack inputs and targets along the batch axis
# so the same samples can be reused for every training epoch.
batches = [(X, y) for X, y in dataloader]
X_train = torch.cat([X for X, _ in batches], dim=0)
y_train = torch.cat([y for _, y in batches], dim=0)


In [None]:
class UltraSimpleResNet(nn.Module):
    """Tiny CNN: four 3x3 convolutions with two additive skip connections.

    Every convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved. The hidden layers have 64 channels.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the returned tensor.
    """

    def __init__(self, in_channels=2, out_channels=1):
        super().__init__()
        hidden = 64
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, hidden, **conv_kwargs)
        self.conv2 = nn.Conv2d(hidden, hidden, **conv_kwargs)
        self.conv3 = nn.Conv2d(hidden, hidden, **conv_kwargs)
        self.conv4 = nn.Conv2d(hidden, out_channels, **conv_kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the network to ``x`` and return the final convolution output."""
        features = self.relu(self.conv1(x))
        # conv2 and conv3 are residual: their output is added to their input
        # before the activation.
        for residual_conv in (self.conv2, self.conv3):
            features = self.relu(residual_conv(features) + features)
        # No activation after the last layer.
        return self.conv4(features)
    



In [None]:
# Instantiate the network with two input channels and a single output channel
model = UltraSimpleResNet(in_channels=2, out_channels=1)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The input below is sent to `device`, so the weights have to live there too;
# otherwise the forward pass fails with a device mismatch whenever CUDA is
# available.
model = model.to(device)

# let's see what we get from the untrained model

out = model(X_train[0].to(device))

In [None]:
# Show the first entry of the untrained model's output as an image
untrained_output = out[0].detach().cpu().numpy()
plt.imshow(untrained_output)
plt.title('Untrained model output')

In [None]:
# unsurprisingly, it's just a load of rubbish
# Let's see if we can train it to do something useful

def train_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module to train; it is switched to training mode.
        dataloader: iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        loss_fn: callable taking ``(outputs, targets)`` and returning a
            scalar loss tensor.
        optimizer: optimizer whose ``zero_grad``/``step`` are called once
            per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []
    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)

# train_epoch moves every batch to `device`; the weights must be there as well.
model.to(device)

# Build the loss and the optimizer once, outside the epoch loop. Adam keeps
# per-parameter running moment estimates and step counts; constructing a new
# optimizer every epoch would throw that state away each time.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loss = []
for epoch in range(10):
    loss = train_epoch(model, train_loader, loss_fn, optimizer, device)
    train_loss.append(loss)
    print(f'Epoch {epoch} - Loss: {loss}')

In [None]:
# Loss value recorded for each epoch, plotted against the epoch index
plt.plot(train_loss)

In [None]:
# Run the trained model on the first training sample and display the result
first_sample = X_train[0].to(device)
out = model(first_sample)
trained_output = out[0].detach().cpu().numpy()
plt.imshow(trained_output)

In [None]:
# That looks a bit more sensible! Put it side by side with the ground truth
# and with the first channel of the input.
fig, ax = plt.subplots(1, 3)
panels = (
    (out[0], 'Model output'),
    (y_train[0], 'Ground truth'),
    (X_train[0][0], 'Input sensitivity'),
)
for axis, (panel_tensor, panel_title) in zip(ax, panels):
    axis.imshow(panel_tensor.detach().cpu().numpy())
    axis.set_title(panel_title)

We're still a long way off (it looks much more like the input sensitivity) but we can see that something is happening!