# Voxel Dataset Visualization

Interactive 3D explorer for the hierarchical voxel dataset, with drill-down navigation from the full object into its sub-volumes.

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
import polars as pl
import json
from ipywidgets import interact, IntSlider, Dropdown, VBox, HBox, Button, Output, Layout, HTML
from IPython.display import display, clear_output

from voxel_dataset_generator.utils.metadata import MetadataAnalyzer

%matplotlib inline

## 1. Load Dataset

In [6]:
dataset_dir = Path("dataset")

# Bind these up front so later cells can check them (e.g. `splits_data is None`)
# even when no dataset has been generated yet.
dataset_metadata = None
splits_data = None

if not dataset_dir.exists():
    print(f"Dataset directory not found: {dataset_dir}")
    print("Please generate a dataset first using: uv run voxel-gen --num-objects 10")
else:
    # Dataset-level metadata: name, base resolution, level count, per-level resolutions.
    with open(dataset_dir / "metadata.json", encoding="utf-8") as f:
        dataset_metadata = json.load(f)
    print("Dataset Information:")
    print(f"  Name: {dataset_metadata['dataset_name']}")
    print(f"  Base Resolution: {dataset_metadata['base_resolution']}^3")
    print(f"  Number of Levels: {dataset_metadata['num_levels']}")
    print(f"  Number of Objects: {dataset_metadata['num_objects']}")
    print(f"\nLevel Resolutions: {dataset_metadata['level_resolutions']}")

    # Load split information if available (optional file).
    splits_file = dataset_dir / "splits.json"
    if splits_file.exists():
        with open(splits_file, encoding="utf-8") as f:
            splits_data = json.load(f)
        print("\nSplit Configuration:")
        print(f"  Train: {splits_data['config']['train_ratio']:.0%}")
        print(f"  Val: {splits_data['config']['val_ratio']:.0%}")
        print(f"  Test: {splits_data['config']['test_ratio']:.0%}")
        print("\nSplit Distribution:")
        stats = splits_data['statistics']
        print(f"  Train: {stats['objects']['train']} objects")
        print(f"  Val: {stats['objects']['val']} objects")
        print(f"  Test: {stats['objects']['test']} objects")
    else:
        print("\nNo split information found.")

Dataset Information:
  Name: Thingi10k_Hierarchical_Voxels
  Base Resolution: 128^3
  Number of Levels: 6
  Number of Objects: 10

Level Resolutions: [128, 64, 32, 16, 8, 4]

Split Configuration:
  Train: 80%
  Val: 20%
  Test: 0%

Split Distribution:
  Train: 8 objects
  Val: 2 objects
  Test: 0 objects


## 2. Hierarchical Explorer with Drill-Down Navigation

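Use the octant buttons below the plot to drill down into an occupied octant's sub-volumes. **Back** moves up one level and **Reset to Top** returns to the full object.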

In [None]:
class HierarchicalVoxelExplorer:
    """Interactive drill-down viewer for a hierarchical voxel dataset.

    Per-object data is read from ``<dataset_dir>/objects/object_<id>/``
    (``level_0.npz`` and ``subdivision_map.json``). Shared sub-volumes are read
    from ``<dataset_dir>/subvolumes/level_<n>/<hash[:2]>/<hash>.npz``.

    Navigation is a stack of global ``(x, y, z)`` origins. An empty stack shows
    the whole level-0 grid. A stack of depth ``n`` shows the (up to eight)
    level-``n`` sub-volumes inside the parent whose origin is on top of the stack.
    """

    # Fallbacks used when neither the constructor nor metadata.json supplies a value.
    DEFAULT_BASE_RESOLUTION = 128
    DEFAULT_NUM_LEVELS = 6

    def __init__(self, dataset_dir: Path, splits_data=None,
                 base_resolution: "int | None" = None, num_levels: "int | None" = None):
        """Index the objects available in ``dataset_dir``.

        Args:
            dataset_dir: Root directory of the generated dataset.
            splits_data: Parsed ``splits.json`` contents, or ``None`` when the
                dataset has no split information.
            base_resolution: Edge length of the level-0 grid. Defaults to the
                ``base_resolution`` in ``metadata.json`` if present, else 128.
            num_levels: Number of hierarchy levels including level 0. Defaults
                to the ``num_levels`` in ``metadata.json`` if present, else 6.
        """
        self.dataset_dir = Path(dataset_dir)
        self.current_object_id = None
        self.subdivision_map = None
        self.splits_data = splits_data

        # Navigation state: one global (x, y, z) origin per level drilled into.
        self.navigation_stack = []

        # Suppresses redraws while the widgets are still being wired up.
        self._initializing = True

        metadata = self._read_dataset_metadata()
        if base_resolution is None:
            base_resolution = metadata.get('base_resolution', self.DEFAULT_BASE_RESOLUTION)
        if num_levels is None:
            num_levels = metadata.get('num_levels', self.DEFAULT_NUM_LEVELS)
        self.base_resolution = int(base_resolution)
        self.num_levels = int(num_levels)
        # Levels are numbered 0..num_levels-1; this is the deepest one we can show.
        self.max_level = self.num_levels - 1

        objects_dir = self.dataset_dir / "objects"
        if objects_dir.is_dir():
            self.object_ids = sorted(
                d.name.replace("object_", "")
                for d in objects_dir.iterdir() if d.is_dir()
            )
        else:
            self.object_ids = []

    def _read_dataset_metadata(self) -> dict:
        """Return the parsed dataset-level ``metadata.json``, or ``{}`` if absent."""
        meta_file = self.dataset_dir / "metadata.json"
        if not meta_file.exists():
            return {}
        with open(meta_file, encoding="utf-8") as f:
            return json.load(f)

    def _grid_size(self, level: int) -> int:
        """Return the edge length (in voxels) of one sub-volume at ``level``."""
        return self.base_resolution // (2 ** level)

    def get_object_split(self, object_id: str) -> "str | None":
        """Return the split (train/val/test) of an object, or ``None`` if unknown."""
        if self.splits_data is None:
            return None
        return self.splits_data.get('object_splits', {}).get(object_id)

    def get_hash_splits(self, hash_val: str) -> list:
        """Return the sorted, de-duplicated splits whose objects use this hash."""
        if self.splits_data is None:
            return []
        # NOTE(review): assumes hash_object_usage maps hash -> {object_id: split}.
        hash_info = self.splits_data.get('hash_object_usage', {}).get(hash_val, {})
        return sorted(set(hash_info.values()))

    def is_trivial_hash(self, hash_val: str) -> bool:
        """Return True if the hash is listed in ``trivial_hashes``.

        Per the split file's convention, these are presumably all-0 or all-1 volumes.
        """
        if self.splits_data is None:
            return False
        return hash_val in self.splits_data.get('trivial_hashes', [])

    def load_voxel_grid(self, object_id: str) -> np.ndarray:
        """Load the full level-0 voxel grid of an object."""
        voxel_file = self.dataset_dir / "objects" / f"object_{object_id}" / "level_0.npz"
        # NpzFile holds an open file handle; close it once the array is read.
        with np.load(voxel_file) as data:
            return data['voxels']

    def load_subdivision_map(self, object_id: str):
        """Load an object's ``subdivision_map.json`` into ``self.subdivision_map``."""
        map_file = self.dataset_dir / "objects" / f"object_{object_id}" / "subdivision_map.json"
        with open(map_file, encoding="utf-8") as f:
            self.subdivision_map = json.load(f)

    def get_subvolumes_for_parent(self, level: int, parent_global_pos: "tuple | None" = None):
        """Return up to 8 level-``level`` sub-volumes inside one parent.

        Args:
            level: Hierarchy level of the children (>= 1).
            parent_global_pos: Global ``(x, y, z)`` origin of the parent at
                ``level - 1``. With ``None`` (or at level 1), the first 8 entries
                of the level are returned in map order.

        Returns:
            Subdivision-map entries in map order, or ``[]`` if no map is loaded.
        """
        if self.subdivision_map is None:
            return []

        level_subvols = [s for s in self.subdivision_map if s['level'] == level]

        if parent_global_pos is None or level == 1:
            # Level 1 holds exactly the root's direct children.
            return level_subvols[:8]

        parent_size = self._grid_size(level - 1)
        px, py, pz = parent_global_pos

        children = []
        for s in level_subvols:
            # A child belongs to the parent if its origin lies inside the parent's box.
            if (px <= s['global_position_x'] < px + parent_size
                    and py <= s['global_position_y'] < py + parent_size
                    and pz <= s['global_position_z'] < pz + parent_size):
                children.append(s)
                if len(children) == 8:
                    break
        return children

    def load_subvolume(self, subvolume_hash: str, level: int) -> "np.ndarray | None":
        """Load a shared sub-volume by hash, or return ``None`` if its file is missing."""
        subvol_file = (self.dataset_dir / "subvolumes" / f"level_{level}"
                       / subvolume_hash[:2] / f"{subvolume_hash}.npz")
        if not subvol_file.exists():
            return None
        with np.load(subvol_file) as data:
            return data['voxels']

    def plot_voxels_3d_cubes(self, voxels, colorscale='Viridis', max_voxels=10000):
        """Render occupied voxels as unit cubes in a single Mesh3d trace.

        Args:
            voxels: 3-D array; nonzero entries are drawn.
            colorscale: Plotly colorscale; cubes are coloured by their z index.
            max_voxels: If more voxels are occupied, a fixed-seed random subset of
                this size is drawn. The seed keeps redraws stable.

        Returns:
            A ``go.Mesh3d`` trace, or ``None`` when nothing is occupied.
        """
        x, y, z = np.nonzero(voxels)
        n = len(x)
        if n == 0:
            return None

        if n > max_voxels:
            keep = np.random.default_rng(0).choice(n, max_voxels, replace=False)
            x, y, z = x[keep], y[keep], z[keep]
            n = max_voxels

        cube_verts = np.array([
            [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
            [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]
        ])
        cube_faces = np.array([
            [0, 1, 2], [0, 2, 3], [4, 5, 6], [4, 6, 7],
            [0, 1, 5], [0, 5, 4], [2, 3, 7], [2, 7, 6],
            [0, 3, 7], [0, 7, 4], [1, 2, 6], [1, 6, 5]
        ])

        # Broadcast the unit cube onto every voxel origin: (n, 8, 3) -> (8n, 3).
        origins = np.stack([x, y, z], axis=1)
        vertices = (origins[:, None, :] + cube_verts[None, :, :]).reshape(-1, 3)
        # Each cube's face indices are offset by the 8 vertices of the preceding cubes.
        faces = (cube_faces[None, :, :] + 8 * np.arange(n)[:, None, None]).reshape(-1, 3)
        colors = np.repeat(z, 8)

        return go.Mesh3d(
            x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            intensity=colors, colorscale=colorscale,
            showscale=False, opacity=0.9
        )

    def plot_voxels_3d_isosurface(self, voxels, colorscale='Viridis'):
        """Render the 0.5 isosurface of a 3-D voxel array.

        Returns ``None`` for an empty grid.
        """
        if not np.any(voxels):
            return None

        # Coordinates in C (row-major) order, matching ravel() of the values.
        xs, ys, zs = np.indices(voxels.shape)
        return go.Isosurface(
            x=xs.ravel(), y=ys.ravel(), z=zs.ravel(),
            # Cast so the 0.5 threshold works regardless of stored dtype (e.g. bool).
            value=np.asarray(voxels, dtype=np.float32).ravel(),
            isomin=0.5, isomax=1.0, surface_count=1,
            colorscale=colorscale, showscale=False, opacity=0.9,
            caps=dict(x_show=False, y_show=False, z_show=False)
        )

    def format_split_badge(self, split_name: str) -> str:
        """Return ``split_name`` with a coloured emoji badge; unknown names pass through."""
        badges = {
            'train': '🟦 Train',
            'val': '🟧 Val',
            'test': '🟥 Test'
        }
        return badges.get(split_name, split_name)

    def _object_label(self, object_id: str) -> str:
        """Return a dropdown label: the id, plus a split badge when the split is known."""
        split = self.get_object_split(object_id)
        return f"{object_id} [{self.format_split_badge(split)}]" if split else object_id

    def _split_suffix(self) -> str:
        """Return `` [badge]`` for the current object's split, or ``''`` if unknown."""
        split = self.get_object_split(self.current_object_id)
        return f" [{self.format_split_badge(split)}]" if split else ""

    def _octant_title(self, idx: int, subvol: dict, grid_size: int) -> str:
        """Return the multi-line subplot title for one sub-volume."""
        parts = [
            f"Octant {idx}",
            f"Global: ({subvol['global_position_x']},{subvol['global_position_y']},{subvol['global_position_z']})",
        ]
        if subvol['is_empty']:
            parts.append('EMPTY')
        else:
            parts.append(f'Occupied ({grid_size}³)')
            hash_splits = self.get_hash_splits(subvol['hash'])
            if hash_splits:
                if self.is_trivial_hash(subvol['hash']):
                    parts.append('🔘 Trivial (shared)')
                elif len(hash_splits) == 1:
                    parts.append(f"{self.format_split_badge(hash_splits[0])} only")
                else:
                    split_badges = [self.format_split_badge(sp) for sp in sorted(hash_splits)]
                    parts.append(f"Shared: {', '.join(split_badges)}")
        # Plotly title/annotation text breaks lines on <br>, not on '\n'.
        return '<br>'.join(parts)

    def _root_figure(self):
        """Return an isosurface figure of the whole level-0 grid, or None if it is empty."""
        voxels = self.load_voxel_grid(self.current_object_id)
        trace = self.plot_voxels_3d_isosurface(voxels, 'Viridis')
        if trace is None:
            return None
        fig = go.Figure(data=[trace])
        fig.update_layout(
            title=f"Object {self.current_object_id}{self._split_suffix()} - Level 0 ({self.base_resolution}³)",
            scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', aspectmode='cube'),
            width=800, height=800
        )
        return fig

    def _octant_figure(self, style):
        """Return a 2x4 grid of the sub-volumes under the current parent, or None."""
        current_level = len(self.navigation_stack)
        subvolumes = self.get_subvolumes_for_parent(current_level, self.navigation_stack[-1])
        if not subvolumes:
            return None

        grid_size = self._grid_size(current_level)
        if style == 'auto':
            # Per-voxel cubes are only affordable for small grids.
            style = 'cubes' if grid_size <= 32 else 'isosurface'
        plot = self.plot_voxels_3d_cubes if style == 'cubes' else self.plot_voxels_3d_isosurface

        fig = make_subplots(
            rows=2, cols=4,
            specs=[[{'type': 'scene'} for _ in range(4)] for _ in range(2)],
            subplot_titles=[self._octant_title(i, s, grid_size) for i, s in enumerate(subvolumes)],
            horizontal_spacing=0.02, vertical_spacing=0.15
        )

        colorscales = ['Viridis', 'Plasma', 'Inferno', 'Magma', 'Cividis', 'Blues', 'Greens', 'Reds']

        for idx, subvol in enumerate(subvolumes):
            row, col = divmod(idx, 4)
            if not subvol['is_empty']:
                voxels = self.load_subvolume(subvol['hash'], current_level)
                if voxels is not None:
                    trace = plot(voxels, colorscales[idx])
                    if trace is not None:
                        fig.add_trace(trace, row=row + 1, col=col + 1)

            # make_subplots names the scenes 'scene', 'scene2', ..., 'scene8'.
            scene_name = f'scene{idx + 1}' if idx > 0 else 'scene'
            fig.update_layout({
                scene_name: dict(
                    xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
                    aspectmode='cube'
                )
            })

        breadcrumb = ' > '.join(f"L{i + 1}" for i in range(current_level)) or "Root"
        fig.update_layout(
            title_text=f"Object {self.current_object_id}{self._split_suffix()} - Level {current_level} ({grid_size}³) - Path: {breadcrumb}",
            showlegend=False, height=800, margin=dict(l=0, r=0, t=100, b=0)
        )
        return fig

    def visualize_current_level(self, style='auto'):
        """Build the figure for the current navigation position.

        Args:
            style: ``'cubes'``, ``'isosurface'`` or ``'auto'``. With ``'auto'``,
                cubes are used for sub-volumes of 32³ or smaller and an isosurface
                otherwise. Ignored at level 0, which is always an isosurface.

        Returns:
            A Plotly figure, or ``None`` when there is nothing to draw.
        """
        if not self.navigation_stack:
            return self._root_figure()
        return self._octant_figure(style)

    def create_interactive_explorer(self):
        """Display the explorer widgets and render the first object.

        The widgets are:
        - an object/style selector with Back and Reset buttons;
        - an info panel;
        - the Plotly figure;
        - a row of per-octant drill-down buttons.

        Every interaction updates ``navigation_stack`` and redraws.
        """
        if not self.object_ids:
            print(f"No objects found under {self.dataset_dir / 'objects'}")
            return

        object_dropdown = Dropdown(
            options=[(self._object_label(obj_id), obj_id) for obj_id in self.object_ids],
            value=self.object_ids[0],
            description='Object:', style={'description_width': '80px'}
        )

        style_dropdown = Dropdown(
            options=['auto', 'cubes', 'isosurface'], value='auto',
            description='Style:', style={'description_width': '80px'}
        )

        reset_btn = Button(description='Reset to Top', button_style='warning', layout=Layout(width='120px'))
        back_btn = Button(description='⬆ Back', button_style='info', layout=Layout(width='100px'))

        breadcrumb_html = HTML(value="<b>Current Path:</b> Root")
        info_output = Output()
        viz_output = Output()
        button_output = Output()

        def make_drill_handler(subvol):
            """Return a click handler that descends into ``subvol``."""
            global_pos = (subvol['global_position_x'], subvol['global_position_y'], subvol['global_position_z'])

            def handler(b):
                self.navigation_stack.append(global_pos)
                update_display()
            return handler

        def update_display():
            """Redraw the info panel, the figure, the buttons and the breadcrumb."""
            if self._initializing:
                return

            current_level = len(self.navigation_stack)
            grid_size = self._grid_size(current_level)

            # Info panel
            info_output.clear_output(wait=True)
            with info_output:
                split = self.get_object_split(self.current_object_id)
                split_info = f" | Split: {self.format_split_badge(split)}" if split else ""

                print(f"📍 Level: {current_level} | Grid Size: {grid_size}³{split_info}")
                if current_level == 0:
                    print("ℹ️  Showing top-level object. Click 'Drill Down' to see sub-volumes.")
                else:
                    print(f"ℹ️  Showing 8 sub-volumes at level {current_level}")
                    print("    Click octant buttons below to drill deeper, or 'Back' to go up.")
                    if self.splits_data:
                        print("    🟦 = Train only | 🟧 = Val only | 🟥 = Test only | 🔘 = Trivial/Shared")

            # Figure
            viz_output.clear_output(wait=True)
            with viz_output:
                fig = self.visualize_current_level(style_dropdown.value)
                if fig is not None:
                    # Hide Plotly's mode bar (zoom/pan/export toolbar).
                    fig.show(config={'displayModeBar': False})

            # Navigation buttons
            button_output.clear_output(wait=True)
            with button_output:
                if current_level >= self.max_level:
                    print(f"🏁 Maximum depth reached (Level {self.max_level} - {self._grid_size(self.max_level)}³ voxels)")
                elif current_level == 0:
                    drill_btn = Button(description='⬇ Drill Down to Level 1',
                                       button_style='success', layout=Layout(width='200px'))

                    def on_drill(b):
                        self.navigation_stack.append((0, 0, 0))  # The root starts at the origin.
                        update_display()
                    drill_btn.on_click(on_drill)
                    display(drill_btn)
                else:
                    subvols = self.get_subvolumes_for_parent(current_level, self.navigation_stack[-1])
                    button_grid = []
                    for idx, subvol in enumerate(subvols[:8]):
                        if subvol['is_empty']:
                            continue  # Nothing to drill into.
                        btn = Button(
                            description=f"Octant {idx} ⬇",
                            button_style='success',
                            layout=Layout(width='110px', margin='2px')
                        )
                        btn.on_click(make_drill_handler(subvol))
                        button_grid.append(btn)

                    if button_grid:
                        print("Click an octant to drill down:")
                        display(HBox(button_grid[:4]))
                        if len(button_grid) > 4:
                            display(HBox(button_grid[4:]))

            # Breadcrumb
            if self.navigation_stack:
                levels = " → ".join(f"L{i + 1}" for i in range(current_level))
                breadcrumb_html.value = f"<b>Path:</b> {levels}"
            else:
                breadcrumb_html.value = "<b>Path:</b> Root (Level 0)"

            back_btn.disabled = not self.navigation_stack

        def on_reset(b):
            self.navigation_stack = []
            update_display()

        def on_back(b):
            if self.navigation_stack:
                self.navigation_stack.pop()
                update_display()

        def on_object_change(change):
            self.current_object_id = change['new']
            self.navigation_stack = []
            self.load_subdivision_map(self.current_object_id)
            update_display()

        def on_style_change(change):
            update_display()

        # Attach event handlers
        reset_btn.on_click(on_reset)
        back_btn.on_click(on_back)
        object_dropdown.observe(on_object_change, 'value')
        style_dropdown.observe(on_style_change, 'value')

        controls = VBox([
            HBox([object_dropdown, style_dropdown, back_btn, reset_btn]),
            breadcrumb_html
        ])

        # Display widgets
        display(controls)
        display(info_output)
        display(viz_output)
        display(button_output)

        # Initialize data for the preselected object.
        self.current_object_id = object_dropdown.value
        self.load_subdivision_map(self.current_object_id)

        # Now enable updates and draw once.
        self._initializing = False
        update_display()

# Launch the explorer only when a dataset has been generated.
if not dataset_dir.exists():
    print("Please create a dataset first!")
else:
    explorer = HierarchicalVoxelExplorer(dataset_dir, splits_data)
    explorer.create_interactive_explorer()

VBox(children=(HBox(children=(Dropdown(description='Object:', options=(('0000 [üüß Val]', '0000'), ('0001 [üüß Val‚Ä¶

Output()

Output()

Output()