In [1]:
#!/usr/bin/env python
# coding: utf-8

# # Optimized Biomass Prediction Pipeline for Gradio Interface
# 
# This notebook contains an optimized version of the biomass prediction pipeline
# specifically designed for processing smaller GeoTIFF chips within a Gradio interface.
# 
# Author: najahpokkiri
# Date: 2025-05-17

# ## 1. Import Libraries

import os
import sys
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import gc
import warnings

# Deep learning libraries
import torch
import torch.nn as nn

# Geospatial libraries
import rasterio
from rasterio.windows import Window

# Data processing
import joblib
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.decomposition import PCA
from io import BytesIO

# Suppress warnings
warnings.filterwarnings("ignore")

# Check if scikit-image is available for texture features
try:
    from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
    from skimage.filters import sobel
    SKIMAGE_AVAILABLE = True
except ImportError:
    print("Warning: scikit-image not available. Texture features will be disabled.")
    SKIMAGE_AVAILABLE = False

# ## 2. Define Model Architecture

class StableResNet(nn.Module):
    """Numerically stable ResNet for biomass regression"""
    def __init__(self, n_features, dropout=0.2):
        super().__init__()
        
        self.input_proj = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        self.layer1 = self._make_simple_resblock(256, 256)
        self.layer2 = self._make_simple_resblock(256, 128)
        self.layer3 = self._make_simple_resblock(128, 64)
        
        self.regressor = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
    
    def _make_simple_resblock(self, in_dim, out_dim):
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU()
        ) if in_dim == out_dim else nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.input_proj(x)
        
        identity = x
        out = self.layer1(x)
        x = out + identity
        
        x = self.layer2(x)
        x = self.layer3(x)
        
        x = self.regressor(x)
        return x.squeeze()

# ## 3. Model Loading Functions

