In [7]:
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import os
import sys
from tqdm.notebook import tqdm

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

from src.mlp_decomposition.mlp_composite import get_model
from external.siren.dataio import get_mgrid


# --- Configuration -------------------------------------------------------------
# Paths are relative to the notebook's current working directory.
# Trained composite occupancy MLP checkpoint for shape e8409b544c626028a9b2becd26dc2fc1.
MLP_CHECKPOINT_PATH = (
    "../logs/test_experiment/"
    "occ_e8409b544c626028a9b2becd26dc2fc1_model_final.pth"
)
# Ground-truth sampled point cloud for the same shape id.
GT_POINTCLOUD_PATH = (
    "../data/baseline/02691156_100000_pc/"
    "e8409b544c626028a9b2becd26dc2fc1.obj.npy"
)

RESOLUTION = 128            # grid samples per axis -> RESOLUTION**3 query points in total
OCCUPANCY_THRESHOLD = 0.5   # sigmoid probability above which a query point counts as inside
MAX_BATCH = 64 ** 3         # query points sent through the model per forward pass


# --- Model -----------------------------------------------------------------------
print("Loading model...")
model = get_model(output_type="occ")
# Map every tensor onto the CPU, regardless of the device the checkpoint was saved from.
state_dict = torch.load(MLP_CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # evaluation mode (matters for layers such as dropout/batch-norm, if present)
print("Model loaded successfully.")

# Names of the parts the composite model can be queried for (later passed as part_name=...).
part_names = [*model.registry.keys()]
print(f"Found parts in model: {part_names}")


print(f"Creating a {RESOLUTION}^3 grid in the normalized [-0.5, 0.5] space...")
# NOTE(review): assumes get_mgrid returns a flattened (RESOLUTION**3, 3) torch tensor spanning
# [-1, 1] per axis, so the 0.5 factor produces the [-0.5, 0.5] cube announced above.
# Confirm against external/siren/dataio.py.
mgrid = get_mgrid(RESOLUTION, dim=3) * 0.5
part_points = {}


def _occupied_points(model, part_name, coords, batch_size, threshold):
    """Return the query points that ``model`` predicts as inside one part.

    Args:
        model: composite occupancy model, called as
            ``model({"coords": batch.unsqueeze(0)}, part_name=part_name)``.
        part_name: which part of the composite model to evaluate.
        coords: tensor of query coordinates, one row per point.
        batch_size: number of points per forward pass.
        threshold: sigmoid probability above which a point counts as inside.

    Returns:
        numpy array of the occupied rows of ``coords``, or ``None`` when no
        point exceeds ``threshold``.

    Raises:
        RuntimeError: if the model does not return exactly one value per
            query point.
    """
    chunks = []
    for start in tqdm(range(0, coords.shape[0], batch_size), desc=f"  Querying {part_name}", leave=False):
        coords_batch = coords[start:start + batch_size]
        model_output = model({"coords": coords_batch.unsqueeze(0)}, part_name=part_name)
        # The raw output is treated as an occupancy logit, hence the sigmoid.
        # NOTE(review): confirm the "occ" model does not already apply a sigmoid itself.
        # reshape(-1) yields exactly one value per query point. A bare .squeeze() would
        # collapse a final batch of a single point to a 0-d tensor and break the mask below.
        occupancy_probs = torch.sigmoid(model_output["model_out"]).reshape(-1)
        if occupancy_probs.shape[0] != coords_batch.shape[0]:
            raise RuntimeError(
                f"Expected one occupancy value per query point for '{part_name}', "
                f"got {occupancy_probs.shape[0]} values for {coords_batch.shape[0]} points"
            )
        inside_mask = occupancy_probs > threshold
        if inside_mask.any():
            chunks.append(coords_batch[inside_mask].numpy())
    return np.concatenate(chunks, axis=0) if chunks else None


with torch.no_grad():  # pure inference: no autograd graph needed
    for part_name in part_names:
        print(f"  Processing part: '{part_name}'...")
        occupied = _occupied_points(model, part_name, mgrid, MAX_BATCH, OCCUPANCY_THRESHOLD)
        if occupied is not None:
            part_points[part_name] = occupied
            print(f"    -> Found {occupied.shape[0]} points for '{part_name}'")
        else:
            print(f"    -> No points found for '{part_name}'")


print(f"\nLoading ground truth point cloud from {GT_POINTCLOUD_PATH}...")
gt_point_cloud = np.load(GT_POINTCLOUD_PATH)
# Expected layout: one row per sample, columns 0-2 = xyz, column 3 = occupancy label.
# Fail early with a readable message instead of an IndexError from the slicing below.
if gt_point_cloud.ndim != 2 or gt_point_cloud.shape[1] < 4:
    raise ValueError(
        f"Expected an (N, >=4) array of xyz + occupancy in {GT_POINTCLOUD_PATH}, "
        f"got shape {gt_point_cloud.shape}"
    )
gt_coords_raw = gt_point_cloud[:, :3]
gt_occupancy = gt_point_cloud[:, 3]

# Keep only samples labelled inside. The exact comparison assumes labels are stored as exactly 0/1.
gt_coords_inside = gt_coords_raw[gt_occupancy == 1]
if gt_coords_inside.shape[0] == 0:
    # Nothing to normalize or plot; stop here rather than crash inside the normalization.
    raise ValueError(
        f"No ground-truth samples are labelled inside (occupancy == 1) in {GT_POINTCLOUD_PATH}"
    )

def normalize_points(points, half_extent: float = 0.5, margin: float = 0.95) -> np.ndarray:
    """Center a point set at its mean and scale it uniformly into a cube.

    A single scale factor is applied to every axis, so the aspect ratio is
    preserved. It is chosen so that the largest absolute centered coordinate
    becomes ``half_extent * margin``. With the defaults, every point ends up
    inside [-0.475, 0.475] along each axis.

    Args:
        points: array-like of coordinates, one row per point.
        half_extent: half side length of the target cube (default 0.5, i.e. [-0.5, 0.5]).
        margin: fraction of ``half_extent`` actually used, leaving a small border
            (default 0.95).

    Returns:
        Array with the same shape as ``points``. If all points coincide, the
        centered (all-zero) array is returned instead of dividing by zero,
        which previously produced NaNs.

    Raises:
        ValueError: if ``points`` is empty.
    """
    points = np.asarray(points)
    if points.size == 0:
        raise ValueError("normalize_points() received an empty point set")
    points_centered = points - np.mean(points, axis=0, keepdims=True)
    # Largest |coordinate| over all axes == max(|min|, |max|) of the centered data.
    max_abs = np.abs(points_centered).max()
    if max_abs == 0:
        return points_centered
    return points_centered * (half_extent * margin / max_abs)

# NOTE(review): the centering/scaling statistics come from the inside-labelled subset only.
# Confirm this matches how the model's training coordinates were normalized; otherwise the
# ground truth will appear shifted/scaled relative to the predicted parts.
gt_coords_normalized = normalize_points(gt_coords_inside)
print(f"Loaded and normalized {gt_coords_normalized.shape[0]} ground truth points.")


if not part_points:
    print("No occupied points were found across any parts to visualize.")
else:
    print("\n--- Visualizing Predicted Part Shapes vs. Ground Truth ---")

    # NOTE(review): this replaces plotly's auto-detected default renderer for the whole
    # kernel session; confirm the "notebook" renderer displays in the frontend in use.
    pio.renderers.default = "notebook"

    # Ground truth goes first as a faint black cloud. Each predicted part is then
    # overlaid using plotly's default colour cycle.
    traces = [
        go.Scatter3d(
            x=gt_coords_normalized[:, 0],
            y=gt_coords_normalized[:, 1],
            z=gt_coords_normalized[:, 2],
            mode='markers',
            marker=dict(size=1.5, color='black', opacity=0.15),
            name='Ground Truth',
        )
    ]
    for part_name, points in part_points.items():
        traces.append(
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(size=2, opacity=0.8),
                name=f"Pred: {part_name}",
            )
        )
    fig = go.Figure(data=traces)

    # Identical fixed ranges and a 1:1:1 aspect ratio keep every axis on the same
    # scale, so shapes are not distorted by plotly's per-axis autoscaling.
    cube_range = [-0.5, 0.5]
    fig.update_layout(
        title='Individual Part Shapes (Color) vs. Ground Truth (Black)',
        legend_title_text='Parts',
        scene=dict(
            xaxis=dict(title=dict(text='X'), range=cube_range),
            yaxis=dict(title=dict(text='Y'), range=cube_range),
            zaxis=dict(title=dict(text='Z'), range=cube_range),
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=1),
        ),
        margin=dict(l=0, r=0, b=0, t=40),
    )

    fig.show()

