In [None]:
# Load the multispectral light-field dataset for a single scene.
# Not used here: the HyperspectralVideoData source ('dataset/hyperspectral-video-33bands',
# loaded with time=31, height=480, width=752, channels=33).
from dataset import MultispectralLightField

scene_str = 'Bust'  # also used to name the output folder when results are saved

mlf_data = MultispectralLightField('dataset/multispectral-light-field')
mlf_data.load_data(scene=scene_str)

In [None]:
# Experiment configuration.
# The tables below are keyed by `compression_power` (the feature-encoding ones
# first by `bases`). Only some combinations have entries, so an unsupported pair
# raises KeyError wherever a table is indexed.
compression_power = 64
bases = 4   # passed as `bases=` to the feature-encoding models
power = 0   # passed as `power=` to the feature-encoding models

# `hidden_features` for the plain Siren / Signet models: {compression_power: width}.
siren_features = {
    64: 845,
    128: 595,
    256: 420,
    512: 295,
}
signet_features = {
    64: 830,
    128: 585,
    256: 405,
    512: 285,
}

# `hidden_features` for the feature-encoding variants: {bases: {compression_power: width}}.
sirenfe_features = {
    4: {64: 300, 512: 105},
    5: {64: 270, 128: 190, 512: 95},
    6: {64: 245, 128: 175, 256: 120, 512: 85},
    7: {512: 80},
    8: {512: 75},
    12: {512: 60},
}
signetfe_features = {
    3: {512: 120},
    4: {64: 300, 512: 105},
    5: {64: 270, 512: 95},
    6: {64: 245, 128: 175, 256: 120, 512: 85},
    7: {512: 80},
    8: {512: 75},
    12: {512: 60},
}

# Pixels drawn per training step (the `batch_size` given to mlf_data.get_pixels):
# {bases: {compression_power: n_pixels}}. Values were chosen based on the sirenfe settings.
siren_batches = {
    3: {512: 500_000},
    4: {64: 150_000, 128: 200_000, 512: 450_000},
    5: {64: 125_000, 128: 175_000, 512: 400_000},
    6: {64: 100_000, 128: 150_000, 256: 250_000, 512: 350_000},
    7: {512: 300_000},
    8: {512: 300_000},
    12: {512: 250_000},
}

In [None]:
import torch
from model import Siren, Signet, SirenFeatureEncoding, SignetFeatureEncoding

# Use the first GPU when one is present. Otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at `siren.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model selection: exactly one `siren = ...` line should be active.
# Widths come from the size tables (keyed by compression_power, and by bases for
# the feature-encoding variants). The Signet variants also receive one batch of
# coordinates from mlf_data.get_pixels at construction time.
# siren = Siren(in_features=5, out_features=1, hidden_features=siren_features[compression_power], num_hidden_layers=3)
# siren = Signet(in_features=5, out_features=1, hidden_features=signet_features[compression_power], num_hidden_layers=3, 
#                c=16, alpha=0.5, batch_coord=mlf_data.get_pixels(batch_size=siren_batches[bases][compression_power])[0])
# siren = SirenFeatureEncoding(in_features=5, out_features=1, hidden_features=sirenfe_features[bases][compression_power], num_hidden_layers=3, bases=bases, power=power, alt=False)
# siren = SirenFeatureEncoding(in_features=5, out_features=1, hidden_features=sirenfe_features[bases][compression_power], num_hidden_layers=3, bases=bases, power=power, alt=True)
siren = SignetFeatureEncoding(in_features=5, out_features=1, hidden_features=signetfe_features[bases][compression_power], num_hidden_layers=3, 
                              c=16, alpha=0.5, batch_coord=mlf_data.get_pixels(batch_size=siren_batches[bases][compression_power])[0], bases=bases, power=power)

# Move the model first, then build the optimizer over its on-device parameters.
# The torch.optim documentation recommends this order.
siren.to(device)

optim = torch.optim.Adam(lr=1e-4, params=siren.parameters())
# Multiplies the learning rate by gamma=0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

siren

In [None]:
from datetime import datetime

# Number of optimisation steps to run.
# None trains until the kernel is interrupted, which is the open-ended behaviour.
# Set an integer (e.g. 50) for a bounded run that survives Restart & Run All.
max_steps = None

# NOTE: re-running this cell resets `step` (and so the 150-step scheduler cadence)
# but keeps training the same model, optimizer and scheduler state.
step = 0
try:
    while max_steps is None or step < max_steps:
        batch_coord, batch_data = mlf_data.get_pixels(batch_size=siren_batches[bases][compression_power])
        batch_coord = batch_coord.to(device)
        batch_data = batch_data.to(device)

        siren_output = siren(batch_coord)

        # L1 (mean absolute error) reconstruction loss.
        loss = (siren_output - batch_data).abs().mean()

        optim.zero_grad()
        loss.backward()
        optim.step()

        step += 1
        if step == 1 or step % 50 == 0:
            print(f"{datetime.now()} step:{step:04d}, loss:{loss.item():0.8f}")
        if step % 150 == 0:
            # Decay the learning rate every 150 steps.
            scheduler.step()
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to end open-ended training.
    # Catching it lets the cell finish without a traceback, and the trained
    # model remains available to later cells.
    print(f"{datetime.now()} training stopped by user at step {step}")

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# def to_uint8(im):
#     return ((im.detach().cpu().clamp(-1,1).numpy() * 0.5 + 0.5) * 255).astype(np.uint8)

