In [3]:
import os
import sys
sys.path.append('../datasets')

import torch
import hyde 

import matplotlib.pyplot as plt

from datasets import HSIDataset, get_eigenimages
from torchvision import transforms
from math import sqrt

In [38]:
# Device: use the GPU when one is available, otherwise fall back to the CPU
# (hard-coding "cuda" makes the notebook crash on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Floating-point precision used for every tensor in this notebook
dtype = torch.float32

# Random seed: torch.manual_seed seeds the generators of all devices, so the
# synthetic noise drawn in the next cell is the same from run to run.
seed = 42
torch.manual_seed(seed)

# Path to the HSI dataset. Override it with the HSI_DATA_PATH environment variable
# instead of editing this machine-specific default.
data_path = os.environ.get('HSI_DATA_PATH', '/home/mhiriy/data/harvard.zarr')

# Target signal-to-noise ratio, in dB (converted with 10**(-SNR/10) below)
SNR = 10

In [39]:
# Load the dataset; ToTensor converts each stored image into a torch tensor
val_transform = transforms.Compose([transforms.ToTensor()])
dataset = HSIDataset(root_dir=data_path, split='train', transform=val_transform)

# Index of the image to denoise (must be smaller than len(dataset))
idx = 43
# Add a batch dimension and move to the configured device/dtype.
# x has shape [1, number of bands, width, height].
x = dataset[idx].unsqueeze(0).to(device=device, dtype=dtype)

# Add a tiny positive perturbation: uniform U[0, 1) noise scaled by 1e-2 / ||x||_F.
# This is uniform noise (not Gaussian with sigma2 = 1e-4, as previously commented).
# NOTE(review): presumably meant to avoid exactly-zero bands — confirm intent.
# Out-of-place add so we never write into a tensor aliased by the dataset.
x = x + 1e-2 * torch.rand_like(x) / torch.norm(x)

# Per-band noise variance for the target SNR (dB):
#   sigma2[b, c] = 10^(-SNR/10) * ||x[b, c]||^2 / (height * width)
n_batch, n_bands, height, width = x.shape
sigma2 = 10**(-SNR/10) * torch.norm(x, dim=[2, 3])**2 / height / width
# Broadcast the per-band variance over the spatial dims. Unlike the previous
# reshape(1, numel, 1, 1), this works for any batch size, and expand_as avoids
# allocating a full-size copy while keeping sigma2's [B, C, H, W] shape.
sigma2 = sigma2.view(n_batch, n_bands, 1, 1).expand_as(x)

# Noisy observation: additive zero-mean Gaussian noise with per-band variance sigma2
y = x + torch.sqrt(sigma2) * torch.randn_like(x)


In [40]:
## FastHyDe method with a BM3D denoiser
method = hyde.FastHyDe()

# y is [1, bands, H, W]; FastHyDe is called here on a channels-last [H, W, bands] cube,
# so drop the batch dimension and move the band axis last.
y_hwc = y.squeeze(0).permute(1, 2, 0)
denoised_hwc = method(y_hwc, noise_type="additive", iid=True, k_subspace=10, normalize=True)

# Restore the [1, bands, H, W] layout shared with x and y
output_hyde = denoised_hwc.permute(2, 0, 1).unsqueeze(0)


In [None]:
## Using a pretrained model

# Pretrained QRNN3D weights (path relative to this notebook; fixed the stray "//")
path_to_model = '../algorithms/HyDe/pretrained-models/qrnn3d/hyde-bs16-blindSNR-gaussian-qrnn3d-l2.pth'
method = hyde.NNInference(arch="qrnn3d", pretrained_file=path_to_model)

# Inference only: disable autograd so no computation graph is kept in memory and the
# result can be passed to .numpy() in the plotting cells.
# NOTE(review): y is [1, bands, H, W], so y.unsqueeze(0) is [1, 1, bands, H, W] and
# band_dim=1 points at the singleton axis. Presumably NNInference ignores band_dim when
# permute=False — confirm. The plotting cells index the result as [1, bands, H, W].
with torch.no_grad():
    output_qrnn3D = method(y.unsqueeze(0), band_dim=1, permute=False).squeeze(0)


In [None]:
# Visual comparison of a few spectral bands: ground truth, noisy input, denoised output.
out = output_qrnn3D
cmap = 'bone'

# Preferred bands; they only exist for cubes with more than 200 bands. For smaller cubes
# (e.g. the Harvard data, presumably 31 bands) fall back to evenly spaced valid bands
# instead of indexing out of range.
n_bands = x.shape[1]
band_index = [0, 10, 50, 100, 200]
if max(band_index) >= n_bands:
    n_panels = min(len(band_index), n_bands)
    band_index = torch.linspace(0, n_bands - 1, n_panels).round().long().tolist()

# One figure row per cube. The loop replaces three copy-pasted blocks, and the loop
# variable `band` no longer overwrites the dataset index `idx` from the loading cell.
band_rows = [
    (x, 'Ground Truth HSI - bands = {}'),
    (y, 'Noisy HSI - band = {}'),
    (out, 'Denoised HSI - band = {}'),
]
for cube, title_fmt in band_rows:
    fig, axes = plt.subplots(1, len(band_index), figsize=(30, 5), squeeze=False)
    for ax, band in zip(axes[0], band_index):
        # detach() guards against tensors that still require grad
        im = ax.imshow(cube[0, band, ...].detach().cpu().numpy(), cmap=cmap)
        fig.colorbar(im, ax=ax)
        ax.set_title(title_fmt.format(band))
        ax.axis('off')

plt.show()

In [None]:
# Compare eigenimages (as returned by get_eigenimages) of the clean, noisy and
# QRNN3D-denoised cubes. get_eigenimages is imported in the first cell.
z = get_eigenimages(x)
z_est = get_eigenimages(output_qrnn3D)
z_noisy = get_eigenimages(y)

# NOTE(review): assumes get_eigenimages returns at least 10 eigenimages along dim 1 — confirm.
eigen_index = [1, 3, 5, 7, 9]

cmap = 'bone'

# One figure row per cube. The loop replaces three copy-pasted blocks, and the loop
# variable `k` no longer overwrites the dataset index `idx` from the loading cell.
eigen_rows = [
    (z, 'Eigenimage {}'),
    (z_noisy, 'Noisy Eigenimage {}'),
    (z_est, 'Denoised Eigenimage {}'),
]
for cube, title_fmt in eigen_rows:
    fig, axes = plt.subplots(1, len(eigen_index), figsize=(30, 5), squeeze=False)
    for ax, k in zip(axes[0], eigen_index):
        # detach() guards against tensors that still require grad
        im = ax.imshow(cube[0, k, ...].detach().cpu().numpy(), cmap=cmap)
        fig.colorbar(im, ax=ax)
        ax.set_title(title_fmt.format(k))
        ax.axis('off')

plt.show()