Loading model...
Model loaded successfully.
Found parts in model: ['wing', 'body', 'tail', 'engine']
Creating a 128^3 grid in the normalized [-0.5, 0.5] space...
  Processing part: 'wing'...


  Querying wing:   0%|          | 0/8 [00:00<?, ?it/s]

    -> Found 5196 points for 'wing'
  Processing part: 'body'...


  Querying body:   0%|          | 0/8 [00:00<?, ?it/s]

    -> Found 8238 points for 'body'
  Processing part: 'tail'...


  Querying tail:   0%|          | 0/8 [00:00<?, ?it/s]

    -> Found 5487 points for 'tail'
  Processing part: 'engine'...


  Querying engine:   0%|          | 0/8 [00:00<?, ?it/s]

    -> Found 386 points for 'engine'

Loading ground truth point cloud from ../data/baseline/02691156_100000_pc/e8409b544c626028a9b2becd26dc2fc1.obj.npy...
Loaded and normalized 36401 ground truth points.

--- Visualizing Predicted Part Shapes vs. Ground Truth ---


In [None]:
import os

import numpy as np

# Quick sanity check of a raw point-cloud file: print its shape, dtype and the
# first three columns of the first 20 rows.
# NOTE(review): this path is relative to the working directory and differs from
# GT_POINTCLOUD_PATH in the first cell ("../data/baseline/..."); confirm which file is intended.
file_path = "data/02691156_100000_pc/e8409b544c626028a9b2becd26dc2fc1.obj.npy"

