In [None]:
# This notebook fits an implicit neural representation (INR), optionally starting from a prior image,
# to measurements generated by a Gaussian measurement matrix A.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from typing import Tuple, List
import torch
import pickle
from NeRP_main.networks import Positional_Encoder, FFN, SIREN
from NeRP_main.utils import get_config, prepare_sub_folder, get_data_loader
from torch.utils.data import TensorDataset, DataLoader, Dataset
from NeRP_main.data import create_grid

from my_utils import *

In [None]:
#Configs
# Most settings should be changed in the YAML file rather than in this cell.
config = get_config("gauss_recon.yaml")
img_size, max_iter = config["img_size"], config['max_iter']

# All training in this notebook runs on the CPU.
device = torch.device('cpu')

# Row count passed to the Gaussian matrix constructor, i.e. the number of
# measurements taken per image (noted as 30% by the original author).
sample_rate = 309

# Switches for the optional prior-embedding stages.
prior = False         # use a low pass prior
better_prior = False  # use the original image plus some additive noise as prior

# When True, training and test PSNR values are printed while the model trains.
display = False

# Parameters for the Gaussian noise; only defined when the noisy-image prior is on.
if better_prior:
    mean, sigma = 0, 0.05

In [None]:
# Build the measurement matrix A with one row per measurement and one column per
# pixel of a vectorised image. The column count is derived from img_size (instead
# of a hard-coded 1024) because A is later applied to images that are reshaped to
# (img_size, img_size); for 32x32 images the result is identical.
num_pixels = img_size ** 2
gaussian10 = gaussian_matrix((sample_rate, num_pixels))
# presumably rescales the matrix by its largest row norm -- confirm in my_utils
A = normalize_max_row_norm(gaussian10) #Create gaussian matrix

In [None]:
#Open the data
# NOTE(review): this is an absolute, machine-specific path; the notebook will not
# run elsewhere until it points at a local copy of the pickle.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
data_path = '/Users/sarahhagan/Desktop/NeRP Research/gray_cifar-100.pkl'
with open(data_path, 'rb') as pkl_file:
    data = pickle.load(pkl_file)

In [None]:
# Show the dataset shape, then simulate measurements of every image through A.
print(data.shape)
# sigma is presumably the std of the noise added to the measurements -- confirm in my_utils
gaussian_observations = generate_observations(data, A, sigma=0.025)

# 'X' holds the per-image measurements (it is read as such by the training loop).
measurement_sets = gaussian_observations['X']
num_imgs = len(measurement_sets)
print("Number of images: ", num_imgs)

In [None]:
# Sanity check on the first image: print its vectorised shape and push it through
# the sensor model with A, both converted to float32 tensors.
first_image_vec = gaussian_observations['Y'][0]
print("Vectorised Image Shape: ", first_image_vec.shape)
first_image_tensor = torch.from_numpy(first_image_vec).to(torch.float32)
y = sensor_output(first_image_tensor, torch.from_numpy(A).to(torch.float32))

In [None]:
# Reconstruct the first images of the dataset from their Gaussian measurements.
# For every image a fresh network is built, optionally pre-trained on a prior
# image (low-pass DCT estimate and/or noisy ground truth), and then trained
# against the measurements through A.
A_tensor = torch.from_numpy(A).to(torch.float32)
print(A_tensor.shape)
learned_models = []          # final network for each image
prior_train_psnr_list = []   # one PSNR curve per prior-embedding run
prior_test_psnr_list = []
target_train_psnr_list = []  # one PSNR curve per measurement-fitting run
target_test_psnr_list = []

if prior:
    #Create DCT matrix to be used with all images
    Fm = dct_matrix(img_size)
    cutoff_idx = config['cutoff_idx'] #Low Pass Filter Cutoff
    FmI = Fm[:cutoff_idx] #Create Low Pass Filtered DCT transforms
    FmI_T = FmI.transpose()
    B = np.kron(FmI_T, FmI_T) #Kronecker Product
    print("B:", B.shape)
    AB = np.matmul(A, B)


