<a href="https://colab.research.google.com/github/tomdct/BlogColabScripts/blob/main/SirenImageFitting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Image fitting with a SIREN**

---

We investigate fitting an image with a SIREN. The original SIREN project page is available [here](https://www.vincentsitzmann.com/siren/), and our code has been partially adapted from theirs.

(**N.B.** You should run the code on this page with the GPU enabled, which you can find under `Runtime > Change runtime type > Hardware accelerator`.)


## **Preliminary business**


In [1]:
#@title We start by making all of the imports that we will need.

import os
import numpy as np
import matplotlib.pyplot as plt
import time
import requests
from PIL import Image
from io import BytesIO

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

In [None]:
#@title Next we load our image.

# Loading the image from a url.

img_url = 'https://cdn.britannica.com/79/150179-050-E2707D87/human-eye.jpg'
# The timeout stops the cell from hanging forever on a stalled connection,
# and raise_for_status() turns an HTTP error response into a clear exception
# instead of an obscure image-decoding failure on the next line.
response = requests.get(img_url, timeout=30)
response.raise_for_status()
eye_img = Image.open(BytesIO(response.content))

# We reduce everything down by a scale factor of four, to
# speed up the experiments. To use the original image, set
# scale_factor = 1.

scale_factor = 4
eye_img = eye_img.resize((eye_img.width // scale_factor, eye_img.height // scale_factor))

# Collect height, width and number of channels (in this case 3, for RGB).
# PIL reports size as (width, height); img_size is stored as (height, width).

img_width, img_height = eye_img.size
img_size = (img_height, img_width)
no_of_channels = len(eye_img.getbands())

# Now show the eye image.

fig, ax = plt.subplots(1, 1)

ax.imshow(eye_img)
ax.set_title('This is the eye image we will be using:')
ax.axis('off')

plt.show()


In [62]:
#@title Converting an image into a tensor and creating a grid

# Converting an image to a tensor converts RGB values to
# floats with values in [0, 1]. We rescale to get values
# in [-1, 1].

def image_to_tensor(img):
    """Convert an image to a tensor and rescale its values.

    The image is first passed through ToTensor and the result is then
    normalised with mean 0.5 and standard deviation 0.5, i.e. every value
    v becomes (v - 0.5) / 0.5.
    """
    mean = torch.Tensor([0.5])
    std = torch.Tensor([0.5])
    pipeline = Compose([ToTensor(), Normalize(mean, std)])
    return pipeline(img)

# Creates an array of grid coordinates

def create_grid(grid_steps, bounds=None):
    """Return the coordinates of every point of a regular n-dimensional grid.

    Args:
        grid_steps: Sequence (a_1, ..., a_n); the grid has a_i evenly spaced
            points along the i-th dimension.
        bounds: Optional sequence of n [low, high] pairs giving the inclusive
            range of each dimension. Defaults to [-1, 1] in every dimension.

    Returns:
        A tensor of shape (a_1 * ... * a_n, n). Rows are ordered with the
        first dimension varying slowest and the last dimension fastest.
    """
    grid_dim = len(grid_steps)

    # Identity check: `bounds == None` would be evaluated element-wise if an
    # array-like were passed, instead of answering "was bounds omitted?".
    if bounds is None:
        bounds = [[-1, 1] for _ in range(grid_dim)]

    axes = [
        torch.linspace(bounds[i][0], bounds[i][1], steps=grid_steps[i])
        for i in range(grid_dim)
    ]

    # indexing="ij" is the matrix-style ordering this function has always
    # relied on; passing it explicitly avoids torch's deprecation warning
    # about the implicit default.
    mesh = torch.meshgrid(*axes, indexing="ij")
    return torch.stack(mesh, dim=-1).reshape(-1, grid_dim)

# Define a dataset (wrapped in a dataloader below) for obtaining the RGB values at each pixel.

class ImageFitter(Dataset):
    """Dataset with a single item: every pixel coordinate and its value.

    The one item is the pair (coords, pixels), where ``coords`` is the grid
    returned by ``create_grid((img.height, img.width))`` and ``pixels`` holds
    the image values produced by ``image_to_tensor``, flattened to one row
    per pixel with the channels in the last dimension.
    """

    def __init__(self, img):
        super().__init__()
        img_tensor = image_to_tensor(img)
        # Read the channel count from the tensor (channels-first) rather than
        # from a notebook-level global, so the dataset is correct for
        # whichever image it is given.
        channels = img_tensor.shape[0]
        self.pixels = img_tensor.permute(1, 2, 0).reshape(-1, channels)
        self.coords = create_grid((img.height, img.width))

    def __len__(self):
        # The whole image is served as one item.
        return 1

    def __getitem__(self, idx):
        if idx > 0: raise IndexError

        return self.coords, self.pixels

img_fitter = ImageFitter(eye_img)
# Only request pinned host memory when a CUDA device exists; asking for it
# on a CPU-only runtime has no benefit and makes the DataLoader emit a warning.
dataloader = DataLoader(img_fitter, batch_size=1, pin_memory=torch.cuda.is_available(), num_workers=0)

In [63]:
#@title Define SineLayer and Siren

class SineLayer(nn.Module):
    """A linear layer followed by a scaled sine: sin(omega_0 * linear(x)).

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for the
    discussion of omega_0.

    The weights of the linear map are re-initialised on construction:
      * is_first=True:  uniform in [-1/in_features, 1/in_features];
      * is_first=False: uniform in +/- sqrt(6/in_features) / omega_0, i.e. the
        bound is divided by omega_0.
    The bias is left at nn.Linear's own initialisation.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0

        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Overwrite the linear weights with a symmetric uniform draw."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        # In-place initialisation must not be tracked by autograd.
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)


class Siren(nn.Module):
    """Fully connected network of SineLayers.

    Args:
        *layer_sizes: Widths of successive layers, input width first and
            output width last, e.g. ``Siren(2, 256, 256, 3)``. At least three
            sizes are expected: the first layer maps layer_sizes[0] ->
            layer_sizes[1] and the last maps layer_sizes[-2] ->
            layer_sizes[-1], so with only two sizes they would not chain
            unless both sizes are equal.
        outermost_linear: If True the last layer is a plain nn.Linear (no
            sine); otherwise it is another SineLayer.
        first_omega_0: omega_0 passed to the first SineLayer.
        hidden_omega_0: omega_0 passed to every later SineLayer, and used in
            the initialisation bound of the final linear layer.
    """

    def __init__(self, *layer_sizes, outermost_linear=False, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        self.no_of_layers = len(layer_sizes)

        # First layer
        layers = [SineLayer(layer_sizes[0], layer_sizes[1], is_first=True, omega_0=first_omega_0)]

        # Hidden layers
        for i in range(1, self.no_of_layers - 2):
            layers.append(SineLayer(layer_sizes[i], layer_sizes[i+1],
                                    is_first=False, omega_0=hidden_omega_0))

        # Last layer
        fan_in = layer_sizes[self.no_of_layers-2]
        fan_out = layer_sizes[self.no_of_layers-1]

        if outermost_linear:
            final_linear = nn.Linear(fan_in, fan_out)

            # The bound is sqrt(6 / fan_in) / omega_0, using the number of
            # INPUTS to this layer, exactly as SineLayer does for non-first
            # layers. Using the output width here would make the final
            # weights far too large whenever fan_out << fan_in.
            bound = np.sqrt(6 / fan_in) / hidden_omega_0
            with torch.no_grad():
                final_linear.weight.uniform_(-bound, bound)

            layers.append(final_linear)
        else:
            layers.append(SineLayer(fan_in, fan_out,
                                    is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network.

        Returns:
            (output, coords), where ``coords`` is a detached clone of the
            input with requires_grad enabled, so that derivatives of the
            output with respect to the input can be taken by the caller.
        """
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords


In [60]:
#@title Implement the training loop

# Build the SIREN (2 inputs, three layers of width 256, 3 outputs, linear
# last layer) together with the optimiser that will train it.

img_siren = Siren(
    2, 256, 256, 256, 3,
    outermost_linear=True,
    first_omega_0=30,
    hidden_omega_0=30,
)

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    img_siren.cuda()

optim = torch.optim.Adam(img_siren.parameters(), lr=1e-4)

# Implement the summary step.

def print_summary(step, loss, model_output):
    """Plot the current reconstruction next to the original image.

    Args:
        step: Training step, shown in the title.
        loss: Loss at this step, shown in the title (formatted with %0.6f).
        model_output: Network output containing exactly
            img_height * img_width * no_of_channels values, ordered pixel by
            pixel with the channels last (e.g. shape (1, H*W, C)).
    """
    fig, ax = plt.subplots(1, 2)
    # Undo the [-1, 1] scaling of the training targets, then clamp: the
    # network is free to overshoot, and imshow only accepts floats in [0, 1].
    pixels = ((model_output.detach() + 1) / 2).clamp(0, 1)
    # The values are already in pixel-major, channel-last order, so a plain
    # reshape recovers the image; no permutation of axes is required.
    a = pixels.cpu().reshape(img_height, img_width, no_of_channels).numpy()
    ax[0].imshow(a)
    ax[0].set_title("After %d steps, loss=%0.6f." % (step, loss))
    ax[1].imshow(eye_img)
    ax[1].set_title("Original image")
    plt.show()

# Define the training loop.

def run_through_training_loop(total_steps, steps_til_summary):
    """Train img_siren on the single batch served by dataloader.

    Runs steps 0..total_steps inclusive. Each step evaluates the model on
    the whole batch, takes the mean squared error against the ground truth
    and applies one update of optim. Whenever the step number is a multiple
    of steps_til_summary, print_summary is called with that step's loss and
    output.
    """
    batch = next(iter(dataloader))
    model_input, ground_truth = batch

    if torch.cuda.is_available():
        model_input = model_input.cuda()
        ground_truth = ground_truth.cuda()

    for step in range(total_steps + 1):
        model_output, _ = img_siren(model_input)
        residual = model_output - ground_truth
        loss = (residual**2).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % steps_til_summary == 0:
            print_summary(step, loss, model_output)

## **Experiment 1: train a SIREN to fit the image.**

We...

In [None]:
# Train through step 5000, printing a summary every 500 steps.
run_through_training_loop(total_steps=5000, steps_til_summary=500)

## **Experiment 2: check for any improvement in resolution.**

We...

In [None]:
#@title Check the improvement in resolution of the close-up patches.

# We pick out three close-up patches which we will use to
# test the resolution capabilities of our siren. The corners below are
# (x, y) pixel positions in the ORIGINAL image, so they (and the patch size)
# are divided by scale_factor before being used with the down-scaled eye_img.

top_left_pixels = [[230, 338], [336, 174], [630, 200]]
patch_size = 128 // scale_factor
patch_width = patch_size
scaled_corners = [(x // scale_factor, y // scale_factor) for x, y in top_left_pixels]

fig, ax = plt.subplots(1, 3, figsize=(8, 8*3))
ax[1].set_title('Here are three close-ups:')

for i, (a, b) in enumerate(scaled_corners):
    ax[i].imshow(eye_img.crop((a, b, a + patch_size, b + patch_size)))
    ax[i].axis('off')

plt.show()

with torch.no_grad():

    fig, ax = plt.subplots(3, 2, figsize=(8, 12))

    for i, (a, b) in enumerate(scaled_corners):
        # create_grid spreads the n pixels of an axis evenly over [-1, 1],
        # so pixel index p sits at coordinate 2*p/(n - 1) - 1. The first
        # grid dimension is the image height (rows, b), the second the
        # width (columns, a).
        zoomed_in_region = create_grid((512, 512), bounds=[
            [2*b/(img_height - 1) - 1, 2*(b + patch_width)/(img_height - 1) - 1],
            [2*a/(img_width - 1) - 1, 2*(a + patch_width)/(img_width - 1) - 1]])

        if torch.cuda.is_available():
            zoomed_in_region = zoomed_in_region.cuda()

        model_out, _ = img_siren(zoomed_in_region)
        # Back from [-1, 1] to [0, 1]; clamp because imshow requires floats
        # in [0, 1] and the network may overshoot slightly.
        model_out = ((model_out + 1)/2).clamp(0, 1)
        ax[i,0].axis('off')
        ax[i,1].axis('off')
        ax[i,0].imshow(model_out.cpu().view(512,512,3).detach().numpy())
        ax[i,1].imshow(eye_img.crop((a, b, a + patch_width, b + patch_width)))
    ax[0,0].set_title("SIREN")
    ax[0,1].set_title("Original image")

    # plt.show(), as in the other cells, rather than fig.show().
    plt.show()



## **Experiment 3: comparison with other types of activation functions.**

We now implement simple ReLU and sigmoid networks with the same layer sizes as the SIREN, and compare how quickly each network learns the image.

In [24]:
#@title Define simple ReLU and Sigmoid networks

# ReLU layers and ReLU networks

class ReLULayer(nn.Module):
    """A linear layer followed by a ReLU activation."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.in_features = in_features

    def forward(self, input):
        pre_activation = self.linear(input)
        return pre_activation.relu()

class ReLUNetwork(nn.Module):
    """Fully connected network of ReLULayers.

    ``layer_sizes`` lists the widths of successive layers, input first and
    output last. When ``outermost_linear`` is True the last layer is a plain
    nn.Linear; otherwise it is a ReLULayer like the others.
    """

    def __init__(self, *layer_sizes, outermost_linear=False):
        super().__init__()

        self.no_of_layers = len(layer_sizes)

        # Every layer except the last.
        layers = [
            ReLULayer(layer_sizes[i], layer_sizes[i + 1])
            for i in range(self.no_of_layers - 2)
        ]

        # The last layer, with or without the activation.
        last_in = layer_sizes[self.no_of_layers - 2]
        last_out = layer_sizes[self.no_of_layers - 1]
        last_layer_type = nn.Linear if outermost_linear else ReLULayer
        layers.append(last_layer_type(last_in, last_out))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        # A detached clone that requires grad, so the caller can take
        # derivatives of the output with respect to the input.
        coords = coords.clone().detach().requires_grad_(True)
        return self.net(coords), coords

# sigmoid layers and sigmoid networks

class SigmoidLayer(nn.Module):
    """A linear layer followed by a sigmoid activation."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.in_features = in_features

    def forward(self, input):
        pre_activation = self.linear(input)
        return pre_activation.sigmoid()

class SigmoidNetwork(nn.Module):
    """Fully connected network of SigmoidLayers.

    ``layer_sizes`` lists the widths of successive layers, input first and
    output last. When ``outermost_linear`` is True the last layer is a plain
    nn.Linear; otherwise it is a SigmoidLayer like the others.
    """

    def __init__(self, *layer_sizes, outermost_linear=False):
        super().__init__()

        self.no_of_layers = len(layer_sizes)

        # Every layer except the last.
        layers = [
            SigmoidLayer(layer_sizes[i], layer_sizes[i + 1])
            for i in range(self.no_of_layers - 2)
        ]

        # The last layer, with or without the activation.
        last_in = layer_sizes[self.no_of_layers - 2]
        last_out = layer_sizes[self.no_of_layers - 1]
        last_layer_type = nn.Linear if outermost_linear else SigmoidLayer
        layers.append(last_layer_type(last_in, last_out))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        # A detached clone that requires grad, so the caller can take
        # derivatives of the output with respect to the input.
        coords = coords.clone().detach().requires_grad_(True)
        return self.net(coords), coords

In [None]:
#@title A comparison of different models


# Build the three competing models with the same layer sizes, each with
# its own Adam optimiser at the same learning rate.

img_siren = Siren(
    2, 256, 256, 256, 3,
    outermost_linear=True,
    first_omega_0=30,
    hidden_omega_0=30,
)
img_relu = ReLUNetwork(2, 256, 256, 256, 3, outermost_linear=True)
img_sigmoid = SigmoidNetwork(2, 256, 256, 256, 3, outermost_linear=True)

# Move all three to the GPU when one is available.
if torch.cuda.is_available():
    img_siren.cuda()
    img_relu.cuda()
    img_sigmoid.cuda()

optim_siren = torch.optim.Adam(img_siren.parameters(), lr=1e-4)
optim_relu = torch.optim.Adam(img_relu.parameters(), lr=1e-4)
optim_sigmoid = torch.optim.Adam(img_sigmoid.parameters(), lr=1e-4)

# Comparison summary for the training loop.

def print_comparison_summary(step, model_output):
    """Plot the three reconstructions next to the original image.

    Args:
        step: Training step. Currently unused; kept so existing calls
            continue to work.
        model_output: Sequence of three network outputs, in the order
            Sin, ReLU, Sigmoid. Each must contain exactly
            img_height * img_width * no_of_channels values, ordered pixel by
            pixel with the channels last. The sequence is not modified.
    """
    fig, ax = plt.subplots(1, 4)
    titles = ["Sin", "ReLU", "Sigmoid"]
    for i in range(3):
        ax[i].axis('off')
        # Work on a local tensor instead of writing the rescaled values back
        # into the caller's list. Clamp because imshow requires floats in
        # [0, 1] and the networks may overshoot.
        pixels = ((model_output[i].detach() + 1) / 2).clamp(0, 1)
        # Already pixel-major and channel-last, so a reshape is sufficient.
        a = pixels.cpu().reshape(img_height, img_width, no_of_channels).numpy()
        ax[i].imshow(a)
        ax[i].set_title(titles[i])
    ax[3].axis('off')
    ax[3].imshow(eye_img)
    ax[3].set_title("Original image")
    plt.show()

# Implement the training loop.

def comparison_training_loop(total_steps, steps_til_summary):
    """Train img_siren, img_relu and img_sigmoid side by side.

    Runs steps 0..total_steps inclusive on the single batch served by
    dataloader. In every step each model is evaluated, its mean squared
    error against the ground truth is computed, and its own optimiser is
    stepped. Whenever the step number is a multiple of steps_til_summary the
    three losses are printed and print_comparison_summary is called with the
    three outputs.
    """
    model_input, ground_truth = next(iter(dataloader))

    if torch.cuda.is_available():
        model_input = model_input.cuda()
        ground_truth = ground_truth.cuda()

    # (model, optimiser) pairs, in the order the summary reports them.
    contenders = (
        (img_siren, optim_siren),
        (img_relu, optim_relu),
        (img_sigmoid, optim_sigmoid),
    )

    for step in range(total_steps + 1):
        model_outputs, losses = [], []

        for model, optimiser in contenders:
            model_output, _ = model(model_input)
            model_outputs.append(model_output)

            loss = ((model_output - ground_truth)**2).mean()
            losses.append(loss)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        if step % steps_til_summary == 0:
            print(f"Siren, ReLU and Sigmoid after %d steps, with losses %0.6f, %0.6f, %0.6f." % (step, *losses))
            print_comparison_summary(step, model_outputs)



# Train through step 5000, printing a comparison every 100 steps.
comparison_training_loop(total_steps=5000, steps_til_summary=100)



## **Experiment 4: messing around with the SIREN.**

In [None]:
# now we...