In [None]:
from neural_poisson.model.neural_poisson import NeuralPoisson

# Restore the trained NeuralPoisson model from its Lightning checkpoint.
ckpt_path = "/home/borth/2d-gaussian-splatting/logs/2025-02-19/14-49-43/checkpoints/epoch_339.ckpt"
model = NeuralPoisson.load_from_checkpoint(checkpoint_path=ckpt_path)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from pytorch3d.ops.marching_cubes import marching_cubes
import open3d as o3d

# Sample the implicit function of a sphere on a regular 3D grid over [-1, 1]^3.
grid_size = 32
x = torch.linspace(-1, 1, grid_size)
y = torch.linspace(-1, 1, grid_size)
z = torch.linspace(-1, 1, grid_size)
X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
radius = 0.5
vol = X**2 + Y**2 + Z**2 - radius**2  # f(p) = |p|^2 - r^2: zero exactly on the sphere

# Extract the zero level set of f. A non-zero isolevel c would instead extract
# the sphere |p| = sqrt(r^2 + c) (e.g. c=0.5 gives radius ~0.87, not 0.5).
verts, faces = marching_cubes(vol[None], isolevel=0.0)

# marching_cubes returns one entry per batch element; take the only one as NumPy.
verts = verts[0].numpy()
faces = faces[0].numpy()

# Wrap the result in an Open3D mesh and render it through Plotly.
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts)
mesh.triangles = o3d.utility.Vector3iVector(faces)
mesh.compute_vertex_normals()
o3d.visualization.draw_plotly([mesh])


In [None]:
import wandb
# Start a W&B run so the following cells can log media to the thesis project.
wandb.init(entity="robinborth", project="thesis")

In [None]:

# Turn the sampled volume into a binary occupancy "video" for W&B:
# the first grid axis is time, and the mask is replicated into 3 channels,
# giving a (time, channels, height, width) uint8 array. Shapes come from `vol`
# itself so changing `grid_size` does not break this cell.
occupancy = (vol > 0).to(torch.uint8) * 255
v = occupancy[:, None].expand(-1, 3, -1, -1)
video = wandb.Video(v.numpy(), fps=20, format="gif")
wandb.log({"video3": video})

In [None]:
# Sanity check: brightest pixel value in the occupancy video.
torch.max(v)

In [None]:
# Dimensions of the sampled volume.
vol.size()

In [None]:
import torch
# Load onto the CPU so inspection works regardless of the device the checkpoint
# was saved from, and show only the top-level keys instead of dumping every
# tensor into the notebook output (falls back to the raw object if not a dict).
ckpt = torch.load("/home/borth/2d-gaussian-splatting/test.ckpt", map_location="cpu")
list(ckpt.keys()) if isinstance(ckpt, dict) else ckpt

In [None]:
import torch
from neural_poisson.model.encoder import MLP

# Fix the global seed so both the MLP initialization and the sampled points
# (and therefore the printed mean) are reproducible across runs.
torch.manual_seed(0)
mlp = MLP()
# Draw 1000 query points uniformly from [-1, 1]^3 in a single sampling step.
points = torch.empty((1000, 3)).uniform_(-1.0, 1.0)
mlp(points).mean()

In [None]:
import torch

# Axis whose coordinate is held at 0; the slice spans the other two axes.
axis = "x"
resolution = 256

# Regular 2D lattice on [-1, 1]^2 that is lifted into 3D below.
grid_vals = torch.linspace(-1.0, 1.0, resolution)
xs, ys = torch.meshgrid(grid_vals, grid_vals, indexing="ij")
zs = torch.zeros_like(xs)

# For each slicing axis, the (x, y, z) components of the slice points.
planes = {
    "x": (zs, xs, ys),
    "y": (xs, zs, ys),
    "z": (xs, ys, zs),
}
if axis not in planes:
    raise ValueError(f"axis must be one of 'x', 'y', 'z', got {axis!r}")
coords = tuple(c.ravel() for c in planes[axis])

# Stack along the last dim so each row is one (x, y, z) point. Stacking along
# dim 0 gives (3, N), and reshaping that to (-1, 3) would mix coordinates of
# different points.
grid = torch.stack(coords, dim=-1)
# Evaluate the model loaded in the first cell on the slice grid. `self` and
# `points` are not defined at notebook scope; `model` and `grid` are the
# notebook-level objects that play those roles.
# NOTE(review): assumes model.forward returns a 2-tuple whose first item is the
# prediction — TODO confirm against NeuralPoisson.forward.
with torch.no_grad():
    x, _ = model.forward(grid.to(model.device))

