In [None]:
# from dataset import HyperspectralVideoData
from dataset import MultispectralLightField

# Alternative dataset (hyperspectral video, 33 bands) — swap in together with
# the HyperspectralVideoData import above:
#   hsv_data = HyperspectralVideoData('dataset/hyperspectral-video-33bands')
#   hsv_data.load_data(time=31, height=480, width=752, channels=33)

# Multispectral light-field dataset; only the 'Bust' scene is loaded.
mlf_root = 'dataset/multispectral-light-field'
mlf_data = MultispectralLightField(mlf_root)
mlf_data.load_data(scene='Bust')

In [None]:
# Presumably the target compression ratio; it selects one entry from each
# lookup table below.
compression_power = 512

# Per-level settings, listed in the same order as `compression_levels`:
#   siren_features  - hidden width (hidden_features) for a plain Siren
#   signet_features - hidden width (hidden_features) for a Signet
#   siren_batches   - number of pixels sampled per training batch
compression_levels = (64, 128, 256, 512)
siren_features = dict(zip(compression_levels, (845, 595, 420, 295)))
signet_features = dict(zip(compression_levels, (830, 585, 405, 285)))
siren_batches = dict(zip(compression_levels, (400_000, 550_000, 800_000, 1_200_000)))

In [None]:
import torch
from model import Siren, Signet, SirenMulti

# Use the first GPU when available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Alternative architectures (uncomment one to replace the Signet below):
# siren = Siren(in_features=5, out_features=1, hidden_features=siren_features[compression_power], num_hidden_layers=3)
# siren = SirenMulti(in_features=3, out_features=33, hidden_features=205, num_hidden_layers=3)

# Signet receives a sample batch of coordinates at construction time; only the
# coordinates ([0]) of the sampled pixels are passed.
# NOTE(review): presumably used to initialise its Gegenbauer encoding — confirm in model.py.
siren = Signet(in_features=5, out_features=1, hidden_features=signet_features[compression_power], num_hidden_layers=3,
               c=16, alpha=0.5, batch_coord=mlf_data.get_pixels(batch_size=siren_batches[compression_power])[0])

# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed from
# the final on-device parameters.
siren.to(device)

optim = torch.optim.Adam(lr=1e-4, params=siren.parameters())
# Multiplies the learning rate by 0.9 each time scheduler.step() is called.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

In [None]:
from datetime import datetime

# Training schedule.
num_train_steps = 50   # number of optimisation steps; set to None to train indefinitely
lr_decay_every = 150   # call scheduler.step() once every this many optimisation steps
log_every = 50         # print the loss at step 1 and every this many steps
# NOTE(review): with 50 steps the LR decay (every 150 steps) never fires.

step = 0
# `step < num_train_steps` runs exactly num_train_steps iterations; the previous
# `step <= 50` ran one extra step (51 in total).
while num_train_steps is None or step < num_train_steps:
    batch_coord, batch_data = mlf_data.get_pixels(batch_size=siren_batches[compression_power])
    batch_coord = batch_coord.to(device)
    batch_data = batch_data.to(device)

    siren_output = siren(batch_coord)

    # L1 (mean absolute error) reconstruction loss.
    loss = (siren_output - batch_data).abs().mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    step += 1
    if step == 1 or step % log_every == 0:
        print(f"{datetime.now()} step:{step:04d}, loss:{loss.item():0.8f}")
    if step % lr_decay_every == 0:
        scheduler.step()

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# 8-bit variant, used by the commented-out hyperspectral-video evaluation:
# def to_uint8(im):
#     return np.rint((im.detach().cpu().clamp(-1,1).numpy() * 0.5 + 0.5) * 255).astype(np.uint8)

def to_uint16(im):
    """Convert a tensor with values in [-1, 1] to a 16-bit unsigned integer array.

    Values outside [-1, 1] are clamped. The range is then mapped linearly onto
    [0, 65535] and rounded to the nearest integer level.

    Args:
        im: torch tensor of any shape, on any device; it is detached from the
            autograd graph and copied to the CPU.

    Returns:
        numpy.ndarray of dtype uint16 with the same shape as ``im``.
    """
    scaled = (im.detach().cpu().clamp(-1, 1).numpy() * 0.5 + 0.5) * (2**16 - 1)
    # Round instead of truncating. ``astype`` truncates towards zero, so float
    # error such as 1233.9999 would lose a whole quantisation level and make
    # even an exact round trip of 16-bit data lossy.
    return np.rint(scaled).astype(np.uint16)

