<a href="https://colab.research.google.com/github/zztanmayzz/zigzaggerz/blob/main/proposed_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
import torch
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from pathlib import Path

class FloodSegmentationDataset(Dataset):
    """
    Custom dataset for flood management semantic segmentation
    """
    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.target_transform = target_transform

        # Get all image files
        self.images = list(self.image_dir.glob("*.jpg")) + list(self.image_dir.glob("*.png"))

        # Class mapping
        self.class_colors = {
            (255, 0, 0): 1,    # Buildings - red
            (0, 255, 0): 2,    # Vegetation - green
            (128, 128, 128): 3, # Roads - grey
            (139, 69, 19): 4,  # Bare ground - brown
            (255, 255, 255): 0 # Background - white
        }

    def __len__(self):
        return len(self.images)

    def rgb_to_class_mask(self, rgb_mask):
        """Convert RGB mask to class indices"""
        h, w = rgb_mask.shape[:2]
        class_mask = np.zeros((h, w), dtype=np.long)

        for color, class_id in self.class_colors.items():
            mask = np.all(rgb_mask == color, axis=-1)
            class_mask[mask] = class_id

        return class_mask

    def __getitem__(self, idx):
        # Load image
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')

        # Load corresponding mask (same name, different directory)
        mask_path = self.mask_dir / img_path.name
        mask = Image.open(mask_path).convert('RGB')
        mask = np.array(mask)

        # Convert RGB mask to class indices
        mask = self.rgb_to_class_mask(mask)
        mask = Image.fromarray(mask.astype(np.uint8))

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask

def create_data_loaders(train_dir, mask_dir, batch_size=4, img_size=224):
    """Create training and validation data loaders"""

    # Data transforms
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    mask_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    # Create dataset
    dataset = FloodSegmentationDataset(
        train_dir, mask_dir,
        transform=train_transform,
        target_transform=mask_transform
    )

    # Split dataset (80% train, 20% val)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

# Example usage
def setup_dataset_structure():
    """
    Create the required directory structure
    """
    directories = [
        "data/satellite_images",    # Original satellite images
        "data/annotations",         # Your corrected segmentation masks
        "data/elevation",           # DEM data
        "data/soil",                # Soil properties
        "models",                   # Trained models
        "outputs"                   # Results and predictions
    ]

    for dir_path in directories:
        os.makedirs(dir_path, exist_ok=True)
        print(f"Created directory: {dir_path}")

if __name__ == "__main__":
    setup_dataset_structure()


Created directory: data/satellite_images
Created directory: data/annotations
Created directory: data/elevation
Created directory: data/soil
Created directory: models
Created directory: outputs


In [40]:
!ls
!cp -r ./image{0..10}.png ./data/satellite_images/
!touch ./data/annotations/image{0..10}.png

 data	       image3.png	 image6.png	   image9.png
 image10.png   image4.png	 image7.png	   models
 image1.png    image5.png	'image8 (1).png'   outputs
 image2.png   'image6 (1).png'	 image8.png	   sample_data
cp: cannot stat './image0.png': No such file or directory


In [41]:
#from google.colab import files
#files.upload()

In [42]:
def fix_color_values(image_path, output_path):
    """
    Fix the color values in annotated images to match exact legend values
    """
    img = Image.open(image_path).convert('RGB')  # Force convert to RGB
    img_array = np.array(img)

    # Color mapping for correction
    color_corrections = {
        # Current -> Target
        (254, 0, 0): (255, 0, 0),      # Buildings: red
        (1, 255, 0): (0, 255, 0),      # Vegetation: green
        (255, 176, 49): (139, 69, 19), # Convert orange areas to bare ground: brown
    }

    # Apply color corrections
    for old_color, new_color in color_corrections.items():
        mask = np.all(img_array == old_color, axis=-1)
        img_array[mask] = new_color

    # Save corrected image
    corrected_img = Image.fromarray(img_array)
    corrected_img.save(output_path)
    print(f"Corrected image saved to: {output_path}")


In [43]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import yaml
from pathlib import Path
import numpy as np
from huggingface_hub import hf_hub_download
#from dataset_preparation import FloodSegmentationDataset, create_data_loaders

