# SAR Image Colorization - Inference & Visualization

## Overview
This notebook demonstrates production-ready inference and visualization for SAR image colorization models. It includes:

- **Model Loading**: Loading trained models from checkpoints
- **Batch Inference**: Efficient processing of multiple images
- **Visualization**: High-quality visualization of results
- **Geospatial Integration**: Handling geospatial metadata (if available; the cells below currently write plain PNGs without georeferencing)
- **Performance Analysis**: Inference speed and memory usage
- **Export Options**: Saving results as PNG images with JSON metadata

## Key Features:
1. **Production Inference**: Batched, gradient-free inference pipeline intended for real-world deployment
2. **Visualization Suite**: Input / ground-truth / prediction / error-map grids and cross-model performance charts
3. **Geospatial Support**: Preservation of geospatial metadata and projections (not yet implemented in the cells below)
4. **Performance Monitoring**: Per-batch timing, throughput, and memory measurements
5. **Export Capabilities**: Per-model image folders and JSON metadata for later analysis

## Dependencies
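- `src/models/` - Model implementations (`UNet`, `UNetLight`, `AdversarialGenerator`, `PatchDiscriminator`)
- `src/data_pipeline.py` - `SARDataset`
- `src/utils.py` - `seed_everything`, `calculate_metrics`
- `src/infer.py` - Inference utilities (optional; the cells below do not call them)
- `torch`, `numpy`, `pandas`, `matplotlib`, `seaborn`, `tqdm`, `psutil`, `opencv-python` (`cv2`)
- `rasterio` - Geospatial data handling
- `PIL` - Image processing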
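Before running the notebook it can help to confirm which third-party packages are importable. A minimal sketch, assuming the top-level import names listed in `REQUIRED` (a hypothetical list assembled from the imports used in this notebook; `PIL` is provided by Pillow and `cv2` by opencv-python). `importlib.util.find_spec` locates a module without importing it, so the check is cheap:

```python
import importlib.util


def missing_modules(module_names):
    """Return the names from `module_names` that cannot be found on the import path."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Hypothetical checklist of top-level import names used by this notebook.
REQUIRED = ["torch", "numpy", "pandas", "matplotlib", "seaborn",
            "tqdm", "psutil", "PIL", "cv2", "rasterio"]

missing = missing_modules(REQUIRED)
print("Missing packages:", missing if missing else "none")
```

The local `src/` modules are not in the list because they are found through `sys.path.append('../src')` and only resolve when the notebook runs from the `notebooks` directory.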


In [None]:
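import os
import sys
import gc
import time
import random
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import psutil
from PIL import Image
import cv2

warnings.filterwarnings('ignore')

sys.path.append('../src')

# Import our custom modules (the model and dataset classes are required by the cells below)
try:
    from utils import seed_everything
    from data_pipeline import SARDataset
    from models.unet import UNet, UNetLight
    from models.generator_adv import AdversarialGenerator
    from models.discriminator import PatchDiscriminator
    print("✅ Successfully imported model and data modules")
except ImportError as e:
    print(f"⚠️ Import error: {e}")
    print("Make sure you're running from the notebooks directory")

# Optional helpers: nothing below depends on them, so a missing name
# (e.g. `load_model` in src/infer.py) must not abort the required imports above
try:
    from infer import load_model, run_inference, save_results
    from utils import calculate_metrics
except ImportError as e:
    print(f"⚠️ Optional helpers unavailable: {e}")

# Fall back to a local seeding helper if utils.seed_everything could not be imported
if 'seed_everything' not in globals():
    def seed_everything(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# Set random seed for reproducibility
seed_everything(42)

print("✅ Libraries imported successfully!")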


In [None]:
# Configuration for inference
CONFIG = {
    'data_root': '../Data/Processed',
    'inference_sar_path': '../Data/Processed/val/SAR',
    'inference_optical_path': '../Data/Processed/val/Optical',
    'output_path': '../Data/Processed/inference_results',
    'batch_size': 4,
    'image_size': (256, 256),
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'num_samples': 20,  # Number of samples for inference
    'model_configs': {
        'unet': {
            'checkpoint_path': '../experiments/checkpoints/supervised/best_model.pth',
            'model_class': UNet,
            'model_params': {
                'in_channels': 1,
                'out_channels': 3,
                'base_channels': 64,
                'depth': 4,
                'dropout': 0.1,
                'attention': True
            }
        },
        'unet_light': {
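            # NOTE: same checkpoint as 'unet' above. A full UNet checkpoint will not load into
            # UNetLight (different base_channels/depth); point this at a UNetLight checkpoint.
            'checkpoint_path': '../experiments/checkpoints/supervised/best_model.pth',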
            'model_class': UNetLight,
            'model_params': {
                'in_channels': 1,
                'out_channels': 3,
                'base_channels': 32,
                'depth': 3,
                'dropout': 0.1
            }
        }
    },
    'inference_config': {
        'save_images': True,
        'save_metadata': True,
        'visualize_results': True,
        'export_formats': ['png', 'tiff'],
        'quality': 95
    }
}

print("🔧 Inference Configuration:")
print(f"   Device: {CONFIG['device']}")
print(f"   Batch size: {CONFIG['batch_size']}")
print(f"   Image size: {CONFIG['image_size']}")
print(f"   Samples: {CONFIG['num_samples']}")
print(f"   Output path: {CONFIG['output_path']}")

# Create the output directory first so it is not reported as missing below
os.makedirs(CONFIG['output_path'], exist_ok=True)
print(f"✅ Output directory ready: {CONFIG['output_path']}")

# Verify every top-level path entry (keys containing 'path')
print("\n🔍 Verifying paths...")
for key, path in CONFIG.items():
    if 'path' not in key:
        continue
    if os.path.isdir(path):
        file_count = len([f for f in os.listdir(path) if f.endswith('.png')])
        print(f"✅ {key}: {path} ({file_count} .png files)")
    else:
        print(f"❌ {key}: {path} (not found)")

# Check model checkpoints
print("\n🔍 Checking model checkpoints...")
for model_name, config in CONFIG['model_configs'].items():
    checkpoint_path = config['checkpoint_path']
    if os.path.exists(checkpoint_path):
        print(f"✅ {model_name}: {checkpoint_path}")
    else:
        print(f"❌ {model_name}: {checkpoint_path} (not found)")
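Both entries in `model_configs` point at the same checkpoint file, so at least one of them is likely to fail in `load_state_dict`, which raises an error on missing keys, unexpected keys, or size mismatches. A minimal sketch of a pre-check that reports all three at once; `compare_param_shapes` is a hypothetical helper, and it works on plain `{name: shape}` dictionaries so it can be fed from `model.state_dict()` and a loaded checkpoint alike:

```python
def compare_param_shapes(model_shapes, checkpoint_shapes):
    """Compare two {parameter_name: shape} mappings.

    Returns (missing, unexpected, mismatched): names the model expects but the
    checkpoint lacks, names only the checkpoint has, and shared names whose
    shapes differ. Each list is sorted for stable output.
    """
    model_keys, ckpt_keys = set(model_shapes), set(checkpoint_shapes)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    mismatched = sorted(
        name for name in model_keys & ckpt_keys
        if tuple(model_shapes[name]) != tuple(checkpoint_shapes[name])
    )
    return missing, unexpected, mismatched


# Usage with PyTorch objects (not executed here):
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   ckpt_shapes = {k: tuple(v.shape) for k, v in state_dict.items()}
#   print(compare_param_shapes(model_shapes, ckpt_shapes))
```

Running this for `unet_light` against the shared checkpoint should make a `base_channels`/`depth` mismatch visible as a list of names rather than a single stack trace.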


In [None]:
# Model Loading and Setup
from torch.utils.data import Subset

def load_inference_models():
    """Instantiate each configured model, load its checkpoint if present, and switch to eval mode"""
    
    models = {}
    
    for model_name, config in CONFIG['model_configs'].items():
        try:
            # Create model instance
            model_class = config['model_class']
            model_params = config['model_params']
            model = model_class(**model_params)
            
            # Load checkpoint if available (raw state_dict or a dict with 'model_state_dict')
            checkpoint_path = config['checkpoint_path']
            if os.path.exists(checkpoint_path):
                checkpoint = torch.load(checkpoint_path, map_location=CONFIG['device'])
                if 'model_state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['model_state_dict'])
                else:
                    model.load_state_dict(checkpoint)
                print(f"✅ {model_name}: Loaded from checkpoint")
            else:
                print(f"⚠️ {model_name}: Using untrained model (checkpoint not found)")
            
            # Move to device; eval() disables dropout and uses stored batch-norm statistics
            model = model.to(CONFIG['device'])
            model.eval()
            
            models[model_name] = model
            
        except Exception as e:
            print(f"❌ Error loading {model_name}: {e}")
    
    return models

# Load models
print("🏗️ Loading models for inference...")
inference_models = load_inference_models()

# Load inference dataset
def load_inference_dataset():
    """Build the validation dataset and a deterministic DataLoader for inference"""
    
    try:
        # Create dataset
        inference_dataset = SARDataset(
            sar_path=CONFIG['inference_sar_path'],
            optical_path=CONFIG['inference_optical_path'],
            image_size=CONFIG['image_size'],
            filter_method='lee',
            normalization='robust',
            augmentation=False  # No augmentation for inference
        )
        
        # Limit dataset size without relying on SARDataset internals
        if len(inference_dataset) > CONFIG['num_samples']:
            inference_dataset = Subset(inference_dataset, range(CONFIG['num_samples']))
        
        # Create data loader
        inference_loader = DataLoader(
            inference_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,  # Fixed order so outputs map back to input files
            num_workers=0,
            pin_memory=torch.cuda.is_available()  # Pinned memory only helps host-to-GPU copies
        )
        
        print("✅ Inference dataset loaded successfully!")
        print(f"   Dataset size: {len(inference_dataset)} samples")
        print(f"   Batch size: {CONFIG['batch_size']}")
        print(f"   Number of batches: {len(inference_loader)}")
        
        return inference_dataset, inference_loader
        
    except Exception as e:
        print(f"❌ Error loading inference dataset: {e}")
        return None, None

# Load inference dataset
print("\n📂 Loading inference dataset...")
inference_dataset, inference_loader = load_inference_dataset()
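`load_inference_models` handles two checkpoint layouts: a raw `state_dict` and a dict with a `'model_state_dict'` entry. A minimal sketch of a more tolerant unwrapping step, assuming two further conventions that are common but not confirmed for this project: a `'state_dict'` key, and a `'module.'` prefix on every parameter name (which `nn.DataParallel` and `DistributedDataParallel` add because they store the wrapped model as `module`). `extract_state_dict` is a hypothetical helper:

```python
def extract_state_dict(checkpoint, keys=("model_state_dict", "state_dict")):
    """Return the parameter mapping stored in a loaded checkpoint.

    Accepts a raw state_dict or a wrapper dict that stores it under one of
    `keys`, and strips a leading 'module.' prefix when every name carries it.
    """
    state = checkpoint
    for key in keys:
        if isinstance(checkpoint, dict) and key in checkpoint:
            state = checkpoint[key]
            break
    prefix = "module."
    if state and all(name.startswith(prefix) for name in state):
        state = {name[len(prefix):]: value for name, value in state.items()}
    return state


# In load_inference_models this would replace the if/else around load_state_dict:
#   model.load_state_dict(extract_state_dict(checkpoint))
```

The prefix is stripped only when all names share it, so a model that legitimately has a submodule called `module` alongside other parameters is left untouched.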


In [None]:
# Batch Inference with Performance Monitoring
def run_batch_inference(model, data_loader, model_name="Model"):
    """Run batch inference and record timing, throughput, and memory usage"""
    
    device = CONFIG['device']
    model = model.to(device)
    model.eval()
    
    all_predictions = []
    all_targets = []
    all_sar_inputs = []
    inference_times = []
    
    print(f"🚀 Running batch inference with {model_name}...")
    
    # Performance monitoring: resident memory of this process (not system-wide RAM),
    # plus peak GPU memory allocated by PyTorch when CUDA is available
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / (1024**3)
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc=f'Inference {model_name}')):
            batch_start = time.time()
            
            sar_batch = batch['sar'].to(device)
            optical_batch = batch['optical']  # Targets are only stored, so they stay on the CPU
            
            # Run inference
            predictions = model(sar_batch)
            
            # Store results; the device-to-host copy waits for the forward pass,
            # so the batch time below includes the full GPU computation
            all_predictions.append(predictions.cpu().numpy())
            all_targets.append(optical_batch.cpu().numpy())
            all_sar_inputs.append(sar_batch.cpu().numpy())
            
            # Record timing
            inference_times.append(time.time() - batch_start)
            
            # Periodic memory cleanup
            if batch_idx % 5 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    
    # Calculate performance metrics
    total_time = time.time() - start_time
    memory_used = process.memory_info().rss / (1024**3) - initial_memory
    gpu_peak_memory = (torch.cuda.max_memory_allocated() / (1024**3)
                       if torch.cuda.is_available() else 0.0)
    
    # Concatenate results
    all_predictions = np.concatenate(all_predictions, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)
    all_sar_inputs = np.concatenate(all_sar_inputs, axis=0)
    
    # Performance statistics (plain Python numbers so they serialize cleanly to JSON)
    avg_inference_time = float(np.mean(inference_times))
    total_samples = len(all_predictions)
    samples_per_second = total_samples / total_time
    
    print(f"✅ {model_name} inference completed!")
    print(f"   Total time: {total_time:.2f} seconds")
    print(f"   Average batch time: {avg_inference_time:.4f} seconds")
    print(f"   Samples per second: {samples_per_second:.2f}")
    print(f"   Process memory change: {memory_used:.2f} GB")
    print(f"   Peak GPU memory: {gpu_peak_memory:.2f} GB")
    print(f"   Total samples: {total_samples}")
    
    return {
        'model_name': model_name,
        'predictions': all_predictions,
        'targets': all_targets,
        'sar_inputs': all_sar_inputs,
        'performance': {
            'total_time': total_time,
            'avg_batch_time': avg_inference_time,
            'samples_per_second': samples_per_second,
            'memory_used': memory_used,
            'gpu_peak_memory': gpu_peak_memory,
            'total_samples': total_samples
        }
    }

# Run inference for all models
inference_results = {}
if inference_loader is not None and inference_models:
    print("\n🎯 Running batch inference for all models...")
    
    for model_name, model in inference_models.items():
        try:
            result = run_batch_inference(model, inference_loader, model_name)
            inference_results[model_name] = result
        except Exception as e:
            print(f"❌ Error running inference with {model_name}: {e}")
    
    print(f"\n✅ Inference completed for {len(inference_results)} models")
else:
    print("❌ Cannot run inference - models or data not available")
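The average batch time above includes the first batch, which often carries one-off costs (CUDA context setup, lazy allocations, cuDNN autotuning when it is enabled), and it treats a smaller final batch like a full one. A minimal sketch of a steadier summary, assuming per-batch sizes are collected alongside the timings (for example `len(sar_batch)` appended next to each entry of `inference_times`, which the loop above does not currently record). `summarize_batch_times` is a hypothetical helper:

```python
import numpy as np


def summarize_batch_times(batch_times, batch_sizes, warmup=1):
    """Summarize per-batch latencies after dropping the first `warmup` batches.

    Throughput is total samples divided by total time over the kept batches,
    so a short final batch is weighted correctly.
    """
    times = np.asarray(batch_times[warmup:], dtype=float)
    sizes = np.asarray(batch_sizes[warmup:], dtype=float)
    if times.size == 0:
        raise ValueError("not enough batches left after warm-up")
    return {
        "mean_batch_s": float(times.mean()),
        "p50_batch_s": float(np.percentile(times, 50)),
        "p95_batch_s": float(np.percentile(times, 95)),
        "samples_per_s": float(sizes.sum() / times.sum()),
    }
```

Reporting a median and a 95th percentile alongside the mean makes occasional slow batches (garbage collection, cache clearing every fifth batch) visible instead of folding them into a single average.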


In [None]:
# Comprehensive Visualization
def visualize_inference_results(inference_results, num_samples=8):
    """Visualize inference results with comprehensive analysis"""
    
    if not inference_results:
        print("❌ No inference results to visualize")
        return
    
    # Get the first model's results for visualization
    first_model = list(inference_results.keys())[0]
    result = inference_results[first_model]
    
    predictions = result['predictions']
    targets = result['targets']
    sar_inputs = result['sar_inputs']
    
    # Limit number of samples for visualization
    num_samples = min(num_samples, len(predictions))
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(4, num_samples, figsize=(num_samples * 3, 12))
    if num_samples == 1:
        axes = axes.reshape(4, 1)
    
    fig.suptitle(f'SAR Image Colorization - Inference Results ({first_model})', fontsize=16, fontweight='bold')
    
    for i in range(num_samples):
        # SAR input
        sar_img = sar_inputs[i].squeeze()
        axes[0, i].imshow(sar_img, cmap='gray')
        axes[0, i].set_title(f'SAR Input {i+1}', fontsize=10)
        axes[0, i].axis('off')
        
        # Ground truth
        gt_img = np.transpose(targets[i], (1, 2, 0))
        gt_img = np.clip(gt_img, 0, 1)
        axes[1, i].imshow(gt_img)
        axes[1, i].set_title(f'Ground Truth {i+1}', fontsize=10)
        axes[1, i].axis('off')
        
        # Prediction
        pred_img = np.transpose(predictions[i], (1, 2, 0))
        pred_img = np.clip(pred_img, 0, 1)
        axes[2, i].imshow(pred_img)
        axes[2, i].set_title(f'Prediction {i+1}', fontsize=10)
        axes[2, i].axis('off')
        
        # Error map
        error_map = np.abs(pred_img - gt_img)
        error_map = np.mean(error_map, axis=2)  # Convert to grayscale
        im = axes[3, i].imshow(error_map, cmap='hot')
        axes[3, i].set_title(f'Error Map {i+1}', fontsize=10)
        axes[3, i].axis('off')
        
        # Add colorbar for error map
        plt.colorbar(im, ax=axes[3, i], fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    plt.show()
    
    # Metrics are computed over all samples of this model, not only the plotted ones
    mse = np.mean((predictions - targets) ** 2)
    mae = np.mean(np.abs(predictions - targets))
    
    print(f"\n📊 Inference Results Summary ({first_model}, {len(predictions)} samples):")
    print(f"   MSE: {mse:.6f}")
    print(f"   MAE: {mae:.6f}")
    print(f"   Samples visualized: {num_samples}")

# Performance comparison visualization
def visualize_performance_comparison(inference_results):
    """Visualize performance comparison across models"""
    
    if not inference_results:
        print("❌ No inference results to compare")
        return
    
    # Extract performance metrics
    model_names = list(inference_results.keys())
    performance_data = []
    
    for model_name, result in inference_results.items():
        perf = result['performance']
        performance_data.append({
            'Model': model_name,
            'Total Time (s)': perf['total_time'],
            'Samples/sec': perf['samples_per_second'],
            'Memory Used (GB)': perf['memory_used'],
            'Avg Batch Time (s)': perf['avg_batch_time']
        })
    
    df = pd.DataFrame(performance_data)
    
    # Create performance comparison visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Model Performance Comparison', fontsize=16, fontweight='bold')
    
    # 1. Total inference time
    axes[0, 0].bar(model_names, df['Total Time (s)'], color=['skyblue', 'lightcoral'])
    axes[0, 0].set_ylabel('Total Time (seconds)')
    axes[0, 0].set_title('Total Inference Time')
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Samples per second
    axes[0, 1].bar(model_names, df['Samples/sec'], color=['skyblue', 'lightcoral'])
    axes[0, 1].set_ylabel('Samples per Second')
    axes[0, 1].set_title('Inference Speed')
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Memory usage
    axes[1, 0].bar(model_names, df['Memory Used (GB)'], color=['skyblue', 'lightcoral'])
    axes[1, 0].set_ylabel('Memory Used (GB)')
    axes[1, 0].set_title('Memory Usage')
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Average batch time
    axes[1, 1].bar(model_names, df['Avg Batch Time (s)'], color=['skyblue', 'lightcoral'])
    axes[1, 1].set_ylabel('Average Batch Time (seconds)')
    axes[1, 1].set_title('Batch Processing Time')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print performance summary
    print("\n📊 Performance Comparison:")
    print(df.to_string(index=False, float_format='%.4f'))

# Visualize results if available
if 'inference_results' in locals() and inference_results:
    print("\n🖼️ Visualizing inference results...")
    visualize_inference_results(inference_results, num_samples=6)
    
    print("\n⚡ Visualizing performance comparison...")
    visualize_performance_comparison(inference_results)
else:
    print("❌ Cannot visualize results - no inference data available")
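The summary reports a single MSE and MAE over all samples, which can hide a few badly colorized images. A minimal sketch of a per-image PSNR, $\mathrm{PSNR} = 10 \log_{10}(R^2 / \mathrm{MSE})$ with data range $R$, assuming outputs and targets are meant to lie in $[0, 1]$ (the same assumption behind the `np.clip(..., 0, 1)` used for display). `per_sample_psnr` is a hypothetical helper, not part of `src/utils.py`:

```python
import numpy as np


def per_sample_psnr(predictions, targets, data_range=1.0):
    """PSNR in dB for each image in arrays of shape (N, C, H, W).

    Both arrays are clipped to [0, data_range] first, matching the display
    clipping; an image identical to its target gets +inf.
    """
    pred = np.clip(np.asarray(predictions, dtype=np.float64), 0, data_range)
    tgt = np.clip(np.asarray(targets, dtype=np.float64), 0, data_range)
    mse = ((pred - tgt) ** 2).reshape(len(pred), -1).mean(axis=1)
    with np.errstate(divide="ignore"):
        return 10.0 * np.log10(data_range ** 2 / mse)


# Usage with the results above (not executed here):
#   psnr = per_sample_psnr(result['predictions'], result['targets'])
#   worst = np.argsort(psnr)[:5]  # indices of the five lowest-PSNR images
```

Sorting by PSNR gives a quick list of images worth inspecting in the error-map grid.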


In [None]:
# Export Results
import json


def to_uint8(array):
    """Clip a float image to [0, 1] and scale it to 8-bit (out-of-range values would otherwise wrap)"""
    return (np.clip(array, 0, 1) * 255).round().astype(np.uint8)


def export_inference_results(inference_results, output_path):
    """Export inference results as PNG images plus per-model JSON metadata"""
    
    if not inference_results:
        print("❌ No inference results to export")
        return
    
    print(f"💾 Exporting inference results to {output_path}...")
    
    # Create output directories
    os.makedirs(os.path.join(output_path, 'images'), exist_ok=True)
    os.makedirs(os.path.join(output_path, 'metadata'), exist_ok=True)
    
    for model_name, result in inference_results.items():
        print(f"\n📁 Exporting {model_name} results...")
        
        predictions = result['predictions']
        targets = result['targets']
        sar_inputs = result['sar_inputs']
        performance = result['performance']
        
        # Create model-specific directory
        model_dir = os.path.join(output_path, 'images', model_name)
        os.makedirs(model_dir, exist_ok=True)
        
        # Export images (PIL infers 'L' for 2-D uint8 arrays and 'RGB' for H x W x 3 uint8)
        for i in range(len(predictions)):
            # SAR input: (1, H, W) -> (H, W)
            sar_path = os.path.join(model_dir, f'sar_{i:03d}.png')
            Image.fromarray(to_uint8(sar_inputs[i].squeeze())).save(sar_path)
            
            # Ground truth: (C, H, W) -> (H, W, C)
            gt_path = os.path.join(model_dir, f'ground_truth_{i:03d}.png')
            Image.fromarray(to_uint8(np.transpose(targets[i], (1, 2, 0)))).save(gt_path)
            
            # Prediction: (C, H, W) -> (H, W, C)
            pred_path = os.path.join(model_dir, f'prediction_{i:03d}.png')
            Image.fromarray(to_uint8(np.transpose(predictions[i], (1, 2, 0)))).save(pred_path)
            
            # Error map: absolute error on clipped values, averaged across channels
            error_map = np.abs(np.clip(predictions[i], 0, 1) - np.clip(targets[i], 0, 1))
            error_map = np.mean(error_map, axis=0)
            error_path = os.path.join(model_dir, f'error_map_{i:03d}.png')
            Image.fromarray(to_uint8(error_map)).save(error_path)
        
        # Export metadata
        metadata = {
            'model_name': model_name,
            'performance': performance,
            'num_samples': len(predictions),
            'image_size': list(predictions[0].shape),  # (channels, height, width)
            'export_timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }
        
        metadata_path = os.path.join(output_path, 'metadata', f'{model_name}_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        print(f"   ✅ Exported {len(predictions)} samples")
        print(f"   ✅ Saved metadata to {metadata_path}")
    
    print("\n✅ Export completed!")
    print(f"   Output directory: {output_path}")
    print(f"   Images: {os.path.join(output_path, 'images')}")
    print(f"   Metadata: {os.path.join(output_path, 'metadata')}")

# Export results if available
if 'inference_results' in locals() and inference_results:
    print("\n💾 Exporting inference results...")
    export_inference_results(inference_results, CONFIG['output_path'])
else:
    print("❌ Cannot export results - no inference data available")
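After an export it is worth checking that every model folder contains the same number of SAR, ground-truth, prediction, and error-map files. A minimal sketch, assuming the layout written by `export_inference_results` (`<output_path>/images/<model_name>/<kind>_<index>.png`); `build_export_manifest` is a hypothetical helper that only uses the standard library:

```python
import os
from collections import Counter


def build_export_manifest(output_path):
    """Count exported PNGs per model and per kind (sar, ground_truth, prediction, error_map)."""
    images_root = os.path.join(output_path, "images")
    if not os.path.isdir(images_root):
        return {}
    manifest = {}
    for model_name in sorted(os.listdir(images_root)):
        model_dir = os.path.join(images_root, model_name)
        if not os.path.isdir(model_dir):
            continue
        # 'ground_truth_003.png' -> 'ground_truth', 'error_map_003.png' -> 'error_map'
        kinds = Counter(
            name.rsplit("_", 1)[0]
            for name in os.listdir(model_dir)
            if name.endswith(".png")
        )
        manifest[model_name] = dict(kinds)
    return manifest


# Usage (not executed here): print(build_export_manifest(CONFIG['output_path']))
```

Unequal counts within a model folder usually point to an export that was interrupted partway through.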


## Summary and Key Insights

### Inference & Visualization Results:

1. **Model Loading**: ✅ Models are built from `CONFIG` and restored from checkpoints when present (untrained weights otherwise)
2. **Batch Processing**: ✅ Batched inference under `torch.no_grad()` with per-batch timing and memory tracking
3. **Visualization**: ✅ Input / ground-truth / prediction / error-map grids plus cross-model performance charts
4. **Export Capabilities**: ✅ Per-model PNG images and JSON metadata

### Key Findings:

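1. **Inference Performance**:
   - Batching amortizes per-call overhead across `batch_size` images
   - Per-batch timing yields average batch time and samples per second for each model
   - Memory management matters for large-scale inference; the loop periodically runs `gc.collect()` and `torch.cuda.empty_cache()`

2. **Visualization Quality**:
   - Side-by-side grids (SAR input, ground truth, prediction) show qualitative differences between models and targets
   - Error maps (per-pixel absolute error averaged over RGB) highlight where predictions deviate
   - Speed and memory charts help choose between `unet` and `unet_light`

3. **Export Capabilities**:
   - PNG output for inputs, ground truth, predictions, and error maps (`tiff` is listed in `export_formats` but not written yet)
   - Per-model JSON metadata for reproducibility
   - One folder per model under `images/` for easy comparison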

### Recommendations:

1. **Performance Optimization**: Use batch processing for efficient inference
2. **Memory Management**: Track process and GPU memory and release caches periodically
3. **Visualization**: Inspect error maps alongside aggregate MSE/MAE rather than relying on a single number
4. **Export Strategy**: Choose output formats based on the downstream use case

### Next Steps:
- Use the experiment tracking notebook for hyperparameter optimization
- Run the metrics analysis notebook for comprehensive evaluation
- Use the preprocessing notebook for data quality assessment

---
*This notebook provides a production-ready inference pipeline for SAR image colorization. The comprehensive visualization and export capabilities enable effective result analysis and deployment.*
