# Notebook E — 05_create_training_data.ipynb (Prepare Dataset)

This notebook loads annotated point clouds — either LAS/LAZ files with an embedded classification field, or files sorted into one folder per class — and generates a training dataset (per-point features + labels) for the classifier.

## Concepts
- **Features**: Geometric and color features extracted from points (XYZ, Normals, Colors, Height).
- **Labels**: Ground-truth class per point, taken from the `.las`/`.laz` classification field or from the class folder the file is placed in (see `CLASS_MAP`).

In [None]:
# Cell E0 — Install dependencies
import sys
import subprocess

def install_packages(packages):
    """Install the given list of pip requirement specifiers into this interpreter.

    Runs ``pip`` through ``sys.executable`` so the packages land in the same
    environment the notebook kernel is using. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install']
    subprocess.check_call(pip_cmd + packages)

import importlib
import importlib.util

# pip distribution name -> importable module name. Each package is checked on
# its own, so a partially provisioned runtime still gets whatever it lacks
# (checking only open3d/laspy skipped everything else whenever those two were
# present). `lazrs` is the LAZ decompression backend laspy needs to read .laz.
REQUIRED_PACKAGES = {
    'numpy': 'numpy',
    'open3d': 'open3d',
    'trimesh': 'trimesh',
    'laspy': 'laspy',
    'lazrs': 'lazrs',
    'scipy': 'scipy',
    'onnxruntime': 'onnxruntime',
    'numba': 'numba',
    'py-vox-io': 'pyvox',
}

# find_spec locates a module without importing it; None means "not installed".
missing = [
    pkg for pkg, module in REQUIRED_PACKAGES.items()
    if importlib.util.find_spec(module) is None
]
if missing:
    print("Installing dependencies...")
    print(f"Missing: {missing}")
    install_packages(missing)
    # Make the freshly installed packages visible to later imports in this session.
    importlib.invalidate_caches()

In [None]:
# Cell E1 — Mount Drive & Setup
from google.colab import drive
import os
import sys
import shutil
import subprocess
import importlib

drive.mount('/content/drive')

BASE = "/content/drive/MyDrive/voxel_engine"
REPO_DIR = "/content/spec-kit"
SRC_DIR = os.path.join(REPO_DIR, "src")

# --- SELF-HEALING SETUP ---
# loader.py is used as the sentinel for a usable checkout. If it is missing,
# wipe any partial clone and clone again.
# NOTE: an existing checkout is NOT updated (no `git pull`); delete REPO_DIR
# to force a fresh clone.
loader_path = os.path.join(SRC_DIR, "loader.py")

if not os.path.exists(loader_path):
    print("Dependencies missing or corrupted. Re-cloning repository...")
    if os.path.exists(REPO_DIR):
        shutil.rmtree(REPO_DIR)

    # Clone repo using subprocess to avoid syntax errors in IDEs
    subprocess.check_call(["git", "clone", "https://github.com/yamatotakeru616/spec-kit.git", REPO_DIR])

    # Install requirements (only on a fresh clone)
    req_path = os.path.join(REPO_DIR, "requirements.txt")
    if os.path.exists(req_path):
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_path])
else:
    print("Repository found.")

# Put src/ at the FRONT of sys.path. Appending it would let any installed
# package with the same generic name (notably `utils`) shadow the repo module.
# Re-running the cell moves an existing entry back to the front.
if not sys.path or sys.path[0] != SRC_DIR:
    if SRC_DIR in sys.path:
        sys.path.remove(SRC_DIR)
    sys.path.insert(0, SRC_DIR)
    print(f"Added {SRC_DIR} to sys.path")

# Imported separately so a numpy problem is not reported as a repo problem.
import numpy as np

# Reload so edits to the .py files on disk are picked up without restarting
# the runtime (reload re-reads the files; it does not fetch new commits).
try:
    import loader, classify, utils
    importlib.reload(loader)
    importlib.reload(classify)
    importlib.reload(utils)
    print("Imports successful!")
except ImportError:
    print("CRITICAL ERROR: Could not import project modules (loader, classify, utils).")
    # Fallback: show what is in the directory to help diagnose the checkout
    print(f"Contents of {SRC_DIR}:", os.listdir(SRC_DIR) if os.path.exists(SRC_DIR) else "Directory not found")
    raise

INPUT_DIR = os.path.join(BASE, "input")
TRAIN_DIR = os.path.join(BASE, "training_data")

os.makedirs(TRAIN_DIR, exist_ok=True)
print("Ready to process data in:", INPUT_DIR)

In [None]:
# Cell E2 — Define Feature Extraction Helper

def process_file(path, label_override=None):
    """
    Load one point-cloud file, compute per-point features, and attach labels.

    Args:
        path: Path of the file to load via the project ``loader`` module.
        label_override: If not None, every feature row is given this class ID
            (folder mode). Otherwise labels come from the file itself, which
            is only possible when ``loader.load_annotated_pointcloud`` exists.

    Returns:
        ``(features, labels)``.
        - Both are None if loading the file failed.
        - ``labels`` is None when no override or file labels are available, or
          when the file labels do not line up one-to-one with the feature
          rows. Callers should skip such files.
    """
    print(f"Processing {os.path.basename(path)}...")

    # Prefer the loader that also returns labels; older loaders only return
    # points and colours, in which case labels can only come from an override.
    try:
        if hasattr(loader, 'load_annotated_pointcloud'):
            pts, cols, file_labels = loader.load_annotated_pointcloud(path)
        else:
            # Fallback if old loader
            pts, cols = loader.load_pointcloud(path)
            file_labels = None
    except Exception as e:
        # Best effort: one unreadable file should not abort the whole dataset build.
        print(f"Failed to load {path}: {e}")
        return None, None

    # compute_features(points, colors=None, normals=None)
    feats = classify.compute_features(pts, colors=cols)
    n_rows = len(feats)

    if label_override is not None:
        # Size by the feature rows (not the raw points) so X and y always align,
        # even if feature computation does not emit one row per input point.
        target_labels = np.full(n_rows, label_override, dtype=np.uint8)
    elif file_labels is not None:
        target_labels = np.asarray(file_labels)
        # The caller stacks X and y independently, so a per-file length
        # mismatch would silently shift every later label onto the wrong row.
        if len(target_labels) != n_rows:
            print(f"Warning: {path} has {len(target_labels)} labels but {n_rows} feature rows. Skipping labels.")
            return feats, None
    else:
        print(f"Warning: No labels found for {path} and no override provided. Skipping labels.")
        return feats, None

    return feats, target_labels

In [None]:
# Cell E3 — Process Dataset (Auto-Detect Mode)

print(f"--- Checking Input Directory: {INPUT_DIR} ---")
if os.path.isdir(INPUT_DIR):
    # Sorted so file processing order (and therefore row order of X / y) is reproducible.
    contents = sorted(os.listdir(INPUT_DIR))
    print(f"Contents of input/: {contents}")
else:
    print(f"Error: Input directory {INPUT_DIR} not found.")
    contents = []

all_feats = []
all_labels = []

# Define your class mapping here if using folders:
# name of a sub-folder of input/ -> class ID given to every point of every file in it.
CLASS_MAP = {
    "ground": 1,
    "vegetation": 2,
    "building": 3,
    "vehicle": 4,
    "other": 5,  # Added example
}

# Extensions accepted in folder mode (labels come from the folder) and in
# file mode (labels must be embedded in the file, so only LAS/LAZ qualify).
FOLDER_MODE_EXTS = ('.las', '.laz', '.ply', '.obj')
FILE_MODE_EXTS = ('.las', '.laz')


def _collect(f_feats, f_lbls, source):
    """Append one file's (features, labels) to the dataset if both exist and are row-aligned."""
    if f_feats is None or f_lbls is None:
        return
    # X and y are stacked independently below; a misaligned file would shift
    # every later label onto the wrong feature row.
    if len(f_feats) != len(f_lbls):
        print(f"Warning: {source}: {len(f_feats)} feature rows but {len(f_lbls)} labels. Skipping file.")
        return
    all_feats.append(f_feats)
    all_labels.append(f_lbls)


