In [None]:
import taichi as ti
from setup.voxel_setup import setup_voxel_scene
from simulator.propagate import *
from common.plot import *

# Pass debug=True to ti.init below to enable out-of-bounds access checks.
ti.init(arch=ti.gpu)

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy import ndimage
import sys

%load_ext autoreload
%autoreload 2

matplotlib.use('Qt5Agg')
%matplotlib widget

## 1. Load voxel model

In [None]:
# Voxel grid resolution (same size along every axis).
NUM_X = NUM_Y = NUM_Z = 128
# Build the voxel scene; floor_height is later used to cut axis 1 (Y) of the grids.
scene, floor_height = setup_voxel_scene(NUM_X, NUM_Y, NUM_Z)

In [None]:
# scene.finish()

## 2. Light simulation

In [None]:
# Pull the refractive-index (IOR) field out of the voxel scene as a dense NumPy volume.
scene_ior = scene.get_ior_grid().to_numpy()
assert isinstance(scene_ior, np.ndarray) and scene_ior.shape == (NUM_X, NUM_Y, NUM_Z), "The scene IOR should be a NumPy array of shape (NUM_X, NUM_Y, NUM_Z)"

# Initial wavefront. sampler_multiplier and pos_perturbation_scale are forwarded to
# generate_initial_wavefront (their exact meaning is defined in simulator.propagate).
sampler_multiplier = 3
pos_perturbation_scale = 0.45
initial_wavefront_pos, initial_wavefront_dir = generate_initial_wavefront(sampler_multiplier, pos_perturbation_scale, NUM_X, NUM_Y, NUM_Z)
# plot_ior_field(scene_ior, initial_wavefront_pos, initial_wavefront_dir, sampler_multiplier, floor_height)

# Smooth the IOR field before differentiating it.
# NOTE(review): radius=1 truncates the kernel to 3 taps per axis, so with sigma=3.0 the
# weights are nearly uniform (close to a 3x3x3 box filter). Confirm this is intended
# rather than a larger radius (or the default truncate=4.0). `radius` needs SciPy >= 1.10.
smoothed_ior = ndimage.gaussian_filter(scene_ior, sigma=3.0, radius=1)
ior_gradients = compute_ior_gradient(smoothed_ior)

# Test the setter and getter for the gradient field: round-trip the gradients through
# the scene and plot what comes back.
scene.set_grad_field(ior_gradients)
test_grad = scene.get_grad_field()
plot_gradients_3d(test_grad, floor_height, threshold=0.01, alpha=0.01)

In [None]:
# Step size scales with the grid height (0.3 per 100 voxels along Y). The step count is
# chosen so that test_num_steps * test_delta_t ~= 1.1 * NUM_Y, i.e. a 10% margin over
# the grid height.
test_delta_t = 0.3 * (NUM_Y / 100)
test_num_steps = int(1.1 * (NUM_Y / test_delta_t))
# NOTE(review): `device` is not defined anywhere in this notebook; it presumably comes
# from one of the star imports at the top. Confirm, or define it explicitly.
irradiance_grid, local_directions = simulate_wavefront_propagation(scene_ior, ior_gradients, initial_wavefront_pos, initial_wavefront_dir, device, sampler_multiplier, floor_height, test_num_steps, test_delta_t)

## 3. Compute irradiance

In [None]:
def remove_under_floor(grid: np.ndarray, floor_height: int, inplace: bool = False) -> np.ndarray:
    """Zero every voxel whose axis-1 (height) index lies below the floor.

    Args:
        grid: 3-D array; axis 1 is treated as the vertical axis being cut.
        floor_height: number of layers along axis 1, starting at index 0, to clear.
            Must be >= 0.
        inplace: if True, modify and return ``grid`` itself. By default a modified
            copy is returned, so the caller's array is left untouched and re-running
            the calling cell gives the same result.

    Returns:
        An array equal to ``grid`` except that ``[:, :floor_height, :]`` is 0.

    Raises:
        ValueError: if ``floor_height`` is negative. A negative slice bound would clear
            all but the top layers instead of nothing.
    """
    if floor_height < 0:
        raise ValueError(f"floor_height must be >= 0, got {floor_height}")
    result = grid if inplace else grid.copy()
    result[:, :floor_height, :] = 0
    return result

# 3-D overview of the raw irradiance. The display threshold scales with
# sampler_multiplier**3 (presumably the number of samples per voxel; confirm).
visualise_irradiance_grid_3d(irradiance_grid, floor_height, threshold=5*sampler_multiplier**3)
# Clear everything below the floor, then plot slices of the result
# (8 slices, z_start=30 .. z_end=120).
above_floor_irradiance_grid = remove_under_floor(irradiance_grid, floor_height=floor_height)
visualise_irradiance_grid_slices(above_floor_irradiance_grid, threshold=3, num_slices=8, z_start=30, z_end=120)

In [None]:
# Lightly smooth the above-floor irradiance (sigma = 0.5 voxel). gaussian_filter allocates
# and returns a new array, so the unfiltered grid stays intact without an explicit .copy().
filtered_above_floor_irradiance_grid = ndimage.gaussian_filter(above_floor_irradiance_grid, sigma=0.5)
visualise_irradiance_grid_slices(filtered_above_floor_irradiance_grid, threshold=3, num_slices=8, z_start=30, z_end=120)

## 6. Neural network irradiance (MLP)

In [None]:
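# NOTE(review): this cell is disabled. Known issues to fix before re-enabling it:
#   - `best_model = model.state_dict()` keeps references to the live parameters, so the
#     weights restored at the end are the final ones, not the best; store
#     `copy.deepcopy(model.state_dict())` instead.
#   - Only `torch.nn` is imported here; `torch` and `device` must come from elsewhere
#     (presumably the star imports). Confirm.
#   - `prepare_data` calls `.clone()` on the flattened input, which fails if
#     `irradiance_grid` is a NumPy array (it is passed to NumPy/SciPy helpers above).
# SEED = 42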
# import torch.nn as nn
# class IrradianceNet(nn.Module):
#     def __init__(self):
#         super(IrradianceNet, self).__init__()
#         self.model = nn.Sequential(
#             nn.Linear(3, 512),
#             nn.ReLU(),
#             nn.Linear(512, 1024),
#             nn.ReLU(),
#             nn.Linear(1024, 512),
#             nn.ReLU(),
#             nn.Linear(512, 1)
#         )
#         for m in self.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.xavier_uniform_(m.weight)

#     def forward(self, x):
#         return self.model(x).squeeze()

# def prepare_data(irradiance: torch.Tensor, train_ratio=0.90) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
#     x = np.arange(irradiance.shape[0])
#     y = np.arange(irradiance.shape[1]) 
#     z = np.arange(irradiance.shape[2])
#     X, Y, Z = np.meshgrid(x, y, z)
#     inputs = np.stack((X, Y, Z), axis=-1).reshape(-1, 3)
#     targets = irradiance.flatten()

#     inputs = torch.from_numpy(inputs).float().to(device)
#     targets = targets.clone().detach().to(device)

#     # Random shuffle
#     # torch.manual_seed(SEED)
#     indices = torch.randperm(len(inputs))
#     inputs = inputs[indices]
#     targets = targets[indices]

#     # Split into training and validation sets
#     train_size = int(train_ratio * len(inputs))
#     train_inputs, val_inputs = inputs[:train_size], inputs[train_size:]
#     train_targets, val_targets = targets[:train_size], targets[train_size:]

#     return train_inputs, val_inputs, train_targets, val_targets

# def train_model(model: IrradianceNet, train_inputs: torch.Tensor, val_inputs: torch.Tensor, train_targets: torch.Tensor, val_targets: torch.Tensor, 
#                 num_epochs=320, batch_size=1024, patience=60):
#     torch.cuda.empty_cache()
    
#     criterion = nn.MSELoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

#     best_val_loss = float('inf')
#     epochs_no_improve = 0
#     best_model = None

#     for epoch in range(num_epochs):
#         model.train()
#         for i in range(0, len(train_inputs), batch_size):
#             batch_inputs = train_inputs[i:i+batch_size]
#             batch_targets = train_targets[i:i+batch_size]
            
#             outputs = model(batch_inputs)
#             loss = criterion(outputs, batch_targets)
            
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
        
#         model.eval()
#         with torch.no_grad():
#             val_outputs = model(val_inputs)
#             val_loss = criterion(val_outputs, val_targets)
        
#         if val_loss < best_val_loss * 0.95: # at least 5% improvement
#             best_val_loss = val_loss
#             epochs_no_improve = 0
#             best_model = model.state_dict()
#             print(f"Current best model is at epoch {epoch + 1} and val loss: {val_loss.item():.4f}")
#         else:
#             epochs_no_improve += 1
        
#         if epochs_no_improve == patience:
#             print(f'Early stopping triggered at epoch {epoch + 1}')
#             break
        
#         if (epoch + 1) % 40 == 0:
#             print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

#     assert best_model is not None
#     model.load_state_dict(best_model)
#     torch.cuda.empty_cache()
#     return model

# def generate_irradiance_field_3d(model: IrradianceNet, size=100) -> torch.Tensor:
#     x = np.arange(size)
#     y = np.arange(size)
#     z = np.arange(size)
#     X, Y, Z = np.meshgrid(x, y, z)
#     coords = np.stack((X, Y, Z), axis=-1).reshape(-1, 3)    
#     with torch.no_grad():
#         inputs = torch.tensor(coords, dtype=torch.float32, device=device)
#         predictions = model(inputs)
    
#     return predictions.reshape(size, size, size)

# # Prepare the data and train the MLP model
# train_inputs, val_inputs, train_targets, val_targets = prepare_data(irradiance_grid)
# model = IrradianceNet().to(device)  
# model = train_model(model, train_inputs, val_inputs, train_targets, val_targets)

# # Visualize the predicted irradiance field
# predicted_irradiance = generate_irradiance_field_3d(model)
# show_radiometric_grid_3d(predicted_irradiance, threshold=0.2, only_above_floor=True)

## 7. Octree

In [None]:
class Node:
    def __init__(self, val: int = 0, children=None):
        self.val = val
        self.children = children or [None] * 8
    
    @property
    def is_leaf(self):
        return all(child is None for child in self.children)
    
    def __repr__(self):
        return self._repr_recursive()

    def _repr_recursive(self, depth=0, max_depth=2):
        indent = "  " * depth
        if self.is_leaf:
            return f"{indent}Node(val={self.val})"
        else:
            if depth >= max_depth:
                return f"{indent}Node(...)"
            children_repr = ",\n".join(self._repr_recursive_child(child, depth + 1, max_depth) for child in self.children)
            return f"{indent}Node(\n{children_repr})"

    def _repr_recursive_child(self, child, depth, max_depth):
        return "  " * depth + "None" if child is None else child._repr_recursive(depth, max_depth)

class OcTree:
    def __init__(self, threshold: int = 0):
        self.threshold = threshold
        self.grid_size = 0
        self.root = None
    
    def construct(self, grid: np.ndarray) -> None:
        self.root = self._build_tree(grid, 0, 0, 0, grid.shape[0])
        self.grid_size = grid.shape[0]
    
    def _build_tree(self, grid: np.ndarray, x: int, y: int, z: int, size: int) -> Node:
        if self._is_homogeneous(grid, x, y, z, size):
            return Node(grid[x, y, z])

        half_size = size // 2
        children = [
            self._build_tree(grid, x, y, z, half_size),
            self._build_tree(grid, x, y, z + half_size, half_size),
            self._build_tree(grid, x, y + half_size, z, half_size),
            self._build_tree(grid, x, y + half_size, z + half_size, half_size),
            self._build_tree(grid, x + half_size, y, z, half_size),
            self._build_tree(grid, x + half_size, y, z + half_size, half_size),
            self._build_tree(grid, x + half_size, y + half_size, z, half_size),
            self._build_tree(grid, x + half_size, y + half_size, z + half_size, half_size)
        ]
        
        if all(child.is_leaf for child in children):
            values = {child.val for child in children}
            if len(values) == 1:
                return Node(values.pop())
        
        return Node(children=children)
    
    def _is_homogeneous(self, grid: np.ndarray, x: int, y: int, z: int, size: int) -> bool:
        return np.ptp(grid[x:x+size, y:y+size, z:z+size]) <= self.threshold

    def query(self, x: int, y: int, z: int) -> int:
        return self._query(self.root, 0, 0, 0, self.grid_size, x, y, z)

    def _query(self, node: Node | None, x: int, y: int, z: int, size: int, qx: int, qy: int, qz: int) -> int:
        if node is None:
            raise ValueError("OcTree is empty. Maybe you forgot to construct it?")
        
        if node.is_leaf:
            return node.val

        half_size = size // 2
        octant = (
            (qx >= x + half_size) << 2 |
            (qy >= y + half_size) << 1 |
            (qz >= z + half_size)
        )
        
        return self._query(
            node.children[octant],
            x + (octant >> 2) * half_size,
            y + ((octant >> 1) & 1) * half_size,
            z + (octant & 1) * half_size,
            half_size,
            qx, qy, qz
        )
    
    def __len__(self):
        return self._count_nodes(self.root)

    def _count_nodes(self, node: Node | None) -> int:
        if node is None:
            return 0
        return 1 + sum(self._count_nodes(child) for child in node.children)
    
    def __sizeof__(self):
        return self._calculate_memory_usage(self.root)
    
    def _calculate_memory_usage(self, node: Node | None) -> int:
        if node is None:
            return 0
        return sys.getsizeof(node) + sum(self._calculate_memory_usage(child) for child in node.children)
    
    def visualize(self, num_slices=8, z_start=30, z_end=120):
        grid = np.zeros((self.grid_size, self.grid_size, self.grid_size))
        assert self.root is not None, "OcTree is empty. Maybe you forgot to construct it?"
        self._fill_grid(self.root, grid, 0, 0, 0, self.grid_size)
        visualise_irradiance_grid_slices(grid, threshold=self.threshold, num_slices=num_slices, z_start=z_start, z_end=z_end)

    def _fill_grid(self, node: Node, grid: np.ndarray, x: int, y: int, z: int, size: int):
        if node is None or size == 0:
            return
        
        if node.is_leaf:
            grid[x:x+size, y:y+size, z:z+size] = node.val
        else:
            half_size = size // 2
            for i, child in enumerate(node.children):
                if child is not None:
                    self._fill_grid(
                        child,
                        grid,
                        x + ((i >> 2) & 1) * half_size,
                        y + ((i >> 1) & 1) * half_size,  
                        z + (i & 1) * half_size,
                        half_size
                    )

# Cubes whose value spread (max - min) is within this threshold are stored as a single
# leaf; the threshold scales with sampler_multiplier**3 like the plotting thresholds above.
octree = OcTree(threshold=2 * (sampler_multiplier**3))
octree.construct(filtered_above_floor_irradiance_grid)
print(f"Number of nodes: {len(octree)}")
# OcTree.__sizeof__ sums sys.getsizeof over the tree's nodes; compare it against the
# dense array's raw buffer size (ndarray.nbytes).
print(f"Octree Memory usage: {octree.__sizeof__()} bytes")
print(f"In comparison, NumPy Storage Usage: {filtered_above_floor_irradiance_grid.nbytes} bytes")
octree.visualize()

In [None]:
# Probe a single voxel through the octree; the bare expression displays the result.
x = y = z = 60
value = octree.query(x, y, z)
value

## Temp Tests

In [None]:
# x = ti.field(float, shape=(3, 3))
# a = np.arange(9).reshape(3, 3).astype(np.int32)
# x.from_numpy(a)
# print(x)
# type(x)
# arr = x.to_numpy()
# print(arr)
# type(arr)
# field = ti.Vector.field(3, int, shape=(256, 512))
# field.shape  # (256, 512)
# field.n      # 3

# array = field.to_numpy()
# array.shape  # (256, 512, 3)
# print(type(array))

# field.from_numpy(array)  # the input array must in the shape (256, 512, 3)
# print(type(field))

In [None]:
@ti.data_oriented
class TiArray:
    """Minimal Taichi scratch test: a 1-D i32 field plus a kernel that adds 1 to each element."""

    def __init__(self, n):
        # Taichi field holding n 32-bit integers.
        self.x = ti.field(dtype=ti.i32, shape=n)

    @ti.kernel
    def inc(self):
          
        # Loop over every index i of the field and increment that element in place.
        for i in self.x:
            self.x[i] += 1

# Smoke test: build a 32-element field, run the kernel once, then print the values and dtype.
a = TiArray(32)
a.inc()
print(a.x.to_numpy())
print(a.x.dtype)

## *. Render