<a href="https://colab.research.google.com/github/vikashkodati/mygig/blob/main/PINNDataPreprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
import os
from pathlib import Path
from typing import Tuple, List, Dict, Optional

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageEnhance

class SatelliteImagePreprocessor:
    """Load a satellite image from disk and turn it into a normalized tensor.

    The pipeline is: open as RGB -> mild contrast/sharpness/brightness
    boost -> resize to ``target_size`` -> scale to [0, 1] -> per-channel
    standardization -> channels-first layout.

    Attributes:
        target_size: Size handed to ``PIL.Image.resize``, which interprets it
            as ``(width, height)``.
        normalize_mean: Per-channel mean subtracted after scaling to [0, 1].
            The values match the widely used ImageNet RGB statistics.
        normalize_std: Per-channel divisor applied after mean subtraction.
    """

    def __init__(self, target_size: Tuple[int, int] = (256, 256)):
        self.target_size = target_size
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]

    def preprocess_image(self, image_path: str) -> torch.Tensor:
        """Preprocess a single satellite image.

        Args:
            image_path: Path of the image file to load.

        Returns:
            A float tensor in channels-first layout ``(3, height, width)``.
        """
        # Image.open is lazy and keeps the file handle open (notably for
        # multi-frame formats such as TIFF). convert() returns an independent
        # in-memory copy, so the source can be closed right away.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        image = self._enhance_image(image)
        image = self._resize_image(image)
        return self._normalize_image(image)

    def _enhance_image(self, image: Image.Image) -> Image.Image:
        """Apply a fixed contrast, sharpness and brightness boost, in that order."""
        # Each factor is relative to 1.0 (= unchanged image).
        enhancement_steps = (
            (ImageEnhance.Contrast, 1.2),
            (ImageEnhance.Sharpness, 1.1),
            (ImageEnhance.Brightness, 1.05),
        )
        for enhancer_cls, factor in enhancement_steps:
            image = enhancer_cls(image).enhance(factor)
        return image

    def _resize_image(self, image: Image.Image) -> Image.Image:
        """Resize image to ``self.target_size`` using Lanczos resampling."""
        return image.resize(self.target_size, Image.Resampling.LANCZOS)

    def _normalize_image(self, image: Image.Image) -> torch.Tensor:
        """Convert to a float tensor, standardize per channel, go channels-first.

        Args:
            image: An image whose array form is ``(height, width, channels)``
                with 8-bit values; only the first three channels are
                standardized.

        Returns:
            A float tensor of shape ``(channels, height, width)``.
        """
        # Scale 8-bit pixel values into [0, 1].
        pixels = torch.from_numpy(np.array(image)).float() / 255.0

        # Broadcast the per-channel statistics over the trailing channel axis
        # instead of looping channel by channel.
        mean = torch.tensor(self.normalize_mean, dtype=pixels.dtype)
        std = torch.tensor(self.normalize_std, dtype=pixels.dtype)
        pixels[:, :, :3] = (pixels[:, :, :3] - mean) / std

        # (H, W, C) -> (C, H, W)
        return pixels.permute(2, 0, 1)

