In [None]:
import torch
import os
import numpy as np
import skimage

# Make the repository root the working directory; the `src.*` imports in the
# following cell presumably rely on this to resolve — verify.
# NOTE(review): absolute, user-specific path — the notebook will not run on
# another machine without editing it.
os.chdir('/homes/dnogina/code/topology-control/')

In [None]:
from src.CPipelineOrchestrator import CPipelineOrchestrator
from src.CModelTrainer import SDFDataset
from src.CEvaluator import CEvaluator
from src.CGeometryUtils import VolumeProcessor

In [None]:
# Load the training checkpoint onto the CPU regardless of the device it was
# saved from: the evaluation grid later in this notebook is built on the CPU,
# and without `map_location` a checkpoint written on a GPU either fails to
# load on a CPU-only machine or yields CUDA tensors that do not match it.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
ckpt = torch.load(
    '/homes/dnogina/code/topology-control/artifacts/experiment_20250722_200033/training_artifacts/run_20250722_200037/checkpoints/latest_checkpoint.pth',
    map_location='cpu',
)

In [None]:
# torch.save(ckpt, '../artifacts/first_working.pth')

In [None]:
# Instantiate the project's pipeline orchestrator with its default arguments;
# later cells reach the architecture manager and the data processor through it.
orc = CPipelineOrchestrator()

In [None]:
# Ask the architecture manager for a model instance; its weights are
# overwritten from the checkpoint in the following cell.
model = orc.architecture_manager.get_model()

In [None]:
# Restore the weights stored under 'model_state_dict' in the checkpoint.
model.load_state_dict(ckpt['model_state_dict'])

In [None]:
# Build the SDF dataset through the orchestrator's data processor.
# The keyword values are passed through unchanged; going by their names they
# configure the latent codes (size 16, mean 0, sd 0.02) — confirm against
# `generate_sdf_dataset`.
processing_results = orc.data_processor.generate_sdf_dataset(
    z_dim=16,
    latent_mean=0,
    latent_sd=0.02,
)

In [None]:
# Wrap the generated dataset description in the dataset class used for training.
ds = SDFDataset(processing_results['dataset_info'])
# Swap in the latent vectors saved in the checkpoint, presumably so that each
# sample is paired with its trained latent code — verify against SDFDataset.
ds.latent_vectors = ckpt['latent_vectors']

In [None]:
# Pick dataset item 3 and unpack it into its three components.
# NOTE(review): the index is hard-coded and the component shapes are not
# visible here; only `latents` is used by the cells below.
coords, latents, sdfs = ds[3]

In [None]:
# Requested resolution of the evaluation grid.
resolution = 150
volume_processor = VolumeProcessor(device='cpu', resolution=resolution)
# Query coordinates for the whole volume plus the grid size that was actually
# used (passed to extract_mesh later as the per-axis size).
# NOTE(review): this calls a private helper of VolumeProcessor; whether
# `actual_grid_size` can differ from `resolution` is not visible here.
all_coords, actual_grid_size = volume_processor._get_volume_coords(device='cpu', resolution=resolution)

In [None]:
# Put the model into inference mode first: without this, layers such as
# dropout or batch-norm (if the architecture has any) would still run in
# training mode and make the predicted SDF noisy / batch-dependent.
model.eval()
with torch.no_grad():
    # `[None]` prepends a batch dimension of size 1 to both inputs.
    predicted_sdfs = model(latents[None], all_coords[None])
    # Drop a trailing singleton dimension, if the output has more than one dim.
    predicted_sdfs = predicted_sdfs.squeeze(-1) if predicted_sdfs.dim() > 1 else predicted_sdfs