# Reconstruct one full image and show it next to the ground truth as a visual
# sanity check.
batch_coord, batch_data = mlf_data.get_images(batch_size=1)
# NOTE(review): here the Signet re-initialisation receives coordinates that are
# still on the CPU and runs outside torch.no_grad(). The evaluation cell does it
# with on-device coordinates under no_grad. Confirm which one Signet expects.
if isinstance(siren, Signet):
    siren.update_gegenbauer_init(batch_coord)
with torch.no_grad():
    batch_coord = batch_coord.to(device)
    batch_data = batch_data.to(device)
    siren_output = siren(batch_coord)

# One labelled two-panel figure instead of two untitled ones, so it is clear
# which image is the reference and which is the network output.
fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(12, 5))
ax_gt.imshow(to_uint16(batch_data[0, ...]), cmap="gray")
ax_gt.set_title("Ground truth")
ax_pred.imshow(to_uint16(siren_output[0, ..., 0]), cmap="gray")
ax_pred.set_title(f"{type(siren).__name__} reconstruction")
for ax in (ax_gt, ax_pred):
    ax.axis("off")
plt.show()

In [None]:
import os
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM

psnrs = []
ssims = []

# Alternative evaluation for the hyperspectral-video dataset (8-bit, per band):
# with torch.no_grad():
#     for ti in range(hsv_data.time):
#         batch_coord = hsv_data.mgrid[ti:ti+1,...].to(device)
#         batch_data = hsv_data.data[ti:ti+1,...].to(device)
#
#         siren_output = siren(batch_coord)
#
#         batch_data = to_uint8(batch_data)
#         siren_output = to_uint8(siren_output)
#         for ci in range(hsv_data.channels):
#             psnrs.append(PSNR(batch_data[0,...,ci], siren_output[0,...,ci]))
#             ssims.append(SSIM(batch_data[0,...,ci], siren_output[0,...,ci]))

# How many light-field views (y, x) and spectral bands to evaluate.
# The loops this replaced had full ranges of 9, 9 and 13 commented out.
# The values below score only a small corner subset for quick checks, and the
# "MEAN" numbers written to metrics.txt cover that subset only. Set them to
# 9, 9, 13 to reproduce the full evaluation.
EVAL_VIEWS_Y, EVAL_VIEWS_X, EVAL_BANDS = 2, 2, 2
# Dynamic range of uint16 images. This is the value skimage would infer from
# the dtype; passing it explicitly makes that dependency visible.
UINT16_RANGE = np.iinfo(np.uint16).max

with torch.no_grad():
    for yi in range(EVAL_VIEWS_Y):
        for xi in range(EVAL_VIEWS_X):
            for li in range(EVAL_BANDS):
                batch_coord = mlf_data.mgrid[yi:yi+1, xi, ..., li, 0:5].to(device)
                batch_data = mlf_data.data[yi, xi, ..., li].to(device)

                # Signet's update_gegenbauer_init is called once, with the
                # coordinates of the first evaluated view/band.
                if isinstance(siren, Signet) and yi == 0 and xi == 0 and li == 0:
                    siren.update_gegenbauer_init(batch_coord)

                siren_output = siren(batch_coord)

                gt_u16 = to_uint16(batch_data)
                pred_u16 = to_uint16(siren_output)[0, ..., 0]

                psnrs.append(PSNR(gt_u16, pred_u16, data_range=UINT16_RANGE))
                ssims.append(SSIM(gt_u16, pred_u16, data_range=UINT16_RANGE))

net_str = 'signet' if isinstance(siren, Signet) else 'siren'
out_folder = f'outputs/{net_str}_{compression_power}'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(out_folder, exist_ok=True)

torch.save(siren.state_dict(), f"{out_folder}/{net_str}.pt")
with open(f"{out_folder}/metrics.txt", 'w') as f:
    f.write(f"MEAN PSNR:{np.array(psnrs).mean()}\n")
    f.write(f"MEAN SSIM:{np.array(ssims).mean()}\n")