class PrithviFloodModel(nn.Module):
    """
    Prithvi-based model for flood management segmentation
    """
    def __init__(self, num_classes=5, pretrained_path=None):
        super(PrithviFloodModel, self).__init__()

        # Download Prithvi model if not available
        if pretrained_path is None:
            try:
                model_path = hf_hub_download(
                    repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
                    filename="Prithvi_100M.pt"
                )
                config_path = hf_hub_download(
                    repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
                    filename="Prithvi_100M_config.yaml"
                )
            except:
                print("Warning: Could not download Prithvi model. Using alternative approach.")
                model_path = None
                config_path = None

        self.num_classes = num_classes

        # For hackathon speed, use a simpler segmentation model
        # Replace with actual Prithvi when available
        from torchvision.models.segmentation import deeplabv3_resnet50
        self.model = deeplabv3_resnet50(pretrained=True)

        # Modify classifier for our classes
        self.model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        return self.model(x)['out']

def calculate_iou(pred, target, num_classes):
    """Calculate Intersection over Union for each class"""
    ious = []
    pred = torch.argmax(pred, dim=1)

    for cls in range(num_classes):
        pred_cls = (pred == cls)
        target_cls = (target == cls)

        intersection = (pred_cls & target_cls).sum().float()
        union = (pred_cls | target_cls).sum().float()

        if union == 0:
            iou = 1.0  # Perfect score for classes not present
        else:
            iou = intersection / union
        ious.append(iou.item())

    return ious

def train_model(train_loader, val_loader, num_epochs=50, learning_rate=1e-4):
    """Train the flood segmentation model"""

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    # Initialize model
    model = PrithviFloodModel(num_classes=5)
    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore unknown pixels
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    best_val_iou = 0.0
    train_losses = []
    val_ious = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for batch_idx, (images, masks) in enumerate(train_loader):
            images = images.to(device)
            masks = masks.to(device).long().squeeze(1)  # Remove channel dimension

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Validation phase
        model.eval()
        val_iou_scores = []

        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device).long().squeeze(1)

                outputs = model(images)
                ious = calculate_iou(outputs, masks, num_classes=5)
                val_iou_scores.append(ious)

        # Calculate average IoU
        mean_ious = np.mean(val_iou_scores, axis=0)
        overall_iou = np.mean(mean_ious[1:])  # Exclude background class

        train_losses.append(train_loss / len(train_loader))
        val_ious.append(overall_iou)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_losses[-1]:.4f}')
        print(f'  Val IoU: {overall_iou:.4f}')
        print(f'  Class IoUs: Background={mean_ious[0]:.3f}, Buildings={mean_ious[1]:.3f}, '
              f'Vegetation={mean_ious[2]:.3f}, Roads={mean_ious[3]:.3f}, Bare={mean_ious[4]:.3f}')

        # Save best model
        if overall_iou > best_val_iou:
            best_val_iou = overall_iou
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_iou': overall_iou,
                'class_ious': mean_ious
            }, 'models/best_flood_model.pth')
            print(f'  New best model saved with IoU: {overall_iou:.4f}')

        scheduler.step()
        print('-' * 50)

    return model, train_losses, val_ious

def main():
    """Main training function"""
    print("Starting Prithvi-based flood segmentation model training...")

    # Create data loaders
    try:
        train_loader, val_loader = create_data_loaders(
            train_dir="data/satellite_images",
            mask_dir="data/annotations",
            batch_size=4,
            img_size=224
        )

        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")

        # Train model
        model, train_losses, val_ious = train_model(
            train_loader, val_loader,
            num_epochs=100,  # Increase for better results
            learning_rate=1e-4
        )

        print("Training completed!")
        print(f"Best validation IoU: {max(val_ious):.4f}")

    except Exception as e:
        print(f"Error during training: {e}")
        print("Make sure you have:")
        print("1. Fixed and placed images in data/satellite_images/")
        print("2. Corrected annotations in data/annotations/")
        print("3. All 4 classes present in annotations")

if __name__ == "__main__":
    main()


Starting Prithvi-based flood segmentation model training...
Training samples: 10
Validation samples: 3
Training on device: cpu




Error during training: cannot identify image file 'data/annotations/image2.png'
Make sure you have:
1. Fixed and placed images in data/satellite_images/
2. Corrected annotations in data/annotations/
3. All 4 classes present in annotations


In [44]:
!ls

 data	       image3.png	 image6.png	   image9.png
 image10.png   image4.png	 image7.png	   models
 image1.png    image5.png	'image8 (1).png'   outputs
 image2.png   'image6 (1).png'	 image8.png	   sample_data
