## 1. Load voxel model

In [None]:
from setup.voxel_setup import setup_voxel_scene
from common.plot import Plotter
from common.figure import *
from simulation.simulator import get_irrad_loc_dir, compute_ior_gradient

import taichi as ti
import matplotlib.pyplot as plt
import numpy as np

%load_ext autoreload
%autoreload 2
# May comment it because the compatibility of this extension is not good
# %matplotlib widget 

# debug=True to check boundary access
ti.init(arch=ti.gpu)

SCENE_CFG = {
    # Optional Names: "geometry", "bunny", "footed_glass", "stemmed_glass"
    "Name": "geometry", 
     
    "HDR Res": (4000, 2000), 
    "HDR Name": "Light_wooden_floor_room_4k.hdr",
    "Cam Pos": [(-2, 1, 4), (0, 0.1, 1.5), (1, 2, 1)],

    # "HDR Res": (2000, 1000),
    # "HDR Name": "Light_wooden_frame_room_2k.hdr",
    # "Cam Pos": [(2, 0.5, 2), (-2, -1, 2), (-2, -1, 0)],

    "Screen Res": (1280, 960),
    
    "Num XYZ": (128, 128, 128),
    'Floor Ratio': -0.95,

    "Sampler Num": 8,

    "Load Save": True,

    "Save Fig": False,
}

PROC_CFG = {
    "Gauss Sigma": 4.0,
    "Gauss Radius": 2,

    "Grad Threshold": 0.0,
}

plotter = Plotter(SCENE_CFG)
scene = setup_voxel_scene(SCENE_CFG)

scene.apply_filter(PROC_CFG)
scene.gradient = compute_ior_gradient(scene.ior)
plotter.plot_wavefront(scene.ior, None, None)

## 2. Perform light simulation

In [None]:
# Run the light simulation. It returns a pair that is unpacked into the
# scene's irradiance grid and its local-direction data.
# NOTE(review): `local_diretion` looks like a misspelling of `local_direction`;
# it is kept as-is because other code may read this exact attribute name — confirm.
scene.irradiance, scene.local_diretion = get_irrad_loc_dir(
    scene,
    SCENE_CFG,
    plotter=plotter,
    num_show_images=2,
)


In [None]:
# Alternative full-grid view:
# plotter.plot_irradiance_grid(scene.irradiance)

# Reference plot: 4 slices of the simulated irradiance over the z range 30-100.
plotter.plot_irradiance_slices(
    scene.irradiance, "np-irrad", num_slices=4, z_start=30, z_end=100
)

# NOTE(review): `floor_height` is not defined in the visible cells; define it
# before re-enabling the line below.
# plotter.plot_local_direction_grid_slices(scene.local_diretion[:, floor_height:, :], num_slices=4, z_start=30, z_end=100)

## 3. Ray-marching render

In [None]:
# Offline render with the simulated irradiance, labelled "Origin".
# Interactive alternative: scene.rt_render(free_mode=False)
save_offline_render(
    scene,
    SCENE_CFG,
    filename="Origin",
    to_plot=True,
)

## 4. Alternative data structures for storing or fitting the irradiance

In [None]:
from data.siren import SirenFitter, siren_post_process
from data.mlp import MLPFitter, mlp_post_process
from data.octree import Octree, octree_post_process

### 4.1 Fit the irradiance with a SIREN

In [None]:
# Fit a SIREN network to the current `scene.irradiance` grid.
siren_fitter = SirenFitter(
    scene.irradiance,
    SCENE_CFG,
    hidden_features=256,
    hidden_layers=3,
    omega=24,
)
siren_fitter.fit(total_epochs=30, batch_size=20000, lr=5e-4)

In [None]:
# Run inference with the trained SIREN and plot the same slices as the reference.
siren_res = siren_fitter.infer()
plotter.plot_irradiance_slices(
    siren_res,
    "siren-irrad",
    threshold=3,
    num_slices=4,
    z_start=30,
    z_end=100,
)
siren_res.shape  # shown as the cell output

In [None]:
# corrected_siren_res = siren_post_process(siren_res, gamma=None)
# plotter.plot_irradiance_slices(corrected_siren_res, "corrected-siren-irrad", threshold=3, num_slices=4, z_start=30, z_end=100)
# assert corrected_siren_res.shape == scene.irradiance.shape, "The shape of the corrected siren result should be the same as the original irradiance grid"

