# Hurricane Damage Classification - Training Notebook

This notebook trains a CNN classifier to detect hurricane damage from satellite images using transfer learning.

**Steps:**
1. Setup & Dependencies
2. Download Dataset from Kaggle
3. Data Preprocessing & Augmentation
4. Model Definition & Training
5. Evaluation

## 1. Setup & Dependencies

In [None]:
# Install required packages (uncomment if running on Colab)
# !pip install kagglehub torch torchvision tqdm pillow matplotlib scikit-learn

In [None]:
import os
import shutil
import random
from pathlib import Path
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

import kagglehub
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Download Dataset from Kaggle

In [None]:
def download_dataset(target_dir: Optional[str] = None) -> str:
    """Fetch the Kaggle hurricane-damage dataset, optionally mirroring it.

    The dataset is downloaded with ``kagglehub``. When ``target_dir`` is
    truthy, every file under the download location is copied into it with
    the same relative layout (existing files are overwritten).

    Args:
        target_dir: Directory to mirror the dataset into, or ``None``/empty
            to use the download location as-is.

    Returns:
        ``target_dir`` when a mirror was requested, otherwise the path
        reported by ``kagglehub``.
    """
    print("Downloading Hurricane Damage Satellite Images dataset...")
    path = kagglehub.dataset_download("kmader/satellite-images-of-hurricane-damage")
    print(f"Dataset downloaded to: {path}")

    # No mirror requested: hand back the location kagglehub reported.
    if not target_dir:
        return path

    mirror_root = Path(target_dir)
    mirror_root.mkdir(parents=True, exist_ok=True)
    source_root = Path(path)

    for src_file in (entry for entry in source_root.rglob("*") if entry.is_file()):
        dst_file = mirror_root / src_file.relative_to(source_root)
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_file, dst_file)

    print(f"Dataset copied to: {target_dir}")
    return target_dir

# Download the dataset and mirror it into DATA_DIR; because a target
# directory is passed, dataset_path is that directory.
DATA_DIR = "./data"
dataset_path = download_dataset(DATA_DIR)