xs

In [None]:
import torch
import torch.nn as nn
from collections import OrderedDict
from typing import Any, Callable
import numpy as np
from functools import partial

class BaseActivation(nn.Module):
    @classmethod
    def get_name(cls):
        return cls.__name__.replace("Activation", "").lower()

    @classmethod
    @torch.no_grad()
    def weight_init(self):
        return None

    @classmethod
    @torch.no_grad()
    def first_layer_weight_init(self):
        return None

class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int = 3,
        out_features: int = 1,
        hidden_features: int = 256,
        num_hidden_layers: int = 5,
        activation: type[nn.Module] | partial = nn.ReLU,
        out_activation: bool = False,
        out_bias: bool = True,
        weight_init: Callable | None = None,
        first_layer_weight_init: Callable | None = None,
    ):
        # extract the name of the activation
        activation_cls = activation
        if isinstance(activation, partial):
            activation_cls = activation.func

        layers: list[Any] = []
        names: list[str] = []

        # input layers
        layers.append(nn.Linear(in_features, hidden_features))
        names.append("layer_0")
        layers.append(activation())
        names.append(f"{activation_cls.get_name()}_0")

        # hidden layers
        for i in range(num_hidden_layers):
            layers.append(nn.Linear(hidden_features, hidden_features))
            names.append(f"layer_{i+1}")
            layers.append(activation())
            names.append(f"{activation_cls.get_name()}_{i+1}")

        # output layer
        layers.append(nn.Linear(hidden_features, out_features, bias=out_bias))
        names.append(f"layer_{i+2}")
        if out_activation:
            layers.append(activation())
            names.append(f"{activation_cls.get_name()}_{i+2}")

        # initilize the mlp with the layers
        ordered_dict = OrderedDict(zip(names, layers))
        super().__init__(ordered_dict)

        # initilize the weights of the mlp based on the activation function
        if weight_init is None:
            weight_init = activation_cls.weight_init()
        if first_layer_weight_init is None:
            first_layer_weight_init = activation_cls.first_layer_weight_init()

        if weight_init is not None:
            self.apply(weight_init)
        if first_layer_weight_init is not None:
            self["layer_0"].apply(first_layer_weight_init)

    def forward(self, x):
        return super().forward(x)



class SinusActivation(BaseActivation):
    def __init__(self, w: float = 1.0):
        super().__init__()
        self.w = w

    @classmethod
    @torch.no_grad()
    def init_weights(cls, m: nn.Module):
        if not hasattr(m, "weight"):
            return
        num_input = m.weight.size(-1)
        U = np.sqrt(6 / num_input) / 30
        m.weight.uniform_(-U, U)
    
    @classmethod
    @torch.no_grad()
    def first_layer_weight_init(m: nn.Module):
        if not hasattr(m, "weight"):
            return
        num_input = m.weight.size(-1)
        U = 1 / num_input 
        m.weight.uniform_(-U, U)

    def forward(self, x: torch.Tensor):
        return torch.sin(self.w * x)

class ReLUActivation(BaseActivation, nn.ReLU):
    pass



# Instantiate the draft MLP defined above with ReLU activations and show its layers.
mlp = MLP(activation=ReLUActivation)
mlp

In [None]:
from neural_poisson.model.encoder import MLP, SinusActivation

# Build the project's MLP (imported in this cell, shadowing the draft MLP and
# SinusActivation defined in the earlier cell) with sine activations.
mlp = MLP(activation=SinusActivation)

In [None]:
# `activation_cls` only exists as a local inside MLP.__init__, so query the
# activation class directly.
# NOTE(review): SinusActivation assumed to be the class of interest — confirm.
SinusActivation.get_name()

In [None]:
from functools import partial

# Wrap the activation class in a functools.partial with no bound arguments,
# presumably to exercise MLP's partial-handling path (the draft MLP unwraps
# `.func` to reach the class) — TODO confirm the intended use.
activation = partial(SinusActivation)

In [None]:
import open3d as o3d
from neural_poisson.data.prepare import load_mesh
# Load a reconstructed mesh with the project loader and render it via Open3D's Plotly backend.
path = "/home/borth/2d-gaussian-splatting/logs/2025-02-26/08-43-11/mesh/00000.obj"
_mesh = load_mesh(path)

mesh = o3d.geometry.TriangleMesh(
    o3d.utility.Vector3dVector(_mesh.verts_packed().cpu().numpy()),
    o3d.utility.Vector3iVector(_mesh.faces_packed().cpu().numpy()),
)
mesh.compute_vertex_normals()
o3d.visualization.draw_plotly([mesh])