In [1]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt

In [None]:
PROBES_DIM_X = 50
PROBES_DIM_Y = 5
PROBES_DIM_Z = 5
PROBES_COUNT = PROBES_DIM_X * PROBES_DIM_Y * PROBES_DIM_Z

INPUT_FLOAT_COUNT = PROBES_COUNT + 18 # probe_feature + pos + angle
SH_FLOAT_COUNT = 27
MLP_HIDDEN_LAYER_WIDTH = 27

In [None]:
class NeuralSH(nn.Module):
    def __init__(self):
        super(NeuralSH, self).__init__()
        self.probe_features = nn.Parameter(torch.rand(PROBES_COUNT), requires_grad=True)

        self.hidden_layer = nn.Linear(INPUT_FLOAT_COUNT, MLP_HIDDEN_LAYER_WIDTH)
        self.output_layer = nn.Linear(MLP_HIDDEN_LAYER_WIDTH, SH_FLOAT_COUNT)

    def trigonometric_encoding(self, x: torch.Tensor, L: int):
        assert x.ndim == 2
        y = []
        for i in range(L):
            s = torch.sin(2**i * torch.pi * x)
            c = torch.cos(2**i * torch.pi * x)
            y.append(s)
            y.append(c)
        y = torch.cat(y, dim=1)
        return y

    def forward(self, pos, angle):
        assert pos.ndim == 2
        assert angle.ndim == 2

        pos = pos.view(-1, 3)
        angle = angle.view(-1, 1)

        assert pos.shape[0] == angle.shape[0]
        batch_size = pos.shape[0]

        pos_enc = self.trigonometric_encoding(pos, L=2)
        angle_enc = self.trigonometric_encoding(angle, L=3)

        x = torch.cat([pos_enc, angle_enc, self.probe_features.repeat(batch_size, 1)], dim=1)
        x = self.hidden_layer(x)
        x = torch.sigmoid(x)
        x = self.output_layer(x)
        return x


In [4]:
xx, yy, zz = np.meshgrid(
    np.linspace(0.0, 1.0, PROBES_DIM_X),
    np.linspace(0.0, 1.0, PROBES_DIM_Y),
    np.linspace(0.0, 1.0, PROBES_DIM_Z),
)
xx = xx.reshape(-1, 1)
yy = yy.reshape(-1, 1)
zz = zz.reshape(-1, 1)
pos_grid = np.hstack([xx, yy, zz])
pos_grid.shape

(1250, 3)

In [5]:
light_angle = np.zeros((pos_grid.shape[0], 1), dtype=float)
light_angle.shape

(1250, 1)

In [6]:
model = NeuralSH()
with torch.no_grad():
    out = model(torch.FloatTensor(pos_grid), torch.FloatTensor(light_angle))
out.shape

torch.Size([1250, 27])

In [7]:
from pathlib import Path

from texture_sampler import Texture, load_texture_by_name, sample_uv

DATA_DIR = Path("LightmapsData/")


def Texture3DSample(tex: Texture, uvw: np.ndarray):
    uvw = uvw.reshape(-1, 3)
    _, _, _, channels = tex.data.shape
    results = []
    for p in uvw:
        color = sample_uv(tex, p[0], p[1], p[2])
        color = color.reshape(-1, channels)
        results.append(color)
    results = np.concatenate(results, axis=0)
    return results


def GetVolumetricLightmapAmbient(BrickTextureUVs: np.ndarray):
    tex = load_texture_by_name(DATA_DIR, "AmbientVector")
    return Texture3DSample(tex, BrickTextureUVs)


def GetVolumetricLightmapSHCoefficients0(BrickTextureUVs: np.ndarray):
    AmbientVector = GetVolumetricLightmapAmbient(BrickTextureUVs)
    SHCoefficients0Red = Texture3DSample(load_texture_by_name(DATA_DIR, "SHCoefficients_0"), BrickTextureUVs) * 2 - 1
    SHCoefficients0Green = Texture3DSample(load_texture_by_name(DATA_DIR, "SHCoefficients_2"), BrickTextureUVs) * 2 - 1
    SHCoefficients0Blue = Texture3DSample(load_texture_by_name(DATA_DIR, "SHCoefficients_4"), BrickTextureUVs) * 2 - 1
    SHDenormalizationScales0 = np.array([
        0.488603 / 0.282095,
		0.488603 / 0.282095,
		0.488603 / 0.282095,
		1.092548 / 0.282095
    ])
    SHCoefficients0Red = SHCoefficients0Red * AmbientVector[:, 0:1] * SHDenormalizationScales0
    SHCoefficients0Green = SHCoefficients0Green * AmbientVector[:, 1:2] * SHDenormalizationScales0
    SHCoefficients0Blue = SHCoefficients0Blue * AmbientVector[:, 2:3] * SHDenormalizationScales0
    return AmbientVector, SHCoefficients0Red, SHCoefficients0Green, SHCoefficients0Blue