def load_model(model_dir, device=None):
    """
    Load the biomass prediction model and associated metadata.
    
    Args:
        model_dir (str): Directory containing model files
        device (torch.device): Device to load the model on
    
    Returns:
        dict: Dictionary containing model, scaler, and metadata
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    result = {
        'model': None,
        'package': None,
        'feature_names': None,
        'metadata': {
            'loaded_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'device': str(device)
        }
    }
    
    try:
        # Load feature names
        feature_path = os.path.join(model_dir, "feature_names.txt")
        if os.path.exists(feature_path):
            with open(feature_path, 'r') as f:
                result['feature_names'] = [line.strip() for line in f if line.strip()]
            print(f"Loaded {len(result['feature_names'])} feature names")
        else:
            print(f"Feature names file not found at {feature_path}")
            result['feature_names'] = []
        
        # Set number of features
        n_features = len(result['feature_names']) if result['feature_names'] else 99
        
        # Load model package
        package_path = os.path.join(model_dir, "model_package.pkl")
        if os.path.exists(package_path):
            result['package'] = joblib.load(package_path)
            print(f"Model package loaded from {package_path}")
            
            # Update n_features if available in package
            if 'n_features' in result['package']:
                n_features = result['package']['n_features']
                print(f"Using n_features={n_features} from package")
        else:
            print(f"Model package not found at {package_path}")
        
        # Load model weights
        model_path = os.path.join(model_dir, "model.pt")
        if os.path.exists(model_path):
            model = StableResNet(n_features=n_features)
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.to(device)
            model.eval()
            result['model'] = model
            print(f"Model loaded successfully from {model_path}")
        else:
            print(f"Model file not found at {model_path}")
        
        return result
        
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return result

# ## 4. Feature Engineering Pipeline

def safe_divide(a, b, fill_value=0.0):
    """Safe division that handles zeros in the denominator"""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    
    # Handle NaN/Inf in inputs
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    b = np.nan_to_num(b, nan=1e-10, posinf=1e10, neginf=-1e10)
    
    if a.ndim == 0 and b.ndim > 0:
        a = np.full_like(b, a)
    elif b.ndim == 0 and a.ndim > 0:
        b = np.full_like(a, b)
    elif a.ndim == 0 and b.ndim == 0:
        if abs(b) < 1e-10:
            return fill_value
        else:
            return float(a / b)
    
    mask = np.abs(b) < 1e-10
    result = np.full_like(a, fill_value, dtype=np.float32)
    if np.any(~mask):
        result[~mask] = a[~mask] / b[~mask]
    
    result = np.nan_to_num(result, nan=fill_value, posinf=fill_value, neginf=fill_value)
    return result

def extract_features(satellite_data, use_advanced_indices=True, 
                    use_texture_features=True, use_spatial_features=True,
                    use_pca_features=True, pca_components=25):
    """
    Extract all features from satellite data in a single function.
    Optimized for memory efficiency and smaller GeoTIFF files.
    
    Args:
        satellite_data (numpy.ndarray): Satellite image with shape (bands, height, width)
        use_advanced_indices (bool): Whether to calculate spectral indices
        use_texture_features (bool): Whether to extract texture features
        use_spatial_features (bool): Whether to calculate spatial features
        use_pca_features (bool): Whether to calculate PCA features
        pca_components (int): Number of PCA components
    
    Returns:
        dict: Dictionary of feature arrays with shape (height, width)
    """
    # Start time
    start_time = time.time()
    print("Extracting features...")
    
    # Get dimensions
    n_bands, height, width = satellite_data.shape
    print(f"Image dimensions: {width}x{height}, {n_bands} bands")
    
    # Initialize features dictionary
    all_features = {}
    
    # 1. Add original bands
    for i in range(n_bands):
        band_data = satellite_data[i].copy()
        band_data = np.nan_to_num(band_data, nan=0.0)
        all_features[f'Band_{i+1:02d}'] = band_data
    
    # 2. Calculate spectral indices
    if use_advanced_indices:
        indices_start = time.time()
        
        # Enhanced band mapping with error checking
        def safe_get_band(idx):
            return satellite_data[idx] if idx < n_bands else None
        
        # Try to get commonly used bands
        blue = safe_get_band(1)  # Adjust indices based on your data
        green = safe_get_band(2)
        red = safe_get_band(3)
        nir = safe_get_band(7)
        swir1 = safe_get_band(9)
        swir2 = safe_get_band(10)
        
        # Calculate indices if bands are available
        if all(b is not None for b in [red, nir]):
            # NDVI (Normalized Difference Vegetation Index)
            all_features['NDVI'] = safe_divide(nir - red, nir + red)
            
            if blue is not None and green is not None:
                # EVI (Enhanced Vegetation Index)
                all_features['EVI'] = 2.5 * safe_divide(nir - red, nir + 6*red - 7.5*blue + 1)
                
                # SAVI (Soil Adjusted Vegetation Index)
                all_features['SAVI'] = 1.5 * safe_divide(nir - red, nir + red + 0.5)
                
                # MSAVI2 (Modified Soil Adjusted Vegetation Index)
                all_features['MSAVI2'] = 0.5 * (2 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red)))
                
                # NDWI (Normalized Difference Water Index)
                all_features['NDWI'] = safe_divide(green - nir, green + nir)
        
        if swir1 is not None and nir is not None:
            # NDMI (Normalized Difference Moisture Index)
            all_features['NDMI'] = safe_divide(nir - swir1, nir + swir1)
        
        if swir2 is not None and nir is not None:
            # NBR (Normalized Burn Ratio)
            all_features['NBR'] = safe_divide(nir - swir2, nir + swir2)
        
        indices_time = time.time() - indices_start
        print(f"Calculated {len(all_features) - n_bands} spectral indices in {indices_time:.2f}s")
    
    # 3. Extract texture features
    if use_texture_features and SKIMAGE_AVAILABLE:
        texture_start = time.time()
        
        # Select representative bands for texture analysis (e.g., NIR bands)
        key_bands = [7]  # NIR band
        
        for band_idx in key_bands:
            if band_idx >= n_bands:
                continue
                
            try:
                band = satellite_data[band_idx].copy()
                
                # Normalize to 0-255 for texture analysis
                band_min, band_max = np.nanpercentile(band[~np.isnan(band)], [1, 99])
                band_norm = np.clip((band - band_min) / (band_max - band_min + 1e-8), 0, 1)
                band_norm = (band_norm * 255).astype(np.uint8)
                
                # Replace NaN with median
                band_norm = np.nan_to_num(band_norm, nan=np.nanmedian(band_norm))
                
                # Edge detection using Sobel
                sobel_response = sobel(band_norm.astype(float))
                all_features[f'Sobel_B{band_idx}'] = sobel_response
                
                # Local Binary Pattern (for small images only)
                if height * width < 1000000:  # ~1MP limit
                    try:
                        lbp = local_binary_pattern(band_norm, 8, 1, method='uniform')
                        all_features[f'LBP_B{band_idx}'] = lbp
                    except Exception as e:
                        print(f"Warning: Error calculating LBP: {e}")
                
                # GLCM properties - simplified approach for efficiency
                # Calculate on a smaller representative patch for memory efficiency
                sample_size = min(128, height, width)
                center_y, center_x = height//2, width//2
                offset = sample_size // 2
                y_start = max(0, center_y - offset)
                y_end = min(height, center_y + offset)
                x_start = max(0, center_x - offset)
                x_end = min(width, center_x + offset)
                
                patch = band_norm[y_start:y_end, x_start:x_end]
                
                if patch.size > 0:
                    try:
                        glcm = graycomatrix(patch, [1], [0], levels=256, symmetric=True, normed=True)
                        # Extract properties as scalar values
                        for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy']:
                            value = float(graycoprops(glcm, prop)[0, 0])
                            # Create 2D arrays with the scalar value
                            all_features[f'GLCM_{prop}_B{band_idx}'] = np.full((height, width), value, dtype=np.float32)
                    except Exception as e:
                        print(f"Warning: Error calculating GLCM: {e}")
                        
            except Exception as e:
                print(f"Warning: Error processing band {band_idx} for texture: {e}")
        
        texture_time = time.time() - texture_start
        texture_count = len(all_features) - n_bands - (len(all_features) - n_bands)
        print(f"Extracted {texture_count} texture features in {texture_time:.2f}s")
    
    # 4. Calculate spatial features
    if use_spatial_features:
        spatial_start = time.time()
        
        # Key bands for spatial analysis
        key_bands = [7]  # NIR band
        
        for band_idx in key_bands:
            if band_idx < n_bands:
                try:
                    band = satellite_data[band_idx].copy()
                    band = np.nan_to_num(band, nan=0.0)
                    
                    # Calculate gradients
                    grad_y, grad_x = np.gradient(band)
                    grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
                    all_features[f'Gradient_B{band_idx}'] = grad_magnitude
                    
                except Exception as e:
                    print(f"Warning: Error calculating spatial features for band {band_idx}: {e}")
        
        # Gradient features for NDVI if available
        if 'NDVI' in all_features:
            try:
                ndvi_clean = np.nan_to_num(all_features['NDVI'], nan=0.0)
                
                # Calculate gradients
                grad_y, grad_x = np.gradient(ndvi_clean)
                grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
                all_features['NDVI_gradient'] = grad_magnitude
                
            except Exception as e:
                print(f"Warning: Error calculating gradient for NDVI: {e}")
        
        spatial_time = time.time() - spatial_start
        spatial_count = len(all_features) - n_bands - (len(all_features) - n_bands) - texture_count
        print(f"Calculated {spatial_count} spatial features in {spatial_time:.2f}s")
    
    # 5. Calculate PCA features (if image is not too large)
    if use_pca_features and height * width < 5000000:  # ~5MP limit for memory
        pca_start = time.time()
        
        # Reshape for PCA (pixels x bands)
        bands_reshaped = satellite_data.reshape(n_bands, -1).T
        
        # Handle NaN values
        valid_mask = ~np.any(np.isnan(bands_reshaped), axis=1)
        bands_clean = bands_reshaped[valid_mask]
        
        if len(bands_clean) > 0:
            try:
                # Standardize and apply PCA
                scaler = StandardScaler()
                bands_scaled = scaler.fit_transform(bands_clean)
                
                # Limit components to avoid memory issues
                n_components = min(pca_components, bands_scaled.shape[1], 25)
                
                pca = PCA(n_components=n_components)
                pca_features = pca.fit_transform(bands_scaled)
                
                # Create full PCA array
                pca_full = np.zeros((height * width, pca_features.shape[1]))
                pca_full[valid_mask] = pca_features
                pca_full = pca_full.reshape(height, width, pca_features.shape[1])
                
                # Convert to dictionary format
                for i in range(pca_full.shape[2]):
                    all_features[f'PCA_{i+1:02d}'] = pca_full[:, :, i]
                
                explained_variance = pca.explained_variance_ratio_.sum()
                print(f"PCA explained variance: {explained_variance:.3f}")
                
            except Exception as e:
                print(f"Warning: Error calculating PCA: {e}")
        else:
            print("Warning: No valid data for PCA")
        
        pca_time = time.time() - pca_start
        pca_count = len(all_features) - n_bands - (len(all_features) - n_bands) - texture_count - spatial_count
        print(f"Calculated {pca_count} PCA features in {pca_time:.2f}s")
    
    # Report total feature count and time
    total_time = time.time() - start_time
    print(f"Extracted {len(all_features)} total features in {total_time:.2f}s")
    
    # Clean up memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return all_features

# ## 5. Prediction Function

def predict_biomass(file_path, model_result, use_advanced_indices=True, 
                   use_texture_features=True, use_spatial_features=True,
                   use_pca_features=True, pca_components=25):
    """
    Main function to predict biomass from a GeoTIFF file.
    Optimized for smaller GeoTIFF chips.
    
    Args:
        file_path (str or BytesIO): Path to GeoTIFF file or file-like object
        model_result (dict): Dictionary containing model, package, feature_names
        use_advanced_indices (bool): Whether to calculate spectral indices
        use_texture_features (bool): Whether to extract texture features
        use_spatial_features (bool): Whether to calculate spatial features
        use_pca_features (bool): Whether to calculate PCA features
        pca_components (int): Number of PCA components
    
    Returns:
        dict: Dictionary containing predictions, statistics, and metadata
    """
    # Check if model is loaded
    if model_result['model'] is None:
        return {"error": "Model not loaded"}
    
    if not model_result['feature_names']:
        return {"error": "No feature names provided"}
    
    result = {
        "predictions": None,
        "statistics": {},
        "metadata": {
            "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            "file_path": str(file_path) if isinstance(file_path, str) else "BytesIO object",
            "processing_time": 0
        }
    }
    
    start_time = time.time()
    
    try:
        # Load configuration from model package
        config = {
            "use_log_transform": True,
            "epsilon": 1.0
        }
        
        if model_result['package'] is not None:
            if 'use_log_transform' in model_result['package']:
                config["use_log_transform"] = model_result['package']['use_log_transform']
            if 'epsilon' in model_result['package']:
                config["epsilon"] = model_result['package']['epsilon']
        
        # Open raster file
        with rasterio.open(file_path) as src:
            # Read metadata
            result["metadata"]["width"] = src.width
            result["metadata"]["height"] = src.height
            result["metadata"]["crs"] = str(src.crs)
            result["metadata"]["transform"] = str(src.transform)
            result["metadata"]["bands"] = src.count
            
            # Read entire image
            image_data = src.read()
            
            # Create validity mask
            valid_mask = np.all(np.isfinite(image_data), axis=0)
            
            if np.sum(valid_mask) == 0:
                return {"error": "No valid pixels in image"}
            
            # Extract all features
            features = extract_features(
                image_data,
                use_advanced_indices=use_advanced_indices,
                use_texture_features=use_texture_features,
                use_spatial_features=use_spatial_features,
                use_pca_features=use_pca_features,
                pca_components=pca_components
            )
            
            # Extract features in correct order for model
            feature_names = model_result['feature_names']
            height, width = valid_mask.shape
            y_indices, x_indices = np.where(valid_mask)
            feature_matrix = np.zeros((len(y_indices), len(feature_names)), dtype=np.float32)
            
            print(f"Preparing {len(y_indices)} valid pixels for prediction...")
            for i, feature_name in enumerate(feature_names):
                if feature_name in features:
                    feature_data = features[feature_name]
                    feature_values = feature_data[y_indices, x_indices]
                    feature_values = np.nan_to_num(feature_values, nan=0.0)
                    feature_matrix[:, i] = feature_values
                else:
                    # If feature is missing, use zeros
                    print(f"Warning: Feature '{feature_name}' not found, using zeros")
                    feature_matrix[:, i] = 0.0
            
            # Apply scaling if available
            if model_result['package'] is not None and 'scaler' in model_result['package']:
                try:
                    feature_matrix = model_result['package']['scaler'].transform(feature_matrix)
                except Exception as e:
                    print(f"Warning: Error applying scaler: {e}")
            
            # Make predictions
            predictions = np.zeros((height, width), dtype=np.float32)
            
            print("Running model inference...")
            with torch.no_grad():
                # Process in batches to avoid memory issues
                batch_size = min(10000, len(y_indices))
                
                for i in range(0, len(y_indices), batch_size):
                    end_idx = min(i + batch_size, len(y_indices))
                    batch = feature_matrix[i:end_idx]
                    
                    # Convert to tensor
                    device = next(model_result['model'].parameters()).device
                    batch_tensor = torch.tensor(batch, dtype=torch.float32).to(device)
                    
                    # Get predictions
                    batch_preds = model_result['model'](batch_tensor).cpu().numpy()
                    
                    # Handle scalar case
                    if not isinstance(batch_preds, np.ndarray):
                        batch_preds = np.array([batch_preds])
                    if batch_preds.ndim == 0:
                        batch_preds = np.array([float(batch_preds)])
                    
                    # Convert from log scale if needed
                    if config["use_log_transform"]:
                        batch_preds = np.exp(batch_preds) - config["epsilon"]
                        batch_preds = np.maximum(batch_preds, 0)  # Ensure non-negative
                    
                    # Map predictions back to the correct pixels
                    for j in range(end_idx - i):
                        y_idx = y_indices[i + j]
                        x_idx = x_indices[i + j]
                        predictions[y_idx, x_idx] = batch_preds[j]
                    
                    print(f"\rProcessed {end_idx}/{len(y_indices)} pixels", end="")
            
            print("\nPrediction complete!")
            
            # Calculate statistics
            valid_predictions = predictions[valid_mask]
            
            result["statistics"] = {
                "min": float(np.min(valid_predictions)),
                "max": float(np.max(valid_predictions)),
                "mean": float(np.mean(valid_predictions)),
                "median": float(np.median(valid_predictions)),
                "std": float(np.std(valid_predictions)),
                "sum": float(np.sum(valid_predictions)),
                "pixel_count": int(valid_predictions.size),
                "valid_percentage": float(100 * valid_predictions.size / (height * width))
            }
            
            # Calculate total biomass if transform is available
            if src.transform:
                pixel_area_m2 = abs(src.transform[0] * src.transform[4])
                area_hectares = valid_predictions.size * pixel_area_m2 / 10000
                total_biomass_mg = np.sum(valid_predictions) * pixel_area_m2 / 10000
                
                result["statistics"]["area_hectares"] = float(area_hectares)
                result["statistics"]["total_biomass_mg"] = float(total_biomass_mg)
            
            # Store predictions
            result["predictions"] = predictions
            result["valid_mask"] = valid_mask
            result["metadata"]["processing_time"] = time.time() - start_time
            
            print(f"Prediction complete in {result['metadata']['processing_time']:.2f}s")
            print(f"Biomass range: {result['statistics']['min']:.2f} - {result['statistics']['max']:.2f} Mg/ha")
            print(f"Mean biomass: {result['statistics']['mean']:.2f} Mg/ha")
            
            # Clean up
            del features, feature_matrix
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            return result
    
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        return {
            "error": str(e),
            "traceback": error_trace,
            "metadata": {
                "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                "file_path": str(file_path) if isinstance(file_path, str) else "BytesIO object",
                "processing_time": time.time() - start_time
            }
        }

# ## 6. Visualization Functions

def create_biomass_visualization(result, visualization_type="heatmap", rgb_indexes=None, dpi=100):
    """
    Create visualization of biomass predictions.
    
    Args:
        result (dict): Result dictionary from predict_biomass
        visualization_type (str): Type of visualization ('heatmap' or 'rgb_overlay')
        rgb_indexes (tuple): Indexes for RGB visualization (e.g., (3,2,1) for true color)
        dpi (int): DPI for output image
    
    Returns:
        BytesIO: Image buffer with visualization
    """
    if "error" in result:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {result['error']}", 
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
        buf.seek(0)
        plt.close(fig)
        return buf
    
    predictions = result["predictions"]
    valid_mask = result.get("valid_mask", np.ones_like(predictions, dtype=bool))
    
    # Use masked array for better visualization
    masked_predictions = np.ma.masked_where(~valid_mask, predictions)
    
    # Get min/max values for better visualization (1-99 percentile)
    if np.sum(valid_mask) > 0:
        vmin = np.nanpercentile(predictions[valid_mask], 1)
        vmax = np.nanpercentile(predictions[valid_mask], 99)
    else:
        vmin, vmax = 0, 100
    
    try:
        if visualization_type == "heatmap":
            # Create heatmap visualization
            fig, ax = plt.subplots(figsize=(10, 8))
            
            im = ax.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
            cbar = fig.colorbar(im, ax=ax, label='Biomass (Mg/ha)')
            
            ax.set_title('Predicted Above-Ground Biomass')
            ax.axis('off')
            
            # Add statistics as text
            if "statistics" in result:
                stats = result["statistics"]
                stats_text = (
                    f"Mean: {stats['mean']:.2f} Mg/ha\n"
                    f"Range: {stats['min']:.2f} - {stats['max']:.2f} Mg/ha\n"
                )
                if "total_biomass_mg" in stats:
                    stats_text += f"Total: {stats['total_biomass_mg']:.2f} Mg"
                
                ax.text(0.02, 0.02, stats_text, transform=ax.transAxes,
                       fontsize=9, verticalalignment='bottom',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
        
        elif visualization_type == "rgb_overlay" and rgb_indexes is not None:
            # Create RGB + overlay visualization
            fig, ax = plt.subplots(figsize=(10, 8))
            
            # Try to get original image
            if "original_image" in result:
                rgb_bands = result["original_image"][rgb_indexes]
                rgb = np.dstack([rgb_bands[0], rgb_bands[1], rgb_bands[2]])
                
                # Normalize RGB for display
                rgb_min = np.nanpercentile(rgb, 2)
                rgb_max = np.nanpercentile(rgb, 98)
                rgb_norm = np.clip((rgb - rgb_min) / (rgb_max - rgb_min), 0, 1)
                
                # Show RGB
                ax.imshow(rgb_norm)
                
                # Create mask for overlay (where we have predictions)
                mask = valid_mask & ~np.isclose(predictions, 0)
                overlay = np.zeros((*predictions.shape, 4))
                
                # Create colormap for biomass
                norm = plt.cm.colors.Normalize(vmin=vmin, vmax=vmax)
                cmap = plt.cm.viridis
                
                # Apply colormap
                overlay[..., :3] = cmap(norm(predictions))[..., :3]
                overlay[..., 3] = np.where(mask, 0.7, 0)  # Set alpha channel
                
                ax.imshow(overlay)
                cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), 
                                   ax=ax, label='Biomass (Mg/ha)')
                
                ax.set_title('Biomass Prediction Overlay')
            else:
                # Fallback to heatmap if no original image
                im = ax.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
                cbar = fig.colorbar(im, ax=ax, label='Biomass (Mg/ha)')
                ax.set_title('Predicted Above-Ground Biomass')
            
            ax.axis('off')
        
        else:
            # Default visualization
            fig, ax = plt.subplots(figsize=(10, 8))
            im = ax.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
            cbar = fig.colorbar(im, ax=ax, label='Biomass (Mg/ha)')
            ax.set_title('Predicted Above-Ground Biomass')
            ax.axis('off')
        
        # Save figure to BytesIO buffer
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
        buf.seek(0)
        plt.close(fig)
        
        return buf
        
    except Exception as e:
        print(f"Error creating visualization: {e}")
        import traceback
        traceback.print_exc()
        
        # Return error image
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Visualization error: {e}", 
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
        buf.seek(0)
        plt.close(fig)
        
        return buf

# ## 7. Testing with Sample Data

def test_prediction_pipeline(model_dir, test_file_path):
    """
    Test the prediction pipeline with a sample file.
    
    Args:
        model_dir (str): Directory containing model files
        test_file_path (str): Path to a test GeoTIFF file
    """
    print(f"Testing prediction pipeline...")
    print(f"Model directory: {model_dir}")
    print(f"Test file: {test_file_path}")
    
    # Load model
    model_result = load_model(model_dir)
    
    if model_result['model'] is None:
        print("Failed to load model")
        return
    
    # Run prediction
    result = predict_biomass(
        test_file_path, 
        model_result,
        use_advanced_indices=True,
        use_texture_features=True,
        use_spatial_features=True,
        use_pca_features=True
    )
    
    if "error" in result:
        print(f"Prediction error: {result['error']}")
        return
    
    # Print statistics
    print("\nBiomass Statistics:")
    for key, value in result["statistics"].items():
        print(f"  {key}: {value}")
    
    # Create visualization
    print("\nCreating visualization...")
    vis_buf = create_biomass_visualization(result, visualization_type="heatmap")
    
    # Save visualization
    output_dir = "test_output"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "test_visualization.png")
    
    with open(output_path, 'wb') as f:
        f.write(vis_buf.getvalue())
    
    print(f"Visualization saved to {output_path}")
    
    return result



In [2]:

# Example test usage (uncomment to run)
model_dir = "/teamspace/studios/this_studio/files/biomass_results_20250516_224934/models/StableResNet_20250516_225706"
test_file = "/teamspace/studios/this_studio/torchgeo/experiments/biomass/center_crop_for_prediction.tif"
test_result = test_prediction_pipeline(model_dir, test_file)

Testing prediction pipeline...
Model directory: /teamspace/studios/this_studio/files/biomass_results_20250516_224934/models/StableResNet_20250516_225706
Test file: /teamspace/studios/this_studio/torchgeo/experiments/biomass/center_crop_for_prediction.tif
Loaded 99 feature names
Model package loaded from /teamspace/studios/this_studio/files/biomass_results_20250516_224934/models/StableResNet_20250516_225706/model_package.pkl
Using n_features=99 from package
Model loaded successfully from /teamspace/studios/this_studio/files/biomass_results_20250516_224934/models/StableResNet_20250516_225706/model.pt
Extracting features...
Image dimensions: 512x512, 98 bands
Calculated 7 spectral indices in 0.02s
Extracted 0 texture features in 0.05s
Calculated 0 spatial features in 0.01s
PCA explained variance: 0.971
Calculated 0 PCA features in 1.27s
Extracted 138 total features in 1.51s
Preparing 250745 valid pixels for prediction...
Running model inference...
Processed 250745/250745 pixels
Prediction