In [None]:
def explore_dataset(data_path: str) -> dict:
    """Walk ``data_path`` and summarise the image files found beneath it.

    Files are matched case-insensitively on the extensions png, jpg, jpeg,
    tif and tiff. Directories and files are visited in sorted order so the
    result (including which files get sampled) is the same on every run.

    Args:
        data_path: Root directory to scan. A missing directory yields an
            all-empty summary.

    Returns:
        A dict with:
            ``total_images``: number of matching files.
            ``classes``: image count keyed by the name of the directory
                that directly contains the files.
            ``image_formats``: sorted list of lower-cased extensions seen.
            ``sample_sizes``: ``Image.size`` of the first images that could
                be opened (at most 10).
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
    max_samples = 10

    formats = set()
    stats = {
        "total_images": 0,
        "classes": {},
        "image_formats": [],
        "sample_sizes": []
    }

    for root, dirs, files in os.walk(Path(data_path)):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        class_name = Path(root).name

        for file in sorted(files):
            if not file.lower().endswith(image_suffixes):
                continue

            stats["total_images"] += 1
            formats.add(file.split('.')[-1].lower())
            stats["classes"][class_name] = stats["classes"].get(class_name, 0) + 1

            if len(stats["sample_sizes"]) < max_samples:
                # Best effort: an unreadable file is reported and skipped,
                # it still counts towards the totals above.
                try:
                    img_path = os.path.join(root, file)
                    with Image.open(img_path) as img:
                        stats["sample_sizes"].append(img.size)
                except Exception as e:
                    print(f"Could not read {file}: {e}")

    # Sets iterate in hash order, which varies between runs for strings.
    stats["image_formats"] = sorted(formats)
    return stats

# Explore the dataset and print each summary field on its own line
stats = explore_dataset(dataset_path)
for heading, field in (
    ("Total images", "total_images"),
    ("Image formats", "image_formats"),
    ("Classes", "classes"),
    ("Sample image sizes", "sample_sizes"),
):
    print(f"{heading}: {stats[field]}")

## 3. Data Preprocessing & Augmentation

In [None]:
def prepare_data_splits(
    data_path: str,
    output_dir: str,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42
) -> Tuple[str, str, str]:
    """Copy images into ``train``/``val``/``test`` folders, one subfolder per class.

    Every directory below ``data_path`` that directly contains image files
    is treated as a class, named after the directory. Directories that share
    a name are merged into the same class folder. Each class directory is
    shuffled and split independently, so class proportions are roughly
    preserved across the three splits.

    Args:
        data_path: Root of the source dataset.
        output_dir: Directory that receives the ``train``, ``val`` and
            ``test`` folders.
        train_ratio: Fraction of each class directory copied to ``train``.
        val_ratio: Fraction of each class directory copied to ``val``.
        test_ratio: Informational only; the test split always receives the
            images left after the train and validation cuts. A warning is
            printed when this value disagrees with that remainder.
        seed: Seed for the module-level ``random`` generator.

    Returns:
        The train, validation and test directory paths as strings.

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + val_ratio``
            exceeds 1.
    """
    if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
        raise ValueError("Split ratios must be non-negative")
    if train_ratio + val_ratio > 1.0 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1 "
            f"(got {train_ratio} + {val_ratio})"
        )
    remainder = 1.0 - train_ratio - val_ratio
    if abs(remainder - test_ratio) > 1e-6:
        print(
            f"Warning: test_ratio={test_ratio} is not used directly; "
            f"the test split receives the remaining {remainder:.4f} of each class"
        )

    random.seed(seed)

    data_path = Path(data_path)
    output_path = Path(output_dir)

    train_dir = output_path / "train"
    val_dir = output_path / "val"
    test_dir = output_path / "test"

    # Find all class directories. Candidates are visited in sorted order and
    # anything inside the output directory is skipped, so a previous run
    # whose output lives under data_path is not picked up as source data.
    output_root = output_path.resolve()
    images_by_dir = {}
    for candidate in sorted(p for p in data_path.rglob("*") if p.is_dir()):
        resolved = candidate.resolve()
        if resolved == output_root or output_root in resolved.parents:
            continue
        found = _list_images(candidate)
        if found:
            images_by_dir[candidate] = found

    print(f"Found {len(images_by_dir)} class directories")

    for class_dir, images in tqdm(images_by_dir.items(), desc="Processing classes"):
        class_name = class_dir.name

        # `images` is sorted, so the seeded shuffle gives the same
        # permutation on every machine.
        random.shuffle(images)

        n = len(images)
        train_end = int(n * train_ratio)
        val_end = train_end + int(n * val_ratio)

        splits = [
            (images[:train_end], train_dir / class_name),
            (images[train_end:val_end], val_dir / class_name),
            (images[val_end:], test_dir / class_name)
        ]

        # NOTE(review): files are copied under their bare name, so two source
        # directories with the same class name and file name overwrite each
        # other - confirm the dataset has no such collisions.
        for split_images, split_dir in splits:
            split_dir.mkdir(parents=True, exist_ok=True)
            for img_path in split_images:
                shutil.copy2(img_path, split_dir / img_path.name)

    print(f"Train: {train_dir}")
    print(f"Validation: {val_dir}")
    print(f"Test: {test_dir}")

    return str(train_dir), str(val_dir), str(test_dir)


def _list_images(directory, suffixes=('.jpeg', '.jpg', '.png', '.tif', '.tiff')):
    """Return the image files directly inside ``directory``, sorted by path."""
    return sorted(
        entry for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in suffixes
    )

# Create data splits under SPLITS_DIR using the function's default ratios
# and seed; the three returned paths feed the ImageFolder datasets.
SPLITS_DIR = "./data_splits"
train_dir, val_dir, test_dir = prepare_data_splits(dataset_path, SPLITS_DIR)

In [None]:
# Define image transforms
IMAGE_SIZE = 224
BATCH_SIZE = 32

# Channel statistics shared by every pipeline (the standard ImageNet values).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: resize, random flips/rotation/colour jitter, normalise.
train_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    normalize
])

# Evaluation pipeline: deterministic, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    normalize
])

# Create datasets
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)
test_dataset = datasets.ImageFolder(test_dir, transform=val_transforms)

# A trailing training batch holding a single image makes BatchNorm1d raise in
# train mode, so that one batch is dropped - and only when it would occur.
drop_singleton_batch = (
    len(train_dataset) > BATCH_SIZE and len(train_dataset) % BATCH_SIZE == 1
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=drop_singleton_batch,
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {train_dataset.classes}")

In [None]:
# Visualize some samples
def show_samples(dataset, num_samples=8):
    """Plot the first ``num_samples`` images of ``dataset`` with their class names.

    Images are laid out four per row; as many rows as needed are created and
    any unused cells are hidden. The request is capped at ``len(dataset)``.

    Args:
        dataset: Indexable dataset whose items are ``(image_tensor, label)``
            pairs and which exposes a ``classes`` list. The image tensor is
            expected to be channel-first and normalised with the mean/std
            hard-coded below.
        num_samples: How many images to show.
    """
    count = min(num_samples, len(dataset))
    if count <= 0:
        return

    cols = 4
    rows = -(-count // cols)  # ceiling division
    # 3x3 inches per cell keeps the default 8-image figure at 12x6.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # NOTE: indices 0..count-1 are shown as-is; on a dataset ordered by class
    # this displays a single class.
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= count:
            continue
        img, label = dataset[i]
        img = img * std + mean  # Denormalize
        img = img.permute(1, 2, 0).numpy().clip(0, 1)
        ax.imshow(img)
        ax.set_title(dataset.classes[label])
    plt.tight_layout()
    plt.show()

# Preview training images; train_dataset applies random augmentation, so the
# pictures differ between runs.
show_samples(train_dataset)

## 4. Model Definition & Training

In [None]:
class HurricaneDamageClassifier(nn.Module):
    """Transfer-learning classifier: torchvision backbone plus an MLP head.

    The backbone's own classification layer is replaced by ``nn.Identity`` so
    it outputs feature vectors; the head turns those into ``num_classes``
    logits.
    """

    _BACKBONES = ("resnet18", "resnet50", "efficientnet_b0")

    def __init__(
        self,
        num_classes: int = 2,
        backbone: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.5,
        freeze_backbone: bool = False
    ):
        """Build the backbone and the classification head.

        Args:
            num_classes: Number of output logits.
            backbone: One of ``resnet18``, ``resnet50``, ``efficientnet_b0``.
            pretrained: Load the torchvision ``DEFAULT`` weights when true.
            dropout_rate: Dropout after the first hidden layer; the second
                hidden layer uses half of it.
            freeze_backbone: Disable gradients for all backbone parameters.

        Raises:
            ValueError: If ``backbone`` is not a supported name.
        """
        super().__init__()

        # Reject unsupported names before any network is constructed.
        if backbone not in self._BACKBONES:
            raise ValueError(f"Unknown backbone: {backbone}")

        self.backbone_name = backbone
        self.num_classes = num_classes

        self.backbone, feature_dim = self._build_backbone(backbone, pretrained)

        if freeze_backbone:
            self.backbone.requires_grad_(False)

        # Head: feature_dim -> 512 -> 256 -> num_classes
        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_classifier_weights()

    @staticmethod
    def _build_backbone(name, pretrained):
        """Return ``(network, feature_dim)`` with the stock classifier removed."""
        if name == "efficientnet_b0":
            net = models.efficientnet_b0(
                weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
            )
            feature_dim = net.classifier[1].in_features
            net.classifier = nn.Identity()
            return net, feature_dim

        if name == "resnet18":
            net = models.resnet18(
                weights=models.ResNet18_Weights.DEFAULT if pretrained else None
            )
        else:
            net = models.resnet50(
                weights=models.ResNet50_Weights.DEFAULT if pretrained else None
            )
        feature_dim = net.fc.in_features
        net.fc = nn.Identity()
        return net, feature_dim

    def _init_classifier_weights(self):
        """Kaiming-initialise the head's linear layers and reset its batch norms."""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.BatchNorm1d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self.backbone(x))