def GetVolumetricLightmapSH3(BrickTextureUVs: np.ndarray):
    AmbientVector, SHCoefficients0Red, SHCoefficients0Green, SHCoefficients0Blue = GetVolumetricLightmapSHCoefficients0(BrickTextureUVs)
    SHCoefficients1Red = Texture3DSample(load_texture_by_name(DATA_DIR, "SHCoefficients_1"), BrickTextureUVs) * 2 - 1
    SHCoefficients1Green = Texture3DSample(load_texture_by_name(DATA_DIR, "SHCoefficients_3"), BrickTextureUVs) * 2 - 1
    SHCoefficients1Blue = Texture3DSample(load_texture_by_name(DATA_DIR, "SHCoefficients_5"), BrickTextureUVs) * 2 - 1
    SHDenormalizationScales1 = np.array([
		1.092548 / 0.282095,
		4.0 * 0.315392 / 0.282095,
		1.092548 / 0.282095,
		2.0 * 0.546274 / 0.282095
    ])
    SHCoefficients1Red = SHCoefficients1Red * AmbientVector[:, 0:1] * SHDenormalizationScales1
    SHCoefficients1Green = SHCoefficients1Green * AmbientVector[:, 1:2] * SHDenormalizationScales1
    SHCoefficients1Blue = SHCoefficients1Blue * AmbientVector[:, 2:3] * SHDenormalizationScales1

    IrradianceSH = np.concatenate([
        AmbientVector[:, 0:1], # .x
        SHCoefficients0Red[:],
        SHCoefficients1Red[:],
        AmbientVector[:, 1:2], # .y
        SHCoefficients0Green[:],
        SHCoefficients1Green[:],
        AmbientVector[:, 2:3], # .z
        SHCoefficients0Blue[:],
        SHCoefficients1Blue[:],
    ], axis=1)
    return IrradianceSH

In [8]:
# BrickTextureUVs __i06676.x, __i16677.x, __i26678.x float3 0.31288, 0.43666, 0.66076

# V0 _8422.x, __i06913.x, __i16915.x, __i26917.x float4 1.48535, -0.41522, 0.71059, 0.04087
# V1 __i36919.x, __i06997.x, __i16999.x, __i27001.x float4 -0.1588, -0.80276, 0.53671, 0.2306
# V2 __i37003.x float -0.3305
# V0 _8423.x, __i06925.x, __i16927.x, __i26929.x float4 0.98926, -0.75354, 0.73377, 0.08955
# V1 __i36931.x, __i07009.x, __i17011.x, __i27013.x float4 -0.16142, -0.97288, 0.60831, 0.24069
# V2 __i37015.x float -0.27203
# V0 _8424.x, __i06937.x, __i16939.x, __i26941.x float4 0.81299, -0.61416, 0.62958, 0.08644
# V1 __i36943.x, __i07021.x, __i17023.x, __i27025.x float4 -0.1487, -0.83557, 0.52588, 0.21856
# V2 __i37027.x float -0.2175


In [9]:
brick_uvs = np.array([0.31288, 0.43666, 0.66076]).reshape(1, 3)

shader_value = np.array([
    1.48535, -0.41522, 0.71059, 0.04087,
    -0.1588, -0.80276, 0.53671, 0.2306,
    -0.3305,
    0.98926, -0.75354, 0.73377, 0.08955,
    -0.16142, -0.97288, 0.60831, 0.24069,
    -0.27203,
    0.81299, -0.61416, 0.62958, 0.08644,
    -0.1487, -0.83557, 0.52588, 0.21856,
    -0.2175,
]).reshape(1, 27)

this_value = GetVolumetricLightmapSH3(brick_uvs)

print("shader: ", shader_value)
print("this: ", this_value)
print("errors: ", np.abs(shader_value - this_value))
print(f"max abs error: {np.abs(shader_value - this_value).max():.5f}")

shader:  [[ 1.48535 -0.41522  0.71059  0.04087 -0.1588  -0.80276  0.53671  0.2306
  -0.3305   0.98926 -0.75354  0.73377  0.08955 -0.16142 -0.97288  0.60831
   0.24069 -0.27203  0.81299 -0.61416  0.62958  0.08644 -0.1487  -0.83557
   0.52588  0.21856 -0.2175 ]]
this:  [[ 1.48513877 -0.41543918  0.71072756  0.04014002 -0.15832449 -0.80275309
   0.53758384  0.22957705 -0.33113261  0.9894107  -0.75399901  0.73368275
   0.0887034  -0.1607023  -0.97259599  0.60830068  0.24001047 -0.27270406
   0.81303811 -0.61448763  0.62945858  0.08567148 -0.14806844 -0.83529595
   0.52589594  0.21794688 -0.21810566]]
