In [1]:
"""
Optimal Preprocessing Pipeline for Multi-Class Road Damage Segmentation
Handles extreme class imbalance (95% normal, 5% damage across 4 classes)
"""

import numpy as np
import pandas as pd
import laspy
import h5py
from sklearn.cluster import DBSCAN
from scipy.spatial import KDTree
import os
from tqdm import tqdm

# ============================================================================
# PREPROCESSING CONFIGURATION
# ============================================================================

CONFIG = {
    # Class definitions
    'classes': {
        0: 'Normal Road',
        1: 'Pothole',
        2: 'Crack',
        3: 'Edge Crack',
        4: 'Flushed Out'
    },
    
    # Patch settings
    'patch_size': 10.0,  # 10m x 10m patches (better for long roads)
    'points_per_patch': 8192,  # Larger patches for better context
    'stride': 5.0,  # 50% overlap
    
    # Sampling strategy
    'damage_patch_ratio': 0.4,  # 40% of patches should contain damage
    'min_damage_points': 50,  # Minimum damage points to consider patch valid
    
    # Data augmentation
    'augmentation': {
        'rotation': True,
        'jitter': 0.01,
        'scale': (0.9, 1.1),
        'dropout': 0.05  # Random point dropout
    }
}

# ============================================================================
# STEP 1: PARSE BOUNDING BOXES
# ============================================================================

def parse_damage_bboxes(csv_file, margin=0.3):
    """
    Parse CSV with S/F points and create bounding boxes for each damage type
    
    CSV format:
    Name,X,Y,Z,Type
    Site1_P1_S,x,y,z,Pothole
    Site1_P1_F,x,y,z,Pothole
    """
    df = pd.read_csv(csv_file)
    
    # Extract damage ID and type
    df['Damage_ID'] = df['Name'].str.extract(r'_(P\d+|C\d+|E\d+|F\d+)')[0]
    df['Point_Type'] = df['Name'].str.extract(r'_(S|F)$')[0]
    
    # Determine damage class from Name or Type column
    def get_damage_class(row):
        name = row['Name'].upper()
        if 'P' in name or (pd.notna(row.get('Type')) and 'POTHOLE' in str(row['Type']).upper()):
            return 1  # Pothole
        elif 'C' in name and 'E' not in name or (pd.notna(row.get('Type')) and 'CRACK' in str(row['Type']).upper() and 'EDGE' not in str(row['Type']).upper()):
            return 2  # Crack
        elif 'E' in name or (pd.notna(row.get('Type')) and 'EDGE' in str(row['Type']).upper()):
            return 3  # Edge Crack
        elif 'F' in name or (pd.notna(row.get('Type')) and 'FLUSH' in str(row['Type']).upper()):
            return 4  # Flushed Out
        return 0
    
    df['Class'] = df.apply(get_damage_class, axis=1)
    
    bboxes = []
    damage_ids = df['Damage_ID'].dropna().unique()
    
    for dmg_id in damage_ids:
        dmg_df = df[df['Damage_ID'] == dmg_id]
        start = dmg_df[dmg_df['Point_Type'] == 'S']
        finish = dmg_df[dmg_df['Point_Type'] == 'F']
        
        if len(start) > 0 and len(finish) > 0:
            x_s, y_s, z_s = start.iloc[0][['X', 'Y', 'Z']].values
            x_f, y_f, z_f = finish.iloc[0][['X', 'Y', 'Z']].values
            damage_class = start.iloc[0]['Class']
            
            bbox = {
                'damage_id': dmg_id,
                'class': damage_class,
                'x_min': min(x_s, x_f) - margin,
                'x_max': max(x_s, x_f) + margin,
                'y_min': min(y_s, y_f) - margin,
                'y_max': max(y_s, y_f) + margin,
                'z_min': min(z_s, z_f) - margin,
                'z_max': max(z_s, z_f) + margin,
            }
            bboxes.append(bbox)
    
    return bboxes

# ============================================================================
# STEP 2: LABEL POINTS WITH DAMAGE CLASSES
# ============================================================================