# Create model: one output per class found in the training folder
NUM_CLASSES = len(train_dataset.classes)
model_kwargs = dict(
    num_classes=NUM_CLASSES,
    backbone="resnet18",
    pretrained=True,
    dropout_rate=0.5,
    freeze_backbone=False,
)
model = HurricaneDamageClassifier(**model_kwargs)
model = model.to(device)

print(f"Model created with backbone: {model.backbone_name}")

# Collect (size, trainable) once, then derive both totals from it.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Training configuration
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4

# Cross-entropy over the class logits; AdamW updates every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# The scheduler is stepped with the validation loss (mode='min'): the learning
# rate is multiplied by 0.5 once that loss has stopped improving for more
# than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        ``(loss, accuracy)``: the batch losses weighted by batch size and
        divided by the number of samples seen, and the percentage of
        samples whose highest-scoring class equals the label.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(dataloader, desc="Training")
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        loss_sum += batch_loss_value * images.size(0)
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        # Running accuracy over the epoch so far, latest batch loss.
        progress.set_postfix({'loss': batch_loss_value, 'acc': 100. * n_correct / n_seen})

    return loss_sum / n_seen, 100. * n_correct / n_seen


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Returns:
        ``(loss, accuracy)``: the batch losses weighted by batch size and
        divided by the number of samples seen, and the percentage of
        samples classified correctly.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item() * images.size(0)
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

    return loss_sum / n_seen, 100. * n_correct / n_seen