# Auto-detect mode:
# if any class folder exists, use folder mode; else if any las/laz file is in
# the root, use file mode; otherwise fall back to folder mode.
scan_mode = "auto"  # "folder", "file", or "auto"

if scan_mode == "auto":
    has_class_folders = any(os.path.isdir(os.path.join(INPUT_DIR, k)) for k in CLASS_MAP)
    has_root_files = any(f.lower().endswith(FILE_MODE_EXTS) for f in contents)

    if has_class_folders:
        scan_mode = "folder"
        print("Auto-mode: Detected class folders. Using 'folder' mode.")
        if has_root_files:
            print("Auto-mode: Note: .las/.laz files directly inside input/ are ignored in 'folder' mode.")
    elif has_root_files:
        scan_mode = "file"
        print("Auto-mode: Detected point cloud files in root. Using 'file' mode.")
    else:
        print("Auto-mode: Could not detect valid data structure. Defaulting to 'folder'.")
        scan_mode = "folder"

if scan_mode == "folder":
    print(f"Scanning for folders: {list(CLASS_MAP.keys())} ...")
    found_any_folder = False
    for folder_name, label_id in CLASS_MAP.items():
        folder_path = os.path.join(INPUT_DIR, folder_name)
        # isdir (not exists): a plain file with a class name would crash os.listdir.
        if not os.path.isdir(folder_path):
            continue

        found_any_folder = True
        files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith(FOLDER_MODE_EXTS))
        print(f"Found {len(files)} files for class '{folder_name}' (ID: {label_id})")

        for f in files:
            fp = os.path.join(folder_path, f)
            f_feats, f_lbls = process_file(fp, label_override=label_id)
            _collect(f_feats, f_lbls, fp)

    if not found_any_folder:
        print("Warning: No class folders found matching CLASS_MAP keys.")
        print("Please create folders like 'ground', 'building' inside 'input/' and put files there.")
        print("OR, if you have a single .las file with classification, ensure it is in 'input/' and re-run.")

