# Point Cloud Serialization Tutorial (PTv3/LitePT)

This notebook walks through the **serialization process** used in Point Transformer V3 (PTv3) and LitePT.

We'll use your car point cloud to understand:
1. **Grid Quantization** - Converting continuous coordinates to discrete grid
2. **Z-Order (Morton) Encoding** - Converting 3D coords to 1D codes
3. **Sorting & Serialization** - Creating ordered sequences
4. **Patch Grouping** - Grouping points for local attention
5. **Shuffle Order Strategy** - Using multiple serialization patterns

---
## Step 0: Imports and Setup

In [None]:
import os
from typing import Tuple, List, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# For nice inline plots: render figures directly in the notebook output
# (IPython magic, not plain Python; already the default in modern Jupyter)
%matplotlib inline
# Reset to matplotlib's default style in case another style was applied earlier
plt.style.use('default')

print("Setup complete!")

---
## Step 1: Load the Point Cloud Data

In [None]:
# Path to the car point cloud CSV. Set the CAR_PCLOUD_CSV environment variable to
# point elsewhere, so the notebook also runs outside the author's directory layout.
CSV_PATH = os.environ.get(
    "CAR_PCLOUD_CSV",
    "/DATA/pyare/Routine/LLM/Reasoning/LLMs-from-scratch-pyare/chapter-2-tokenization/car_pcloud.csv",
)

# Load the data
df = pd.read_csv(CSV_PATH)

# Later cells index these columns directly; fail early with a clear message.
REQUIRED_COLUMNS = ['x', 'y', 'z', 'intensity']
missing_columns = [c for c in REQUIRED_COLUMNS if c not in df.columns]
assert not missing_columns, f"CSV is missing required columns: {missing_columns}"

print(f"Loaded {len(df)} points")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nFirst 5 rows:")
df.head()

In [None]:
# Pull the XYZ columns out as an (N, 3) array
coords_raw = df[['x', 'y', 'z']].values

print("Raw coordinate ranges:")
print("\n".join(
    f"  {axis}: [{coords_raw[:, k].min():.2f}, {coords_raw[:, k].max():.2f}]"
    for k, axis in enumerate("XYZ")
))

# Shift so the per-axis minimum (bounding-box corner) lands on the origin (0, 0, 0)
coords = coords_raw - coords_raw.min(axis=0)

print(f"\nNormalized coordinate ranges:")
print("\n".join(
    f"  {axis}: [{coords[:, k].min():.2f}, {coords[:, k].max():.2f}]"
    for k, axis in enumerate("XYZ")
))

In [None]:
# Visualize the original point cloud: two 3D views, two color channels
fig = plt.figure(figsize=(12, 5))

# Panel 1 — color = per-point intensity from the CSV
ax1 = fig.add_subplot(121, projection='3d')
scatter1 = ax1.scatter(*coords.T, c=df['intensity'].values, cmap='viridis', s=1)
ax1.set(xlabel='X', ylabel='Y', zlabel='Z',
        title='Car Point Cloud (colored by intensity)')
plt.colorbar(scatter1, ax=ax1, shrink=0.5, label='Intensity')

# Panel 2 — color = height (normalized z)
ax2 = fig.add_subplot(122, projection='3d')
scatter2 = ax2.scatter(*coords.T, c=coords[:, 2], cmap='jet', s=1)
ax2.set(xlabel='X', ylabel='Y', zlabel='Z',
        title='Car Point Cloud (colored by height)')
plt.colorbar(scatter2, ax=ax2, shrink=0.5, label='Z height')

plt.tight_layout()
plt.show()

print(f"\nTotal points: {len(coords)}")

---
## Step 2: Grid Quantization

**Purpose:** Convert continuous 3D coordinates into discrete grid coordinates.

This is essential because:
1. Morton/Hilbert codes work on integers
2. It acts as a form of voxelization
3. Controls the resolution of the serialization

```
Continuous (0.15, 0.27, 0.43) → Grid (1, 2, 4) with grid_size=0.1
```

In [None]:
def quantize_to_grid(coords: np.ndarray, grid_size: float = 0.1) -> np.ndarray:
    """
    Convert continuous 3D coordinates to discrete grid coordinates.

    Each coordinate is mapped to floor(coord / grid_size), the index of the
    voxel containing it. Inputs are not shifted, so negative coordinates give
    negative grid indices; normalize first (as this notebook does) before
    Morton/Hilbert encoding, which expects non-negative integers.

    Note: values lying exactly on a voxel boundary are subject to float
    rounding, e.g. 0.15 / 0.05 evaluates to 2.9999999999999996 and floors to 2.

    Args:
        coords: (N, 3) array of [x, y, z] coordinates
        grid_size: Size of each voxel (e.g., 0.1 = 10cm); must be positive

    Returns:
        grid_coords: (N, 3) int64 array of grid coordinates

    Raises:
        ValueError: if grid_size is not a positive finite number (a zero or
            negative size would otherwise yield inf/garbage indices silently)
    """
    if not (np.isfinite(grid_size) and grid_size > 0):
        raise ValueError(f"grid_size must be a positive finite number, got {grid_size}")
    grid_coords = np.floor(np.asarray(coords) / grid_size).astype(np.int64)
    return grid_coords

# Voxel edge length in metres (experiment with different values!)
GRID_SIZE = 0.05  # 5cm voxels

grid_coords = quantize_to_grid(coords, GRID_SIZE)

print(f"Grid size: {GRID_SIZE}m ({GRID_SIZE*100}cm)")
print(f"\nGrid coordinate ranges:")
print("\n".join(
    f"  {axis}: [{grid_coords[:, k].min()}, {grid_coords[:, k].max()}] "
    f"({grid_coords[:, k].max() - grid_coords[:, k].min() + 1} cells)"
    for k, axis in enumerate("XYZ")
))

# Several points can share a voxel; count the distinct occupied ones
unique_cells = np.unique(grid_coords, axis=0).shape[0]
print(f"\nUnique grid cells: {unique_cells} (out of {len(coords)} points)")
print(f"Average points per cell: {len(coords) / unique_cells:.2f}")

In [None]:
# Let's look at a few example points. The hard-coded indices are filtered (and
# de-duplicated, preserving order) so this cell also works on small clouds.
candidate_ids = [0, 1, 2, 100, 500, len(coords) - 1]
example_ids = list(dict.fromkeys(i for i in candidate_ids if 0 <= i < len(coords)))

print("Example: Continuous → Grid conversion")
print("=" * 60)
print(f"{'Point':<8} {'Continuous (x,y,z)':<30} {'Grid (gx,gy,gz)':<20}")
print("-" * 60)

for i in example_ids:
    cx, cy, cz = coords[i]
    gx, gy, gz = grid_coords[i]
    print(f"P{i:<6} ({cx:>7.3f}, {cy:>7.3f}, {cz:>7.3f})    ({gx:>3}, {gy:>3}, {gz:>3})")

---
## Step 3: Z-Order (Morton) Encoding

**Purpose:** Convert 3D grid coordinates into a single 1D integer code.

**How it works:** Interleave the bits of x, y, z coordinates.

```
Example: Point at grid (5, 3, 2)
  x = 5 = 101 (binary)
  y = 3 = 011 (binary)
  z = 2 = 010 (binary)
  
Interleave (z₂y₂x₂ z₁y₁x₁ z₀y₀x₀):
  bit 2: z=0, y=0, x=1 → 001
  bit 1: z=1, y=1, x=0 → 110
  bit 0: z=0, y=1, x=1 → 011
  
Result: 001 110 011 = 115 (decimal)
```

In [None]:
def z_order_encode(grid_coords: np.ndarray, depth: int = 16) -> np.ndarray:
    """
    Encode 3D grid coordinates into Z-order (Morton) codes.

    Bit i of x, y and z is written to code bit 3i, 3i+1 and 3i+2 respectively,
    so reading the code from the most significant end gives
    ... z_1 y_1 x_1 z_0 y_0 x_0.

    Args:
        grid_coords: (N, 3) array of non-negative integer grid coordinates
        depth: Number of bits per dimension (max coord = 2^depth - 1).
            Must be in [1, 21] so the 3 * depth interleaved bits fit in uint64.

    Returns:
        codes: (N,) uint64 array of Morton codes

    Raises:
        ValueError: if depth is out of range, or any coordinate is negative or
            does not fit in `depth` bits. Without these checks, negatives wrap
            around in the uint64 cast and bits above `depth` are dropped,
            giving colliding / misordered codes with no warning.
    """
    grid_coords = np.asarray(grid_coords)
    if not 1 <= depth <= 21:
        raise ValueError(f"depth must be between 1 and 21 (3*depth bits must fit in uint64), got {depth}")
    if grid_coords.size > 0:
        if grid_coords.min() < 0:
            raise ValueError("grid_coords must be non-negative; shift coordinates to the origin first")
        if grid_coords.max() >= (1 << depth):
            raise ValueError(
                f"grid coordinate {grid_coords.max()} needs more than depth={depth} bits; "
                "increase depth or use a coarser grid"
            )

    x = grid_coords[:, 0].astype(np.uint64)
    y = grid_coords[:, 1].astype(np.uint64)
    z = grid_coords[:, 2].astype(np.uint64)

    codes = np.zeros(len(x), dtype=np.uint64)

    # Interleave bits: for each bit level i, drop the x, y, z bits into one triple
    for i in range(depth):
        codes |= ((x >> i) & 1) << (3 * i)       # x at position 3i
        codes |= ((y >> i) & 1) << (3 * i + 1)   # y at position 3i+1
        codes |= ((z >> i) & 1) << (3 * i + 2)   # z at position 3i+2

    return codes

# One Morton code per point (default depth of 16 bits per axis)
morton_codes = z_order_encode(grid_coords)

# Points in the same voxel share a code, so (as long as every coordinate fits
# in the encoding depth) unique codes correspond to occupied voxels
print(f"Morton codes computed for {morton_codes.size} points")
print(f"Code range: [{np.min(morton_codes)}, {np.max(morton_codes)}]")
print(f"Unique codes: {np.unique(morton_codes).size}")

In [None]:
# Let's trace through a few examples step by step
def explain_morton_encoding(gx: int, gy: int, gz: int, depth: int = 8) -> int:
    """
    Print a step-by-step trace of Morton-encoding one grid point and return the code.

    Uses the same bit layout as ``z_order_encode``: bit i of x/y/z goes to code
    bit 3i / 3i+1 / 3i+2. ``depth`` is a minimum: it is widened to the bit
    length of the largest coordinate, so coordinates >= 2**depth are not
    silently truncated into a code that disagrees with ``z_order_encode``.

    Args:
        gx, gy, gz: non-negative integer grid coordinates (numpy ints accepted)
        depth: minimum number of bits per axis to interleave

    Returns:
        The Morton code as a Python int.

    Raises:
        ValueError: if any coordinate is negative.
    """
    # Plain Python ints: arbitrary precision and predictable bit operations
    gx, gy, gz = int(gx), int(gy), int(gz)
    if min(gx, gy, gz) < 0:
        raise ValueError(f"coordinates must be non-negative, got ({gx}, {gy}, {gz})")
    depth = max(depth, gx.bit_length(), gy.bit_length(), gz.bit_length())

    print(f"\nEncoding point ({gx}, {gy}, {gz}):")
    print(f"  x = {gx} = {bin(gx)} (binary)")
    print(f"  y = {gy} = {bin(gy)} (binary)")
    print(f"  z = {gz} = {bin(gz)} (binary)")

    code = 0
    interleaved_bits = ""

    print(f"\n  Bit interleaving (z_i, y_i, x_i for each position i):")

    for i in range(depth):
        x_bit = (gx >> i) & 1
        y_bit = (gy >> i) & 1
        z_bit = (gz >> i) & 1

        code |= (x_bit << (3 * i)) | (y_bit << (3 * i + 1)) | (z_bit << (3 * i + 2))

        # Prepend so the most significant triple ends up leftmost in the string
        interleaved_bits = f"{z_bit}{y_bit}{x_bit}" + interleaved_bits

        if i < 4:  # Only trace the lowest 4 bit levels to keep the output short
            print(f"    bit {i}: z={z_bit}, y={y_bit}, x={x_bit} → {z_bit}{y_bit}{x_bit}")

    print(f"\n  Interleaved: {interleaved_bits}")
    print(f"  Morton code: {code}")

    return code

# Examples from your data (indices beyond the cloud size are skipped)
print("=" * 60)
print("MORTON ENCODING EXAMPLES")
print("=" * 60)

for i in [0, 10, 100]:
    if i >= len(grid_coords):
        continue  # small clouds may not have this many points
    gx, gy, gz = grid_coords[i]
    # Interleave enough bits to cover the largest coordinate (at least 8)
    n_bits = max(8, int(max(gx, gy, gz)).bit_length())
    code = explain_morton_encoding(gx, gy, gz, depth=n_bits)
    # The hand-traced code must agree with the vectorized encoder
    assert code == int(morton_codes[i]), (code, int(morton_codes[i]))

---
## Step 4: Sorting (Serialization)

**Purpose:** Sort points by their Morton codes to create an ordered sequence.

After sorting:
- Points with low Morton codes come first
- Spatially nearby points tend to be close in the sequence
- This is the "serialization" - converting unordered 3D points to ordered 1D sequence

In [None]:
# Serialize: the permutation that orders points by increasing Morton code
sorted_indices = morton_codes.argsort()
# argsort of a permutation is its inverse: original index -> serial rank
inverse_indices = sorted_indices.argsort()  # To restore original order

print("Serialization complete!")
print(f"\nFirst 10 points in serialized order:")
print(f"{'Rank':<6} {'Orig Idx':<10} {'Grid (x,y,z)':<20} {'Morton Code':<15} {'Continuous (x,y,z)'}")
print("-" * 80)

for rank in range(10):
    orig_idx = sorted_indices[rank]
    (gx, gy, gz), (cx, cy, cz) = grid_coords[orig_idx], coords[orig_idx]
    code = morton_codes[orig_idx]
    print(f"{rank:<6} {orig_idx:<10} ({gx:>3},{gy:>3},{gz:>3}){'':>8} {code:<15} ({cx:.2f}, {cy:.2f}, {cz:.2f})")

In [None]:
# Visualize the serialization order
fig = plt.figure(figsize=(14, 5))

# serial_position[i] = rank of point i in the Z-order sequence, i.e. the inverse
# permutation of sorted_indices (vectorized scatter instead of a Python loop).
# Kept as float so it feeds straight into a colormap; the final figure reuses it.
serial_position = np.zeros(len(coords))
serial_position[sorted_indices] = np.arange(len(coords))

# Original order
ax1 = fig.add_subplot(131, projection='3d')
ax1.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
            c=np.arange(len(coords)), cmap='viridis', s=1)
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')
ax1.set_title('Original Order\n(colored by original index)')

# Serialized order (colored by position in serialized sequence)
ax2 = fig.add_subplot(132, projection='3d')
scatter = ax2.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
                      c=serial_position, cmap='viridis', s=1)
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
ax2.set_title('Z-Order Serialization\n(colored by serial position)')
plt.colorbar(scatter, ax=ax2, shrink=0.5, label='Serial Position')

# Show just the first N points in serialized order. Capped at the cloud size so
# the color array always matches the number of plotted points.
N_SHOW = min(200, len(coords))
first_n_indices = sorted_indices[:N_SHOW]
ax3 = fig.add_subplot(133, projection='3d')
ax3.scatter(coords[first_n_indices, 0],
            coords[first_n_indices, 1],
            coords[first_n_indices, 2],
            c=np.arange(N_SHOW), cmap='plasma', s=5)
ax3.set_xlabel('X')
ax3.set_ylabel('Y')
ax3.set_zlabel('Z')
ax3.set_title(f'First {N_SHOW} points in Z-order\n(spatially clustered!)')

plt.tight_layout()
plt.show()

print(f"\nNotice: The first {N_SHOW} points in Z-order are spatially clustered!")
print("This is the locality-preserving property of Morton codes.")

---
## Step 5: Trans Z-Order (Swap X and Y)

**Purpose:** Create a different serialization pattern by swapping axes.

PTv3 uses **4 different patterns** and cycles through them:
1. Z-order (x, y, z)
2. Trans Z-order (y, x, z) ← swap x and y
3. Hilbert
4. Trans Hilbert

This ensures different points are grouped together in different layers.

In [None]:
def z_order_encode_trans(grid_coords: np.ndarray, depth: int = 16) -> np.ndarray:
    """
    Trans Z-order: Morton-encode with the x and y axes exchanged.

    The columns are reordered to [y, x, z] before calling ``z_order_encode``,
    so y's bits take the lowest slot of every interleaved triple instead of x's.
    """
    swapped = np.stack((grid_coords[:, 1], grid_coords[:, 0], grid_coords[:, 2]), axis=1)
    return z_order_encode(swapped, depth)

# Trans Z-order codes and the ordering they induce
morton_codes_trans = z_order_encode_trans(grid_coords)
sorted_indices_trans = morton_codes_trans.argsort()

print("Comparing Z-order vs Trans Z-order:")
print("\nFirst 10 points in each order:")
print(f"{'Rank':<6} {'Z-order Idx':<12} {'Trans Z-order Idx':<18} {'Same?'}")
print("-" * 50)

for rank in range(10):
    z_idx, zt_idx = sorted_indices[rank], sorted_indices_trans[rank]
    same = {True: "✓", False: "✗"}[bool(z_idx == zt_idx)]
    print(f"{rank:<6} {z_idx:<12} {zt_idx:<18} {same}")

# Sequence positions where the two orderings disagree
different = (sorted_indices != sorted_indices_trans).sum()
print(f"\nPoints in different positions: {different} / {len(sorted_indices)} ({100*different/len(sorted_indices):.1f}%)")

In [None]:
# Visualize both serializations side by side
fig = plt.figure(figsize=(12, 5))

# Cap at the cloud size so the color array matches the number of plotted points
N_SHOW = min(300, len(coords))

# Same plot for each ordering: its first N_SHOW points, colored by serial rank
for panel, (label, order_idx) in enumerate(
    [('Z-Order', sorted_indices), ('Trans Z-Order', sorted_indices_trans)], start=1
):
    ax = fig.add_subplot(1, 2, panel, projection='3d')
    first_pts = order_idx[:N_SHOW]
    ax.scatter(coords[first_pts, 0], coords[first_pts, 1], coords[first_pts, 2],
               c=np.arange(N_SHOW), cmap='plasma', s=3)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'{label}: First {N_SHOW} points')

plt.tight_layout()
plt.show()

print("Notice: The spatial clustering pattern is different!")
print("This is why PTv3 uses multiple serialization patterns.")

---
## Step 6: Patch Grouping

**Purpose:** Split the serialized sequence into fixed-size patches.

**Why?** 
- Global attention has O(N²) complexity - too expensive!
- Local attention within patches: O(K²) per patch, O(N/K × K²) = O(NK) total
- With K=64 and N=2000: speedup from 4M to 128K operations!

```
Serialized: [P₁, P₂, P₃, P₄, P₅, P₆, P₇, P₈, P₉, P₁₀, ...]
             └─────────────┘  └─────────────┘
                 Patch 0          Patch 1
```

In [None]:
def group_into_patches(
    n_points: int,
    sorted_indices: np.ndarray,
    patch_size: int = 64
) -> List[np.ndarray]:
    """
    Group serialized points into consecutive fixed-size patches.

    The serialized sequence ``sorted_indices[:n_points]`` is cut into chunks of
    ``patch_size``. If the final chunk is short and at least one chunk precedes
    it, it is padded up to ``patch_size`` with the indices immediately before
    it. Those padding points therefore also belong to the previous patch, so
    the patches are not strictly disjoint. When ``n_points < patch_size`` there
    is nothing to borrow from, and the single patch is returned unpadded.

    Args:
        n_points: Total number of points (patches cover sorted_indices[:n_points])
        sorted_indices: Serialization order (presumably an argsort permutation)
        patch_size: Number of points per patch (K in PTv3, typically 1024); must be > 0

    Returns:
        patches: List of index arrays, one per patch, in serialized order

    Raises:
        ValueError: if patch_size is not positive
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    patches = []
    for start in range(0, n_points, patch_size):
        stop = min(start + patch_size, n_points)
        patch_indices = sorted_indices[start:stop].copy()

        # Pad a short tail patch by re-using the points just before it
        shortfall = patch_size - len(patch_indices)
        if shortfall > 0 and start > 0:
            padding = sorted_indices[start - shortfall:start]
            patch_indices = np.concatenate([patch_indices, padding])

        patches.append(patch_indices)

    return patches

# Cut the Z-order sequence into fixed-size patches
PATCH_SIZE = 64  # PTv3 uses 1024, we use smaller for visualization

patches = group_into_patches(n_points=len(coords), sorted_indices=sorted_indices,
                             patch_size=PATCH_SIZE)

print(f"Patch size: {PATCH_SIZE}")
print(f"Number of patches: {len(patches)}")
print(f"\nPatch details:")
print(f"{'Patch':<8} {'Size':<8} {'First Idx':<12} {'Last Idx':<12} {'Spatial Extent (Δx, Δy, Δz)'}")
print("-" * 70)

# Bounding-box size of each of the first 10 patches: small extents = local patches
for i, patch in enumerate(patches[:10]):
    patch_coords = coords[patch]
    dx, dy, dz = patch_coords.max(axis=0) - patch_coords.min(axis=0)
    print(f"{i:<8} {len(patch):<8} {patch[0]:<12} {patch[-1]:<12} ({dx:.2f}, {dy:.2f}, {dz:.2f})")

In [None]:
# Visualize patches: every patch (left) and a close-up of the first five (right)
fig = plt.figure(figsize=(14, 6))

# Left panel: all patches, colors cycling through the 20-entry tab20 palette
ax1 = fig.add_subplot(121, projection='3d')
colors = plt.cm.tab20(np.linspace(0, 1, 20))

for i, patch in enumerate(patches):
    patch_coords = coords[patch]
    # Only the first five patches get legend entries
    ax1.scatter(*patch_coords.T, c=[colors[i % 20]], s=2, alpha=0.7,
                label=f'Patch {i}' if i < 5 else None)

ax1.set(xlabel='X', ylabel='Y', zlabel='Z',
        title=f'All {len(patches)} Patches\n(20 colors cycling)')
ax1.legend(loc='upper left', fontsize=8)

# Right panel: first five patches, each in its own named color
ax2 = fig.add_subplot(122, projection='3d')
distinct_colors = ['red', 'blue', 'green', 'orange', 'purple']

for i, (patch, color) in enumerate(zip(patches, distinct_colors)):
    patch_coords = coords[patch]
    ax2.scatter(*patch_coords.T, c=color, s=10, alpha=0.8, label=f'Patch {i}')

ax2.set(xlabel='X', ylabel='Y', zlabel='Z',
        title='First 5 Patches\n(spatially localized!)')
ax2.legend()

plt.tight_layout()
plt.show()

print("\nKey insight: Each patch contains spatially nearby points!")
print("Attention within each patch captures local relationships.")

---
## Step 7: Shuffle Order Strategy

**Problem:** With fixed patches, points in Patch 0 never attend to points in Patch 5.

**Solution:** Use different serialization patterns in different attention layers!

```
Layer 0: Z-order        → Patches: [A,B,C,D], [E,F,G,H], ...
Layer 1: Trans Z-order  → Patches: [A,C,E,G], [B,D,F,H], ...  ← Different grouping!
Layer 2: Hilbert        → Patches: [A,B,E,F], [C,D,G,H], ...  ← Another grouping!
Layer 3: Trans Hilbert  → Patches: [A,D,E,H], [B,C,F,G], ...
```

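Over multiple layers, information can flow between points that never share a patch, which enlarges each point's effective receptive field. Not every pair of points ever attends to each other directly.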

In [None]:
def hilbert_encode_simple(grid_coords: np.ndarray, depth: int = 10) -> np.ndarray:
    """
    Simplified Hilbert-like encoding (not true Hilbert, but demonstrates variety).

    This is a Morton-style bit interleaving whose axis priority alternates per
    bit level. On even levels the bits fill the triple as (z, x, y) from low to
    high; on odd levels the order is (y, z, x). No per-level rotation or
    reflection is applied, as a real Hilbert curve would do. It simply yields a
    different ordering to shuffle with.

    Args:
        grid_coords: (N, 3) array of non-negative integer grid coordinates
        depth: Number of bits per dimension (max coord = 2^depth - 1);
            must be in [1, 21] so the 3 * depth bits fit in uint64

    Returns:
        codes: (N,) uint64 array of codes

    Raises:
        ValueError: if depth is out of range, or any coordinate is negative or
            does not fit in `depth` bits. Without these checks, negatives wrap
            in the uint64 cast and high bits are silently dropped.
    """
    grid_coords = np.asarray(grid_coords)
    if not 1 <= depth <= 21:
        raise ValueError(f"depth must be between 1 and 21 (3*depth bits must fit in uint64), got {depth}")
    if grid_coords.size > 0:
        if grid_coords.min() < 0:
            raise ValueError("grid_coords must be non-negative; shift coordinates to the origin first")
        if grid_coords.max() >= (1 << depth):
            raise ValueError(
                f"grid coordinate {grid_coords.max()} needs more than depth={depth} bits; "
                "increase depth or use a coarser grid"
            )

    x = grid_coords[:, 0].astype(np.uint64)
    y = grid_coords[:, 1].astype(np.uint64)
    z = grid_coords[:, 2].astype(np.uint64)

    codes = np.zeros(len(x), dtype=np.uint64)

    for i in range(depth):
        x_bit = (x >> i) & 1
        y_bit = (y >> i) & 1
        z_bit = (z >> i) & 1

        # Different interleaving pattern than Z-order, alternating by level
        if i % 2 == 0:
            codes |= (z_bit << (3 * i))
            codes |= (x_bit << (3 * i + 1))
            codes |= (y_bit << (3 * i + 2))
        else:
            codes |= (y_bit << (3 * i))
            codes |= (z_bit << (3 * i + 1))
            codes |= (x_bit << (3 * i + 2))

    return codes

def serialize_all_patterns(grid_coords: np.ndarray) -> dict:
    """
    Compute all 4 serialization patterns used in PTv3.

    Returns a dict mapping each pattern name ('z', 'z-trans', 'hilbert',
    'hilbert-trans', in that insertion order) to the argsort permutation of
    that pattern's codes. The "-trans" variants swap the x and y columns.
    """
    codes_by_pattern = {
        'z': z_order_encode(grid_coords),
        'z-trans': z_order_encode_trans(grid_coords),
        'hilbert': hilbert_encode_simple(grid_coords),
        'hilbert-trans': hilbert_encode_simple(grid_coords[:, [1, 0, 2]]),
    }
    return {name: np.argsort(codes) for name, codes in codes_by_pattern.items()}

# Compute every serialization pattern at once
all_patterns = serialize_all_patterns(grid_coords)

print("All 4 serialization patterns computed!")
print("\nFirst 10 indices in each pattern:")
for pattern_name, order_idx in all_patterns.items():
    first_ten = order_idx[:10].tolist()
    print(f"  {pattern_name:<15}: {first_ten}")

In [None]:
# Visualize how Patch 0 changes across patterns: one 3D panel per pattern
fig = plt.figure(figsize=(16, 4))

pattern_names = list(all_patterns)
colors_patch = ['red', 'blue', 'green', 'orange']

for idx, (name, sorted_idx) in enumerate(all_patterns.items()):
    ax = fig.add_subplot(1, 4, idx + 1, projection='3d')

    # Patch 0 = the first PATCH_SIZE points of this pattern's sequence
    patch_0 = sorted_idx[:PATCH_SIZE]
    patch_0_coords = coords[patch_0]
    centroid = patch_0_coords.mean(axis=0)

    # Whole cloud as faint gray context, then Patch 0 highlighted on top
    ax.scatter(*coords.T, c='lightgray', s=1, alpha=0.3)
    ax.scatter(*patch_0_coords.T, c=colors_patch[idx], s=10, alpha=0.9)

    ax.set(xlabel='X', ylabel='Y', zlabel='Z',
           title=f'{name}\nPatch 0 centroid: ({centroid[0]:.1f}, {centroid[1]:.1f}, {centroid[2]:.1f})')

plt.suptitle('Shuffle Order: Patch 0 is DIFFERENT in each serialization pattern!', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print("\nKey insight: Different serialization patterns group different points together!")
print("This is why PTv3 cycles through patterns: [z → z-trans → hilbert → hilbert-trans → z → ...]")

In [None]:
# Analyze overlap between Patch 0 across different patterns
print("Overlap analysis between Patch 0 across patterns:")
print("=" * 60)

pattern_names = list(all_patterns.keys())
patch_0_sets = {name: set(indices[:PATCH_SIZE]) for name, indices in all_patterns.items()}

print(f"\n{'Pattern 1':<15} {'Pattern 2':<15} {'Overlap':<10} {'Percentage'}")
print("-" * 55)

for i, name1 in enumerate(pattern_names):
    for name2 in pattern_names[i+1:]:
        overlap = len(patch_0_sets[name1] & patch_0_sets[name2])
        pct = 100 * overlap / PATCH_SIZE
        print(f"{name1:<15} {name2:<15} {overlap:<10} {pct:.1f}%")

print(f"\nWith {PATCH_SIZE} points per patch, low overlap means different groupings!")

---
## Step 8: Attention Complexity Analysis

**Why serialization + local attention matters:**

In [None]:
N = len(coords)  # Total points
K = PATCH_SIZE   # Patch size
num_patches = (N + K - 1) // K  # ceil(N / K): the last patch is padded up to K points

print("ATTENTION COMPLEXITY ANALYSIS")
print("=" * 60)
print(f"\nNumber of points (N): {N:,}")
print(f"Patch size (K): {K}")
print(f"Number of patches: {num_patches}")

# Global attention: every point attends to every point
global_ops = N * N
print(f"\n1. GLOBAL ATTENTION (no serialization):")
print(f"   Operations: N² = {global_ops:,}")

# Local attention: K x K scores inside each (padded) patch
local_ops = num_patches * K * K
print(f"\n2. LOCAL PATCH ATTENTION (with serialization):")
print(f"   Operations: (N/K) × K² = {local_ops:,}")

# Speedup
speedup = global_ops / local_ops
print(f"\n3. SPEEDUP: {speedup:.1f}×")

# Memory: one score per (query, key) pair, so element counts equal the op counts
print(f"\n4. MEMORY (attention matrix):")
print(f"   Global: {N}×{N} = {N*N:,} elements")
print(f"   Local:  {num_patches} × ({K}×{K}) = {local_ops:,} elements")
print(f"   Memory reduction: {global_ops / local_ops:.1f}×")

# PTv3 scale. Use the same ceil-division patch count as above; floor division
# would drop the partially filled last patch and understate the local cost.
N_ptv3 = 100000
K_ptv3 = 1024
num_patches_ptv3 = (N_ptv3 + K_ptv3 - 1) // K_ptv3
print(f"\n5. AT PTv3 SCALE (K={K_ptv3}, N={N_ptv3:,}):")
global_ptv3 = N_ptv3 * N_ptv3
local_ptv3 = num_patches_ptv3 * K_ptv3 * K_ptv3
print(f"   Global: {global_ptv3:,} operations")
print(f"   Local:  {local_ptv3:,} operations")
print(f"   Speedup: {global_ptv3/local_ptv3:.0f}×")

---
## Step 9: Complete Serialization Pipeline (PTv3 Style)

Let's put it all together in a single class, similar to how PTv3 implements it.

In [None]:
class PointCloudSerializer:
    """
    Complete serialization pipeline for point cloud transformers.

    This mimics the serialization in PTv3/LitePT. The cloud is normalized and
    quantized once, and one sort order is computed per serialization pattern.
    ``get_patches`` then hands out fixed-size patches, where layer ``l`` uses
    pattern ``orders[l % len(orders)]``.
    """

    # Pattern names that _encode knows how to compute.
    SUPPORTED_ORDERS = ('z', 'z-trans', 'hilbert', 'hilbert-trans')

    def __init__(
        self,
        grid_size: float = 0.05,
        patch_size: int = 64,
        orders: tuple = ('z', 'z-trans', 'hilbert', 'hilbert-trans'),
        depth: int = 16
    ):
        """
        Args:
            grid_size: Voxel edge length used for quantization; must be > 0.
            patch_size: Number of points per patch; must be > 0.
            orders: Serialization pattern names, cycled through by layer index.
                Each must be one of SUPPORTED_ORDERS.
            depth: Bits per axis passed to the code functions.

        Raises:
            ValueError: for a non-positive grid_size or patch_size, an empty
                ``orders``, or an unknown pattern name. Validating here reports
                configuration mistakes at construction time instead of
                part-way through ``serialize`` / ``get_patches``.
        """
        if not grid_size > 0:
            raise ValueError(f"grid_size must be positive, got {grid_size}")
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if len(orders) == 0:
            raise ValueError("orders must contain at least one serialization pattern")
        unknown = [o for o in orders if o not in self.SUPPORTED_ORDERS]
        if unknown:
            raise ValueError(f"Unknown order(s): {unknown}; expected some of {self.SUPPORTED_ORDERS}")

        self.grid_size = grid_size
        self.patch_size = patch_size
        self.orders = orders
        self.depth = depth

    def serialize(self, coords: np.ndarray) -> dict:
        """
        Complete serialization of a point cloud.

        Args:
            coords: (N, 3) array of [x, y, z] coordinates (any origin; they are
                shifted so the per-axis minimum becomes 0)

        Returns:
            dict with:
            - coords: the shifted (normalized) coordinates
            - grid_coords: quantized coordinates
            - serialized_codes: list of codes for each pattern
            - serialized_order: list of sort indices for each pattern
            - serialized_inverse: list of inverse indices
            - n_points: number of points
            - n_patches: number of patches (ceil(n_points / patch_size))
        """
        # Step 1: Normalize (min corner -> origin) and quantize
        coords_norm = coords - coords.min(axis=0)
        grid_coords = np.floor(coords_norm / self.grid_size).astype(np.int64)

        # Step 2: Compute all serialization patterns
        result = {
            'coords': coords_norm,
            'grid_coords': grid_coords,
            'serialized_codes': [],
            'serialized_order': [],
            'serialized_inverse': [],
            'n_points': len(coords),
            'n_patches': (len(coords) + self.patch_size - 1) // self.patch_size
        }

        for order in self.orders:
            codes = self._encode(grid_coords, order)
            sorted_idx = np.argsort(codes)
            # argsort of a permutation is its inverse: point index -> serial rank
            inverse_idx = np.argsort(sorted_idx)

            result['serialized_codes'].append(codes)
            result['serialized_order'].append(sorted_idx)
            result['serialized_inverse'].append(inverse_idx)

        return result

    def _encode(self, grid_coords: np.ndarray, order: str) -> np.ndarray:
        """Encode grid coordinates based on the specified order."""
        if order == 'z':
            return z_order_encode(grid_coords, self.depth)
        elif order == 'z-trans':
            return z_order_encode_trans(grid_coords, self.depth)
        elif order == 'hilbert':
            return hilbert_encode_simple(grid_coords, self.depth)
        elif order == 'hilbert-trans':
            transposed = grid_coords[:, [1, 0, 2]]
            return hilbert_encode_simple(transposed, self.depth)
        else:
            raise ValueError(f"Unknown order: {order}")

    def get_patches(self, result: dict, layer_idx: int) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Get patch indices for a specific layer.

        The layer index picks the serialization pattern (``layer_idx % len(orders)``).
        The pattern's sequence is cut into ``patch_size`` chunks. A short final
        chunk is padded with the indices just before it, when any exist, so it
        overlaps the previous patch.

        Args:
            result: dict produced by ``serialize`` (reads 'serialized_order'
                and 'n_points')
            layer_idx: transformer layer index

        Returns:
            (patches, sorted_idx): list of per-patch index arrays, and the full
            serialization order used for this layer
        """
        order_idx = layer_idx % len(self.orders)
        sorted_idx = result['serialized_order'][order_idx]

        patches = []
        n = result['n_points']

        for i in range(0, n, self.patch_size):
            end_idx = min(i + self.patch_size, n)
            patch = sorted_idx[i:end_idx]

            # Pad a short tail patch with the points immediately before it
            if len(patch) < self.patch_size and i > 0:
                pad_size = self.patch_size - len(patch)
                padding = sorted_idx[i - pad_size:i]
                patch = np.concatenate([patch, padding])

            patches.append(patch)

        return patches, sorted_idx

# Test the complete pipeline
print("Testing complete serialization pipeline...")
print("=" * 60)

# Reuse the notebook's GRID_SIZE / PATCH_SIZE so the class reproduces the
# step-by-step results above (and follows any experiments with those constants)
serializer = PointCloudSerializer(
    grid_size=GRID_SIZE,
    patch_size=PATCH_SIZE,
    orders=('z', 'z-trans', 'hilbert', 'hilbert-trans')
)

# Run serialization (the class does its own min-corner normalization)
serial_result = serializer.serialize(coords_raw)

# Sanity check: the class's Z-order codes must match the manual Step 3 codes
assert np.array_equal(serial_result['serialized_codes'][0], morton_codes), \
    "PointCloudSerializer Z-order codes disagree with the step-by-step morton_codes"

print(f"\nSerialization complete!")
print(f"  Points: {serial_result['n_points']}")
print(f"  Patches: {serial_result['n_patches']}")
print(f"  Patterns: {serializer.orders}")

# Test getting patches for different layers. Use a local name so the Step 6
# `patches` (Z-order) is not overwritten by the last layer's pattern.
print(f"\nPatches for each layer (showing first patch):")
for layer in range(4):
    layer_patches, _ = serializer.get_patches(serial_result, layer)
    pattern = serializer.orders[layer % len(serializer.orders)]
    print(f"  Layer {layer} ({pattern}): Patch 0 indices = {layer_patches[0][:5].tolist()}...")

---
## Summary

### What we learned:

1. **Grid Quantization**: Continuous coords → Discrete grid coords
   - `(0.16, 0.27, 0.43) → (3, 5, 8)` with grid_size=0.05

2. **Z-Order Encoding**: Grid coords → Morton code (bit interleaving)
   - `(3, 5, 8) → interleave bits → single integer`

3. **Sorting**: Create ordered sequence by Morton codes
   - Spatially nearby points get nearby positions

4. **Patch Grouping**: Split into fixed-size groups
   - Enables local attention: O(NK) instead of O(N²)

5. **Shuffle Order**: Cycle through 4 patterns
   - Different patterns → Different groupings
   - Lets information flow between patches across layers, so points interact indirectly

### LitePT's insight:
- Only use serialization + attention at **late stages** (E3, E4)
- Use convolution at **early stages** (E0, E1, E2)
- Replace heavy conv-based positional encoding with **PointROPE** (parameter-free)

In [None]:
# Final visualization: Summary of the complete pipeline
fig = plt.figure(figsize=(16, 12))

# 1. Original point cloud
ax1 = fig.add_subplot(231, projection='3d')
ax1.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
            c=df['intensity'].values, cmap='viridis', s=1)
ax1.set_title('1. Original Point Cloud')
ax1.set_xlabel('X'); ax1.set_ylabel('Y'); ax1.set_zlabel('Z')

# 2. Grid visualization
ax2 = fig.add_subplot(232, projection='3d')
ax2.scatter(grid_coords[:, 0], grid_coords[:, 1], grid_coords[:, 2],
            c=morton_codes, cmap='viridis', s=1)
ax2.set_title('2. Grid Coordinates\n(colored by Morton code)')
ax2.set_xlabel('Grid X'); ax2.set_ylabel('Grid Y'); ax2.set_zlabel('Grid Z')

# 3. Serialization order
ax3 = fig.add_subplot(233, projection='3d')
ax3.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
            c=serial_position, cmap='plasma', s=1)
ax3.set_title('3. Serialization Order\n(colored by position)')
ax3.set_xlabel('X'); ax3.set_ylabel('Y'); ax3.set_zlabel('Z')

# 4. Patch grouping. Rebuild the Z-order patches here instead of reading the
# global `patches`, so this panel does not depend on whatever later cells may
# have rebound that name to.
zorder_patches = group_into_patches(len(coords), sorted_indices, PATCH_SIZE)
ax4 = fig.add_subplot(234, projection='3d')
for i, patch in enumerate(zorder_patches[:15]):
    patch_coords = coords[patch]
    ax4.scatter(patch_coords[:, 0], patch_coords[:, 1], patch_coords[:, 2],
                c=[colors[i % 20]], s=2, alpha=0.7)
ax4.set_title('4. Patch Grouping\n(first 15 Z-order patches)')
ax4.set_xlabel('X'); ax4.set_ylabel('Y'); ax4.set_zlabel('Z')

# 5. Shuffle order comparison
ax5 = fig.add_subplot(235, projection='3d')
for idx, (name, sorted_idx) in enumerate(all_patterns.items()):
    patch_0 = sorted_idx[:PATCH_SIZE]
    patch_coords = coords[patch_0]
    ax5.scatter(patch_coords[:, 0], patch_coords[:, 1], patch_coords[:, 2],
                c=colors_patch[idx], s=5, alpha=0.5, label=name)
ax5.set_title('5. Shuffle Order\n(Patch 0 in each pattern)')
ax5.set_xlabel('X'); ax5.set_ylabel('Y'); ax5.set_zlabel('Z')
ax5.legend(fontsize=8)

# 6. Complexity comparison (text)
ax6 = fig.add_subplot(236)
ax6.axis('off')
summary_text = f"""
SERIALIZATION SUMMARY
{'='*30}

Points: {N:,}
Grid size: {GRID_SIZE}m
Patch size: {PATCH_SIZE}
Patches: {num_patches}

COMPLEXITY:
  Global attention: {N*N:,} ops
  Local attention:  {num_patches*PATCH_SIZE*PATCH_SIZE:,} ops
  Speedup: {N*N/(num_patches*PATCH_SIZE*PATCH_SIZE):.1f}×

PATTERNS:
  • z-order
  • z-order-trans (swap x,y)
  • hilbert
  • hilbert-trans
"""
ax6.text(0.1, 0.9, summary_text, transform=ax6.transAxes,
         fontsize=11, verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.suptitle('Point Cloud Serialization Pipeline (PTv3/LitePT)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()