# Sinusoidal Representation Network for Single Image Super Resolution
This is the implementation of the study conducted in https://github.com/robertobressani/ai_ml_siren_sr

The following code (except for [Using EDSR](#using-edsr), which uses a TPU) must be run with GPU selected as the hardware accelerator (Runtime -> Change runtime type)

The first 3 blocks should be run once before running every experiment

NOTE: some blocks can have problems with output logging. `train` accepts a `with_log` param that can be set to `False` to avoid this problem (all the output is still reported under the `/plots/*` folder)

In [None]:
!rm -r ai_ml_siren_sr
!git clone -b main https://github.com/robertobressani/ai_ml_siren_sr
!pip3 install 'torch'
!pip3 install 'torchvision'
!pip3 install 'kornia'

In [None]:
# Absolute path of the cloned repository
# (assumes the clone above ran from /content, as on Colab — TODO confirm)
BASE_DIR='/content/ai_ml_siren_sr'
# Device name handed to the trainers in the experiments below
DEVICE= 'cuda'

import sys
# Put the repository on the import path so its packages can be imported
sys.path.insert(1, BASE_DIR)

In [None]:
import math
import torch
from torch import nn
import torch.nn.functional as F
import os
import PIL
import matplotlib
from torchvision.transforms import Resize
import matplotlib.pyplot as plt
import numpy


from utils import data_utils, math_utils, summary_utils
from core.trainer import Trainer
from core.network import NetworkParams, NetworkDimensions
from datasets.AudioSignalDataset import AudioSignal
from datasets.ImageFittingDataset import ImageFitting
from datasets.PoissonImageDataset import PoissonImageDataset

from layers.relulayer import ReLuLayer
from layers.sinelayer import SineLayer


# Audio signal fitting

In [None]:
# Audio clips to fit; each name is resolved to {BASE_DIR}/data/<name>.wav
playlist = ["gt_bach", "gt_counting", "gt_movie"]

# One entry per experiment. The description string is also parsed later to
# choose the first-layer omega_0 ("siren", "omega_r") and the loss ("fft").
nets_params = [
    NetworkParams(layer_class=ReLuLayer, description='relu'),
    NetworkParams(description='siren'),
    NetworkParams(description='siren_omega_r'),
    NetworkParams(description='siren_omega_r_fft'),
]

# Training iterations per audio clip
iterations = 500


def get_first_omega_0(description, omega_r):
    """Pick the first-layer omega_0 from an experiment description.

    Args:
        description: experiment description string.
        omega_r: value to use when the description contains both "siren"
            and "omega_r".

    Returns:
        30 when the description does not contain "siren"; ``omega_r`` when
        it contains both "siren" and "omega_r"; 3000 for any other
        description containing "siren".
    """
    if "siren" not in description:
        return 30
    if "omega_r" in description:
        return omega_r
    return 3000


def get_hidden_layers(num_samples, channels):
    """Return the number of hidden layers for a signal of the given size.

    The count is one layer plus the truncated value of
    ``0.6 * num_samples * channels / 256**2``, capped at 5.
    """
    extra_layers = int((0.6 * num_samples * channels) / 256 ** 2)
    layer_count = 1 + extra_layers
    if layer_count > 5:
        return 5
    return layer_count


for net_params in nets_params:

    # One trainer per experiment; its statistics are printed once the whole
    # playlist has been trained and tested.
    trainer = Trainer(net_params, device=DEVICE)

    # The loss only depends on the experiment, so choose it once here:
    # descriptions containing "fft" use get_fft_mse, the others plain MSE.
    combining_losses = "fft" in net_params.description
    if combining_losses:
        loss_fn = lambda gt, model_output: data_utils.get_fft_mse(gt, model_output, 50)
    else:
        loss_fn = data_utils.get_mse

    for name in playlist:
        # load dataset
        dataset = AudioSignal(f"{BASE_DIR}/data/{name}.wav", name=name)

        # number of hidden layers grows with the size of the signal
        hidden_layers = get_hidden_layers(dataset.get_num_samples(), dataset.channels)

        # omega_0 for the first layer (dataset.rate is used for "omega_r" experiments)
        first_omega_0 = get_first_omega_0(net_params.description, dataset.rate)

        # prepare network dimensions: 1 input feature, one output per channel
        dims = NetworkDimensions(1, dataset.channels, hidden_layers=hidden_layers, first_omega_0=first_omega_0)

        # train
        print(f"\nTrain with:\nhidden_layers={hidden_layers}\nfirst omega={first_omega_0}\ncombining losses={combining_losses}\n")
        # NOTE(review): patience uses min(100, ...) here while the image cells
        # use max(100, ...) — confirm this difference is intended.
        trainer.train(dims, [dataset], [iterations], summary_fn=summary_utils.audio_summary, lr=1e-4, loss_fn=loss_fn,
                      patience=min(100, 10 * hidden_layers))

        # test
        trainer.test(dataset, validate_fn=summary_utils.audio_validate)

    mean, std = trainer.statistics()
    print(f"\nMSE Mean {mean[0]} and std {std[0]} for {net_params.description}")

# Image fitting

In [None]:
images = data_utils.get_div2k_images()

# Experiments: SIREN, ReLU baseline, and SIREN with a per-image first-layer omega_0
nets_params = [
    NetworkParams(description='ImageFitting with SIREN'),
    NetworkParams(layer_class=ReLuLayer, description='ImageFitting using ReLU activation'),
    NetworkParams(description='ImageFitting with SIREN + custom omega_0 '),
]
iterations = 5000
hidden_layers = 2

for net_params in nets_params:
    # A fresh trainer per experiment; its statistics are printed after all images
    trainer = Trainer(net_params, device=DEVICE)

    for image_name in images:
        # Low-resolution image and the dataset to fit
        image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
        dataset = ImageFitting(data_image=image, name=image_name, normalized=True)

        # Laplacian dataset: the std of its ground truth drives the custom omega_0,
        # which is only used when the description mentions "omega"
        lapl = PoissonImageDataset(data_image=image, name=image_name, fit_laplacian=True)
        if "omega" in net_params.description:
            first_omega_0 = torch.std(lapl.gt).detach()*250
        else:
            first_omega_0 = 30

        # Network with 2 input features and one output per image channel
        dims = NetworkDimensions(2, dataset.channels, hidden_layers=hidden_layers, first_omega_0=first_omega_0)

        # Learning rate depends on the activation: lower for the ReLU baseline
        if net_params.layer_class == ReLuLayer:
            lr = 1e-4
        else:
            lr = 5e-4

        trainer.train(
            dims, [dataset], [iterations],
            summary_fn=summary_utils.image_fitting_summary,
            loss_fn=data_utils.get_mse,
            lr=lr,
            patience=max(100, 10*hidden_layers),
        )

        # Evaluate the fitted representation
        trainer.test(dataset, validate_fn=summary_utils.image_fitting_validate)

    # Results of the experiment over all images
    mean, std = trainer.statistics(compute_PSNR=True)
    print(f"PSNR Mean {mean[0]} dB and std {std[0]} dB for {net_params.description}")

# Solving Poisson Equation



## Training on gradient

In [None]:
images = data_utils.get_div2k_images()

# Experiments to perform; descriptions containing "numerical" select the
# numerical gradient manipulator and the default validation function below.
nets_params = [
    NetworkParams(description='Poisson trained on grad with SIREN'),
    NetworkParams(description='Poisson trained on grad with SIREN with numerical methods'),
    NetworkParams(layer_class=ReLuLayer, description='Poisson trained on grad using ReLU activation with numerical methods'),
]
iterations = 5000
hidden_layers = 2

# Labels for the statistics returned by trainer.statistics()
desc = ["Image", "Gradient", "Laplacian"]

for net_params in nets_params:
    # defining trainer for every experiment
    trainer = Trainer(net_params, device=DEVICE)

    # Experiment-level switches, evaluated once instead of repeating the
    # string search (the original mixed `>0` and `>=0` comparisons).
    uses_numerical = "numerical" in net_params.description
    is_siren = net_params.layer_class == SineLayer

    for image_name in images:

        # defining dataset (fitted on the gradient)
        image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
        dataset = PoissonImageDataset(data_image=image, name=image_name,
                                      fit_laplacian=False, normalized=True)

        # computing omega_0* from the std of the laplacian ground truth (SIREN only)
        lapl = PoissonImageDataset(data_image=image, name=image_name, fit_laplacian=True)
        first_omega_0 = torch.std(lapl.gt).detach()*250 if is_siren else 30

        # transformation of the model output on which the loss is computed
        if uses_numerical:
            manipulation = data_utils.get_manipulator("grad_num")
        else:
            manipulation = data_utils.get_manipulator("grad", .1)

        # defining network dimensions
        dims = NetworkDimensions(2, dataset.channels, hidden_layers=hidden_layers, first_omega_0=first_omega_0)

        # learning rate: higher only for SIREN trained with numerical methods
        lr = 5e-4 if is_siren and uses_numerical else 1e-4

        # training the net
        trainer.train(dims, [dataset], [iterations], summary_fn=summary_utils.poisson_image_summary,
                      loss_fn=data_utils.get_mse, lr=lr, output_manipulation=manipulation,
                      patience=max(100, 10*hidden_layers)
                      )

        # validation function: the non-numerical experiment passes explicit factors
        if uses_numerical:
            validation = summary_utils.poisson_image_validate
        else:
            validation = lambda model_output, coords, dataset, layer_folder: \
                summary_utils.poisson_image_validate(model_output, coords, dataset, layer_folder, numerical=False,
                                                     lapl_factor=0.05, grad_factor=2.5)

        # testing results
        trainer.test(dataset, validate_fn=validation)

    # reporting results
    mean, std = trainer.statistics(compute_PSNR=True)
    print(f"{net_params.description}:")
    for label, stat_mean, stat_std in zip(desc, mean, std):
        print(f"  {label}\t (mean, std): {stat_mean},{stat_std}")

## Training on laplacian

In [None]:
images = data_utils.get_div2k_images()

# Experiments: SIREN and ReLU, both described as using numerical methods
nets_params = [
    NetworkParams(description='Poisson trained on laplacian with SIREN with numerical methods'),
    NetworkParams(layer_class=ReLuLayer, description='Poisson trained on laplacian using ReLU activation with numerical methods'),
]
iterations = 5000
hidden_layers = 2

# Labels for the statistics returned by trainer.statistics()
desc = ["Image", "Gradient", "Laplacian"]

for net_params in nets_params:
    # One trainer per experiment
    trainer = Trainer(net_params, device=DEVICE)

    # Experiment-level switches derived from the parameters
    uses_numerical = "numerical" in net_params.description
    is_siren = net_params.layer_class == SineLayer

    for image_name in images:
        # Dataset fitted on the laplacian of the low-resolution image
        image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
        dataset = PoissonImageDataset(data_image=image, name=image_name,
                                      fit_laplacian=True, normalized=True)

        # omega_0* from the std of the dataset ground truth (SIREN only)
        if is_siren:
            first_omega_0 = torch.std(dataset.gt).detach()*250
        else:
            first_omega_0 = 30

        # Transformation of the model output on which the loss is computed
        if uses_numerical:
            manipulation = data_utils.get_manipulator("lapl_num")
        else:
            manipulation = data_utils.get_manipulator("lapl", 0.05)

        dims = NetworkDimensions(2, dataset.channels, hidden_layers=hidden_layers, first_omega_0=first_omega_0)

        # Learning rate depends on the activation
        if is_siren:
            lr = 5e-4
        else:
            lr = 1e-4

        trainer.train(
            dims, [dataset], [iterations],
            summary_fn=summary_utils.poisson_image_summary,
            loss_fn=data_utils.get_mse,
            lr=lr,
            output_manipulation=manipulation,
            patience=max(100, 10*hidden_layers),
        )

        # Validation function: explicit factors are passed when not numerical
        if uses_numerical:
            validation = summary_utils.poisson_image_validate
        else:
            validation = lambda model_output, coords, dataset, layer_folder: \
                summary_utils.poisson_image_validate(model_output, coords, dataset, layer_folder, numerical=False,
                                                     lapl_factor=0.05, grad_factor=2.5)

        trainer.test(dataset, validate_fn=validation)

    # Report one line per statistic
    mean, std = trainer.statistics(compute_PSNR=True)
    print(f"{net_params.description}:")
    for i in range(len(mean)):
        print(f"  {desc[i]}\t (mean, std): {mean[i]},{std[i]}")

# Exploiting super resolution

## Using bicubic method

In [None]:
# Upscaling factor used in plot titles and output file names.
# The actual target size is taken from the high-resolution image itself.
UPSCALING = 4

images = data_utils.get_div2k_images()

os.makedirs("./plots/bicubic/", exist_ok=True)

# PSNR of every image, aggregated at the end
results = []

for image_name in images:
    # getting the high-resolution ground truth and its low-resolution version
    image_hr = data_utils.get_image_tensor(data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/"), down_scale=1)
    image_lr = data_utils.get_image_tensor(data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low'), down_scale=1)
    channels, height, width = image_hr.shape

    # upsampling the low-resolution image to the ground-truth size using bicubic
    super_resolution = Resize([int(height), int(width)], interpolation=PIL.Image.BICUBIC)
    output = super_resolution(image_lr)

    image = data_utils.to_hwc(image_hr)
    # clamp to [0, 1] before converting the layout
    output = data_utils.to_hwc(torch.clamp(output, min=0, max=1))

    # measuring the results
    mse = data_utils.get_mse(image, output)
    PSNR = math_utils.PSNR(mse)

    # single-channel images are reshaped to 2-D before plotting
    if channels == 1:
        image = image.view(height, width)
        output = output.view(height, width)

    # convert once; .cpu() is applied consistently for both plotting and saving
    image_np = image.cpu().detach().numpy()
    output_np = output.cpu().detach().numpy()

    # plotting and saving results
    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    plt.suptitle("Bicubic Super Resolution", fontsize=15)
    axes[0].imshow(image_np)
    axes[0].set_title("Ground truth")
    axes[1].imshow(output_np)
    axes[1].set_title(f"Reconstruction x{UPSCALING}")
    plt.savefig(f"./plots/bicubic/{image_name}_x{UPSCALING}.png")
    plt.show()

    matplotlib.image.imsave(f"./plots/bicubic/{image_name}_x{UPSCALING}_reconstruction.png", output_np)

    print(image_name, "\t mse: ", mse, " PSNR: ", PSNR)
    results.append(PSNR)

# The collected values are PSNRs, so label them as such
print(f"Bicubic PSNR (mean,std): {numpy.mean(results)}, {numpy.std(results)}")

## Using EDSR

Testing results of EDSR on our dataset.

**NOTE: Pay attention that this code must be run using TPU and not GPU**

In [None]:
# Fetch the EDSR implementation and rename the checkout to EDSR
! git clone https://github.com/krasserm/super-resolution
! mv super-resolution EDSR

import sys

# Put the EDSR checkout on the import path so its modules can be imported
sys.path.insert(1, '/content/EDSR')

# Download the EDSR weights archive ...
! wget https://martin-krasser.de/sisr/weights-edsr-16-x4.tar.gz

# ... and unpack it in the working directory
! tar xvfz weights-edsr-16-x4.tar.gz

In [None]:
from model import resolve_single
from model.common import psnr
from model.edsr import edsr
import tensorflow as tf
import statistics

from EDSR.utils import load_image, plot_sample

# x4 EDSR model with 16 residual blocks, loaded with the downloaded weights
model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

# High-resolution ground truths and their x4 low-resolution counterparts
images = data_utils.get_div2k_images()
images_hr = [load_image(f"{BASE_DIR}/data/images/resized/{name}.png") for name in images]
images_lr = [load_image(f"{BASE_DIR}/data/images/resized/{name}x4.png") for name in images]

# PSNR of every image
p = []
for lr, gt in zip(images_lr, images_hr):
    # Super-resolve the low-resolution image and score it against the ground truth
    sr = resolve_single(model, lr)
    ps = float(tf.get_static_value(psnr(gt, sr)))
    p.append(ps)
    print(ps)

    plot_sample(lr, sr)

print("PSNR (mean, std):", statistics.mean(p), ",", statistics.stdev(p))

## Using SIREN
Results are also reported under `plots/image_super_resolution/Super Resolution */results` to better appreciate the differences between hr and ground_truth

### Derivation of $\omega_{HR}$
This is a run on the training images. The same results are obtained with the whole DIV2K validation dataset (which has not been uploaded to the repository, to avoid overloading the network for this experiment)

In [None]:
images = data_utils.get_div2k_images()

# Per-image ratio between the laplacian std of the upsampled low-resolution
# image and the laplacian std of the true high-resolution image
res = []

for image_name in images:

    # laplacian of the low-resolution image, built with up_scale=4
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
    lapl = PoissonImageDataset(data_image=image, name=image_name, fit_laplacian=True, up_scale=4)

    # laplacian of the high-resolution image
    image_hr = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='high')
    lapl_hr = PoissonImageDataset(data_image=image_hr, name=image_name, fit_laplacian=True)

    # computing their relation; stored as a plain Python float so that the
    # numpy aggregation below does not depend on tensor-to-array conversion
    ratio = torch.std(lapl.gt).detach() / torch.std(lapl_hr.gt).detach()
    res.append(ratio.item())

# mean and std of the ratio over all images
print(numpy.mean(res), numpy.std(res))

### Basic SIREN training

In [None]:
images = data_utils.get_div2k_images()

# Training schedule: a single stage on the image as-is
resolutions = [
    {"down": 1, "up": 1},
]
iterations = [5000]
net_params = NetworkParams(description="Super Resolution basic")

hidden_layers = 2

# A single trainer for all images, using the shared DEVICE setting
trainer = Trainer(net_params, device=DEVICE)
for image_name in images:
    # getting the low-resolution image
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')

    # one dataset per training stage (a single one in the basic case)
    datasets = [
        ImageFitting(data_image=image, name=image_name, normalized=True,
                     down_scale=item["down"], up_scale=item["up"])
        for item in resolutions
    ]

    # computing omega_HR from the laplacian built with up_scale=4.
    # 0.15 is presumably the ratio estimated in the "Derivation of omega_HR" cell — confirm.
    # NOTE(review): omega is stored on trainer.params here instead of being passed to
    # NetworkDimensions as in the other experiments — confirm that Trainer reads it.
    lapl = PoissonImageDataset(data_image=image, name=image_name, fit_laplacian=True, up_scale=4)
    trainer.params.first_omega_0 = torch.std(lapl.gt).detach()*250/0.15

    dims = NetworkDimensions(2, datasets[0].channels, hidden_layers=hidden_layers, hidden_features=256)

    # training the network
    trainer.train(dims, datasets, iterations, summary_fn=summary_utils.image_super_resolution_summary,
                  loss_fn=data_utils.get_mse, lr=5e-4,
                  regularization=5e-6,
                  output_manipulation=data_utils.get_manipulator('grad_num', .10),
                  patience=max(100, 10*hidden_layers)
                  )

    # getting the HR image (default resolution) to test against
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/")
    dataset_hr = ImageFitting(data_image=image, name=image_name, normalized=True)

    # testing
    trainer.test(dataset_hr, validate_fn=summary_utils.image_super_resolution_validate)

# reporting results
mean, std = trainer.statistics(compute_PSNR=True)
for psnr_mean, psnr_std in zip(mean, std):
    print(f"PSNR Mean {psnr_mean} dB and std {psnr_std} dB for Basic SIREN SR")

### SIREN training trick 1

In [None]:
images = data_utils.get_div2k_images()

# Training schedule: one dataset per stage, each with its own down/up scale,
# paired position by position with the iteration counts below
resolutions = [
    {"down": 4, "up": 4},
    {"down": 3, "up": 3},
    {"down": 2, "up": 2},
    {"down": 1.5, "up": 1.5},
    {"down": 1, "up": 1},
    {"down": 1, "up": 1.5},
    {"down": 1, "up": 2},
]
iterations = [500, 500, 500, 500, 1000, 1000, 1000]
net_params = NetworkParams(description="Super Resolution trick1")

hidden_layers = 2

# A single trainer for all images, using the shared DEVICE setting
trainer = Trainer(net_params, device=DEVICE)
for image_name in images:
    # getting the low-resolution image
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')

    # one dataset per training stage
    datasets = [
        ImageFitting(data_image=image, name=image_name, normalized=True,
                     down_scale=item["down"], up_scale=item["up"])
        for item in resolutions
    ]

    # computing omega_HR from the laplacian built with up_scale=4.
    # 0.15 is presumably the ratio estimated in the "Derivation of omega_HR" cell — confirm.
    # NOTE(review): omega is stored on trainer.params here instead of being passed to
    # NetworkDimensions as in the other experiments — confirm that Trainer reads it.
    lapl = PoissonImageDataset(data_image=image, name=image_name, fit_laplacian=True, up_scale=4)
    trainer.params.first_omega_0 = torch.std(lapl.gt).detach()*250/0.15

    dims = NetworkDimensions(2, datasets[0].channels, hidden_layers=hidden_layers, hidden_features=256)

    # training the network
    trainer.train(dims, datasets, iterations, summary_fn=summary_utils.image_super_resolution_summary,
                  loss_fn=data_utils.get_mse, lr=5e-4,
                  regularization=5e-6,
                  output_manipulation=data_utils.get_manipulator('grad_num', .10),
                  patience=max(100, 10*hidden_layers)
                  )

    # getting the HR image (default resolution) to test against
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/")
    dataset_hr = ImageFitting(data_image=image, name=image_name, normalized=True)

    # testing
    trainer.test(dataset_hr, validate_fn=summary_utils.image_super_resolution_validate)

# reporting results
mean, std = trainer.statistics(compute_PSNR=True)
for psnr_mean, psnr_std in zip(mean, std):
    print(f"PSNR Mean {psnr_mean} dB and std {psnr_std} dB for SIREN SR trick1")

### SIREN training trick 2

In [None]:
images = data_utils.get_div2k_images()

# Training schedule: a first stage on the image as-is (1000 iterations),
# followed by 20 small stages of 200 iterations whose up-scale grows
# from 1.1 to 3.0 in steps of 0.1
fine_steps = 20
resolutions = [{"down": 1, "up": 1}]
resolutions += [{"down": 1, "up": 1+(0.1*step)} for step in range(1, fine_steps + 1)]
iterations = [1000] + [200] * fine_steps

net_params = NetworkParams(description="Super Resolution trick2")

hidden_layers = 2

# A single trainer for all images, using the shared DEVICE setting
trainer = Trainer(net_params, device=DEVICE)
for image_name in images:
    # getting the low-resolution image
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')

    # one dataset per training stage
    datasets = [
        ImageFitting(data_image=image, name=image_name, normalized=True,
                     down_scale=item["down"], up_scale=item["up"])
        for item in resolutions
    ]

    # computing omega_HR from the laplacian built with up_scale=4.
    # 0.15 is presumably the ratio estimated in the "Derivation of omega_HR" cell — confirm.
    # NOTE(review): omega is stored on trainer.params here instead of being passed to
    # NetworkDimensions as in the other experiments — confirm that Trainer reads it.
    lapl = PoissonImageDataset(data_image=image, name=image_name, fit_laplacian=True, up_scale=4)
    trainer.params.first_omega_0 = torch.std(lapl.gt).detach()*250/0.15

    dims = NetworkDimensions(2, datasets[0].channels, hidden_layers=hidden_layers, hidden_features=256)

    # training the network
    trainer.train(dims, datasets, iterations, summary_fn=summary_utils.image_super_resolution_summary,
                  loss_fn=data_utils.get_mse, lr=5e-4,
                  regularization=5e-6,
                  output_manipulation=data_utils.get_manipulator('grad_num', .10),
                  patience=max(100, 10*hidden_layers)
                  )

    # getting the HR image (default resolution) to test against
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/")
    dataset_hr = ImageFitting(data_image=image, name=image_name, normalized=True)

    # testing
    trainer.test(dataset_hr, validate_fn=summary_utils.image_super_resolution_validate)

# reporting results
mean, std = trainer.statistics(compute_PSNR=True)
for psnr_mean, psnr_std in zip(mean, std):
    print(f"PSNR Mean {psnr_mean} dB and std {psnr_std} dB for SIREN SR trick2")

# Ablation studies

In [None]:
# run this code before ablation study execution
from core.network import Network

## Baseline for activation distributions

Analyzing activations and spectrum under Sitzmann's initialization

In [None]:
  # NOTE(review): every line of this cell carries a two-space indent; presumably it
  # only runs because the notebook dedents cell input — consider removing the indent.
  # Network with 1 input, 1 output and 2 hidden layers of 2048 features; first-layer omega_0 is 30
  dims = NetworkDimensions(in_features=1, out_features=1, hidden_layers=2, hidden_features=2048, first_omega_0=30)
  params = NetworkParams(outermost_linear=True)
  model = Network(params=params, dimensions=dims)

  # 65536//4 = 16384 evenly spaced inputs in [-1, 1], reshaped to (1, 16384, 1)
  input_signal = torch.linspace(-1, 1, 65536//4).view(1, -1, 1)
  # Forward pass that also returns the activations (presumably one entry per layer — confirm)
  activations = model.forward_with_activations(input_signal, retain_grad=True)

  # The entry stored under the last key of `activations` is used as the network output
  output = activations[next(reversed(activations))]

  # Backpropagate the mean output before plotting (presumably so the gradients can be plotted too)
  output.mean().backward()

  data_utils.plot_all_activations_and_grads(activations)

## First layer $\omega_0$

### $\omega_0 = 1$

In [None]:
omega_0 = 1
print(f"Network with omega 0={omega_0}")

# 1-input / 1-output network with 2 hidden layers of 2048 features,
# whose first layer uses the omega_0 under test
dims = NetworkDimensions(
    in_features=1,
    out_features=1,
    hidden_layers=2,
    hidden_features=2048,
    first_omega_0=omega_0,
)
params = NetworkParams(outermost_linear=True)
model = Network(params=params, dimensions=dims)

# 16384 evenly spaced inputs in [-1, 1], shaped (1, 16384, 1)
sample_count = 65536//4
input_signal = torch.linspace(-1, 1, sample_count).view(1, -1, 1)

# Forward pass collecting the activations; the entry under the last key is the output
activations = model.forward_with_activations(input_signal, retain_grad=True)
last_key = next(reversed(activations))
output = activations[last_key]
output.mean().backward()

# Plot what was collected
data_utils.plot_all_activations_and_grads(activations)

### $\omega_0 = 30$

In [None]:
omega_0 = 30
print(f"Network with omega 0={omega_0}")

# 1-input / 1-output network with 2 hidden layers of 2048 features,
# whose first layer uses the omega_0 under test
dims = NetworkDimensions(
    in_features=1,
    out_features=1,
    hidden_layers=2,
    hidden_features=2048,
    first_omega_0=omega_0,
)
params = NetworkParams(outermost_linear=True)
model = Network(params=params, dimensions=dims)

# 16384 evenly spaced inputs in [-1, 1], shaped (1, 16384, 1)
sample_count = 65536//4
input_signal = torch.linspace(-1, 1, sample_count).view(1, -1, 1)

# Forward pass collecting the activations; the entry under the last key is the output
activations = model.forward_with_activations(input_signal, retain_grad=True)
last_key = next(reversed(activations))
output = activations[last_key]
output.mean().backward()

# Plot what was collected
data_utils.plot_all_activations_and_grads(activations)

### $\omega_0 = 1000$

In [None]:
omega_0 = 1000
print(f"Network with omega 0={omega_0}")

# 1-input / 1-output network with 2 hidden layers of 2048 features,
# whose first layer uses the omega_0 under test
dims = NetworkDimensions(
    in_features=1,
    out_features=1,
    hidden_layers=2,
    hidden_features=2048,
    first_omega_0=omega_0,
)
params = NetworkParams(outermost_linear=True)
model = Network(params=params, dimensions=dims)

# 16384 evenly spaced inputs in [-1, 1], shaped (1, 16384, 1)
sample_count = 65536//4
input_signal = torch.linspace(-1, 1, sample_count).view(1, -1, 1)

# Forward pass collecting the activations; the entry under the last key is the output
activations = model.forward_with_activations(input_signal, retain_grad=True)
last_key = next(reversed(activations))
output = activations[last_key]
output.mean().backward()

# Plot what was collected
data_utils.plot_all_activations_and_grads(activations)

### Testing images with different $\omega_0$

In [None]:
iterations = 500
hidden_layers = 2

# 91 is the value of omega_0* (discussed in the report) for the image under analysis
omega_values = [1, 30, 91, 1000, 2000]
image_name = "0803"

# Mean PSNR obtained for each omega_0, in the order of omega_values
results = []
for omega_0 in omega_values:
    # Reload the image and rebuild the dataset for every run
    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
    dataset = ImageFitting(data_image=image, name=image_name, normalized=True)

    # Network and trainer for this omega_0
    dims = NetworkDimensions(2, dataset.channels, hidden_layers=hidden_layers, first_omega_0=omega_0)
    run_params = NetworkParams(description=f"siren_w{omega_0}")
    trainer = Trainer(run_params, device=DEVICE)

    # Train with logging disabled, then evaluate on the same dataset
    trainer.train(
        dims, [dataset], [iterations],
        summary_fn=summary_utils.image_fitting_summary,
        loss_fn=data_utils.get_mse,
        lr=5e-4,
        patience=max(100, 10*hidden_layers),
        with_log=False,
    )
    trainer.test(dataset, validate_fn=summary_utils.image_fitting_validate)

    # Keep the first mean reported by the trainer
    stats = trainer.statistics(compute_PSNR=True)
    mean = stats[0][0]
    results.append(mean)
    print(f"PSNR Mean {mean} dB for omega_0 = {omega_0}")

In [None]:
# Bar plot of the values collected in `results` by the omega_0 sweep.
# The labels are hard-coded and must stay in the same order as the swept
# omega values (1, 30, omega_0*, 1000, 2000).
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
# Raw string: "\o" is not a valid escape sequence in a regular string literal
ax.bar(['1', '30', r'$\omega_0^*$', '1000', '2000'], results)

plt.show()

## First layer initialization

#### He initialization

In [None]:
# He (Kaiming normal) initializer, applied to the first layer via first_init
def init(weights):
    return nn.init.kaiming_normal_(weights, a=0.0, nonlinearity='relu', mode='fan_in')


# 1-input / 1-output network with 2 hidden layers of 2048 features
dims = NetworkDimensions(
    in_features=1,
    out_features=1,
    hidden_layers=2,
    hidden_features=2048,
    first_omega_0=30,
)
params = NetworkParams(outermost_linear=True, first_init=init)
model = Network(params=params, dimensions=dims)

# 16384 evenly spaced inputs in [-1, 1], shaped (1, 16384, 1)
sample_count = 65536//4
input_signal = torch.linspace(-1, 1, sample_count).view(1, -1, 1)

# Forward pass collecting the activations; the entry under the last key is the output
activations = model.forward_with_activations(input_signal, retain_grad=True)
last_key = next(reversed(activations))
output = activations[last_key]
output.mean().backward()

# Plot what was collected
data_utils.plot_all_activations_and_grads(activations)

#### Xavier initialization

In [None]:
# Xavier (uniform) initializer, applied to the first layer via first_init
def init(weights):
    return nn.init.xavier_uniform_(weights)


# 1-input / 1-output network with 2 hidden layers of 2048 features
dims = NetworkDimensions(
    in_features=1,
    out_features=1,
    hidden_layers=2,
    hidden_features=2048,
    first_omega_0=30,
)
params = NetworkParams(outermost_linear=True, first_init=init)
model = Network(params=params, dimensions=dims)

# 16384 evenly spaced inputs in [-1, 1], shaped (1, 16384, 1)
sample_count = 65536//4
input_signal = torch.linspace(-1, 1, sample_count).view(1, -1, 1)

# Forward pass collecting the activations; the entry under the last key is the output
activations = model.forward_with_activations(input_signal, retain_grad=True)
last_key = next(reversed(activations))
output = activations[last_key]
output.mean().backward()

# Plot what was collected
data_utils.plot_all_activations_and_grads(activations)

### Testing images with different initializations of first layer

In [None]:
iterations = 500
hidden_layers = 2
image_name = "0803"


# Initializers under comparison; None is paired with the "Sitzmann" label
def _he_init(weights):
    return nn.init.kaiming_normal_(weights, a=0.0, nonlinearity='relu', mode='fan_in')


def _xavier_init(weights):
    return nn.init.xavier_uniform_(weights)


inits = [None, _he_init, _xavier_init]
descriptions = ["Sitzmann", "He", "Xavier"]

for init, description in zip(inits, descriptions):
    # A new trainer per scheme, with the initializer applied to the first layer
    scheme_params = NetworkParams(description=f"first_init_{description}", first_init=init)
    trainer = Trainer(scheme_params, device=DEVICE)

    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
    dataset = ImageFitting(data_image=image, name=image_name, normalized=True)

    # NOTE(review): omega_0 is 96 here, while the omega_0 sweep cell quotes 91 as
    # omega_0* for the same image — confirm which value is intended.
    dims = NetworkDimensions(2, dataset.channels, hidden_layers=hidden_layers, first_omega_0=96)

    trainer.train(
        dims, [dataset], [iterations],
        summary_fn=summary_utils.image_fitting_summary,
        loss_fn=data_utils.get_mse,
        lr=5e-4,
        patience=max(100, 10*hidden_layers),
    )
    trainer.test(dataset, validate_fn=summary_utils.image_fitting_validate)

    mean, std = trainer.statistics(compute_PSNR=True)
    print(f"PSNR Mean {mean[0]} dB and std {std[0]} dB for first init = {description}")

## Hidden layers initialization

#### He initialization

In [None]:
 # initialization function definition
init = lambda weights : nn.init.kaiming_normal_(weights, a=0.0, nonlinearity='relu', mode='fan_in')

dims = NetworkDimensions(in_features=1, out_features=1, hidden_layers=2, hidden_features=2048, first_omega_0=30)
params = NetworkParams(outermost_linear=True, hidden_init=init)
model = Network(params=params, dimensions=dims)

input_signal = torch.linspace(-1, 1, 65536//4).view(1, -1, 1)
# generating the output and activations for a uniform input
activations = model.forward_with_activations(input_signal, retain_grad=True)
output = activations[next(reversed(activations))]
output.mean().backward()

# plot activations at every layer
data_utils.plot_all_activations_and_grads(activations)

#### Xavier initialization

In [None]:
# Xavier (uniform) initializer, applied to the hidden layers via hidden_init
def init(weights):
    return nn.init.xavier_uniform_(weights)


# 1-input / 1-output network with 2 hidden layers of 2048 features
dims = NetworkDimensions(
    in_features=1,
    out_features=1,
    hidden_layers=2,
    hidden_features=2048,
    first_omega_0=30,
)
params = NetworkParams(outermost_linear=True, hidden_init=init)
model = Network(params=params, dimensions=dims)

# 16384 evenly spaced inputs in [-1, 1], shaped (1, 16384, 1)
sample_count = 65536//4
input_signal = torch.linspace(-1, 1, sample_count).view(1, -1, 1)

# Forward pass collecting the activations; the entry under the last key is the output
activations = model.forward_with_activations(input_signal, retain_grad=True)
last_key = next(reversed(activations))
output = activations[last_key]
output.mean().backward()

# Plot what was collected
data_utils.plot_all_activations_and_grads(activations)

### Testing images with different initializations of hidden layers

In [None]:
iterations = 500
hidden_layers = 2
image_name = "0803"


# Initializers under comparison; None is paired with the "Sitzmann" label
def _he_init(weights):
    return nn.init.kaiming_normal_(weights, a=0.0, nonlinearity='relu', mode='fan_in')


def _xavier_init(weights):
    return nn.init.xavier_uniform_(weights)


inits = [None, _he_init, _xavier_init]
descriptions = ["Sitzmann", "He", "Xavier"]

for init, description in zip(inits, descriptions):
    # A new trainer per scheme, with the initializer applied to the hidden layers
    scheme_params = NetworkParams(description=f"hidden_init_{description}", hidden_init=init)
    trainer = Trainer(scheme_params, device=DEVICE)

    image = data_utils.get_div2k_image(image_name, dir=BASE_DIR+"/", resolution='low')
    dataset = ImageFitting(data_image=image, name=image_name, normalized=True)

    dims = NetworkDimensions(2, dataset.channels, hidden_layers=hidden_layers, first_omega_0=30)

    trainer.train(
        dims, [dataset], [iterations],
        summary_fn=summary_utils.image_fitting_summary,
        loss_fn=data_utils.get_mse,
        lr=1e-4,
        patience=max(100, 10*hidden_layers),
    )
    trainer.test(dataset, validate_fn=summary_utils.image_fitting_validate)

    mean, std = trainer.statistics(compute_PSNR=True)
    print(f"PSNR Mean {mean[0]} dB and std {std[0]} dB for hidden init = {description}")