In [None]:
def extract_mesh(grad_size_axis, sdf, level=0.0):
    """
    Extract a triangle mesh from a cubic SDF volume using marching cubes.

    Args:
        grad_size_axis: Number of samples along each axis of the cubic grid.
        sdf: Torch tensor holding exactly ``grad_size_axis ** 3`` SDF values;
            it is reshaped to a ``(grad_size_axis,) * 3`` volume.
        level: Iso-value of the surface to extract. If it lies outside the
            value range of the volume it is replaced: by 0.0 when the volume
            changes sign, otherwise by the 20th percentile of the values.

    Returns:
        Tuple ``(vertices, faces)`` of numpy arrays from marching cubes, with
        the vertices rescaled from grid-index units towards the [-1, 1] cube.
        Failures are not raised: the error is printed and two empty arrays of
        shape ``(0, 3)`` are returned, so callers can still unpack the
        columns (``x, y, z = vertices.T``) without a secondary error.
    """
    print(f"        extract_mesh: grid_size={grad_size_axis}, sdf_shape={sdf.shape}, sdf_numel={sdf.numel()}")

    try:
        # Validate input size
        expected_size = grad_size_axis ** 3
        if sdf.numel() != expected_size:
            raise ValueError(f"SDF size mismatch: got {sdf.numel()}, expected {expected_size} for {grad_size_axis}³ grid")

        # Check minimum grid size for marching cubes
        if grad_size_axis < 2:
            raise ValueError(f"Grid size {grad_size_axis} is too small. Marching cubes requires at least 2x2x2 grid.")

        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. slices or transposed outputs.
        grid_sdf = sdf.reshape(grad_size_axis, grad_size_axis, grad_size_axis).detach().cpu().numpy()
        sdf_min, sdf_max = grid_sdf.min(), grid_sdf.max()
        print(f"        Grid SDF shape after reshape: {grid_sdf.shape}")
        print(f"        Grid SDF range: [{sdf_min:.4f}, {sdf_max:.4f}]")

        # Automatically adjust level if it's outside the SDF range
        original_level = level
        if level < sdf_min or level > sdf_max:
            if sdf_min < 0 < sdf_max:
                level = 0.0
                print("        Using zero level set (level=0.0)")
            else:
                # Fall back to a value that is guaranteed to lie inside the range.
                level = np.percentile(grid_sdf, 20)
                print(f"        Level {original_level} outside range [{sdf_min:.4f}, {sdf_max:.4f}]")
                print(f"        Using level={level:.4f} (20th percentile)")

        # Import the submodule explicitly: a bare `import skimage` does not
        # expose `skimage.measure` on older scikit-image releases.
        from skimage import measure

        vertices, faces, _, _ = measure.marching_cubes(grid_sdf, level=level)
        print(f"        Marching cubes extracted: {len(vertices)} vertices, {len(faces)} faces")

        # Rescale vertices extracted with marching cubes
        # NOTE(review): marching cubes returns vertices in index units spanning
        # [0, grad_size_axis - 1], so dividing by grad_size_axis leaves the far
        # side slightly short of x_max; confirm against the layout of the
        # sampling grid before changing the divisor.
        x_max = np.array([1, 1, 1])
        x_min = np.array([-1, -1, -1])
        vertices = vertices * ((x_max - x_min) / grad_size_axis) + x_min

        return vertices, faces
    except Exception as e:
        # Best-effort by design: report the problem and hand back an empty mesh
        # whose arrays keep the usual (N, 3) layout.
        print(f"        extract_mesh error: {e}")
        return np.empty((0, 3)), np.empty((0, 3), dtype=np.int64)

In [None]:
# Extract the zero level set from the predicted SDF of the first batch element
# (the batch has size 1, added via `[None]` at inference time).
vertices, faces = extract_mesh(actual_grid_size, predicted_sdfs[0], level=0.0)

In [None]:
import plotly.graph_objects as go
import numpy as np
from trimesh import Trimesh

import plotly.io as pio
# Select plotly's 'notebook' renderer as the default.
pio.renderers.default = 'notebook'

# Negate every vertex coordinate before display, then split into per-axis columns.
flipped_vertices = -np.array(vertices)
x, y, z = flipped_vertices.T
# One column of vertex indices per triangle corner.
triangle_corners = np.array(faces)
i, j, k = triangle_corners.T

# Single fully opaque surface trace for the extracted mesh.
mesh_trace = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, opacity=1)
fig = go.Figure(data=[mesh_trace])
# Disabled extras kept for reference:
# go.Scatter3d(x=coords[:,0], y=coords[:,1], z=coords[:,2], mode='markers', marker=dict(size=3, color=sdfs),marker_colorscale='Viridis')
# fig.update_layout(template='plotly_white',  show_scale=True)
fig.show()


In [None]:
# Switch into the `volume` directory before importing from it; presumably this
# is what makes `compute_volume` importable — verify.
# NOTE(review): this also changes the working directory for everything that
# runs afterwards, and the path is absolute and user-specific.
os.chdir('/homes/dnogina/code/topology-control/volume')
from compute_volume import compute_genus

In [None]:
# Run the project's genus computation on the extracted mesh (presumably the
# topological genus — confirm in compute_volume); as the last expression of
# the cell, its return value is displayed.
compute_genus(vertices, faces)