errors:  [[2.11226082e-04 2.19178445e-04 1.37558532e-04 7.29984310e-04
  4.75512036e-04 6.91442618e-06 8.73842717e-04 1.02294865e-03
  6.32609990e-04 1.50698414e-04 4.59014470e-04 8.72505959e-05
  8.46600861e-04 7.17695855e-04 2.84014562e-04 9.32347010e-06
  6.79534589e-04 6.74060019e-04 4.81107330e-05 3.27627989e-04
  1.21424787e-04 7.68522614e-04 6.31563045e-04 2.74048831e-04
  1.59423057e-0

In [10]:
true_sh = GetVolumetricLightmapSH3(pos_grid)
true_sh.shape

(1250, 27)

# Training

In [11]:
from torch.utils.data import Dataset, DataLoader, random_split

class SHDataset(Dataset):
    def __init__(
            self,
            positions_data: np.ndarray,
            light_angles_data: np.ndarray,
            spherical_harmonics_data: np.ndarray
        ):
        super().__init__()
        self.positions_data = torch.FloatTensor(positions_data)
        self.light_angles_data = torch.FloatTensor(light_angles_data)
        self.spherical_harmonics_data = torch.FloatTensor(spherical_harmonics_data)
        assert spherical_harmonics_data.shape[1] == SH_FLOAT_COUNT

    def __getitem__(self, index):
        source = self.positions_data[index], self.light_angles_data[index]
        target = self.spherical_harmonics_data[index]
        return source, target
    
    def __len__(self):
        assert len(self.positions_data) == len(self.light_angles_data)
        return len(self.positions_data)

In [12]:
dataset = SHDataset(pos_grid, light_angle, true_sh)
train_ds, test_ds = random_split(dataset, [0.8, 0.2], torch.Generator().manual_seed(42))

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=len(test_ds))

In [22]:
def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
          epochs=100, lr=1e-3, patience=10, device="cuda"):
    model = model.to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float("inf")
    epochs_no_improve = 0
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        for source, target in train_loader:
            pos, angle = source
            pos = pos.to(device)
            angle = angle.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            outputs = model(pos, angle)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)

        # Validation
        model.eval()
        val_losses = []
        with torch.no_grad():
            for source, target in val_loader:
                pos, angle = source
                pos = pos.to(device)
                angle = angle.to(device)
                target = target.to(device)
                outputs = model(pos, angle)
                val_loss = criterion(outputs, target)
                val_losses.append(val_loss.item())

        avg_val_loss = np.mean(val_losses)

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Loss: {avg_train_loss:.6f} "
              f"Val Loss: {avg_val_loss:.6f}")

        if patience != 0: # early stopping based on validation loss
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), "best_model.pth")
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs.")
                    model.load_state_dict(torch.load("best_model.pth"))
                    break
    
    if patience == 0:
        torch.save(model.state_dict(), "best_model.pth")

    return model, history


In [24]:
model = NeuralSH()

In [25]:
model, history = train(model, train_dl, test_dl, epochs=10000, patience=0)

Epoch [1/10000] Train Loss: 0.128200 Val Loss: 0.091394
Epoch [2/10000] Train Loss: 0.118679 Val Loss: 0.089517
Epoch [3/10000] Train Loss: 0.117376 Val Loss: 0.087377
Epoch [4/10000] Train Loss: 0.114376 Val Loss: 0.085631
Epoch [5/10000] Train Loss: 0.114497 Val Loss: 0.084310
Epoch [6/10000] Train Loss: 0.112886 Val Loss: 0.083345
Epoch [7/10000] Train Loss: 0.109642 Val Loss: 0.082816
Epoch [8/10000] Train Loss: 0.110562 Val Loss: 0.082346
Epoch [9/10000] Train Loss: 0.108491 Val Loss: 0.082107
Epoch [10/10000] Train Loss: 0.112057 Val Loss: 0.082329
Epoch [11/10000] Train Loss: 0.111485 Val Loss: 0.082111
Epoch [12/10000] Train Loss: 0.110773 Val Loss: 0.081966
Epoch [13/10000] Train Loss: 0.111652 Val Loss: 0.081922
Epoch [14/10000] Train Loss: 0.109706 Val Loss: 0.082061
Epoch [15/10000] Train Loss: 0.108174 Val Loss: 0.082124
Epoch [16/10000] Train Loss: 0.110928 Val Loss: 0.081872
Epoch [17/10000] Train Loss: 0.111098 Val Loss: 0.081853
Epoch [18/10000] Train Loss: 0.108586 Va