# Cell E4 — Process Dataset (Option B: Single file with embedded labels)
if scan_mode == "file":
    # Reuse the listing taken above: it is already sorted and does not crash
    # when input/ is missing (calling os.listdir again would).
    files = [f for f in contents if f.lower().endswith(FILE_MODE_EXTS)]
    print(f"Scanning files in root: {files}")

    if not files:
        print("Error: scan_mode is 'file' but no .las/.laz files found in input/.")

    for f in files:
        fp = os.path.join(INPUT_DIR, f)
        f_feats, f_lbls = process_file(fp)
        _collect(f_feats, f_lbls, fp)

if all_feats:
    X = np.vstack(all_feats)
    y = np.concatenate(all_labels)
    classes, counts = np.unique(y, return_counts=True)
    print("\n=== SUMMARY ===")
    print(f"Total Training Data: {X.shape[0]} points")
    print(f"Features: {X.shape[1]}")
    print(f"Classes found: {classes}")
    print(f"Points per class: {dict(zip(classes.tolist(), counts.tolist()))}")
else:
    print("\n=== RESULT: NO DATA GENERATED ===")
    print("Possible reasons for empty output:")
    print("1. No files found in the expected locations.")
    print("2. Files found but failed to load.")
    print("3. Files loaded but classification/labels were missing (and no override provided).")

In [None]:
# Cell E5 — Save Training Data
# Set SAVE_SAMPLE = True to also write a small random subset for quick inspection.
SAVE_SAMPLE = False
SAMPLE_SIZE = 10000

if all_feats:
    save_path = os.path.join(TRAIN_DIR, "training_data.npz")
    np.savez_compressed(save_path, X=X, y=y)
    print(f"Saved dataset to {save_path}")

    if SAVE_SAMPLE:
        # Fixed seed so the same sample is produced for the same dataset.
        rng = np.random.default_rng(0)
        sample_idx = rng.choice(len(X), min(SAMPLE_SIZE, len(X)), replace=False)
        sample_path = os.path.join(TRAIN_DIR, "sample_data.npz")
        np.savez(sample_path, X=X[sample_idx], y=y[sample_idx])
        print(f"Saved sample to {sample_path}")
else:
    # Make the no-op explicit instead of silently skipping the save.
    print("Nothing to save: no training data was generated (see the output of Cell E3).")