In [None]:
# Render the scene with the SIREN reconstruction in place of the simulated
# irradiance. The original grid is restored afterwards (even if rendering
# raises), so later sections still fit and store the simulated irradiance
# rather than the SIREN approximation.
# NOTE(review): assumes `scene.irradiance` is a plain attribute, so rebinding
# it brings back the original object unchanged — confirm.
gt_irradiance = scene.irradiance
scene.irradiance = siren_res
try:
    # scene.rt_render(False)
    save_offline_render(scene, SCENE_CFG, filename="Siren", to_plot=True)
finally:
    scene.irradiance = gt_irradiance

### 4.2 Fit the irradiance with an MLP

In [None]:
# Build the MLP fitter on the current `scene.irradiance`. No separate fit call
# follows, so it presumably trains for `num_epoches` here — confirm.
# NOTE(review): this fits whatever `scene.irradiance` holds when the cell runs;
# make sure no earlier cell has left another model's approximation there.
mlp = MLPFitter(
    scene.irradiance,
    SCENE_CFG,
    num_epoches=1000,
)

In [None]:
# Predict the irradiance field with the MLP (padded output) and plot the same
# slices as the reference for visual comparison.
mlp_predicted_irradiance = mlp.predict(pad=True)
plotter.plot_irradiance_slices(
    mlp_predicted_irradiance,
    "mlp-irrad",
    threshold=3,
    num_slices=4,
    z_start=30,
    z_end=100,
)

In [None]:
# corrected_mlp_res = mlp_post_process(mlp_predicted_irradiance, None)
# plotter.plot_irradiance_slices(corrected_mlp_res, "corrected-mlp-irrad", threshold=3, num_slices=4, z_start=30, z_end=100)
# assert corrected_mlp_res.shape == scene.irradiance.shape, "The shape of the corrected mlp result should be the same as the original irradiance grid"

In [None]:
# Render the scene with the MLP prediction in place of the simulated
# irradiance. The original grid is restored afterwards (even if rendering
# raises), so the octree section below is built from the simulated irradiance
# rather than the MLP approximation.
# NOTE(review): the clip assumes irradiance values live in [0, 255] — confirm.
# NOTE(review): assumes `scene.irradiance` is a plain attribute, so rebinding
# it brings back the original object unchanged — confirm.
gt_irradiance = scene.irradiance
scene.irradiance = np.clip(mlp_predicted_irradiance, 0, 255)
try:
    # scene.rt_render(False)
    save_offline_render(scene, SCENE_CFG, filename="MLP", to_plot=True)
finally:
    scene.irradiance = gt_irradiance

### 4.3 Store the irradiance in an octree

In [None]:
# Build an octree (threshold=12) over the current irradiance grid.
octree = Octree(threshold=12)
octree.construct(scene.irradiance)

# Compare the octree's footprint with the dense NumPy grid.
# NOTE(review): `__sizeof__` reports a deep size only if Octree overrides it;
# the default implementation is the shallow object size.
print(f"Number of nodes: {len(octree)}")
print(f"Octree Memory usage: {octree.__sizeof__()} bytes")
print(f"In comparison, NumPy Storage Usage: {scene.irradiance.nbytes} bytes")

octree.visualize(
    plotter,
    "octree-irrad",
    num_slices=4,
    z_start=30,
    z_end=100,
)

In [None]:

# Fill a fresh grid from the octree, starting at the root node at the origin
# (0, 0, 0) and covering the full grid extent.
octree_res = octree.init_empty_grid()
octree.fill_grid(octree.root, octree_res, 0, 0, 0, octree.grid_size)
# plotter.plot_irradiance_slices(octree_res, threshold=3, num_slices=4, z_start=30, z_end=100)
# corrected_octree_res = octree_post_process(octree_res)
# plotter.plot_irradiance_slices(corrected_octree_res, threshold=3, num_slices=4, z_start=30, z_end=100)
# assert corrected_octree_res.shape == scene.irradiance.shape, "The shape of the corrected octree result should be the same as the original irradiance grid"

# Render with the octree reconstruction in place of the simulated irradiance,
# then restore the original grid (even if rendering raises) so that re-running
# earlier or later cells is unaffected.
# NOTE(review): the clip assumes irradiance values live in [0, 255] — confirm.
# NOTE(review): assumes `scene.irradiance` is a plain attribute, so rebinding
# it brings back the original object unchanged — confirm.
gt_irradiance = scene.irradiance
scene.irradiance = np.clip(octree_res, 0, 255)
try:
    # scene.rt_render(False)
    save_offline_render(scene, SCENE_CFG, filename="Octree", to_plot=True)
finally:
    scene.irradiance = gt_irradiance

# x, y, z = 60, 60, 60
# value = octree.query(x, y, z)
# value