class SatelliteDataAugmenter:
    """Data augmentation specifically designed for satellite imagery.

    The pipeline is built so that a before/after pair receives the *same*
    randomly sampled transforms, and so that the output is always converted
    to tensors even when the random augmentations are skipped.

    Attributes:
        p: Probability that the block of random augmentations is applied.
        augmentation: The albumentations pipeline used by ``augment_pair``.
    """

    def __init__(self, p: float = 0.5):
        self.p = p

        # Random augmentations, gated as a group by ``p``. Each transform
        # additionally carries its own probability.
        random_augmentations = A.Sequential([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.HueSaturationValue(p=0.3),
            A.GaussNoise(p=0.2),
            A.GaussianBlur(p=0.2),
            A.CLAHE(p=0.3),  # Contrast Limited Adaptive Histogram Equalization
            A.RandomGamma(p=0.3),
        ], p=p)

        self.augmentation = A.Compose(
            [
                random_augmentations,
                # Kept outside the p-gated group so the declared tensor
                # return type holds on every call.
                ToTensorV2(),
            ],
            # Register 'image2' as a second image target; without this the
            # second image is not transformed together with the first.
            additional_targets={'image2': 'image'},
        )

    def augment_pair(self, before_img: np.ndarray, after_img: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """Augment a pair of before/after images consistently.

        Args:
            before_img: First image as a numpy array.
            after_img: Second image as a numpy array; presumably the same
                height/width/channel layout as ``before_img`` (verify against
                callers).

        Returns:
            ``(before, after)`` as tensors produced by ``ToTensorV2``, both
            transformed with the same sampled augmentation parameters.
        """
        augmented = self.augmentation(image=before_img, image2=after_img)
        return augmented['image'], augmented['image2']

class SatelliteVisualizer:
    """Visualization tools for satellite change detection results.

    Attributes:
        class_colors: Maps class id -> RGB color used in the change map.
        class_names: Display name per class id (index 0 is the background).
    """

    # Per-channel statistics used to undo the input standardization before
    # display. They mirror the mean/std values used by
    # SatelliteImagePreprocessor in this file.
    _NORMALIZE_MEAN = (0.485, 0.456, 0.406)
    _NORMALIZE_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        self.class_colors = {
            0: [0, 0, 0],      # Background - Black
            1: [255, 0, 0],    # Building - Red
            2: [0, 0, 255],    # Water - Blue
            3: [0, 255, 0],    # Vegetation - Green
            4: [255, 255, 0]   # Debris - Yellow
        }
        self.class_names = ['Background', 'Building', 'Water', 'Vegetation', 'Debris']

    def visualize_change_detection(self, before_img: torch.Tensor, after_img: torch.Tensor,
                                 change_map: torch.Tensor, save_path: Optional[str] = None) -> None:
        """Visualize change detection results.

        The top row shows the before image, the after image and the colored
        change map; the bottom row shows one heat map per non-background
        class.

        Args:
            before_img: Standardized image tensor, channels-first.
            after_img: Standardized image tensor, channels-first.
            change_map: Per-class score tensor indexed as
                ``change_map[class_id]``, i.e. ``(classes, height, width)``.
            save_path: If given, the figure is also written to this path.
        """
        # Convert tensors to numpy arrays
        before_np = self._tensor_to_numpy(before_img)
        after_np = self._tensor_to_numpy(after_img)
        change_np = change_map.detach().cpu().numpy()

        # (class_id, name) for every class except the background.
        foreground = list(enumerate(self.class_names))[1:]

        # The grid must be wide enough for the three overview panels AND for
        # one panel per foreground class (four with the default classes).
        n_cols = max(3, len(foreground))
        fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10))

        overview_panels = [
            (before_np, 'Before Image'),
            (after_np, 'After Image'),
            (self._create_change_visualization(change_np), 'Change Detection Map'),
        ]
        for ax, (panel, title) in zip(axes[0], overview_panels):
            ax.imshow(panel)
            ax.set_title(title)

        # Class-specific changes
        for ax, (class_id, class_name) in zip(axes[1], foreground):
            ax.imshow(change_np[class_id], cmap='hot')
            ax.set_title(f'{class_name} Changes')

        # Hide ticks/frames everywhere, including grid cells left unused.
        for ax in axes.ravel():
            ax.axis('off')

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def _tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a standardized channels-first tensor to a displayable array.

        Args:
            tensor: Tensor of shape ``(channels, height, width)``. Channel
                ``c`` is de-standardized with the statistics of RGB channel
                ``c % 3``, so stacked RGB images (e.g. 6 channels) are handled
                as well.

        Returns:
            A float32 array of shape ``(height, width, channels)`` clipped to
            ``[0, 1]``.
        """
        # Work on a detached CPU copy so the caller's tensor is not modified.
        image = tensor.detach().cpu().float().numpy().astype(np.float32)

        # Repeat the RGB statistics to cover however many channels there are.
        channels = image.shape[0]
        repeats = -(-channels // 3)  # ceil division
        mean = np.tile(np.asarray(self._NORMALIZE_MEAN, dtype=np.float32), repeats)[:channels]
        std = np.tile(np.asarray(self._NORMALIZE_STD, dtype=np.float32), repeats)[:channels]

        # Undo (x - mean) / std per channel.
        image = image * std[:, None, None] + mean[:, None, None]

        # (C, H, W) -> (H, W, C), then clip to the valid display range.
        image = np.transpose(image, (1, 2, 0))
        return np.clip(image, 0, 1)

    def _create_change_visualization(self, change_map: np.ndarray) -> np.ndarray:
        """Create a colored visualization of the change map.

        Args:
            change_map: Array of per-class scores with the class axis first.

        Returns:
            A uint8 RGB array where each pixel has the color of its
            highest-scoring class; classes without an entry in
            ``self.class_colors`` stay black.
        """
        # Get the class with maximum score for each pixel
        class_map = np.argmax(change_map, axis=0)

        colored_map = np.zeros((*class_map.shape, 3), dtype=np.uint8)
        for class_id, color in self.class_colors.items():
            colored_map[class_map == class_id] = color

        return colored_map

    def plot_training_history(self, history: Dict[str, List[float]], save_path: Optional[str] = None) -> None:
        """Plot training history.

        Args:
            history: Must contain ``'train_losses'`` and ``'val_losses'``;
                ``'physics_losses'`` is optional.
            save_path: If given, the figure is also written to this path.

        Raises:
            KeyError: If ``'train_losses'`` or ``'val_losses'`` is missing.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Loss plot
        ax1.plot(history['train_losses'], label='Train Loss', color='blue')
        ax1.plot(history['val_losses'], label='Validation Loss', color='red')
        ax1.set_title('Training History')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Physics loss components (if available)
        if 'physics_losses' in history:
            ax2.plot(history['physics_losses'], label='Physics Loss', color='green')
            ax2.set_title('Physics Constraint Loss')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Physics Loss')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
        else:
            # Nothing to show: hide the panel rather than draw empty axes.
            ax2.axis('off')

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

