In [3]:
import sys
sys.path.append('../datasets')

import torch

import matplotlib.pyplot as plt

from datasets import HSIDataset

from torchvision import transforms
from math import sqrt

In [4]:
# Configuration: compute device, floating-point precision, RNG seed and data location.
import os

# Device on which all tensors are placed (default is "cpu").
device = "cpu"

# Floating-point precision used for every tensor in this notebook.
dtype = torch.float64

# Seed torch's global RNG so the random jitter / noise drawn below is reproducible.
seed = 42
torch.manual_seed(seed)

# Location of the dataset (a .zarr store). Set the HSI_DATA_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
data_path = os.environ.get('HSI_DATA_PATH', '/home/mhiriy/data/harvard.zarr')


In [5]:
# Transform handed to HSIDataset: only converts each sample to a torch tensor.
val_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Training split of the dataset stored at `data_path`.
dataset = HSIDataset(
    root_dir=data_path,
    split='train',
    transform=val_transform,
)

In [24]:

# Load one ground-truth image and simulate the observations used in the cells below.

# Index of the selected image in the dataset.
idx = 17
# Add a batch axis and move to the configured device/dtype: x is [1, number of bands, ...].
# NOTE(review): spatial axis order comes from HSIDataset / ToTensor (ToTensor yields C, H, W) — confirm.
x = dataset[idx].unsqueeze(0).to(device=device, dtype=dtype)
# Add a tiny non-negative jitter: UNIFORM samples in [0, 1e-2 / ||x||).
# NOTE(review): this used to be documented as white Gaussian noise with sigma2 = 1e-4,
# but the code draws torch.rand_like (uniform, non-negative) — confirm which is intended.
x += torch.rand_like(x) * 1e-2 / torch.norm(x)

# Panchromatic image: average of the ground-truth HSI over its bands, kept as [1, 1, ...].
panc = torch.sum(x, dim=1, keepdim=True) / x.shape[1]

# Signal-to-noise ratio (dB) of the simulated noisy observation.
SNR = 10

# Per-band noise variance: mean power of each band (||x_b||^2 / #pixels) scaled by 10^(-SNR/10).
sigma2 = 10**(-SNR/10) * torch.norm(x, dim=[2, 3])**2 / x.shape[2] / x.shape[3]
# [1, bands] -> [1, bands, 1, 1], then tiled to the full shape of x.
sigma2 = sigma2.reshape(1, -1, 1, 1).repeat(1, 1, x.shape[2], x.shape[3])

# Noisy observation: additive white Gaussian noise with the per-band variance above.
y = x + torch.sqrt(sigma2) * torch.randn_like(x)



In [None]:
# Show, as three separate 10x10 figures without axes: the clean RGB composite,
# the panchromatic image (gray colormap) and the noisy RGB composite.
rgb = dataset.rgb_index  # band indices picked for the RGB composite

# Bands -> last axis, then min-max rescale to [0, 1] for display.
x_rgb = x[0, rgb, ...].cpu().numpy().transpose(1, 2, 0)
x_rgb = (x_rgb - x_rgb.min()) / (x_rgb.max() - x_rgb.min())

y_rgb = y[0, rgb, ...].cpu().numpy().transpose(1, 2, 0)
y_rgb = (y_rgb - y_rgb.min()) / (y_rgb.max() - y_rgb.min())

panels = (
    (x_rgb, None),
    (panc[0, 0, ...].cpu().numpy(), 'gray'),
    (y_rgb, None),
)
for img, img_cmap in panels:
    plt.figure(figsize=(10, 10))
    plt.imshow(img, cmap=img_cmap)
    plt.axis('off')

plt.show()

In [None]:
# Compare selected spectral bands: noisy observation y (left) vs. ground truth x (right).
band_idx = [0, 10, 20, 30]

# Loop variable is `band`, not `idx`, so the image index `idx` chosen earlier is not overwritten.
for band in band_idx:
    plt.figure(figsize=(20, 20))

    plt.subplot(121)
    plt.imshow(y[0, band, ...].cpu().numpy(), cmap='gray')
    plt.colorbar()
    plt.title(f'Band {band} (noisy)')
    plt.axis('off')

    # Decorate the ground-truth panel like the noisy one so the two can be told apart.
    plt.subplot(122)
    plt.imshow(x[0, band, ...].cpu().numpy(), cmap='gray')
    plt.colorbar()
    plt.title(f'Band {band} (ground truth)')
    plt.axis('off')

plt.show()

In [44]:

# Export, for a selection of dataset images, the clean RGB composite, the noisy
# RGB composite and the panchromatic image as 8-bit PNGs under `figures/`.
#
# All per-image work happens inside `_export_hsi_figures`, so this cell does not
# overwrite the notebook globals `x`, `y`, `idx` and `panc`. The weight-map cell
# below passes `panc` to `nabla` / torch ops and would otherwise receive a NumPy array.
import os

import numpy as np
from PIL import Image

EXPORT_IDX = [6, 17, 22, 42, 8, 43]  # dataset indices of the images to export
EXPORT_SNR = 10                      # SNR (dB) of the simulated noisy observation
FIGURES_DIR = 'figures'


def _minmax(a):
    """Linearly rescale array `a` so its minimum maps to 0 and its maximum to 1."""
    return (a - a.min()) / (a.max() - a.min())