def to_uint16(im):
    """Quantise a tensor with values in [-1, 1] into a uint16 numpy array.

    The tensor is detached, copied to the CPU, clamped to [-1, 1], mapped
    linearly onto [0, 2**16 - 1], and truncated (not rounded) by the uint16 cast.
    """
    unit = im.detach().cpu().clamp(-1, 1).numpy() * 0.5 + 0.5  # now in [0, 1]
    return (unit * (2**16 - 1)).astype(np.uint16)

# Preview one full image: ground truth next to the network's reconstruction.
batch_coord, batch_data = mlf_data.get_images(batch_size=1)
with torch.no_grad():
    batch_coord = batch_coord.to(device)
    batch_data = batch_data.to(device)
    # Call update_gegenbauer_init the same way as the evaluation loop:
    # on-device coordinates, inside no_grad.
    # Presumably it refreshes coordinate-dependent state before inference; TODO confirm.
    if isinstance(siren, Signet):
        siren.update_gegenbauer_init(batch_coord)
    siren_output = siren(batch_coord)

fig, (ax_gt, ax_rec) = plt.subplots(1, 2, figsize=(12, 6))
ax_gt.imshow(to_uint16(batch_data[0, ...]), cmap="gray")
ax_gt.set_title("ground truth")
ax_gt.axis("off")
ax_rec.imshow(to_uint16(siren_output[0, ..., 0]), cmap="gray")
ax_rec.set_title(f"{type(siren).__name__} reconstruction")
ax_rec.axis("off")
plt.show()

In [None]:
import os
import imageio
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM
from utils import wavelength_to_rgb

# Evaluate the trained network on every view and band of the light field.
# Saves the checkpoint, per-band PNGs and spectral renders, and writes mean PSNR/SSIM.
# (A commented-out per-frame evaluation for HyperspectralVideoData used to live
# here; it is not applicable to the light-field dataset.)

# Light-field grid and spectral sampling. Reduce these (e.g. to 2) for a quick debug run.
n_views_y, n_views_x, n_bands = 3, 3, 13
# Band li is colourised at wavelength 400 + 25 * li (presumably nm; TODO confirm units).
wavelength_start, wavelength_step = 400, 25

# Output naming: '<net>' or '<net>fe_<bases>_<power>' for the feature-encoding models.
# SignetFeatureEncoding is listed explicitly so it is labelled 'signet...' whether
# or not it subclasses Signet.
net_str = 'signet' if isinstance(siren, (Signet, SignetFeatureEncoding)) else 'siren'
if isinstance(siren, (SirenFeatureEncoding, SignetFeatureEncoding)):
    net_str += f'fe_{bases}_{power}'
out_folder = f'outputs/{scene_str}_{compression_power}_{net_str}'
os.makedirs(out_folder, exist_ok=True)

torch.save(siren.state_dict(), f"{out_folder}/{net_str}.pt")

psnrs = []
ssims = []
with torch.no_grad():
    for yi in range(n_views_y):
        for xi in range(n_views_x):
            for li in range(n_bands):
                batch_coord = mlf_data.mgrid[yi:yi+1, xi, ..., li, 0:5].to(device)
                batch_data = mlf_data.data[yi, xi, ..., li].to(device)

                # Signet models get update_gegenbauer_init once, on the first evaluated coordinates.
                if isinstance(siren, Signet) and yi == 0 and xi == 0 and li == 0:
                    siren.update_gegenbauer_init(batch_coord)

                siren_output = siren(batch_coord)

                # Compare in the 16-bit integer domain; skimage infers data_range from the uint16 dtype.
                gt_u16 = to_uint16(batch_data)
                rec_u16 = to_uint16(siren_output)[0, ..., 0]

                psnrs.append(PSNR(gt_u16, rec_u16))
                ssims.append(SSIM(gt_u16, rec_u16))

                if yi == xi:
                    # Colourise each band by its wavelength.
                    # The bands are then averaged into RGB renders for ground truth (g)
                    # and reconstruction (r).
                    rgb_weights = wavelength_to_rgb(wavelength_start + li * wavelength_step)
                    # NOTE(review): the uint16 cast wraps around if a weighted value
                    # exceeds 65535; confirm wavelength_to_rgb returns weights in [0, 1].
                    gt_rgb = (np.tile(gt_u16[..., None], (1, 1, 3)) * rgb_weights).astype(np.uint16)
                    rec_rgb = (np.tile(rec_u16[..., None], (1, 1, 3)) * rgb_weights).astype(np.uint16)
                    render_g = (gt_rgb / n_bands) if li == 0 else render_g + (gt_rgb / n_bands)
                    render_r = (rec_rgb / n_bands) if li == 0 else render_r + (rec_rgb / n_bands)
                    if li == n_bands - 1:
                        imageio.imwrite(f"{out_folder}/{yi}_{xi}_render_g.png", render_g.astype(np.uint16))
                        imageio.imwrite(f"{out_folder}/{yi}_{xi}_render_r.png", render_r.astype(np.uint16))
                    imageio.imwrite(f"{out_folder}/{yi}_{xi}_{li:02d}_g.png", gt_rgb)
                    imageio.imwrite(f"{out_folder}/{yi}_{xi}_{li:02d}_r.png", rec_rgb)

with open(f"{out_folder}/metrics.txt", 'w') as f:
    f.write(f"MEAN PSNR:{np.mean(psnrs)}\n")
    f.write(f"MEAN SSIM:{np.mean(ssims)}\n")