if os.path.exists(file_path):
    try:
        data_array = np.load(file_path)

        # Printed one at a time so that partial output survives if a later line fails.
        print(f"File loaded successfully: {file_path}")
        print(f"Dimensions (shape): {data_array.shape}")
        print(f"Data type (dtype): {data_array.dtype}")
        print(f"point raw: {data_array[:20, :3]}")

    except Exception as exc:
        # Best-effort inspection cell: report the problem instead of raising.
        print(f"An error occurred while loading the file: {exc}")
else:
    print(f"Error: File not found at {file_path}")

File loaded successfully: data/02691156_100000_pc/e8409b544c626028a9b2becd26dc2fc1.obj.npy
Dimensions (shape): (200000, 4)
Data type (dtype): float64
point raw: [[ 0.05899151  0.02914964 -0.09924176]
 [-0.04267594  0.01086618 -0.1676903 ]
 [-0.2093799  -0.00855105  0.04620329]
 [ 0.00841943  0.17768449  0.45538166]
 [ 0.00884222  0.15889561  0.44986304]
 [-0.00180511  0.03176417 -0.0860197 ]
 [-0.01899128  0.02555662  0.39024669]
 [ 0.00682636 -0.07368231 -0.4106121 ]
 [ 0.13570586 -0.01752314 -0.05232706]
 [-0.01720981  0.03408437  0.31563173]
 [ 0.08692576 -0.00203494 -0.05368338]
 [ 0.0238001  -0.04114561 -0.00295171]
 [-0.08849004  0.01376307  0.35517888]
 [-0.02152399  0.02769117 -0.00899179]
 [-0.0048515  -0.03150539 -0.36424426]
 [-0.29225496  0.01349883  0.09805016]
 [ 0.01321483  0.19286927  0.45891458]
 [-0.10673672 -0.00786897 -0.08045414]
 [ 0.037921   -0.01867238 -0.06928201]
 [ 0.05221188  0.00696405  0.37786461]]
