In [None]:
#This notebook runs INR (implicit neural representation) reconstruction using a prior built from multi-look measurements with speckle noise.
#Parameters should be configured in the speckle_recon.yaml config file.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from typing import Tuple, List
import torch
import pickle
from NeRP_main.networks import Positional_Encoder, FFN, SIREN
from NeRP_main.utils import get_config, prepare_sub_folder, get_data_loader
from torch.utils.data import TensorDataset, DataLoader, Dataset
from NeRP_main.data import create_grid
import torch.nn as nn

from my_utils import *
from multilook_utils import *

In [None]:
#Configs
config = get_config("speckle_recon.yaml")
device = torch.device('cpu')
img_size = config["img_size"]
max_iter = config['max_iter']
num_looks = config["num_looks"]
prior = True #Enable to use a prior generated using a low pass filter
better_prior = True #Enable to use a prior generated by adding noise to the original image. This can be used sequentially
#Parameters for the Gaussian noise added to the ground truth (only used when better_prior is enabled)
mean = 0
sigma = 0.05
display = False #Enable to see the training/test losses and PSNR while the model is training

# Seed the NumPy and torch RNGs used later (Gaussian noise for the "better prior", shuffled
# DataLoaders and, presumably, network weight initialisation) so Restart & Run All is reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Load the data - no additive noise
# The dataset location can be overridden with a `data_path` entry in speckle_recon.yaml;
# otherwise the original local path is used.
data_path = config.get(
    "data_path",
    "/Users/sarahhagan/Desktop/NeRP Research/Speckle_noise_datasets/coherent_just_speckle_cifar.npz",
)
# The context manager closes the .npz file handle; indexing reads each array fully into memory.
with np.load(data_path) as data:
    A_loaded = data["A"] # m * n
    X_loaded = data["X"] # N * n
    Y_loaded = data["Y"] # N * L * m
m = A_loaded.shape[0]
n = A_loaded.shape[1]
N = X_loaded.shape[0]
L = Y_loaded.shape[1]
print("m =",m,"n =",n,"N =",N,"L =",L)
# Fail early: slicing Y with more looks than exist would silently return fewer looks, and the
# later reshape to `num_looks` looks would then fail with a much less obvious error.
assert num_looks <= L, f"num_looks={num_looks} exceeds the {L} looks available in the dataset"