def _export_hsi_figures(i, snr=EXPORT_SNR, out_dir=FIGURES_DIR):
    """Simulate the observations for dataset image `i` and save three PNGs.

    Writes `{out_dir}/{i}_rgb.png` (clean RGB), `{out_dir}/{i}_noisy.png`
    (noisy RGB at `snr` dB) and `{out_dir}/{i}_panc.png` (panchromatic image
    divided by its maximum). Reads the notebook globals `dataset`, `device`, `dtype`.
    """
    # Add a batch axis -> [1, bands, ...] on the configured device/dtype.
    x_i = dataset[i].unsqueeze(0).to(device=device, dtype=dtype)
    # Tiny non-negative jitter: uniform samples scaled by 1e-2 / ||x||.
    x_i = x_i + torch.rand_like(x_i) * 1e-2 / torch.norm(x_i)

    # Panchromatic image: band average, divided by its maximum; .cpu() so a GPU device works.
    panc_i = torch.sum(x_i, dim=1)[0] / x_i.shape[1]
    panc_i = (panc_i / panc_i.max()).cpu().numpy()

    # Additive white Gaussian noise, per-band variance = band power * 10^(-snr/10).
    sigma2_i = 10 ** (-snr / 10) * torch.norm(x_i, dim=[2, 3]) ** 2 / x_i.shape[2] / x_i.shape[3]
    y_i = x_i + torch.sqrt(sigma2_i).reshape(1, -1, 1, 1) * torch.randn_like(x_i)

    # RGB composites: bands -> last axis, min-max rescaled to [0, 1].
    rgb = dataset.rgb_index
    x_rgb = _minmax(x_i[0, rgb, ...].cpu().numpy().transpose(1, 2, 0))
    y_rgb = _minmax(y_i[0, rgb, ...].cpu().numpy().transpose(1, 2, 0))

    for suffix, img in (('rgb', x_rgb), ('noisy', y_rgb), ('panc', panc_i)):
        Image.fromarray((img * 255).astype(np.uint8)).save(f'{out_dir}/{i}_{suffix}.png')


os.makedirs(FIGURES_DIR, exist_ok=True)  # Image.save does not create missing directories
for i in EXPORT_IDX:
    _export_hsi_figures(i)



In [None]:
# Edge-dependent weight maps built from the gradient of the panchromatic image:
# w2 = exp(-mu * |grad panc| / mean|grad panc|) and w1 = 2 - w2, for several mu.
import os
import sys
if '../algorithms' not in sys.path:  # keep sys.path from growing on every re-run
    sys.path.append('../algorithms')
from nabla import nabla

from torch.linalg import norm

# NOTE(review): expects `panc` to be the torch panchromatic tensor built from the
# ground-truth HSI; if `panc` has been rebound (e.g. to a NumPy array), re-run the
# cell that builds it first.
grad_gt_panc = nabla(panc)

# Per-pixel gradient magnitude: 2-norm over the last axis
# (presumably the gradient components — confirm against `nabla`).
norm_grad_panc = norm(grad_gt_panc, dim=-1, keepdim=True)
# Mean gradient magnitude; dividing by it makes mu independent of the image scale.
norm_mean = norm_grad_panc.mean()

# Gradient rotated by 90 degrees: (g0, g1) -> (g1, -g0).
# zeros_like already matches the device and dtype of grad_gt_panc.
grad_panc_orth = torch.zeros_like(grad_gt_panc)
grad_panc_orth[..., 0] = grad_gt_panc[..., 1]
grad_panc_orth[..., 1] = -grad_gt_panc[..., 0]

cmap = 'RdYlBu'   # colormap for w1 and w2
cmap2 = 'jet'     # colormap for w1 - w2
MU_VALUES = [0, 0.1, 0.5, 1.0, 10.0, 100.0]


def _show_weight_map(img, title, path, img_cmap):
    """Plot tensor `img` in a new 15x10 figure (color range [0, 2], no axes, colorbar), then save it to `path`.

    Saving happens after all decorations so every exported PNG looks like the displayed figure.
    """
    plt.figure(figsize=(15, 10))
    plt.title(title)
    plt.imshow(img.cpu().numpy(), vmin=0, vmax=2, cmap=img_cmap)
    plt.axis('off')
    plt.colorbar()
    plt.savefig(path, bbox_inches='tight')


os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories

for mu in MU_VALUES:
    weight_tensor = torch.exp(-mu * norm_grad_panc / norm_mean)
    weight_tensor_perp = 2 - weight_tensor

    # Raw f-strings: "\mu" is LaTeX for matplotlib, not a Python escape sequence.
    _show_weight_map(weight_tensor.squeeze(),
                     rf'$w_2(x,y)$ ($\mu$={mu})',
                     f'figures/mu_{mu}.w2.png', cmap)
    _show_weight_map(weight_tensor_perp.squeeze(),
                     rf'$w_1(x,y)$ ($\mu$={mu})',
                     f'figures/mu_{mu}.w1.png', cmap)
    _show_weight_map(weight_tensor_perp.squeeze() - weight_tensor.squeeze(),
                     rf'$ w_1(x,y) - w_2(x,y)$ ($\mu$={mu})',
                     f'figures/mu_{mu}.w1-w2.png', cmap2)

plt.show()