In [None]:
from dataset import HyperspectralVideoData

# Load the hyperspectral video into memory. The shape is passed explicitly to
# `load_data`: 31 time steps of 480x752 pixels with 33 spectral bands.
# NOTE(review): presumably these must match the files under DATASET_DIR — TODO confirm.
DATASET_DIR = 'dataset/hyperspectral-video-33bands'
hsv_data = HyperspectralVideoData(DATASET_DIR)
hsv_data.load_data(
    time=31,
    height=480,
    width=752,
    channels=33,
)

In [None]:
import torch
from model import Siren

# Seed torch's RNG so weight initialization (and any torch-based sampling) is
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the first CUDA GPU, but fall back to CPU so the notebook still runs
# (slowly) on machines without one instead of crashing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# SIREN MLP mapping a 3-D input coordinate (presumably (t, y, x) — TODO confirm)
# to the 33 spectral bands.
siren = Siren(in_features=3, out_features=33, hidden_features=256, num_hidden_layers=3)

# Move the parameters to the target device *before* constructing the optimizer,
# as the PyTorch optimizer docs recommend.
siren.to(device)

optim = torch.optim.Adam(lr=1e-4, params=siren.parameters())
# Multiplies the learning rate by 0.9 every time `scheduler.step()` is called.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

# Display the model architecture as the cell output.
siren

In [None]:
from datetime import datetime

# Training hyper-parameters. The loop is bounded so the notebook finishes under
# Restart & Run All (an unbounded `while True` has to be interrupted by hand,
# which also aborts every later cell). Raise NUM_STEPS for longer training.
NUM_STEPS = 3_000
LOG_EVERY = 50          # print the loss every LOG_EVERY steps (and at step 1)
LR_DECAY_EVERY = 150    # call scheduler.step() every LR_DECAY_EVERY steps
PIXELS_PER_BATCH = 1_000_000

siren.train()
for step in range(1, NUM_STEPS + 1):
    # Fresh batch of (coordinate, spectral value) pairs from the dataset.
    batch_coord, batch_data = hsv_data.get_pixels(batch_size=PIXELS_PER_BATCH)
    batch_coord = batch_coord.to(device)
    batch_data = batch_data.to(device)

    siren_output = siren(batch_coord)

    # Mean absolute error (L1) between prediction and ground truth.
    loss = (siren_output - batch_data).abs().mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    if step == 1 or step % LOG_EVERY == 0:
        print(f"{datetime.now()} step:{step:04d}, loss:{loss.item():0.8f}")
    if step % LR_DECAY_EVERY == 0:
        scheduler.step()

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

def to_uint8(im):
    """Convert a tensor with values in [-1, 1] to a uint8 NumPy array in [0, 255].

    The tensor is detached and copied to the CPU first, so it may live on any
    device and may require grad. Values outside [-1, 1] are clamped.

    Args:
        im: torch.Tensor of any shape, presumably normalized to [-1, 1].

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``im``.
    """
    scaled = (im.detach().cpu().clamp(-1, 1).numpy() * 0.5 + 0.5) * 255
    # Round to nearest rather than truncating: a bare astype(uint8) floors every
    # value, biasing all pixels down by up to one gray level. After the clamp,
    # scaled is within [0, 255], so the cast cannot wrap around.
    return np.rint(scaled).astype(np.uint8)

# Visual sanity check: ground truth vs. SIREN reconstruction for a few
# (batch index, band index) slices of one batch returned by `get_images`.
batch_coord, batch_data = hsv_data.get_images(batch_size=1)

with torch.no_grad():
    batch_coord = batch_coord.to(device)
    batch_data = batch_data.to(device)

    siren_output = siren(batch_coord)

# Slices to show; indexing as [b, ..., c] assumes a channels-last layout
# (batch, H, W, bands) — TODO confirm.
# NOTE(review): with batch_size=1, batch index -1 is the same element as 0, so
# both rows show the same batch item and only the band differs. Confirm intent.
slices = [(0, 0), (-1, -1)]
panels = [("ground truth", batch_data), ("SIREN", siren_output)]

fig, axes = plt.subplots(len(slices), len(panels),
                         figsize=(6 * len(panels), 4 * len(slices)), squeeze=False)
for row, (bi, ci) in enumerate(slices):
    for ax, (label, tensor) in zip(axes[row], panels):
        # Fixed vmin/vmax so both panels share one gray scale; imshow would
        # otherwise stretch each image to its own min/max and make the
        # comparison misleading.
        ax.imshow(to_uint8(tensor[bi, ..., ci]), cmap="gray", vmin=0, vmax=255)
        ax.set_title(f"{label}: batch {bi}, band {ci}")
        ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
import os
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM

out_folder = 'outputs/siren_256'
os.makedirs(out_folder, exist_ok=True)

# Save the trained weights first so a failure in the (slow) metric loop below
# cannot lose them.
torch.save(siren.state_dict(), f"{out_folder}/siren.pt")

# Per-(frame, band) PSNR / SSIM. Each frame is reconstructed on its full
# coordinate grid, then both prediction and ground truth are quantized to uint8
# so the metrics are computed on a 0..255 scale.
psnrs = []
ssims = []
with torch.no_grad():
    for ti in range(hsv_data.time):
        frame_coord = hsv_data.mgrid[ti:ti + 1, ...].to(device)
        frame_data = hsv_data.data[ti:ti + 1, ...].to(device)

        frame_pred = siren(frame_coord)

        # Drop the leading singleton frame axis; bands are the last axis.
        gt_u8 = to_uint8(frame_data)[0]
        pred_u8 = to_uint8(frame_pred)[0]
        for ci in range(hsv_data.channels):
            gt_band, pred_band = gt_u8[..., ci], pred_u8[..., ci]
            # data_range=255 is what skimage infers for uint8; stated explicitly.
            psnrs.append(PSNR(gt_band, pred_band, data_range=255))
            ssims.append(SSIM(gt_band, pred_band, data_range=255))

mean_psnr = np.mean(psnrs)
mean_ssim = np.mean(ssims)

with open(f"{out_folder}/metrics.txt", 'w') as f:
    f.write(f"MEAN PSNR:{mean_psnr}\n")
    f.write(f"MEAN SSIM:{mean_ssim}\n")

print(f"MEAN PSNR:{mean_psnr}  MEAN SSIM:{mean_ssim}")