In [None]:
class SimpleSpeckleLoss(nn.Module):
    """Speckle data-fidelity loss averaged over a number of looks.

    With x = input clamped to at least 1e-8 (keeps the log and the division finite),
    the returned value is

        (1/n) * sum_{k < n} sum_over_elements( log(x) + targets[k]**2 / (2 * x**2) )
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, targets, n=1):
        """Compute the look-averaged loss.

        Args:
            input: model output; must broadcast against each targets[k].
            targets: indexable by look; targets[k] is the k-th look.
            n: number of looks, targets[0] .. targets[n-1], to average over.

        Returns:
            A scalar tensor. n == 0 raises ZeroDivisionError.
        """
        floor = 1e-8
        x = torch.clamp(input, min=floor)
        log_x = torch.log(x)
        denom = 2 * x**2
        per_look = (torch.sum(log_x + (targets[k]**2) / denom) for k in range(n))
        return sum(per_look) / n

In [None]:
A_tensor = torch.from_numpy(A_loaded).to(torch.float32)
learned_models = []
prior_train_psnr_list = []
prior_test_psnr_list = []
better_prior_train_psnr_list = []
better_prior_test_psnr_list = []
target_train_psnr_list = []
target_test_psnr_list = []

if prior:
    #Create DCT matrix to be used with all images
    Fm = dct_matrix(img_size)
    cutoff_idx = config['cutoff_idx'] #Low Pass Filter Cutoff
    FmI = Fm[:cutoff_idx] #Create Low Pass Filtered DCT transforms
    FmI_T = FmI.transpose()
    B = np.kron(FmI_T, FmI_T) #Kronecker Product
    print("B:", B.shape)
    AB = np.matmul(A_loaded, B)

# The training helpers are called with a device string (e.g. "cpu"); derive it from `device`.
train_device = str(device)


def build_model():
    """Create the INR selected by config['model'] ('SIREN' or 'FFN') on `device`, in train mode.

    Raises:
        NotImplementedError: for any other model name.
    """
    if config['model'] == 'SIREN':
        net = SIREN(config['net'])
    elif config['model'] == 'FFN':
        net = FFN(config['net'])
    else:
        raise NotImplementedError(f"Unsupported model: {config['model']}")
    net.to(device)
    net.train()
    return net


def build_optimizer(net):
    """Create the optimiser selected by config['optimizer'] (only 'Adam' is supported).

    Raises:
        NotImplementedError: for any other optimiser. The error must actually be raised;
        merely constructing it would leave `optim` undefined or stale from the previous image.
    """
    if config['optimizer'] == 'Adam':
        return torch.optim.Adam(net.parameters(), lr=config['lr'],
                                betas=(config['beta1'], config['beta2']),
                                weight_decay=config['weight_decay'])
    raise NotImplementedError(f"Unsupported optimizer: {config['optimizer']}")


def build_loss():
    """Loss for fitting the speckled measurements, selected by config['loss'].

    Raises:
        NotImplementedError: for any loss other than 'L2', 'L1' or 'simple_speckle'.
    """
    if config['loss'] == 'L2':
        return torch.nn.MSELoss()
    if config['loss'] == 'L1':
        return torch.nn.L1Loss()
    if config['loss'] == 'simple_speckle':
        return SimpleSpeckleLoss()
    raise NotImplementedError(f"Unsupported loss: {config['loss']}")


def build_prior_loss():
    """Loss for fitting a prior image: 'L2' and 'simple_speckle' map to MSE, 'L1' maps to L1.

    Raises:
        NotImplementedError: for any other loss name.
    """
    if config['loss'] in ('L2', 'simple_speckle'):
        return torch.nn.MSELoss()
    if config['loss'] == 'L1':
        return torch.nn.L1Loss()
    raise NotImplementedError(f"Unsupported loss: {config['loss']}")


def make_loader(img):
    """Shuffled, drop-last DataLoader over myImageDataset_2D(img, img_size)."""
    return DataLoader(dataset=myImageDataset_2D(img, img_size),
                      batch_size=config["batch_size"],
                      shuffle=True,
                      drop_last=True,
                      num_workers=0)


def psnr_vs_truth(estimate, reference):
    """PSNR in dB of `estimate` against `reference`, i.e. -10*log10(MSE) (assumes peak value 1)."""
    prior_loss = 0.5 * np.sum((estimate - reference)**2)/(img_size**2)
    return - 10 * np.log10(2 * prior_loss).item()


#Change the range to alter which images to run INR over
for img_idx in range(0, 50):
    truth = X_loaded[img_idx].reshape(img_size,img_size)
    measurements = Y_loaded[img_idx][:num_looks]
    print(measurements.shape)
    print(A_tensor.shape)
    print("--------------------------------- Image ", img_idx, "---------------------------------")
    plt.imshow(truth, cmap='gray')
    plt.title("Ground Truth")
    plt.show()
    plt.imshow(np.abs(measurements[0].reshape(img_size,img_size)), cmap='gray')
    plt.title("With Speckle")
    plt.show()

    # Fresh encoder/model/optimiser/loss per image; the same model and optimiser are then
    # trained sequentially on the prior(s) and finally on the measurements.
    encoder = Positional_Encoder(config['encoder'])
    model = build_model()
    optim = build_optimizer(model)
    loss_fn = build_loss()

    if prior:
        prior_loss_fn = build_prior_loss()
        print("Prior:")
        #Create a "prior image" using DCT from the root-mean-square amplitude over the looks.
        #np.abs makes this an intensity average even if the measurements are complex-valued.
        prior_measurements = np.sqrt(np.mean(np.abs(measurements)**2, axis=0))
        v_x_tilde, residuals, rank, singular_values = np.linalg.lstsq(AB, prior_measurements, rcond=None) #Least Squares
        x_tilde = undo_vec(v_x_tilde, cutoff_idx, cutoff_idx) #Solution in the Frequency Domain
        x_hat = np.matmul(np.matmul(FmI_T, x_tilde), FmI) #Solution in the Image Domain

        plt.imshow(x_hat, cmap="gray")
        plt.title("Prior Reconstruction")
        plt.show()

        prior_psnr = psnr_vs_truth(x_hat, truth)
        print("Prior PSNR: ", prior_psnr)

        #Embed the prior image
        prior_img = torch.from_numpy(x_hat).to(torch.float32).reshape(1, img_size, img_size, 1)
        model, prior_train_psnr, prior_test_psnr = train_gauss(model, optim, prior_loss_fn, make_loader(truth), prior_img, encoder, config, A_tensor=A_tensor, display=display, display_img=False, learn_from_proj=False, device=train_device, max_iter=max_iter)
        prior_train_psnr_list.append(prior_train_psnr)
        prior_test_psnr_list.append(prior_test_psnr)

    if better_prior:
        prior_loss_fn = build_prior_loss()
        print("Better Prior:")

        #Create a "prior image" by adding Gaussian noise (same shape as the image) to the ground truth
        gaussian_noise = np.random.normal(mean, sigma, truth.shape).astype(np.float32)
        x_hat = np.add(truth, gaussian_noise)

        plt.imshow(x_hat, cmap="gray")
        plt.title("Better Prior")
        plt.show()

        prior_psnr = psnr_vs_truth(x_hat, truth)
        print("Prior PSNR: ", prior_psnr)

        #Embed the prior image with the prior loss (the measurement loss may be the speckle
        #likelihood, which does not fit an image directly) and honour the `display` flag.
        prior_img = torch.from_numpy(x_hat).to(torch.float32).reshape(1, img_size, img_size, 1)
        model, prior_train_psnr, prior_test_psnr = train_gauss(model, optim, prior_loss_fn, make_loader(truth), prior_img, encoder, config, A_tensor=A_tensor, display=display, display_img=False, learn_from_proj=False, device=train_device, max_iter=max_iter)
        better_prior_train_psnr_list.append(prior_train_psnr)
        better_prior_test_psnr_list.append(prior_test_psnr)

    #Train from the sensor measurements
    print("Target:")
    measurements = torch.from_numpy(measurements).to(torch.float32)
    measurements = measurements.reshape(num_looks, 1, img_size, img_size, 1)#remove this line later
    model, target_train_psnr, target_test_psnr = train_multilook_gauss(model, optim, loss_fn, make_loader(truth), measurements, encoder, config, A_tensor=A_tensor, display=display, display_img=False, learn_from_proj=False, device=train_device, max_iter=max_iter)
    learned_models.append(model)
    target_train_psnr_list.append(target_train_psnr)
    target_test_psnr_list.append(target_test_psnr)

In [None]:
#Show the graphs of the PSNRs
# Only the first `max_plotted_imgs` images are drawn so the figures stay readable;
# raise it (e.g. to len(target_train_psnr_list)) to overlay more images.
max_plotted_imgs = 1
num_imgs = min(len(target_train_psnr_list), max_plotted_imgs)
colors = plt.get_cmap('tab20', max(num_imgs, 1))


def plot_psnr_curves(train_psnr_list, test_psnr_list, title, num_curves):
    """Plot train (dashed) and test (solid) PSNR vs iteration for the first `num_curves` images.

    Each list holds one PSNR history per image. Train entries are placed every
    config['log_iter'] iterations and test entries every config['val_iter'] iterations.
    Nothing is drawn if either list is empty.
    """
    if not train_psnr_list or not test_psnr_list:
        return
    train_indices = range(0, len(train_psnr_list[0])*config['log_iter'], config['log_iter'])
    test_indices = range(0, len(test_psnr_list[0])*config['val_iter'], config['val_iter'])
    plt.figure(figsize=(10, 6))
    for k in range(min(num_curves, len(train_psnr_list), len(test_psnr_list))):
        color = colors(k)
        plt.plot(train_indices, train_psnr_list[k], color=color, linestyle='--')
        plt.plot(test_indices, test_psnr_list[k], color=color, linestyle='-')
    plt.title(title)
    plt.xlabel('Iteration')
    plt.ylabel('PSNR (dB)')
    plt.grid(True)
    plt.show()


#Prior Graph
if prior:
    plot_psnr_curves(prior_train_psnr_list, prior_test_psnr_list, 'Test vs Train PSNR for Prior images', num_imgs)

#Better Prior Graph (gated on its own flag: its lists are only filled when better_prior is on)
if better_prior:
    plot_psnr_curves(better_prior_train_psnr_list, better_prior_test_psnr_list, 'Test vs Train PSNR for Better Prior images', num_imgs)

#Target graph
plot_psnr_curves(target_train_psnr_list, target_test_psnr_list, 'Test vs Train PSNR for Target images', num_imgs)

In [None]:
def mean_final_psnr(psnr_lists):
    """Average, over images, of the last PSNR value recorded in each per-image history."""
    final_values = [history[-1] for history in psnr_lists]
    return sum(final_values) / len(final_values)


if prior:
    avg_prior_train_psnr = mean_final_psnr(prior_train_psnr_list)
    avg_prior_test_psnr = mean_final_psnr(prior_test_psnr_list)
avg_target_train_psnr = mean_final_psnr(target_train_psnr_list)
avg_target_test_psnr = mean_final_psnr(target_test_psnr_list)

In [None]:
# avg_prior_* only exists when `prior` is enabled, so gate on `prior` alone.
if prior:
    print(f"Avg Prior Train PSNR: {avg_prior_train_psnr:.2f}")
    print(f"Avg Prior Test PSNR:  {avg_prior_test_psnr:.2f}")
if better_prior:
    # Mean over images of the final logged PSNR, matching how the other averages are computed
    avg_better_prior_train_psnr = sum(history[-1] for history in better_prior_train_psnr_list) / len(better_prior_train_psnr_list)
    avg_better_prior_test_psnr = sum(history[-1] for history in better_prior_test_psnr_list) / len(better_prior_test_psnr_list)
    print(f"Avg Better Prior Train PSNR: {avg_better_prior_train_psnr:.2f}")
    print(f"Avg Better Prior Test PSNR:  {avg_better_prior_test_psnr:.2f}")
print(f"Avg Target Train PSNR: {avg_target_train_psnr:.2f}")
print(f"Avg Target Test PSNR:  {avg_target_test_psnr:.2f}")