def _make_data_loader(image):
    """Wrap `image` in a myImageDataset_2D and return a shuffled DataLoader over it."""
    dataset = myImageDataset_2D(image, img_size)
    return DataLoader(dataset=dataset,
                      batch_size=config["batch_size"],
                      shuffle=True,
                      drop_last=True,
                      num_workers=0)


def _psnr(estimate, reference):
    """Return -10*log10(MSE) in dB, with the MSE taken over img_size**2 pixels.

    Equivalent to the former 0.5*sum/N loss followed by the factor 2 inside the log.
    NOTE(review): this form of PSNR assumes a peak intensity of 1 -- confirm.
    """
    mse = np.sum((estimate - reference)**2)/(img_size**2)
    return - 10 * np.log10(mse).item()


def _embed_prior(model, optim, loss_fn, encoder, x_hat, truth):
    """Train `model` on the prior image `x_hat`, log its PSNR curves, return the model."""
    prior_img = torch.from_numpy(x_hat).to(torch.float32).reshape(1, img_size, img_size, 1)
    data_loader = _make_data_loader(truth)
    model, prior_train_psnr, prior_test_psnr = train_gauss(model, optim, loss_fn, data_loader, prior_img, encoder, config, A_tensor=A_tensor, display=display, display_img=False, learn_from_proj=False, device="cpu", max_iter=max_iter)
    prior_train_psnr_list.append(prior_train_psnr)
    prior_test_psnr_list.append(prior_test_psnr)
    return model


# Process at most 50 images, but never more than were actually loaded.
num_recon_imgs = min(50, len(gaussian_observations['Y']))

for i in range(num_recon_imgs):
    # 'Y' holds the vectorised images, 'X' the corresponding measurements.
    truth = gaussian_observations['Y'][i].reshape(img_size,img_size)
    measurements = gaussian_observations['X'][i]
    print("--------------------------------- Image ", i, "---------------------------------")
    plt.imshow(truth, cmap='gray')
    plt.title("Ground Truth")
    plt.show()

    # Setup input encoder:
    encoder = Positional_Encoder(config['encoder'])
    # Setup model
    if config['model'] == 'SIREN':
        model = SIREN(config['net'])
    elif config['model'] == 'FFN':
        model = FFN(config['net'])
    else:
        raise NotImplementedError
    model.to(device)
    model.train()
    #Set up optimiser
    if config['optimizer'] == 'Adam':
        optim = torch.optim.Adam(model.parameters(), lr=config['lr'], betas=(config['beta1'], config['beta2']), weight_decay=config['weight_decay'])
    else:
        # Previously a bare expression: an unsupported optimizer was silently ignored.
        raise NotImplementedError(f"Unsupported optimizer: {config['optimizer']!r}")

    # Setup loss function
    if config['loss'] == 'L2':
        loss_fn = torch.nn.MSELoss()
    elif config['loss'] == 'L1':
        loss_fn = torch.nn.L1Loss()
    else:
        # Previously a bare expression: an unsupported loss was silently ignored.
        raise NotImplementedError(f"Unsupported loss: {config['loss']!r}")

    # NOTE(review): if both `prior` and `better_prior` are enabled the same model is
    # embedded twice and two entries per image are appended to the prior PSNR lists.
    if prior:
        print("Prior:")
        #Create a "prior image" using DCT
        v_x_tilde = np.linalg.lstsq(AB, measurements)[0] #Least Squares
        x_tilde = undo_vec(v_x_tilde, cutoff_idx, cutoff_idx) #Solution in the Frequency Domain
        x_hat = np.matmul(np.matmul(FmI_T, x_tilde), FmI) #Solution in the Image Domain

        plt.imshow(x_hat, cmap="gray")
        plt.title("Prior Reconstruction")
        plt.show()

        prior_psnr = _psnr(x_hat, truth)
        print("Prior PSNR: ", prior_psnr)

        #Embed the prior image
        model = _embed_prior(model, optim, loss_fn, encoder, x_hat, truth)

    #Embed the second prior
    if better_prior:
        print("Better Prior:")
        #Create a "prior image" by adding noise to original image

        # Generate Gaussian noise with the same shape as the image
        gaussian_noise = np.random.normal(mean, sigma, truth.shape).astype(np.float32)

        # Add the noise to the image
        x_hat = cv2.add(truth, gaussian_noise)

        plt.imshow(x_hat, cmap="gray")
        plt.title("Better Prior")
        plt.show()

        prior_psnr = _psnr(x_hat, truth)
        print("Better Prior PSNR: ", prior_psnr)

        #Embed the prior image
        model = _embed_prior(model, optim, loss_fn, encoder, x_hat, truth)

    #Train from the sensor measurements
    print("Target:")
    measurements = torch.from_numpy(measurements).to(torch.float32)
    inverse_data_loader = _make_data_loader(truth)
    # NOTE(review): this stage uses a fixed 3000 iterations rather than
    # config['max_iter'] -- confirm that this is intentional.
    model, target_train_psnr, target_test_psnr = train_gauss(model, optim, loss_fn, inverse_data_loader, measurements, encoder, config, A_tensor=A_tensor, display=display, display_img=False, learn_from_proj=True, device="cpu", max_iter=3000)
    learned_models.append(model)
    target_train_psnr_list.append(target_train_psnr)
    target_test_psnr_list.append(target_test_psnr)

