# Implicit Neural Representations Tutorial for BVM 2025

This tutorial will introduce the concept of Implicit Neural Representations (INRs) and how they can be used in medical image analysis. We will cover the basics of INRs, how they can be used in image reconstruction/representation, denoising and non-linear registration tasks. We will also cover the basics of the [SIREN](https://vsitzmann.github.io/siren/) architecture and how it can be used to implement INRs. Additionally, we will look at a coordinate encoding method that can be used to improve the performance of INRs. 

In [None]:
# import and set up the environment
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from IPython.display import clear_output, display
from matplotlib import pyplot as plt

# download the required images and helper files
!wget -nc https://cloud.imi.uni-luebeck.de/s/fBnwQNLWXNDqsj5/download -O ct_image_pytorch.pth
!wget -nc https://cloud.imi.uni-luebeck.de/s/Njgp9L78KDkJXFj/download -O images_flow.pth
!wget -nc https://cloud.imi.uni-luebeck.de/s/85824cEMDK2zbFr/download -O utils.py

import utils

# set the matplotlib theme to default
plt.style.use('default')

# Set device
# if nvidia gpu is available use it
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using {device}")
# if mps on MacBook is available use it
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
    print(f"Using {device} and setting PYTORCH_ENABLE_MPS_FALLBACK=1")
# fallback to cpu
else:
    device = torch.device('cpu')
    print(f"Using {device}")


## Task 1: Image Reconstruction

Implicit neural representations aim to represent a single data instance, i.e. an image in our case, as a continuous function. This function is parameterized by a neural network that takes coordinates as input and is optimized to return the image value at each coordinate. In short, we seek to parameterize a greyscale image $f(\mathbf{x})$ with pixel coordinates $\mathbf{x}$ by a neural network $\Phi$ such that $\mathcal{L}=\iint_{\Omega} \lVert \Phi(\mathbf{x}) - f(\mathbf{x}) \rVert^2\,\mathrm{d}\mathbf{x}$ is minimized, where $\Omega$ is the domain of the image.


Unlike in most other deep learning tasks, an INR is not trained on a large training dataset; instead, it is deliberately overfitted to that single data instance.
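To make the "continuous function" idea concrete, here is a minimal sketch (not the notebook's SIREN) in which a tiny, untrained coordinate network is queried on two grids of different resolution. Because the network takes coordinates rather than pixel indices, the same parameters can be sampled at any resolution. The layer sizes, the `tanh` activation and the helper `sample_grid` are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A tiny coordinate network (x, y) -> intensity; sizes and activation are arbitrary
inr = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))

def sample_grid(n):
    # hypothetical helper: n x n coordinates in [-1, 1], flattened to (n*n, 2)
    axis = torch.linspace(-1, 1, n)
    yy, xx = torch.meshgrid(axis, axis, indexing="ij")
    return torch.stack([xx, yy], dim=-1).view(-1, 2)

with torch.no_grad():
    coarse = inr(sample_grid(8)).view(8, 8)    # the function sampled as an 8 x 8 image
    fine = inr(sample_grid(64)).view(64, 64)   # the same function sampled as 64 x 64

print(coarse.shape, fine.shape)
```

The corner values of both images agree, because both grids contain the coordinates (-1, -1) and (1, 1).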

Let's first load in the image we want to represent.

In [None]:
img = torch.load('ct_image_pytorch.pth').squeeze()
print(f"Image shape: {img.shape}")

plt.imshow(torch.clamp(img, -500, 500).cpu().numpy(), cmap='gray')
plt.axis('off')
plt.show()


### Task 1.1 SIREN implementation

Implement the [SIREN](https://vsitzmann.github.io/siren/) network architecture. The network always has a first linear layer, $n$ hidden layers (`num_layers` in the code) and a last linear layer. The hidden layers are all identical. Every linear layer except the last one is followed by a sine activation applied to its output multiplied by the scaling factor `scale`, i.e. $\sin(\text{scale} \cdot z)$ for a layer output $z$; the last linear layer has no activation.

Start by implementing the network class itself.

In [None]:
class SIREN(nn.Module):
    def __init__(self,in_features, out_features, hidden_ch=256,num_layers=3, scale=30):
        super().__init__()
        self.scale = scale
        # Todo: initialize the first linear layer
        
        # Todo: initialize the hidden layers
            

        # Todo: initialize the last linear layer
        

        
    def forward(self,x):
        # Todo: apply sin(scale * layer(x)) after every layer except the last one
        return x


Furthermore, implement a function that initializes the weights of the linear layers.
- The first layer should be initialized with a uniform distribution in the range: $$\mathcal{U}(-\frac{1}{fan\_in}, \frac{1}{fan\_in})$$
- All other layers (the hidden layers and the last layer) should be initialized with a uniform distribution in the range: $$\mathcal{U}(-\frac{\sqrt{\frac{6}{fan\_in}}}{scale}, \frac{\sqrt{\frac{6}{fan\_in}}}{scale})$$

You can use the `torch.nn.init` functions to initialize the weights of the linear layers; see the [Documentation](https://pytorch.org/docs/stable/nn.init.html) for more details. The biases are initialized with zeros.
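As a sketch of how the two bounds and `torch.nn.init` fit together, the snippet below initializes two stand-alone layers (a first layer with 2 inputs and a hidden layer with 256 inputs). The helper `siren_bound` is hypothetical and not part of the notebook; it only restates the two ranges above, assuming `scale=30`.

```python
import math
import torch
import torch.nn as nn

def siren_bound(fan_in, is_first, scale=30.0):
    # hypothetical helper: first layer U(-1/fan_in, 1/fan_in),
    # all other layers U(-sqrt(6/fan_in)/scale, sqrt(6/fan_in)/scale)
    if is_first:
        return 1.0 / fan_in
    return math.sqrt(6.0 / fan_in) / scale

torch.manual_seed(0)
first = nn.Linear(2, 256)
hidden = nn.Linear(256, 256)

with torch.no_grad():
    for layer, is_first in [(first, True), (hidden, False)]:
        fan_in = layer.weight.shape[1]           # weight shape is (fan_out, fan_in)
        b = siren_bound(fan_in, is_first)
        nn.init.uniform_(layer.weight, -b, b)    # in-place uniform initialization
        nn.init.zeros_(layer.bias)               # biases start at zero

print(siren_bound(2, True), siren_bound(256, False))
```

Note how small the hidden-layer bound becomes: dividing by `scale` compensates for the multiplication by `scale` inside the sine activation.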

In [None]:
def weights_init(m, scale=30):
    with torch.no_grad():
        # Given: check if the module is a linear layer
        if isinstance(m, nn.Linear):
            # Shape of the weight matrix: (fan_out, fan_in)
            # We want to initialize the weights depending on fan_in
            fan_in = None  # Todo: read fan_in from the weight shape
            if m.weight.shape[1] > 4:  # heuristic: only the first layer has <= 4 (coordinate) inputs
                # Todo: calculate the bound val for the hidden and last layers
                val = None
            else:
                # Todo: calculate the bound val for the first layer
                val = None

            # Todo: initialize the weights with a uniform distribution between -val and val

            # Todo: initialize the bias with zeros



Test your implementation of the SIREN network with the following settings: `in_features=1`, `out_features=1`, `hidden_ch=64`, `num_layers=3`.

Call `weights_init` on your network using `.apply` ([Doc](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.apply)).

The network you have implemented should have 12673 parameters.
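The expected parameter counts can be checked by hand. A minimal sketch, assuming `num_layers` counts only the hidden layers and every linear layer has a weight matrix and a bias (the function name is hypothetical):

```python
def siren_num_params(in_features, out_features, hidden_ch, num_layers):
    # first layer + num_layers hidden layers + last layer, each with weights and bias
    first = in_features * hidden_ch + hidden_ch
    hidden = num_layers * (hidden_ch * hidden_ch + hidden_ch)
    last = hidden_ch * out_features + out_features
    return first + hidden + last

print(siren_num_params(1, 1, 64, 3))     # 12673
print(siren_num_params(2, 1, 256, 3))    # 198401
print(siren_num_params(256, 1, 256, 3))  # 263425
```

Only the first layer depends on the number of input features, so feeding 256-dimensional encodings instead of 2D coordinates adds $254 \cdot 256 = 65024$ parameters.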

In [None]:
# Given: Code to check the number of parameters
net = SIREN(in_features=1, out_features=1, hidden_ch=64, num_layers=3)
print(net)
net.apply(weights_init)
print('Parameters:', utils.num_params(net))


### Task 1.2 Coordinate grid

Since the inputs to the SIREN network (and to INRs in general) are coordinates, write a function that returns a grid of coordinates for a given size $H\times W$. The coordinate values should lie between -1 and 1.

You can do that by using `torch.meshgrid`, `torch.linspace` and `torch.stack`.
Alternatively, you can do it in one step using `torch.nn.functional.affine_grid` and `torch.eye`.

The grid should be reshaped to a tensor of shape $N\times 2$, where $N = HW$ is the number of pixels in the image (use `.view`).
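The two suggested routes differ in the order of the last dimension, which matters later for `F.grid_sample`. This sketch compares them on a tiny 3 x 4 grid; it is an illustration of the ordering, not a reference `create_grid`, and it assumes `align_corners=True` so that the `affine_grid` coordinates coincide with `torch.linspace(-1, 1, ...)`.

```python
import torch
import torch.nn.functional as F

h, w = 3, 4

# Route 1: meshgrid with 'ij' indexing, channel 0 varies along rows (y), channel 1 along columns (x)
ys = torch.linspace(-1, 1, h)
xs = torch.linspace(-1, 1, w)
grid_ij = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)       # (h, w, 2), (y, x) order

# Route 2: identity affine transform, last dimension in (x, y) order as grid_sample expects
theta = torch.eye(2, 3).unsqueeze(0)                                       # (1, 2, 3)
grid_affine = F.affine_grid(theta, size=(1, 1, h, w), align_corners=True)  # (1, h, w, 2)

flat_coords = grid_affine.reshape(-1, 2)                                   # (h*w, 2)
print(flat_coords.shape)
print(torch.allclose(grid_affine[0], grid_ij.flip(-1)))
```

Flipping the last dimension of the meshgrid version turns (y, x) into (x, y), which makes both routes identical.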

In [None]:
# create_grid (H,W): create coordinate grid of size HxW
def create_grid(H,W):
    # Todo: create a grid of size HxW with coordinates in the range [-1,1]
    coords = None

    return coords


In [None]:
# Given: Test your function to create an example grid of size 100 x 100
coords = create_grid(100,100)
print(coords.shape)
plt.scatter(coords[:, 0], coords[:, 1], s=1.)


### Task 1.3 Training
Now you can implement the training routine to reconstruct the image. The loss function should be the mean squared error between the predicted and the ground truth image.
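The training loop below relies on two small ingredients: min-max normalization of the target and the MSE loss over all pixels. A minimal sketch of both, in which a random tensor is a made-up stand-in for the CT slice and a perturbed copy of it plays the role of the network output; all `toy_` names are hypothetical and chosen so they do not clash with the notebook's variables.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
toy_img = torch.randn(4, 5) * 300 + 40       # made-up stand-in for a CT slice (HU-like range)

# Min-max normalization to [0, 1]
toy_min, toy_max = toy_img.min(), toy_img.max()
toy_normed = (toy_img - toy_min) / (toy_max - toy_min)

# Fake INR output for the 4*5 coordinates: shape (H*W, 1), close to the target
toy_target = toy_normed.view(-1, 1)
toy_pred = toy_target + 0.01 * torch.randn(20, 1)

# MSE loss: mean over all pixels of the squared difference
toy_loss = F.mse_loss(toy_pred, toy_target)
manual_loss = ((toy_pred - toy_target) ** 2).mean()

# Undo the normalization to return to the original intensity range (as in the plotting code)
toy_recon = toy_pred.view(4, 5) * (toy_max - toy_min) + toy_min
print(toy_loss.item(), manual_loss.item())
```

The network output and the target must have the same shape, here $HW \times 1$; otherwise `F.mse_loss` broadcasts and silently computes the wrong quantity.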

In [None]:
# Given: optimization setup
num_iters = 100
lr = 1e-3
H, W = img.shape

# Todo: create a 2D grid of coordinates with the same size as the image
coordinates = None

# Todo: create a SIREN Network with the following settings
# in_features = 2, out_features = 1, hidden_ch = 256, num_layers = 3
# The Network you have implemented should have 198401 parameters
net = None

print('Parameters:', utils.num_params(net))

# Todo: initialize the weights of the network


# Todo: initialize the optimizer with the parameters of the network and a learning rate of 1e-3
optimizer = None

# Todo: normalize image by its maximum value and minimum value and move it to the device
scale_min, scale_max = None, None
img_normed = None

# Given: plotting
fig, ax = plt.subplots(1, 4, figsize=(24, 5))
labels = ['Reconstructed Image', 'Ground Truth', 'Difference: Ground Truth - Reconstructed', 'Loss']
for i in range(4):
    ax[i].set_title(labels[i])
    if i != 3:
        ax[i].set_axis_off()
ax[3].set_xlim(0, num_iters)


# statistics
running_loss = []

for i in range(num_iters):

    # Todo: reset the gradients of the optimizer
    

    # Todo: forward pass of the network
    siren_recon = None

    # Todo: calculate the loss
    loss =  None

    # Todo: backward pass and update the weights

    running_loss.append(loss.item())
    if i % 10 == 0 or i == num_iters - 1:
        print('Iteration %d    Loss %.4f' % (i, loss.item()))
        ax[0].imshow(torch.clamp(siren_recon.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min, -500, 500), cmap='gray')
        ax[1].imshow(torch.clamp(img, -500, 500), cmap='gray')
        ax[2].imshow((utils.normalize(img) - utils.normalize(siren_recon.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min)).abs(), vmin=0, vmax=1, cmap='jet')
        ax[3].plot(running_loss, 'r')
        display(fig); plt.close()

plt.show()


### Why the periodic activation function?
In other networks, ReLU, sigmoid or tanh are very common activation functions. For INRs, however, they cause problems.
The coordinate inputs of neighboring pixels differ only slightly. A network built from linear layers and these activations is strongly biased towards smooth, low-frequency functions of its input, so such small coordinate changes barely change the output. This prevents it from properly representing high-frequency details of the image (e.g. edges).
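One structural reason is easy to check: an MLP with ReLU activations is piecewise linear in its input, so it can only bend at a limited number of kinks. The toy 1D experiment below (random, untrained weights; sizes, sample count and tolerance are arbitrary assumptions) counts where the second finite difference is non-zero, i.e. where the function is not locally linear, for a ReLU MLP and for a small network with a scaled sine activation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_dense = torch.linspace(-1, 1, 10001, dtype=torch.float64).unsqueeze(1)  # dense 1D coordinates

relu_net = nn.Sequential(
    nn.Linear(1, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 1),
).double()

class SineAct(nn.Module):
    # scaled sine activation sin(scale * z), as used in SIREN
    def __init__(self, scale=30.0):
        super().__init__()
        self.scale = scale

    def forward(self, z):
        return torch.sin(self.scale * z)

sine_net = nn.Sequential(nn.Linear(1, 16), SineAct(), nn.Linear(16, 1)).double()

def curved_fraction(net, tol=1e-9):
    # second finite difference f(x+h) - 2 f(x) + f(x-h) is ~0 wherever f is linear
    with torch.no_grad():
        y = net(x_dense).squeeze(1)
    d2 = y[2:] - 2 * y[1:-1] + y[:-2]
    return (d2.abs() > tol).double().mean().item()

print("ReLU MLP:", curved_fraction(relu_net))
print("sine net:", curved_fraction(sine_net))
```

For the ReLU MLP only the few samples next to a kink show curvature, while the sine network curves almost everywhere. This illustrates the piecewise-linear structure; it is not a full demonstration of the spectral bias you will observe during training.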

To demonstrate the difference, create a class `MLP` that is identical to the `SIREN` class except for the nonlinearity: use `F.relu` instead of the scaled sine function. Then copy the training from above and use the simple ReLU MLP instead of the SIREN network. You will see that, within the same number of iterations, the ReLU network fails to capture high-frequency details and only finds a blurry, low-frequency fit of the image.

In [None]:
# Todo: copy the training from above and change SIREN to MLP
class MLP(SIREN):
    def forward(self,x):
        # Todo: replace the scaled sine activations of SIREN with F.relu
        return x

# Given: optimization setup
num_iters = 100
lr = 1e-3
H, W = img.shape

# Todo: create a 2D grid of coordinates with the same size as the image using your create_grid function
coordinates = None

# Todo: create a Network with the following settings
# in_features = 2, out_features = 1, hidden_ch = 256, num_layers = 3
# The Network you have implemented should have 198401 parameters
net = None

# Given: print the number of parameters of the network
print('Parameters:', utils.num_params(net))
print(net)

# Todo: normalize image by its maximum value and minimum value and move it to the device
scale_min, scale_max = None, None
img_normed = None

# Todo: initialize the optimizer with the parameters of the network and a learning rate of 1e-3
optimizer = None

# Given: plotting
fig, ax = plt.subplots(1, 4, figsize=(24, 5))
labels = ['Reconstructed Image', 'Ground Truth', 'Difference: Ground Truth - Reconstructed', 'Loss']
for i in range(4):
    ax[i].set_title(labels[i])
    if i != 3:
        ax[i].set_axis_off()
ax[3].set_xlim(0, num_iters)


# statistics
running_loss = []

for i in range(num_iters):

    # Todo: reset the gradients of the optimizer
    

    # Todo: forward pass of the network
    reluMLP_recon = None

    # Todo: calculate the loss
    loss = None

    # Todo: backward pass and update the weights

    running_loss.append(loss.item())
    if i % 10 == 0 or i == num_iters - 1:
        print('Iteration %d    Loss %.4f' % (i, loss.item()))
        ax[0].imshow(torch.clamp(reluMLP_recon.detach().cpu().reshape(H, W) * (scale_max - scale_min) + scale_min, -500, 500), cmap='gray')
        ax[1].imshow(torch.clamp(img, -500, 500), cmap='gray')
        ax[2].imshow((utils.normalize(img) - utils.normalize(reluMLP_recon.detach().cpu().reshape(H, W) * (scale_max - scale_min) + scale_min)).abs(), vmin=0, vmax=1, cmap='jet')
        ax[3].plot(running_loss, 'r')
        display(fig); plt.close()

plt.show()


Compare the results of the SIREN and the ReLU MLP networks.

In [None]:
# Given: compare the results of siren and mlp networks by calculating their ssim and psnr values and plot the images
siren_recon = siren_recon.detach().cpu().reshape(H, W)
reluMLP_recon = reluMLP_recon.detach().cpu().reshape(H, W)

img_list = [ siren_recon, reluMLP_recon]
titles = ['SIREN', 'MLP']
utils.plot_comparison(img, img_list, titles)



### Other approaches
Aside from SIREN, there are further approaches to properly represent high-frequency image content. Activation functions other than sinusoids, e.g. [wavelets](https://openaccess.thecvf.com/content/CVPR2023/papers/Saragadam_WIRE_Wavelet_Implicit_Neural_Representations_CVPR_2023_paper.pdf) or Gaussians, can also work.

Furthermore, instead of changing the activation function, one can directly transform the input coordinates, e.g. through a [Fourier feature mapping](https://arxiv.org/pdf/2006.10739) or a hash grid encoding, to obtain a better representation of higher frequencies.

### Task 1.4 Fourier Feature encoding

<center width="100%" style="padding:25px"><img src="https://bmild.github.io/fourfeat/img/teaser.png" width="1000px"></center>

Instead of using sinusoidal or wavelet activation functions, we can apply a positional encoding such as [Fourier Features](https://bmild.github.io/fourfeat/) to the coordinates. This improves the convergence of a simple ReLU MLP on high frequencies and lets it achieve a performance comparable to SIREN or WIRE networks.
In this bonus task, we will implement the Fourier feature mapping, which is based on Bochner's theorem for approximating shift-invariant kernels. Here we use random Fourier features to approximate a Gaussian kernel. The Fourier features map the coordinates to a high-dimensional feature space $\gamma(\mathbf{v}): \mathbb{R}^D \to \mathbb{R}^{\mathcal{F}}$, where $D$ is the dimensionality of the input coordinates and $\mathcal{F}$ is the dimensionality of the feature space.

The Fourier features are computed as follows:
$$
\begin{align}
\gamma(\mathbf{v})=[\cos (2 \pi \mathbf{B v}), \sin (2 \pi \mathbf{B v})]^{\mathrm{T}}
\end{align}
$$

where $\mathbf{B}$ is sampled from a Gaussian distribution $\mathbf{B}_i \sim \mathcal{N}(0, \sigma^2 \mathbf{I})$ and $\gamma(\mathbf{v})$ is the Fourier feature of the coordinate $\mathbf{v}$. The scale $\sigma^2$ of the Gaussian matrix helps us control the amount of underfitting/overfitting on high-frequency details.

Hints:
- `torch.randn` can be used to sample from a Gaussian distribution.
- `torch.matmul` can be used to multiply two tensors.
- `torch.cat` can be used to concatenate tensors.
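The formula for $\gamma(\mathbf{v})$ can be written as a plain function. This is a sketch, not the `FourierFeatures` class you are asked to implement: `fourier_features` is a hypothetical name, and `sigma` is treated here as the standard deviation of the entries of $\mathbf{B}$ (so their variance is $\sigma^2$); how the notebook's `scale` argument maps onto this is an assumption you should check against the description. $\mathbf{B}$ has $\mathcal{F}/2$ rows, so that concatenating cosines and sines yields $\mathcal{F}$ features.

```python
import math
import torch

def fourier_features(v, B):
    # v: (N, D) coordinates, B: (F/2, D) random frequency matrix
    proj = 2 * math.pi * v @ B.T                                    # (N, F/2)
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)    # (N, F)

torch.manual_seed(0)
D, num_freqs, sigma = 2, 8, 5.0
B = sigma * torch.randn(num_freqs // 2, D)   # entries ~ N(0, sigma^2)
v = torch.rand(10, D) * 2 - 1                # 10 random coordinates in [-1, 1]

gamma = fourier_features(v, B)
print(gamma.shape)                           # torch.Size([10, 8])
```

Each cosine/sine pair lies on the unit circle, so every feature is bounded in $[-1, 1]$ regardless of $\sigma$; only how fast the features oscillate across the image changes.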

### Fourier Feature implementation

In [None]:
# Given: Class skeleton for FourierFeatures
class FourierFeatures(nn.Module):
    def __init__(self, coord_dims, num_freqs, scale=1):
        super().__init__()
        # Given
        self.scale = scale
        # Ensure that the number of frequencies is even
        assert num_freqs % 2 == 0
        # Given: calculate the feature dimensions
        self.feature_dims = num_freqs // 2

        # Todo: initialize the basis matrix B with random values and use the scale factor
        self.B = None

    def forward(self, coordinates):
        # Todo: implement the forward pass based on the formula in the description
        features = None
        return features


Test your implementation on the coordinates. The output should have the shape $HW \times F$.

Your features should look similar to the following image (remember that it is not a 100% match since the features are random):
<center width="100%" style="padding:25px"><img src="https://cloud.imi.uni-luebeck.de/s/NfMa2P8WEf884F2/download" width="2000px"></center>

In [None]:
# Given: Test your implementation
frequency_dim = 256
sigma2 = 5

ff_encodings = FourierFeatures(coord_dims=2, num_freqs=frequency_dim, scale=sigma2)(coordinates)
print('Fourier Encodings:', ff_encodings.shape)
ff_encodings = ff_encodings.view(H, W, -1).cpu().detach().numpy()

# Given: plot some of the fourier features
fig, ax = plt.subplots(1, 5, figsize=(20, 5))
for i in range(5):
    ax[i].imshow(ff_encodings[..., i], cmap='jet')
    ax[i].set_title(f'Feature {i}')
    ax[i].axis('off')
plt.show()


### Training

Now, let's train the ReLU MLP with the Fourier feature mapping. The training routine should be the same as the one for the ReLU MLP network but we encode the coordinates with the Fourier feature mapping before feeding them to the network.

In [None]:
# Todo: copy the training from above and change SIREN to MLP
class MLP(SIREN):
    def forward(self,x):
        # Todo: replace the scaled sine activations of SIREN with F.relu
        return x

# Given: optimization setup
num_iters = 100
lr = 1e-3

# Todo: create a 2D grid of coordinates with your create_grid function
H, W = img.shape
coordinates = None

# Todo: create a Network with the following settings
# in_features = 256 (num_freqs), out_features = 1, hidden_ch = 256, num_layers = 3
# The Network you have implemented should have 263425 parameters
net = None


# print the number of parameters of the network
print('Parameters:', utils.num_params(net))
print(net)

# Todo: create your FourierFeatures encoding with coord_dims=2, num_freqs=256, scale=5 (sigma2)
encoding = None

# Todo: normalize image by its maximum value and minimum value and move it to the device
scale_min, scale_max = None, None
img_normed = None

# Todo: initialize the optimizer with the parameters of the network and a learning rate of 1e-3
optimizer = None

# plotting
fig, ax = plt.subplots(1, 4, figsize=(24, 5))
labels = ['Reconstructed Image', 'Ground Truth', 'Difference: Ground Truth - Reconstructed', 'Loss']
for i in range(4):
    ax[i].set_title(labels[i])
    if i != 3:
        ax[i].set_axis_off()
ax[3].set_xlim(0, num_iters)


# statistics
running_loss = []

for i in range(num_iters):

    # Todo: reset the gradients of the optimizer
    

    # Todo: forward pass of the network with the encoded coordinates
    ff_reluMLP_recon = None

    # Todo: calculate the loss
    loss = None

    # Todo: backward pass and update the weights

    running_loss.append(loss.item())
    if i % 10 == 0 or i == num_iters - 1:
        print('Iteration %d    Loss %.4f' % (i, loss.item()))
        ax[0].imshow(torch.clamp(ff_reluMLP_recon.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min, -500, 500), cmap='gray')
        ax[1].imshow(torch.clamp(img, -500, 500), cmap='gray')
        ax[2].imshow((utils.normalize(img) - utils.normalize(ff_reluMLP_recon.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min)).abs(), vmin=0, vmax=1, cmap='jet')
        ax[3].plot(running_loss, 'r')
        display(fig); plt.close()

plt.show()


Let's plot the comparison with our previous results. With the Fourier feature mapping, the ReLU MLP should now be able to represent high-frequency details and achieve a performance comparable to SIREN.

In [None]:
ff_reluMLP_recon = ff_reluMLP_recon.detach().cpu().reshape(H, W)
img_list = [siren_recon, reluMLP_recon, ff_reluMLP_recon]
names = ['SIREN', 'ReLU MLP', 'Fourier Features MLP']
utils.plot_comparison(img, img_list, names)


## Task 2: Image Denoising

We now want to use the SIREN network to not only reconstruct an image but also remove added noise from it. To achieve that, we add another loss term to our routine that penalizes noise as done in Exercise 1.

First let's add some noise to our image.

In [None]:
# Given: add noise to the image
noisy_img = utils.add_noise(img, 60)

# plot the images
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(img, cmap='gray')
plt.title('Original Image')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(noisy_img, cmap='gray')
plt.title('Noisy Image')
plt.axis('off')


### Task 2.1 Total Variation
Now we implement the [total variation regularization](https://en.wikipedia.org/wiki/Total_variation_denoising) as the sum of the absolute differences between neighboring pixels, computed along both image dimensions and added together. Normalize the result by dividing by `(Nx - 1) * (Ny - 1)`. Avoid for-loops in your implementation and use clever indexing (slicing) instead.
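The slicing idea can be sketched on a tiny tensor whose total variation is easy to verify by hand. This is one possible reading of the description above; the names `total_variation` and `toy_patch` are hypothetical, and the sketch makes no claim about reproducing the exact reference values printed in the next cell.

```python
import torch

def total_variation(x):
    # sum of absolute differences between vertical and horizontal neighbours,
    # normalised by (Nx - 1) * (Ny - 1)
    nx, ny = x.shape
    dv = (x[1:, :] - x[:-1, :]).abs().sum()   # neighbours along dimension 0
    dh = (x[:, 1:] - x[:, :-1]).abs().sum()   # neighbours along dimension 1
    return (dv + dh) / ((nx - 1) * (ny - 1))

toy_patch = torch.tensor([[0., 1., 3.],
                          [2., 2., 2.]])
print(total_variation(toy_patch))             # tensor(3.5000)
```

By hand: the vertical differences are 2, 1 and 1 (sum 4), the horizontal ones 1, 2, 0 and 0 (sum 3), and $7 / ((2-1)(3-1)) = 3.5$. A constant image has zero total variation, while noise increases it, which is why it works as a denoising penalty.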

In [None]:
def TVLoss(x):

    Nx = x.shape[0]
    Ny = x.shape[1]
    # Todo: calculate the TVLoss
    tv_loss = None
    return tv_loss

# try out the TVLoss
print('TVLoss:', TVLoss(img)) # should return "TVLoss: tensor(53.9729)" for the original image
print('TVLoss:', TVLoss(noisy_img)) #should be around "TVLoss: tensor(160.2281)" for the noisy image


### Task 2.2 Training
Slightly modify the reconstruction training from above. The loss should penalize:
- the difference between the reconstructed image and the noisy image
- the total variation of the reconstructed image

The ratio between both losses should be controlled with the variable `lambda_tv`. Experiment with this value to obtain a suitable result. (Hint: the total variation loss should be weighted less than the image similarity loss.)

In [None]:
# Optimization setup
num_iters = 100
lr = 1e-3
# Todo: weighting factor for the TVLoss
lambda_tv = 0.05

# Todo: create a 2D grid of coordinates with your create_grid function
H, W = img.shape
coordinates = None
# Todo: create a SIREN Network with the following settings
# in_features = 2, out_features = 1, hidden_ch = 256, num_layers = 3
# The Network you have implemented should have 198401 parameters
denoising_siren = None

# Todo: initialize the weights of the network

# Todo: normalize noisy image by its maximum value and minimum value and move it to the device
scale_min, scale_max = None, None
noisy_img_normed = None

# Todo: initialize the optimizer with the parameters of the network and a learning rate of 1e-3
optimizer = None

# Given: plotting
fig, ax = plt.subplots(1, 5, figsize=(24, 5))
labels = ['Noisy Image', 'Denoised Image', 'Ground Truth', 'Difference: Ground Truth - Denoised', 'Loss']
for i in range(5):
    ax[i].set_title(labels[i])
    if i != 4:
        ax[i].set_axis_off()
ax[4].set_xlim(0, num_iters)


# statistics
running_loss = []

for i in range(num_iters):

    # Todo: reset the gradients of the optimizer

    # Todo: forward pass of the network
    prediction = None

    # Todo: calculate the loss
    loss = None

    # Todo: backward pass and update the weights

    running_loss.append(loss.item())
    if i % 20 == 0 or i == num_iters - 1:
        print('Iteration %d    Loss %.4f' % (i, loss.item()))
        ax[0].imshow(torch.clamp(noisy_img_normed.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min, -500, 500), cmap='gray')
        ax[1].imshow(torch.clamp(prediction.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min, -500, 500), cmap='gray')
        ax[2].imshow(torch.clamp(img, -500, 500), cmap='gray')
        ax[3].imshow((utils.normalize(img) - utils.normalize(prediction.detach().cpu().reshape(H, W) * (scale_max-scale_min) + scale_min)).abs(), vmin=0, vmax=1)
        ax[4].plot(running_loss, 'r')
        display(fig); plt.close()

plt.show()


## Task 3: Registration
For the last task, we want to use implicit neural representations to register the following two images.
Instead of fitting an image, the INR now fits the deformation field. The loss combines the difference between the fixed image and the moving image warped with the deformation field, and a regularization term on the deformation field.

In [None]:
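# Given: fall back to the CPU on Apple MPS devices
# (device is a torch.device, so compare its .type instead of the object itself)
if device.type == 'mps':
    device = torch.device('cpu')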

images = torch.load('images_flow.pth')
fixed = images[0:1].unsqueeze(1).to(device)
moving = images[1:2].unsqueeze(1).to(device)

fig, ax = plt.subplots(1, 3, figsize=(10,20))
ax[0].imshow(fixed.cpu().squeeze(), cmap='gray')
ax[0].set_title('fixed')
ax[1].imshow(moving.cpu().squeeze(), cmap='gray')
ax[1].set_title('moving')
dif = ax[2].imshow((moving - fixed).cpu().abs().squeeze(), vmin=0, vmax=200)
ax[2].set_title('differences')
plt.colorbar(dif, fraction=0.05, pad=0.1)
ax[0].axis('off')
ax[1].axis('off')
ax[2].axis('off')

fig.tight_layout()
plt.show()


Initialize everything needed for the registration task:
- **height and width** parameters $H$, $W$
- the **SIREN network**: the network predicts a displacement for both the x- and the y-direction, which should be reflected in the number of output features
- **coordinate input** of size $HW \times 2$
- an **identity grid** of size $1\times H \times W \times 2$ to add to the deformation field for the warping step (you can initialize this the same way as the coordinates)
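How the identity grid and a predicted displacement combine in `F.grid_sample` can be sketched on a random toy image. This sketch assumes `align_corners=True` for both `affine_grid` and `grid_sample` (the two calls must agree), and the names `toy_moving`, `toy_identity` and `warp` are hypothetical. It checks that a zero displacement reproduces the image and that a displacement of one pixel in x, which is $2/(W-1)$ in normalized units, shifts the content by one column.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
h, w = 5, 7
toy_moving = torch.rand(1, 1, h, w)                          # (N, C, H, W)

# Identity grid (1, H, W, 2); last dimension in (x, y) order as grid_sample expects
theta = torch.eye(2, 3).unsqueeze(0)
toy_identity = F.affine_grid(theta, size=(1, 1, h, w), align_corners=True)

def warp(image, disp):
    # disp: (H*W, 2) displacement per pixel, as an INR with 2 outputs would predict it
    grid = toy_identity + disp.view(1, h, w, 2)
    return F.grid_sample(image, grid, align_corners=True)

# Zero displacement: sampling exactly at the pixel positions returns the image itself
toy_warped = warp(toy_moving, torch.zeros(h * w, 2))

# One pixel in +x (normalized units 2 / (W - 1)): warped(p) = moving(p + u(p)),
# so the image content appears shifted one pixel to the left
shift = torch.zeros(h * w, 2)
shift[:, 0] = 2 / (w - 1)
toy_shifted = warp(toy_moving, shift)
```

Because `grid_sample` pulls values from the moving image at the displaced positions, the predicted field describes where each fixed-image pixel should look in the moving image, not where moving pixels are pushed to.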


In [None]:
# Given: Move the images to the device and get the height and width of the images
fixed = fixed.to(device)
moving = moving.to(device)
H,W = fixed.size()[-2:]

# Todo: initialize the SIREN network (remember to initialize the weights as well) with hidden_ch=64 and num_layers=3
net = None  

# Todo: create an HxW identity grid of shape 1xHxWx2. grid_sample (used in the training) expects the
# last dimension in (x, y) order, so either use affine_grid, or apply .flip(-1) to the output of
# create_grid() if it was built with torch.meshgrid in 'ij' order
identity_grid = None  
# input coordinates are the same as the identity grid but viewed as HW x 2
coordinates = None


### Training

Next, implement the training loop.

The model output should be used to warp the moving image using `F.grid_sample`. To perform the warping step, add the predefined identity grid to your predicted displacement field (use `.view` on the model output).

The loss consists of two parts:
- The mean squared error between the fixed and the warped image.
- The mean of the squared spatial gradient of the deformation field. The gradient can be obtained with `torch.gradient` and stacked with `torch.stack`. Compute it separately for the x- and y-component of the displacement (first and second feature dimension of the model output).

The two parts can be weighted with a factor. (Hint: the smoothness term should be weighted more strongly than the image similarity term.)

In the end, you should be able to obtain a smooth displacement field that creates a well fitting warped image.
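The smoothness term can be sketched with `torch.gradient`, which returns one finite-difference tensor per dimension. Taking the mean over both displacement components and both spatial directions is one reading of "mean of the squared gradient"; the function name `smoothness_loss` and the toy fields are hypothetical. A pure translation is perfectly smooth, while a field that varies linearly in x is penalized.

```python
import torch

def smoothness_loss(disp, h, w):
    # disp: (H*W, 2) displacement predicted for every pixel -> (H, W, 2)
    field = disp.view(h, w, 2)
    grads = []
    for c in range(2):                                # x- and y-component separately
        d_rows, d_cols = torch.gradient(field[..., c])  # finite differences along H and W
        grads.append(torch.stack([d_rows, d_cols]))     # (2, H, W)
    return (torch.stack(grads) ** 2).mean()           # mean squared gradient

h, w = 6, 8
toy_const = torch.ones(h * w, 2) * 0.1                # pure translation
toy_ramp = torch.linspace(0, 1, w).repeat(h, 1)       # (H, W), increases along x
toy_linear = torch.stack([toy_ramp, torch.zeros(h, w)], dim=-1).view(-1, 2)

print(smoothness_loss(toy_const, h, w))               # tensor(0.)
print(smoothness_loss(toy_linear, h, w))
```

For the linear field, only the x-derivative of the x-component is non-zero (1/7 per pixel), giving a mean of $(1/49)/4 = 1/196$ over the four gradient maps. Since `torch.gradient` is differentiable, the term can be added directly to the MSE and backpropagated.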

In [None]:
total_steps = 500
steps_til_summary = 100

# Todo: Adam optimizer with learning rate 1e-4
optim = None

# plotting
fig, ax = plt.subplots(1, 5, figsize=(24, 5))
labels = ['Fixed', 'Warped', 'Moving', 'Displacement Field', 'Loss']
for i in range(5):
    ax[i].set_title(labels[i])
    if i != 4:
        ax[i].set_axis_off()
ax[4].set_xlim(0, total_steps)

running_loss = []

for step in range(total_steps):
    # Todo: reset gradient

    # Todo: get the prediction
    model_output = None
    # Todo: warp the image
    warped = None

    # Todo: Compute loss criteria
    loss = None
    
    # Todo: backpropagation and optimizer step

    running_loss.append(loss.item())
    if step % steps_til_summary == 0 or step == total_steps - 1:
        rgb = utils.showFlow(model_output.reshape(1,H,W,2).permute(0,3,1,2).data.cpu())
        print('Iteration %d    Loss %.4f' % (step, loss.item()))
        ax[0].imshow(fixed.detach().cpu().reshape(H, W), cmap='gray')
        ax[1].imshow(warped.detach().cpu().reshape(H, W), cmap='gray')
        ax[2].imshow(moving.detach().cpu().reshape(H, W), cmap='gray')
        ax[3].imshow(rgb)
        ax[4].plot(running_loss, 'r')
        display(fig); plt.close()

    
