In [None]:
from dataset import HyperspectralVideoData

# Where the 33-band hyperspectral video lives, and the clip dimensions to load from it.
data_root = 'dataset/hyperspectral-video-33bands'
clip_shape = dict(time=31, height=480, width=752, channels=33)

hsv_data = HyperspectralVideoData(data_root)
hsv_data.load_data(**clip_shape)

In [None]:
import torch
from model import Siren

# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so the SIREN weight initialisation is reproducible across runs.
# NOTE(review): if HyperspectralVideoData samples with numpy/random, seed those too.
SEED = 0
torch.manual_seed(SEED)

# 3 input coordinates (presumably (t, y, x) -- TODO confirm) -> 33 outputs, matching
# channels=33 passed to load_data when the dataset is loaded.
siren = Siren(in_features=3, out_features=33, hidden_features=1024, num_hidden_layers=3)
# Move the model before building the optimiser, as the torch.optim docs recommend.
siren.to(device)

optim = torch.optim.Adam(lr=1e-4, params=siren.parameters())
# Each scheduler.step() multiplies the learning rate by gamma.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

batch_coord, batch_data = hsv_data.get_images(batch_size=1)
plt.imshow(batch_data[0,...,0] * 0.5 + 0.5, cmap="gray")
plt.show()

In [None]:
import itertools
from datetime import datetime

# Bounded training so "Restart Kernel -> Run All" can reach the evaluation cells.
# Set NUM_STEPS = None to train until the cell is interrupted manually.
NUM_STEPS = 3_000
PIXELS_PER_BATCH = 250_000
LOG_EVERY = 50        # print the loss on step 1 and every LOG_EVERY steps
LR_DECAY_EVERY = 150  # call scheduler.step() (lr *= gamma) every LR_DECAY_EVERY steps

siren.train()
steps = itertools.count(1) if NUM_STEPS is None else range(1, NUM_STEPS + 1)
for step in steps:
    # Random subset of pixels, moved to the model's device.
    batch_coord, batch_data = hsv_data.get_pixels(batch_size=PIXELS_PER_BATCH)
    batch_coord = batch_coord.to(device)
    batch_data = batch_data.to(device)

    siren_output = siren(batch_coord)

    # Mean absolute error (L1) between predicted and true values.
    loss = (siren_output - batch_data).abs().mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    # loss.item() synchronises with the device, so only call it when logging.
    if step == 1 or step % LOG_EVERY == 0:
        print(f"{datetime.now()} step:{step:04d}, loss:{loss.item():0.8f}")
    if step % LR_DECAY_EVERY == 0:
        scheduler.step()

In [None]:
# Reconstruct a full image batch with the trained network for visual comparison.
batch_coord, batch_data = hsv_data.get_images(batch_size=1)
batch_coord = batch_coord.to(device)
batch_data = batch_data.to(device)

# Run inference in eval mode, then restore whatever mode the model was in, so that
# re-running the training cell afterwards is unaffected.
was_training = siren.training
siren.eval()
with torch.no_grad():
    siren_output = siren(batch_coord)
siren.train(was_training)

In [None]:
def show_band(ax, images, index, band, title, clamp=False):
    """Draw ``images[index, ..., band]`` on ``ax`` as a grayscale image.

    Values are mapped with ``x * 0.5 + 0.5`` and displayed with a fixed
    ``vmin=0, vmax=1``, so every panel shares one intensity scale. Without the
    fixed scale, imshow would normalise each panel to its own min/max and hide
    brightness errors. The source range [-1, 1] is assumed from the original
    display code -- TODO confirm against HyperspectralVideoData.

    Args:
        ax: matplotlib Axes to draw on.
        images: tensor indexed as ``[batch, ..., channel]``.
        index: batch index to show (negative indices allowed).
        band: channel index to show (negative indices allowed).
        title: panel title.
        clamp: clip values to [-1, 1] before display (used for network output).
    """
    image = images[index, ..., band].detach().cpu()
    if clamp:
        image = image.clamp(-1, 1)
    ax.imshow(image * 0.5 + 0.5, cmap="gray", vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.axis("off")


# Left column: ground truth. Right column: SIREN reconstruction.
# Top row shows band 0 of batch item 0; bottom row shows the last band of the last
# batch item (if the batch holds a single image, both rows show that same image).
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
show_band(axes[0, 0], batch_data, 0, 0, "Ground truth: batch[0], band 0")
show_band(axes[0, 1], siren_output, 0, 0, "SIREN: batch[0], band 0", clamp=True)
show_band(axes[1, 0], batch_data, -1, -1, "Ground truth: batch[-1], last band")
show_band(axes[1, 1], siren_output, -1, -1, "SIREN: batch[-1], last band", clamp=True)
fig.tight_layout()
plt.show()