In [None]:
# Training loop: train, validate, step the LR scheduler on the validation
# loss, and checkpoint the weights whenever validation accuracy improves.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; the evaluation cell loads that file.
best_val_acc = float("-inf")
best_model_path = "best_model.pth"

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 30)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step(val_loss)

    # Record history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"✓ Saved best model with val acc: {val_acc:.2f}%")

print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%")

## 5. Evaluation

In [None]:
# Plot training history: loss on the left, accuracy on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

panels = (
    (ax1, 'loss', 'Loss', 'Training & Validation Loss'),
    (ax2, 'acc', 'Accuracy (%)', 'Training & Validation Accuracy'),
)
for axis, metric, y_label, title in panels:
    axis.plot(history[f'train_{metric}'], label='Train')
    axis.plot(history[f'val_{metric}'], label='Validation')
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)
    axis.set_title(title)
    axis.legend()

plt.tight_layout()
plt.show()

In [None]:
# Load best model and evaluate on test set.
# map_location places the saved tensors on the current device, so a
# checkpoint written on a GPU also loads on a CPU-only runtime.
best_state = torch.load(best_model_path, map_location=device)
model.load_state_dict(best_state)
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
# Confusion matrix and classification report
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

class_names = train_dataset.classes
# Passing the label ids explicitly keeps the report and the matrix at one
# row/column per class even if a class never appears in the test labels or
# predictions.
label_ids = list(range(len(class_names)))

# Collect predictions for the whole test set
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, preds = outputs.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Print classification report
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

# Plot confusion matrix (rows: actual class, columns: predicted class)
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(cm, cmap='Blues')
ax.set_xticks(label_ids)
ax.set_yticks(label_ids)
ax.set_xticklabels(class_names)
ax.set_yticklabels(class_names)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')

# Annotate each cell; white text on the darker half of the colour scale.
text_threshold = cm.max() / 2
for i in label_ids:
    for j in label_ids:
        ax.text(j, i, cm[i, j], ha='center', va='center', color='white' if cm[i, j] > text_threshold else 'black')

plt.colorbar(im)
plt.tight_layout()
plt.show()

In [None]:
# Save final model: weights plus the metadata needed to rebuild the network
FINAL_CHECKPOINT_PATH = 'hurricane_classifier_final.pth'

checkpoint = {
    'model_state_dict': model.state_dict(),
    'classes': train_dataset.classes,
    'backbone': model.backbone_name,
    'num_classes': model.num_classes,
    'history': history,
}
torch.save(checkpoint, FINAL_CHECKPOINT_PATH)

print(f"Model saved to {FINAL_CHECKPOINT_PATH}")