class SatelliteMetrics:
    """Metrics for evaluating satellite change detection performance.

    All per-class metrics are macro-averaged: the score is computed for each
    class id and the plain mean is returned. A class that never occurs (zero
    denominator) contributes a score of 0 to that mean.

    Attributes:
        metrics: An initially empty dict; nothing in this class writes to it.
    """

    def __init__(self):
        self.metrics = {}

    def calculate_metrics(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """Calculate various metrics for change detection.

        Args:
            predictions: Per-class scores with the class axis at dimension 1,
                e.g. ``(batch, classes, height, width)``.
            targets: Either per-class scores / one-hot maps with the same
                number of dimensions as ``predictions``, or class-index maps
                with exactly one dimension fewer.

        Returns:
            A dict with keys ``accuracy``, ``precision``, ``recall``,
            ``f1_score`` and ``iou``.

        Raises:
            ValueError: If the inputs' dimensionalities are incompatible.
        """
        pred_np = predictions.detach().cpu().numpy()
        target_np = targets.detach().cpu().numpy()

        if pred_np.ndim < 2:
            raise ValueError("predictions must have a class axis at dimension 1")

        # Average over the classes the model actually predicts rather than a
        # fixed count.
        num_classes = pred_np.shape[1]
        pred_classes = np.argmax(pred_np, axis=1)

        if target_np.ndim == pred_np.ndim:
            target_classes = np.argmax(target_np, axis=1)
        elif target_np.ndim == pred_np.ndim - 1:
            # Already class indices; an argmax here would collapse a spatial axis.
            target_classes = target_np
        else:
            raise ValueError(
                f"targets with {target_np.ndim} dims are incompatible with "
                f"predictions with {pred_np.ndim} dims"
            )

        return {
            'accuracy': self._calculate_accuracy(pred_classes, target_classes),
            'precision': self._calculate_precision(pred_classes, target_classes, num_classes),
            'recall': self._calculate_recall(pred_classes, target_classes, num_classes),
            'f1_score': self._calculate_f1_score(pred_classes, target_classes, num_classes),
            'iou': self._calculate_iou(pred_classes, target_classes, num_classes),
        }

    @staticmethod
    def _class_counts(pred: np.ndarray, target: np.ndarray,
                      num_classes: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return per-class (true positives, predicted count, target count)."""
        true_pos = np.zeros(num_classes, dtype=np.int64)
        n_pred = np.zeros(num_classes, dtype=np.int64)
        n_target = np.zeros(num_classes, dtype=np.int64)
        for class_id in range(num_classes):
            is_pred = pred == class_id
            is_target = target == class_id
            true_pos[class_id] = np.count_nonzero(is_pred & is_target)
            n_pred[class_id] = np.count_nonzero(is_pred)
            n_target[class_id] = np.count_nonzero(is_target)
        return true_pos, n_pred, n_target

    @staticmethod
    def _macro_ratio(numerator: np.ndarray, denominator: np.ndarray) -> float:
        """Mean of numerator/denominator per class, using 0 where denominator is 0."""
        ratios = np.divide(
            numerator,
            denominator,
            out=np.zeros(numerator.shape, dtype=np.float64),
            where=denominator > 0,
        )
        return float(np.mean(ratios))

    def _calculate_accuracy(self, pred: np.ndarray, target: np.ndarray) -> float:
        """Calculate pixel-wise accuracy."""
        return float(np.mean(pred == target))

    def _calculate_precision(self, pred: np.ndarray, target: np.ndarray, num_classes: int = 5) -> float:
        """Calculate macro-averaged precision, TP / (TP + FP), over ``num_classes`` classes."""
        true_pos, n_pred, _ = self._class_counts(pred, target, num_classes)
        # TP + FP is simply the number of pixels predicted as the class.
        return self._macro_ratio(true_pos, n_pred)

    def _calculate_recall(self, pred: np.ndarray, target: np.ndarray, num_classes: int = 5) -> float:
        """Calculate macro-averaged recall, TP / (TP + FN), over ``num_classes`` classes."""
        true_pos, _, n_target = self._class_counts(pred, target, num_classes)
        # TP + FN is simply the number of pixels labelled as the class.
        return self._macro_ratio(true_pos, n_target)

    def _calculate_f1_score(self, pred: np.ndarray, target: np.ndarray, num_classes: int = 5) -> float:
        """Calculate F1 as the harmonic mean of macro precision and macro recall.

        Note that this is not the mean of per-class F1 scores.
        """
        precision = self._calculate_precision(pred, target, num_classes)
        recall = self._calculate_recall(pred, target, num_classes)
        denominator = precision + recall
        if denominator > 0:
            return float(2 * (precision * recall) / denominator)
        return 0.0

    def _calculate_iou(self, pred: np.ndarray, target: np.ndarray, num_classes: int = 5) -> float:
        """Calculate macro-averaged Intersection over Union over ``num_classes`` classes."""
        true_pos, n_pred, n_target = self._class_counts(pred, target, num_classes)
        # |A or B| = |A| + |B| - |A and B|
        union = n_pred + n_target - true_pos
        return self._macro_ratio(true_pos, union)

def _paint_patches(image: np.ndarray, patches) -> None:
    """Fill each (row_start, row_stop, col_start, col_stop, rgb) patch in place."""
    for row_start, row_stop, col_start, col_stop, rgb in patches:
        image[row_start:row_stop, col_start:col_stop] = rgb


def _write_rgb_image(path: Path, rgb_image: np.ndarray) -> None:
    """Write an RGB array with OpenCV, raising if the write fails."""
    # cv2.imwrite interprets 3-channel arrays as BGR, so swap the channel
    # order first; otherwise red and blue are exchanged in the saved file.
    bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    # imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(str(path), bgr_image):
        raise OSError(f"Failed to write image: {path}")


def create_sample_satellite_data(output_dir: str = "sample_satellite_data", num_samples: int = 10,
                                 seed: Optional[int] = None) -> None:
    """Create synthetic before/after satellite image pairs for testing.

    For sample ``i`` three files are written to ``output_dir``:
    ``region_{i}_before.jpg``, ``region_{i}_after.jpg`` and
    ``region_{i}_metadata.json``. Change types are assigned round-robin.

    Args:
        output_dir: Directory to write into; created (with parents) if needed.
        num_samples: Number of image pairs to generate.
        seed: Optional seed for the random background noise, for
            reproducible output.

    Raises:
        OSError: If an image cannot be written.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # change type -> (patches on the "before" image, patches on the "after"
    # image). Each patch is (row_start, row_stop, col_start, col_stop, rgb).
    scenarios = {
        # Water expansion
        "flooding": (
            [(100, 150, 50, 200, (100, 150, 200))],   # Blue water
            [(80, 170, 30, 220, (80, 120, 180))],     # Expanded water
        ),
        # Building changes
        "construction": (
            [(80, 120, 80, 120, (150, 150, 150))],    # Gray buildings
            [(60, 140, 60, 140, (180, 180, 180)),     # More buildings
             (160, 200, 160, 200, (160, 160, 160))],
        ),
        # Land cover changes
        "vegetation": (
            [(120, 180, 100, 180, (50, 150, 50))],    # Green vegetation
            [(100, 160, 80, 160, (30, 120, 30))],     # Changed vegetation
        ),
        # Debris accumulation
        "debris": (
            [(90, 130, 90, 130, (100, 100, 50))],     # Brown debris
            [(70, 150, 70, 150, (120, 120, 70))],     # Spread debris
        ),
        # Landform changes
        "erosion": (
            [(140, 200, 60, 160, (120, 100, 80))],    # Brown eroded soil
            [(120, 220, 40, 180, (100, 80, 60))],     # Expanded erosion
        ),
    }
    change_types = list(scenarios)

    rng = np.random.default_rng(seed)

    for i in range(num_samples):
        change_type = change_types[i % len(change_types)]
        before_patches, after_patches = scenarios[change_type]

        # Noisy background standing in for terrain, values in [50, 200).
        before_img = rng.integers(50, 200, size=(256, 256, 3), dtype=np.uint8)
        _paint_patches(before_img, before_patches)

        # The "after" image starts from "before" and adds the change on top.
        after_img = before_img.copy()
        _paint_patches(after_img, after_patches)

        before_name = f"region_{i}_before.jpg"
        after_name = f"region_{i}_after.jpg"
        _write_rgb_image(output_path / before_name, before_img)
        _write_rgb_image(output_path / after_name, after_img)

        metadata = {
            "change_type": change_type,
            "before_path": before_name,
            "after_path": after_name,
            "description": f"Sample {change_type} change detection"
        }

        with open(output_path / f"region_{i}_metadata.json", 'w', encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

    print(f"Created {num_samples} sample satellite image pairs in {output_dir}")

if __name__ == "__main__":
    # Generate the synthetic before/after pairs in the default directory.
    create_sample_satellite_data()

    # Run the preprocessing pipeline on the first generated "before" image.
    sample_image = "sample_satellite_data/region_0_before.jpg"
    preprocessor = SatelliteImagePreprocessor()

    if Path(sample_image).exists():
        processed_img = preprocessor.preprocess_image(sample_image)
        print(f"Processed image shape: {processed_img.shape}")

        # The visualizer and metrics helpers are only instantiated here;
        # none of their methods are called.
        visualizer = SatelliteVisualizer()
        print("Visualization tools ready")

        metrics = SatelliteMetrics()
        print("Metrics calculator ready")

        print("All utilities tested successfully!")