def label_points_from_bboxes(points, bboxes):
    """Assign class labels to points based on bounding boxes"""
    labels = np.zeros(len(points), dtype=np.int32)  # Default: normal road
    
    for bbox in bboxes:
        mask = (
            (points[:, 0] >= bbox['x_min']) & (points[:, 0] <= bbox['x_max']) &
            (points[:, 1] >= bbox['y_min']) & (points[:, 1] <= bbox['y_max']) &
            (points[:, 2] >= bbox['z_min']) & (points[:, 2] <= bbox['z_max'])
        )
        labels[mask] = bbox['class']
    
    return labels

# ============================================================================
# STEP 3: REFINE LABELS WITH DBSCAN
# ============================================================================

def refine_labels_with_clustering(points, labels, eps=0.5, min_samples=10):
    """Use DBSCAN to refine damage labels and remove outliers"""
    refined_labels = labels.copy()
    
    for damage_class in [1, 2, 3, 4]:
        damage_mask = labels == damage_class
        if np.sum(damage_mask) < min_samples:
            continue
        
        damage_points = points[damage_mask]
        
        # Cluster damage points
        clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(damage_points)
        
        # Remove noise points (label -1)
        noise_mask = clustering.labels_ == -1
        damage_indices = np.where(damage_mask)[0]
        noise_indices = damage_indices[noise_mask]
        refined_labels[noise_indices] = 0  # Set to normal
    
    return refined_labels

# ============================================================================
# STEP 4: INTELLIGENT PATCH EXTRACTION
# ============================================================================

def extract_balanced_patches(points, labels, config):
    """
    Extract patches with balanced sampling:
    - Ensure adequate representation of damage classes
    - Use sliding window for full coverage
    - Oversample damage-containing patches
    """
    patch_size = config['patch_size']
    stride = config['stride']
    points_per_patch = config['points_per_patch']
    
    x_min, y_min = points[:, 0].min(), points[:, 1].min()
    x_max, y_max = points[:, 0].max(), points[:, 1].max()
    
    damage_patches = []
    normal_patches = []
    
    # Sliding window extraction
    x = x_min
    while x < x_max:
        y = y_min
        while y < y_max:
            # Define patch bounds
            mask = (
                (points[:, 0] >= x) & (points[:, 0] < x + patch_size) &
                (points[:, 1] >= y) & (points[:, 1] < y + patch_size)
            )
            
            patch_indices = np.where(mask)[0]
            
            if len(patch_indices) < 500:  # Skip sparse patches
                y += stride
                continue
            
            patch_points = points[patch_indices]
            patch_labels = labels[patch_indices]
            
            # Count damage points
            damage_count = np.sum(patch_labels > 0)
            damage_ratio = damage_count / len(patch_labels)
            
            # Sample points
            if len(patch_points) >= points_per_patch:
                choice = np.random.choice(len(patch_points), points_per_patch, replace=False)
            else:
                choice = np.random.choice(len(patch_points), points_per_patch, replace=True)
            
            sampled_points = patch_points[choice]
            sampled_labels = patch_labels[choice]
            
            # Categorize patch
            if damage_count >= config['min_damage_points']:
                damage_patches.append({
                    'points': sampled_points,
                    'labels': sampled_labels,
                    'damage_ratio': damage_ratio,
                    'class_distribution': np.bincount(sampled_labels, minlength=5)
                })
            else:
                normal_patches.append({
                    'points': sampled_points,
                    'labels': sampled_labels
                })
            
            y += stride
        x += stride
    
    return damage_patches, normal_patches

# ============================================================================
# STEP 5: BALANCE DATASET
# ============================================================================