In [None]:
#Show the graphs of the PSNRs
# One colour per image; dashed lines are training PSNR, solid lines are test PSNR.
num_imgs = len(target_train_psnr_list)

# x-axes: training PSNR is spaced by config['log_iter'], test PSNR by config['val_iter'].
if prior or better_prior:
    prior_train_indices = range(0, len(prior_train_psnr_list[0])*config['log_iter'], config['log_iter'])
    prior_test_indices = range(0, len(prior_test_psnr_list[0])*config['val_iter'], config['val_iter'])
target_train_indices = range(0, len(target_train_psnr_list[0])*config['log_iter'], config['log_iter'])
target_test_indices = range(0, len(target_test_psnr_list[0])*config['val_iter'], config['val_iter'])
colors = plt.get_cmap('tab20', num_imgs)


def _plot_psnr_curves(train_curves, test_curves, train_x, test_x, title):
    """Draw the train/test PSNR curves of the first `num_imgs` images in one figure."""
    plt.figure(figsize=(10, 6))
    for img_idx in range(num_imgs):
        shade = colors(img_idx)
        plt.plot(train_x, train_curves[img_idx], color=shade, linestyle='--')
        plt.plot(test_x, test_curves[img_idx], color=shade, linestyle='-')
    plt.title(title)
    plt.xlabel('Iteration')
    plt.ylabel('PSNR (dB)')
    plt.grid(True)
    plt.show()


#Prior Graph
if prior or better_prior:
    _plot_psnr_curves(prior_train_psnr_list, prior_test_psnr_list,
                      prior_train_indices, prior_test_indices,
                      'Test vs Train PSNR for Prior images')

#Target graph
_plot_psnr_curves(target_train_psnr_list, target_test_psnr_list,
                  target_train_indices, target_test_indices,
                  'Test vs Train PSNR for Target images')

In [None]:
def _mean_final_psnr(psnr_curves):
    """Average the last recorded PSNR value across all curves in `psnr_curves`."""
    return sum(curve[-1] for curve in psnr_curves)/len(psnr_curves)


# Prior averages only exist when a prior-embedding stage was run.
if prior or better_prior:
    avg_prior_train_psnr = _mean_final_psnr(prior_train_psnr_list)
    avg_prior_test_psnr = _mean_final_psnr(prior_test_psnr_list)
avg_target_train_psnr = _mean_final_psnr(target_train_psnr_list)
avg_target_test_psnr = _mean_final_psnr(target_test_psnr_list)

In [None]:
# Report the averaged final PSNR values; the prior rows appear only when a prior was used.
if prior or better_prior:
    print(f"Avg Prior Train PSNR: {avg_prior_train_psnr:.2f}",
          f"Avg Prior Test PSNR:  {avg_prior_test_psnr:.2f}",
          sep="\n")
print(f"Avg Target Train PSNR: {avg_target_train_psnr:.2f}",
      f"Avg Target Test PSNR:  {avg_target_test_psnr:.2f}",
      sep="\n")