# Wildfire Burnt Area Detection - Quick Start Example

This notebook demonstrates how to use the Bi-temporal Attention U-Net to detect wildfire burnt areas from pairs of pre-fire and post-fire remote sensing images.

## 📋 Contents
1. [Setup and Installation](#setup)
2. [Data Preparation](#data)
3. [Model Training](#training)
4. [Training Loop (Simplified)](#training-loop)
5. [Inference](#inference)
6. [Visualization](#visualization)

## 1. Setup and Installation <a id="setup"></a>

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision opencv-python albumentations pillow matplotlib seaborn scikit-learn pyyaml tqdm

import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import yaml
from pathlib import Path

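# Add the project root to the import path so `models`, `datasets` and `utils` resolve
# (assumes this notebook lives one directory below the project root)
project_root = Path('..').resolve()
sys.path.append(str(project_root))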

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")

## 2. Data Preparation <a id="data"></a>

In [None]:
from utils.data_utils import analyze_dataset, check_file_naming, visualize_random_samples

# Set your data paths
data_dir = "path/to/your/wildfire_data"  # Update this path
pre_dir = os.path.join(data_dir, "pre_fire")
post_dir = os.path.join(data_dir, "post_fire")
mask_dir = os.path.join(data_dir, "burnt_masks")

# Check that all three directories exist (masks are needed for analysis and training)
for directory in [pre_dir, post_dir, mask_dir]:
    if not os.path.exists(directory):
        print(f"⚠️  Directory not found: {directory}")
        print("Please update the data_dir path above")
    else:
        print(f"✅ Found directory: {directory}")

# Analyze dataset (uncomment when you have data)
# stats = analyze_dataset(pre_dir, post_dir, mask_dir)
# naming = check_file_naming(pre_dir, post_dir, mask_dir)

In [None]:
# Visualize random samples from your wildfire dataset
# visualize_random_samples(pre_dir, post_dir, mask_dir, n_samples=2)
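This notebook does not document the rules the project's `check_file_naming` applies. If you want an independent sanity check first, the sketch below assumes that a pre-fire image, its post-fire counterpart and its mask share the same filename stem, although their extensions may differ. `pair_by_stem` and `IMAGE_EXTS` are hypothetical names and are not part of `utils.data_utils`.

```python
from pathlib import Path

# Extensions treated as images; adjust for your data (hypothetical default)
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def pair_by_stem(pre_dir, post_dir, mask_dir=None):
    """Match files across directories by filename stem (name without extension).

    Returns (matched, unmatched): the sorted stems present in every directory,
    and, per directory, the sorted stems missing from at least one other directory.
    """
    dirs = {"pre": pre_dir, "post": post_dir}
    if mask_dir is not None:
        dirs["mask"] = mask_dir
    # Collect image stems per directory, ignoring subfolders and non-image files
    stems = {
        key: {p.stem for p in Path(d).iterdir()
              if p.is_file() and p.suffix.lower() in IMAGE_EXTS}
        for key, d in dirs.items()
    }
    common = set.intersection(*stems.values())
    unmatched = {key: sorted(s - common) for key, s in stems.items()}
    return sorted(common), unmatched
```

`matched, unmatched = pair_by_stem(pre_dir, post_dir, mask_dir)` returns the stems that are usable for training. It also returns, for each directory, the stems that lack a counterpart elsewhere.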

## 3. Model Training <a id="training"></a>

In [None]:
from models.attention_unet import get_model
from datasets.bitemporal_dataset import BiTemporalDataset
from utils.losses import get_loss_function
from utils.metrics import calculate_metrics
from torch.utils.data import DataLoader

# Load configuration (relative to the project root defined in the setup cell)
with open(project_root / 'configs' / 'config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded:")
print(f"Model type: {config['model']['type']}")
print(f"Image size: {config['data']['image_size']}")
print(f"Batch size: {config['training']['batch_size']}")
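The cells in this notebook read the configuration keys listed below. If one is missing, the error only appears later as a `KeyError`. The small check below was written for this notebook and is not a schema shipped with the project; the key list is taken from the cells here. Running it right after loading makes a bad config fail early.

```python
# (section, key) pairs read by the cells in this notebook
REQUIRED_KEYS = [
    ("model", "type"), ("model", "params"),
    ("data", "image_size"), ("data", "normalize"),
    ("training", "batch_size"), ("training", "lr"), ("training", "weight_decay"),
    ("loss", "type"), ("loss", "params"),
]


def missing_config_keys(config):
    """Return dotted paths (e.g. 'training.lr') of required keys absent from config."""
    missing = []
    for section, key in REQUIRED_KEYS:
        block = config.get(section)
        if not isinstance(block, dict) or key not in block:
            missing.append(f"{section}.{key}")
    return missing
```

For example, `missing = missing_config_keys(config)` followed by `assert not missing, f"config.yaml is missing: {missing}"` stops the notebook before any model is built.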

In [None]:
# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(config['model']['type'], **config['model']['params'])
model.to(device)

print(f"Model created: {config['model']['type']}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Device: {device}")

In [None]:
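# Test model with dummy data (a small batch is enough to check output shapes;
# a large batch of full-size images can exhaust GPU or CPU memory)
batch_size = 2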
image_size = config['data']['image_size']

# Create dummy input
x_pre = torch.randn(batch_size, 3, image_size[0], image_size[1]).to(device)
x_post = torch.randn(batch_size, 3, image_size[0], image_size[1]).to(device)

# Test forward pass
model.eval()
with torch.no_grad():
    if 'bitemporal' in config['model']['type']:
        output = model(x_pre, x_post)
    else:
        # Concatenate for single-input models
        x_concat = torch.cat([x_pre, x_post], dim=1)
        output = model(x_concat)

print("✅ Model test successful!")
print(f"Input shape: {x_pre.shape} + {x_post.shape}")
print(f"Output shape: {output.shape}")
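Some model types do not contain `'bitemporal'`. For those, the cell above fuses the two dates by concatenating them along the channel axis (early fusion). A single-input model therefore sees 3 + 3 = 6 channels and presumably needs to be configured for 6 input channels through `config['model']['params']`. The NumPy sketch below reproduces the shape arithmetic of `torch.cat([x_pre, x_post], dim=1)` for NCHW batches. The helper `early_fusion` is hypothetical.

```python
import numpy as np


def early_fusion(pre, post):
    """Concatenate pre- and post-fire NCHW batches along the channel axis (axis 1)."""
    if pre.shape != post.shape:
        raise ValueError(f"shape mismatch: {pre.shape} vs {post.shape}")
    return np.concatenate([pre, post], axis=1)


pre = np.zeros((2, 3, 64, 64), dtype=np.float32)   # stand-in pre-fire batch
post = np.ones((2, 3, 64, 64), dtype=np.float32)   # stand-in post-fire batch
fused = early_fusion(pre, post)
print(fused.shape)  # (2, 6, 64, 64): channels 0-2 pre-fire, 3-5 post-fire
```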

In [None]:
# Create dataset and dataloader (uncomment when you have data)
# dataset = BiTemporalDataset(
#     pre_dir=pre_dir,
#     post_dir=post_dir,
#     mask_dir=mask_dir,
#     image_size=config['data']['image_size'],
#     normalize=config['data']['normalize']
# )

# dataloader = DataLoader(
#     dataset,
#     batch_size=config['training']['batch_size'],
#     shuffle=True,
#     num_workers=2
# )

# print(f"Wildfire dataset size: {len(dataset)}")
# print(f"Batch size: {config['training']['batch_size']}")
# print(f"Number of batches: {len(dataloader)}")
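`train_one_epoch` unpacks each batch as `(pre_imgs, post_imgs, targets)`, so `BiTemporalDataset` is expected to return triples in that order. To try the training cells before real data is ready, you could use an in-memory stand-in. The hypothetical class below follows PyTorch's map-style dataset protocol (`__len__` plus `__getitem__`), and its `(1, H, W)` mask shape is an assumption. The default collate function of `DataLoader` converts the NumPy arrays to tensors. The helper `num_batches` shows where the printed batch count comes from: with the default `drop_last=False`, the last partial batch is kept.

```python
import math

import numpy as np


class PairedArrayDataset:
    """Hypothetical in-memory stand-in for BiTemporalDataset.

    Follows the map-style protocol (__len__ / __getitem__) and returns
    (pre, post, mask) triples in the order train_one_epoch unpacks them.
    """

    def __init__(self, pre, post, masks):
        if not (len(pre) == len(post) == len(masks)):
            raise ValueError("pre, post and masks must have the same length")
        self.pre, self.post, self.masks = pre, post, masks

    def __len__(self):
        return len(self.pre)

    def __getitem__(self, idx):
        return self.pre[idx], self.post[idx], self.masks[idx]


def num_batches(n_samples, batch_size, drop_last=False):
    """Batches per epoch: the last partial batch is kept unless drop_last=True."""
    return n_samples // batch_size if drop_last else math.ceil(n_samples / batch_size)


ds = PairedArrayDataset(
    np.zeros((10, 3, 32, 32), dtype=np.float32),  # pre-fire images
    np.zeros((10, 3, 32, 32), dtype=np.float32),  # post-fire images
    np.zeros((10, 1, 32, 32), dtype=np.float32),  # masks (shape assumed)
)
print(len(ds), num_batches(len(ds), 4), num_batches(len(ds), 4, drop_last=True))  # 10 3 2
```

You can pass it to `DataLoader(ds, batch_size=4, shuffle=True)` in place of the real dataset when smoke-testing the training loop.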

## 4. Training Loop (Simplified) <a id="training-loop"></a>

In [None]:
# Setup training components
loss_fn = get_loss_function(
    config['loss']['type'], 
    **config['loss']['params']
)

optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=config['training']['lr'],
    weight_decay=config['training']['weight_decay']
)

print("Training components initialized:")
print(f"Loss function: {config['loss']['type']}")
print(f"Optimizer: Adam (lr={config['training']['lr']})")
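The loss that `get_loss_function` builds depends on `config['loss']['type']` and is not shown here. Burnt pixels are often a small fraction of a scene. For imbalanced binary masks like these, a common choice is binary cross-entropy combined with a soft Dice term. The NumPy sketch below only illustrates that combination and is not the project's implementation. `bce_weight` is a hypothetical parameter.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, in a numerically stable form."""
    # max(x, 0) - x*y + log(1 + exp(-|x|)) equals BCE(sigmoid(x), y) without overflow
    return np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits))))


def soft_dice_loss(logits, targets, eps=1e-6):
    """1 - soft Dice coefficient between sigmoid probabilities and binary targets."""
    probs = 0.5 * (1.0 + np.tanh(0.5 * logits))  # sigmoid, stable for large |x|
    intersection = np.sum(probs * targets)
    dice = (2.0 * intersection + eps) / (np.sum(probs) + np.sum(targets) + eps)
    return 1.0 - dice


def bce_dice_loss(logits, targets, bce_weight=0.5):
    """Weighted sum of BCE and soft Dice (bce_weight is a hypothetical knob)."""
    return bce_weight * bce_with_logits(logits, targets) + (1.0 - bce_weight) * soft_dice_loss(logits, targets)


mask = np.zeros((1, 1, 8, 8))
mask[..., 2:5, 2:5] = 1.0                      # a 3x3 burnt patch
good = np.where(mask == 1, 10.0, -10.0)        # confident, correct logits
print(bce_dice_loss(good, mask) < bce_dice_loss(-good, mask))  # True
```

Here Dice is computed over the whole batch. Averaging it per image is a common variant.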

In [None]:
# Simplified training loop (for demonstration)
def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch_idx, (pre_imgs, post_imgs, targets) in enumerate(dataloader):
        pre_imgs = pre_imgs.to(device)
        post_imgs = post_imgs.to(device)
        targets = targets.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        if 'bitemporal' in config['model']['type']:
            outputs = model(pre_imgs, post_imgs)
        else:
            concat_imgs = torch.cat([pre_imgs, post_imgs], dim=1)
            outputs = model(concat_imgs)
        
        # Calculate loss
        loss = loss_fn(outputs, targets)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')
    
    return total_loss / len(dataloader)

# For actual training, uncomment:
# num_epochs = 5
# for epoch in range(num_epochs):
#     avg_loss = train_one_epoch(model, dataloader, loss_fn, optimizer, device)
#     print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')

print("Training function defined. Uncomment the training loop to start training.")
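`calculate_metrics` is imported in the Model Training cell but never called, and its exact outputs are not documented here. To check predictions quickly against `burnt_masks`, you can compute the standard pixel-wise metrics directly from the confusion counts: IoU = TP / (TP + FP + FN) and F1 = 2PR / (P + R). `binary_segmentation_metrics` is a hypothetical helper.

```python
import numpy as np


def binary_segmentation_metrics(pred, target, eps=1e-9):
    """Pixel-wise IoU, precision, recall and F1 for binary masks of equal shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    tp = np.logical_and(pred, target).sum()    # burnt, predicted burnt
    fp = np.logical_and(pred, ~target).sum()   # unburnt, predicted burnt
    fn = np.logical_and(~pred, target).sum()   # burnt, missed
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "iou": tp / (tp + fp + fn + eps),
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall + eps),
    }


m = binary_segmentation_metrics([[1, 1, 0, 0]], [[1, 0, 1, 0]])
print({k: round(float(v), 3) for k, v in m.items()})
# {'iou': 0.333, 'precision': 0.5, 'recall': 0.5, 'f1': 0.5}
```

Because of `eps`, an image with no burnt pixels and no false positives scores 0 rather than 1. When many tiles are fire-free, accumulate TP/FP/FN over the whole test set before dividing, or handle that case explicitly.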

## 5. Inference <a id="inference"></a>

In [None]:
# Load a pretrained model (if available)
# model_path = "path/to/your/trained_model.pth"
# if os.path.exists(model_path):
#     checkpoint = torch.load(model_path, map_location=device)
#     model.load_state_dict(checkpoint)
#     print(f"✅ Model loaded from {model_path}")
# else:
#     print("No pretrained model found. Using randomly initialized model for demonstration.")

model.eval()
print("Model set to evaluation mode")

In [None]:
# Inference function
def predict_burnt_area(model, pre_img, post_img, device):
    """
    Predict the burnt-area mask for a pre-fire/post-fire image pair
    
    Args:
        model: Trained model
        pre_img: Pre-fire image tensor [1, 3, H, W]
        post_img: Post-fire image tensor [1, 3, H, W]
        device: Device to run inference on
    
    Returns:
        prediction: Binary burnt area mask [H, W]
        probability: Burnt area probability map [H, W]
    """
    model.eval()
    
    with torch.no_grad():
        pre_img = pre_img.to(device)
        post_img = post_img.to(device)
        
        # Forward pass
        if 'bitemporal' in config['model']['type']:
            output = model(pre_img, post_img)
        else:
            concat_img = torch.cat([pre_img, post_img], dim=1)
            output = model(concat_img)
        
        # Convert to probability
        if output.shape[1] == 1:  # Binary segmentation
            probability = torch.sigmoid(output).cpu().numpy()[0, 0]
        else:  # Multi-class: channel 1 is assumed to be the burnt class
            probability = torch.softmax(output, dim=1).cpu().numpy()[0, 1]
        
        # Binary prediction
        prediction = (probability > 0.5).astype(np.uint8)
        
    return prediction, probability

print("Wildfire burnt area inference function defined")
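`predict_burnt_area` converts raw model outputs into probabilities in one of two ways. For single-channel outputs it applies a sigmoid. Otherwise it applies a channel-wise softmax and takes channel 1 as the burnt class. It then thresholds the probabilities at 0.5. The NumPy sketch below mirrors that post-processing and exposes the threshold as a parameter; `logits_to_mask` and `threshold` are hypothetical names. Lowering the threshold generally trades precision for recall. For two classes, the softmax of channel 1 equals the sigmoid of the logit difference, so the two branches agree.

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5, burnt_class=1):
    """Convert (C, H, W) logits for one image to (mask, probability) maps.

    C == 1: sigmoid. C > 1: softmax over channels, keeping `burnt_class`
    (channel 1, matching predict_burnt_area).
    """
    if logits.shape[0] == 1:
        prob = 0.5 * (1.0 + np.tanh(0.5 * logits[0]))  # stable sigmoid
    else:
        shifted = logits - logits.max(axis=0, keepdims=True)  # avoid exp overflow
        exp = np.exp(shifted)
        prob = (exp / exp.sum(axis=0, keepdims=True))[burnt_class]
    return (prob > threshold).astype(np.uint8), prob


two_class = np.array([[[0.0, 0.0]], [[1.0, -1.0]]])  # shape (2, 1, 2)
mask, prob = logits_to_mask(two_class)
print(mask.tolist())  # [[1, 0]]
# Two-class softmax equals sigmoid of the logit difference (channel 1 - channel 0)
diff = two_class[1:] - two_class[:1]
print(np.allclose(prob, logits_to_mask(diff)[1]))  # True
```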

In [None]:
# Demo inference with dummy data
# Create dummy test images
test_pre = torch.randn(1, 3, *config['data']['image_size'])
test_post = torch.randn(1, 3, *config['data']['image_size'])

# Run wildfire burnt area detection
pred, prob = predict_burnt_area(model, test_pre, test_post, device)

print("✅ Wildfire detection inference completed!")
print(f"Prediction shape: {pred.shape}")
print(f"Probability range: [{prob.min():.3f}, {prob.max():.3f}]")
print(f"Burnt pixels: {np.sum(pred)} / {pred.size} ({np.sum(pred)/pred.size*100:.1f}%)")
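Once the ground sampling distance is known, the burnt-pixel count converts to area. With square pixels of side $s$ metres, $A_{\text{ha}} = N_{\text{burnt}} \cdot s^2 / 10^4$. This holds only for masks at the imagery's native resolution. If the images were resized to `image_size`, scale $s$ by the resize factor. In the sketch below, `burnt_area_hectares` is a hypothetical helper, and `pixel_size_m` is an input you must take from your imagery's metadata. For example, Sentinel-2's visible bands have 10 m pixels.

```python
import numpy as np


def burnt_area_hectares(mask, pixel_size_m):
    """Burnt area in hectares for a binary mask with square pixels of side pixel_size_m metres."""
    n_burnt = int(np.count_nonzero(mask))
    return n_burnt * pixel_size_m ** 2 / 10_000  # 1 ha = 10,000 m²


mask = np.zeros((100, 100), dtype=np.uint8)
mask[:10, :] = 1                       # 1,000 burnt pixels
print(burnt_area_hectares(mask, 10))   # 10.0 (1,000 px × 100 m² = 100,000 m²)
```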

## 6. Visualization <a id="visualization"></a>

In [None]:
# Visualize prediction results
def visualize_prediction(pre_img, post_img, prediction, probability, title="Wildfire Burnt Area Detection Result"):
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    # Convert [1, 3, H, W] tensors to [H, W, 3] arrays min-max scaled to [0, 1] for display;
    # the small epsilon avoids division by zero on constant images
    if isinstance(pre_img, torch.Tensor):
        pre_np = pre_img.cpu().numpy()[0].transpose(1, 2, 0)
        pre_np = (pre_np - pre_np.min()) / (pre_np.max() - pre_np.min() + 1e-8)
    else:
        pre_np = pre_img
    
    if isinstance(post_img, torch.Tensor):
        post_np = post_img.cpu().numpy()[0].transpose(1, 2, 0)
        post_np = (post_np - post_np.min()) / (post_np.max() - post_np.min() + 1e-8)
    else:
        post_np = post_img
    
    # Pre-fire image
    axes[0].imshow(pre_np)
    axes[0].set_title('Pre-fire')
    axes[0].axis('off')
    
    # Post-fire image
    axes[1].imshow(post_np)
    axes[1].set_title('Post-fire')
    axes[1].axis('off')
    
    # Probability map
    im2 = axes[2].imshow(probability, cmap='hot', vmin=0, vmax=1)
    axes[2].set_title('Burnt Area Probability')
    axes[2].axis('off')
    plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
    
    # Binary prediction overlay on the post-fire image
    axes[3].imshow(post_np)
    # RGBA overlay: red at 60% opacity where burnt, fully transparent elsewhere,
    # so unburnt pixels are not darkened
    overlay = np.zeros((*prediction.shape, 4))
    overlay[prediction == 1] = [1, 0, 0, 0.6]
    axes[3].imshow(overlay)
    axes[3].set_title('Burnt Area Overlay')
    axes[3].axis('off')
    
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Visualize the demo result
visualize_prediction(test_pre, test_post, pred, prob, "Wildfire Burnt Area Demo Result")
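To save overlays to disk without matplotlib, you can alpha-blend the red highlight into the image array itself and touch only the masked pixels. `utils.visualize.save_prediction_overlay` (imported in the next cell) may already do something similar, but its behavior isn't documented here. `blend_overlay` is an independent, hypothetical sketch that expects an `(H, W, 3)` uint8 image.

```python
import numpy as np


def blend_overlay(rgb, mask, color=(255, 0, 0), alpha=0.6):
    """Alpha-blend `color` into an (H, W, 3) uint8 image where mask is nonzero.

    Pixels outside the mask are returned unchanged.
    """
    out = rgb.astype(np.float32)          # float copy; the input is not modified
    m = np.asarray(mask).astype(bool)
    out[m] = (1.0 - alpha) * out[m] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


rgb = np.full((2, 2, 3), 100, dtype=np.uint8)
pred = np.array([[1, 0], [0, 0]], dtype=np.uint8)
out = blend_overlay(rgb, pred)
print(out[0, 0].tolist(), out[0, 1].tolist())  # [193, 40, 40] [100, 100, 100]
```

You can write the result with Pillow, e.g. `Image.fromarray(blend_overlay(rgb, pred)).save('overlay.png')`. The display arrays in `visualize_prediction` are floats in [0, 1], so convert them first, e.g. `(post_np * 255).astype(np.uint8)`.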

In [None]:
# Example of how to process real images (when you have data)
from utils.visualize import save_prediction_overlay

def process_image_pair(pre_path, post_path, model, device, config):
    """
    Process a single image pair for wildfire burnt area detection
    """
    from torchvision import transforms
    
    # Define transforms. The Normalize statistics below are an assumption: they
    # must match whatever preprocessing BiTemporalDataset applied during training
    steps = [transforms.Resize(config['data']['image_size']), transforms.ToTensor()]
    if config['data']['normalize']:
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(steps)
    
    # Load and preprocess images
    pre_img = Image.open(pre_path).convert('RGB')
    post_img = Image.open(post_path).convert('RGB')
    
    pre_tensor = transform(pre_img).unsqueeze(0)
    post_tensor = transform(post_img).unsqueeze(0)
    
    # Run wildfire detection inference
    prediction, probability = predict_burnt_area(model, pre_tensor, post_tensor, device)
    
    # Visualize
    visualize_prediction(pre_tensor, post_tensor, prediction, probability,
                         f"Wildfire Detection Result: {os.path.basename(pre_path)}")
    
    return prediction, probability

# Example usage (uncomment when you have actual image files):
# pre_fire_path = "path/to/pre_fire_image.jpg"
# post_fire_path = "path/to/post_fire_image.jpg"
# 
# if os.path.exists(pre_fire_path) and os.path.exists(post_fire_path):
#     pred, prob = process_image_pair(pre_fire_path, post_fire_path, model, device, config)
#     print(f"Wildfire detection completed for {os.path.basename(pre_fire_path)}")

print("Wildfire image processing function defined. Uncomment to process real images.")
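`process_image_pair` resizes the whole scene to `config['data']['image_size']`. That works for image chips of that size, but shrinking a large scene discards detail. One alternative assumes the model was trained on chips of `image_size`. Cut the co-registered pre-fire and post-fire scenes into tiles of that size at native resolution, run `predict_burnt_area` on each tile pair, and stitch the per-tile masks back together. `tile_image` and `untile_mask` below are hypothetical helpers. Non-overlapping tiles can leave visible seams; overlapping tiles with blended predictions usually reduce them.

```python
import numpy as np


def tile_image(img, tile):
    """Split an (H, W, C) array into non-overlapping (tile, tile, C) patches.

    Zero-pads the bottom/right edges so H and W become multiples of `tile`.
    Returns (tiles, padded_hw); tiles are in row-major order, shape (N, tile, tile, C).
    """
    h, w, c = img.shape
    padded = np.pad(img, ((0, -h % tile), (0, -w % tile), (0, 0)))
    H, W = padded.shape[:2]
    tiles = (padded.reshape(H // tile, tile, W // tile, tile, c)
                   .transpose(0, 2, 1, 3, 4)
                   .reshape(-1, tile, tile, c))
    return tiles, (H, W)


def untile_mask(tiles, padded_hw, orig_hw):
    """Reassemble (N, tile, tile) per-tile masks and crop away the padding."""
    H, W = padded_hw
    tile = tiles.shape[1]
    full = tiles.reshape(H // tile, W // tile, tile, tile).transpose(0, 2, 1, 3).reshape(H, W)
    return full[: orig_hw[0], : orig_hw[1]]


scene = np.random.default_rng(0).random((100, 70, 3))
tiles, padded_hw = tile_image(scene, 32)
print(tiles.shape, padded_hw)  # (12, 32, 32, 3) (128, 96)
```

Before each `(tile, tile, 3)` patch goes to the model as a `[1, 3, tile, tile]` tensor, it still needs the same preprocessing as in `process_image_pair`, minus the resize.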

## 🎯 Summary

This notebook demonstrated:

1. **Setup**: Loading the Bi-temporal Attention U-Net framework for wildfire detection
2. **Data Analysis**: Tools to analyze your wildfire dataset structure
3. **Model Creation**: Instantiating the architecture selected in `configs/config.yaml` and checking it with a dummy forward pass
4. **Training**: Basic training loop structure for wildfire detection models
5. **Inference**: Running burnt area detection on pre/post-fire image pairs
6. **Visualization**: Displaying results with overlays and probability maps

## 🚀 Next Steps

1. **Prepare your data**: Organize images into `pre_fire/`, `post_fire/`, and `burnt_masks/` directories
2. **Update paths**: Set `data_dir` in the Data Preparation cell to your dataset root
3. **Configure model**: Adjust `configs/config.yaml` for your specific wildfire detection needs
4. **Train model**: Use `scripts/train.py` for full training pipeline
5. **Evaluate**: Use `scripts/evaluate.py` for comprehensive evaluation

## 📚 Documentation

- See `README.md` for detailed setup instructions
- Check `configs/config.yaml` for all configuration options
- Explore `utils/` for additional functionality

Happy wildfire detection! 🛰️🔥🔍