def create_balanced_dataset(damage_patches, normal_patches, config):
    """
    Create balanced training dataset:
    - Keep all damage patches (minority class)
    - Undersample normal patches
    - Add extra copies of rare damage types
    """
    num_damage = len(damage_patches)
    damage_ratio = config['damage_patch_ratio']
    
    # Calculate how many normal patches to keep
    num_normal_target = int(num_damage * (1 - damage_ratio) / damage_ratio)
    
    # Randomly sample normal patches
    if len(normal_patches) > num_normal_target:
        normal_indices = np.random.choice(len(normal_patches), num_normal_target, replace=False)
        selected_normal = [normal_patches[i] for i in normal_indices]
    else:
        selected_normal = normal_patches
    
    # Oversample rare damage types
    class_counts = {}
    for patch in damage_patches:
        main_class = np.argmax(patch['class_distribution'][1:]) + 1  # Ignore class 0
        class_counts[main_class] = class_counts.get(main_class, 0) + 1
    
    # Find rarest class
    if class_counts:
        max_count = max(class_counts.values())
        oversampled_damage = damage_patches.copy()
        
        for damage_class, count in class_counts.items():
            if count < max_count * 0.5:  # If less than 50% of most common
                # Find patches of this class
                class_patches = [p for p in damage_patches 
                               if np.argmax(p['class_distribution'][1:]) + 1 == damage_class]
                # Duplicate them
                oversample_times = int(max_count / count) - 1
                oversampled_damage.extend(class_patches * oversample_times)
    else:
        oversampled_damage = damage_patches
    
    # Combine
    all_patches = oversampled_damage + selected_normal
    np.random.shuffle(all_patches)
    
    print(f"Dataset balance:")
    print(f"  Damage patches: {len(oversampled_damage)}")
    print(f"  Normal patches: {len(selected_normal)}")
    print(f"  Total patches: {len(all_patches)}")
    print(f"  Damage ratio: {len(oversampled_damage)/len(all_patches)*100:.1f}%")
    
    return all_patches

# ============================================================================
# STEP 6: NORMALIZATION & AUGMENTATION
# ============================================================================

def normalize_patch(points):
    """Normalize patch to unit sphere"""
    centroid = np.mean(points[:, :3], axis=0)
    points[:, :3] -= centroid
    max_dist = np.max(np.sqrt(np.sum(points[:, :3]**2, axis=1)))
    if max_dist > 0:
        points[:, :3] /= max_dist
    return points

def augment_patch(points, labels, config):
    """Apply data augmentation"""
    aug_config = config['augmentation']
    
    # Random rotation around Z-axis
    if aug_config['rotation']:
        theta = np.random.uniform(0, 2 * np.pi)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[cos_t, -sin_t, 0],
                           [sin_t, cos_t, 0],
                           [0, 0, 1]])
        points[:, :3] = points[:, :3] @ rotation.T
    
    # Random jitter
    points[:, :3] += np.random.normal(0, aug_config['jitter'], points[:, :3].shape)
    
    # Random scaling
    scale = np.random.uniform(*aug_config['scale'])
    points[:, :3] *= scale
    
    # Random point dropout
    if aug_config['dropout'] > 0:
        keep_mask = np.random.random(len(points)) > aug_config['dropout']
        keep_indices = np.where(keep_mask)[0]
        if len(keep_indices) >= len(points) // 2:  # Keep at least 50%
            points = points[keep_indices]
            labels = labels[keep_indices]
            # Resample to original size
            resample_indices = np.random.choice(len(points), CONFIG['points_per_patch'], replace=True)
            points = points[resample_indices]
            labels = labels[resample_indices]
    
    return points, labels

# ============================================================================
# STEP 7: SAVE TO H5 FORMAT
# ============================================================================

def save_patches_to_h5(patches, output_file, config, augment=False):
    """Save patches to H5 format for efficient loading"""
    all_points = []
    all_labels = []
    
    for patch in tqdm(patches, desc="Processing patches"):
        points = patch['points'].copy()
        labels = patch['labels'].copy()
        
        # Normalize
        points = normalize_patch(points)
        
        # Augment if training
        if augment:
            points, labels = augment_patch(points, labels, config)
        
        all_points.append(points)
        all_labels.append(labels)
    
    all_points = np.array(all_points, dtype=np.float32)
    all_labels = np.array(all_labels, dtype=np.int32)
    
    # Save to H5
    with h5py.File(output_file, 'w') as f:
        f.create_dataset('data', data=all_points, compression='gzip')
        f.create_dataset('label', data=all_labels, compression='gzip')
        f.create_dataset('num_classes', data=5)
        
        # Save class statistics
        class_counts = np.bincount(all_labels.flatten(), minlength=5)
        f.create_dataset('class_counts', data=class_counts)
        
        # Save config
        f.attrs['patch_size'] = config['patch_size']
        f.attrs['points_per_patch'] = config['points_per_patch']
    
    print(f"\nSaved {len(patches)} patches to {output_file}")
    print("\nClass distribution:")
    class_counts = np.bincount(all_labels.flatten(), minlength=5)
    for i, (class_name, count) in enumerate(zip(config['classes'].values(), class_counts)):
        percentage = 100 * count / class_counts.sum()
        print(f"  {class_name}: {count:,} points ({percentage:.2f}%)")

# ============================================================================
# MAIN PREPROCESSING PIPELINE
# ============================================================================

def preprocess_single_file(las_file, csv_file, config):
    """Preprocess a single LAS file with its labels"""
    print(f"\nProcessing: {os.path.basename(las_file)}")
    
    # Load LAS
    las = laspy.read(las_file)
    points = np.vstack((las.x, las.y, las.z)).T
    
    # Add intensity if available
    if hasattr(las, 'intensity'):
        intensity = las.intensity.reshape(-1, 1) / 255.0
        points = np.hstack((points, intensity))
    
    print(f"  Loaded {len(points):,} points")
    
    # Parse bounding boxes
    bboxes = parse_damage_bboxes(csv_file)
    print(f"  Found {len(bboxes)} damage regions")
    
    # Label points
    labels = label_points_from_bboxes(points, bboxes)
    
    # Refine with clustering
    labels = refine_labels_with_clustering(points, labels)
    
    # Print label statistics
    unique, counts = np.unique(labels, return_counts=True)
    print(f"  Label distribution:")
    for cls, count in zip(unique, counts):
        print(f"    Class {cls} ({config['classes'][cls]}): {count:,} ({100*count/len(labels):.2f}%)")
    
    # Extract patches
    damage_patches, normal_patches = extract_balanced_patches(points, labels, config)
    
    return damage_patches, normal_patches

def preprocess_all_files(las_folder, csv_folder, output_folder, config):
    """Preprocess all LAS files"""
    import glob
    
    os.makedirs(output_folder, exist_ok=True)
    
    las_files = sorted(glob.glob(os.path.join(las_folder, '*.las')))
    print(f"Found {len(las_files)} LAS files")
    
    all_damage_patches = []
    all_normal_patches = []
    
    for las_file in las_files:
        base_name = os.path.splitext(os.path.basename(las_file))[0]
        csv_file = os.path.join(csv_folder, base_name + '.csv')
        
        if not os.path.exists(csv_file):
            print(f"Warning: No CSV found for {base_name}, skipping...")
            continue
        
        try:
            damage_patches, normal_patches = preprocess_single_file(las_file, csv_file, config)
            all_damage_patches.extend(damage_patches)
            all_normal_patches.extend(normal_patches)
        except Exception as e:
            print(f"Error processing {base_name}: {e}")
            continue
    
    print(f"\n{'='*60}")
    print(f"TOTAL STATISTICS")
    print(f"{'='*60}")
    print(f"Total damage patches: {len(all_damage_patches)}")
    print(f"Total normal patches: {len(all_normal_patches)}")
    
    # Create balanced dataset
    balanced_patches = create_balanced_dataset(all_damage_patches, all_normal_patches, config)
    
    # Split train/val
    np.random.shuffle(balanced_patches)
    split_idx = int(0.8 * len(balanced_patches))
    train_patches = balanced_patches[:split_idx]
    val_patches = balanced_patches[split_idx:]
    
    # Save
    save_patches_to_h5(train_patches, 
                      os.path.join(output_folder, 'train_balanced.h5'),
                      config, augment=True)
    save_patches_to_h5(val_patches,
                      os.path.join(output_folder, 'val_balanced.h5'),
                      config, augment=False)
    
    print(f"\n✓ Preprocessing complete!")

# ============================================================================
# USAGE EXAMPLE
# ============================================================================

if __name__ == "__main__":
    # Configure paths
    LAS_FOLDER = "C:/Users/umair.muhammad/Documents/PhD/Research Work/FedLearn/training/All_Nome/LAS_Files_Site1"
    CSV_FOLDER = "C:/Users/umair.muhammad/Documents/PhD/Research Work/FedLearn/training/All_Nome/Labels_Site1"
    OUTPUT_FOLDER = "C:/Users/umair.muhammad/Documents/PhD/Research Work/FedLearn/training/All_Nome/h5"
    
    # Run preprocessing
    preprocess_all_files(LAS_FOLDER, CSV_FOLDER, OUTPUT_FOLDER, CONFIG)
    
    print("\n" + "="*60)
    print("NEXT STEPS:")
    print("="*60)
    print("1. Use the generated H5 files for training")
    print("2. Apply class weights: {0: 1.0, 1: 20.0, 2: 150.0, 3: 500.0, 4: 50.0}")
    print("3. Use Focal Loss for extreme imbalance")
    print("4. Monitor per-class IoU, not just overall accuracy")
    print("="*60)

Found 15 LAS files

Processing: Site1_0.las
  Loaded 1,917,694 points
  Found 1 damage regions
  Label distribution:
    Class 0 (Normal Road): 1,916,653 (99.95%)
    Class 1 (Pothole): 1,041 (0.05%)

Processing: Site1_1.las
  Loaded 306,546 points
  Found 20 damage regions
  Label distribution:
    Class 0 (Normal Road): 298,388 (97.34%)
    Class 1 (Pothole): 8,158 (2.66%)

Processing: Site1_10.las
  Loaded 497,676 points
  Found 38 damage regions
  Label distribution:
    Class 0 (Normal Road): 464,556 (93.35%)
    Class 1 (Pothole): 33,120 (6.65%)

Processing: Site1_11.las
  Loaded 455,059 points
  Found 30 damage regions
  Label distribution:
    Class 0 (Normal Road): 429,402 (94.36%)
    Class 1 (Pothole): 24,298 (5.34%)
    Class 3 (Edge Crack): 1,359 (0.30%)

Processing: Site1_12.las
  Loaded 456,490 points
  Found 35 damage regions
  Label distribution:
    Class 0 (Normal Road): 425,765 (93.27%)
    Class 1 (Pothole): 30,725 (6.73%)

Processing: Site1_13.las
  Loaded 698,678

Processing patches: 100%|██████████| 325/325 [00:00<00:00, 688.19it/s]



Saved 325 patches to C:/Users/umair.muhammad/Documents/PhD/Research Work/FedLearn/training/All_Nome/h5\train_balanced.h5

Class distribution:
  Normal Road: 2,568,810 points (96.48%)
  Pothole: 92,456 points (3.47%)
  Crack: 0 points (0.00%)
  Edge Crack: 1,134 points (0.04%)
  Flushed Out: 0 points (0.00%)


Processing patches: 100%|██████████| 82/82 [00:00<00:00, 2557.14it/s]



Saved 82 patches to C:/Users/umair.muhammad/Documents/PhD/Research Work/FedLearn/training/All_Nome/h5\val_balanced.h5

Class distribution:
  Normal Road: 649,591 points (96.70%)
  Pothole: 22,153 points (3.30%)
  Crack: 0 points (0.00%)
  Edge Crack: 0 points (0.00%)
  Flushed Out: 0 points (0.00%)

✓ Preprocessing complete!

NEXT STEPS:
1. Use the generated H5 files for training
2. Apply class weights: {0: 1.0, 1: 20.0, 2: 150.0, 3: 500.0, 4: 50.0}
3. Use Focal Loss for extreme imbalance
4. Monitor per-class IoU, not just overall accuracy
