<a href="https://colab.research.google.com/github/zztanmayzz/zigzaggerz/blob/main/segmentation_indian_cities.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
#!pip install kaggle
#from google.colab import files

# Upload kaggle.json from your local machine (get from Kaggle account)
#files.upload()

In [12]:
#!mkdir -p ~/.kaggle
#!mv "kaggle(1).json" ~/.kaggle/kaggle.json
#!chmod 600 ~/.kaggle/kaggle.json

In [10]:
#!kaggle datasets download -d balraj98/deepglobe-road-extraction-dataset
#!unzip deepglobe-road-extraction-dataset.zip -d deepglobe_dataset

Dataset URL: https://www.kaggle.com/datasets/balraj98/deepglobe-road-extraction-dataset
License(s): other
deepglobe-road-extraction-dataset.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  deepglobe-road-extraction-dataset.zip
replace deepglobe_dataset/class_dict.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [10]:
# Install dependencies (run in notebook or terminal)
!pip install segmentation-models-pytorch torch torchvision opencv-python matplotlib albumentations

import torch
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2


# Define dataset class for loading images and masks
class SatelliteSegmentationDataset(Dataset):
    """Map-style dataset pairing satellite images with per-pixel segmentation masks.

    Args:
        image_paths: Indexable sequence of image file paths. Images are read with
            OpenCV and converted from BGR to RGB.
        mask_paths: Indexable sequence of mask file paths, aligned index-by-index with
            ``image_paths``. Masks are read as single-channel (grayscale) images.
        transform: Optional albumentations-style callable, called as
            ``transform(image=..., mask=...)``, that returns a dict with ``'image'`` and
            ``'mask'``. Its outputs may be numpy arrays or torch tensors (e.g. after
            ``ToTensorV2``).

    Each item is ``(image, mask)``. ``image`` is a float32 ``(C, H, W)`` tensor and
    ``mask`` is an int64 tensor of the mask's pixel values.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} images but {len(mask_paths)} masks; "
                "they must be paired one-to-one."
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally transform, and tensorize the ``idx``-th image/mask pair.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or mask file.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None, not raising
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # assuming single channel mask
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # NOTE(review): CrossEntropyLoss needs class ids in [0, num_classes). If the masks
        # encode classes as other pixel values (e.g. 0/255), remap them here. Confirm
        # against the dataset.

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return self._to_chw_float_tensor(image), self._to_long_mask(mask)

    @staticmethod
    def _to_chw_float_tensor(image):
        """Convert an HWC array, or an already-CHW tensor, to a float32 CHW tensor.

        uint8 inputs are scaled to [0, 1]. Float inputs (e.g. the output of
        ``A.Normalize``) are treated as already normalized and are not rescaled.
        """
        if isinstance(image, torch.Tensor):
            # ToTensorV2 produces CHW tensors. Transposing with three axes, as an ndarray
            # would be, is invalid for a tensor.
            tensor = image
        else:
            tensor = torch.from_numpy(
                np.ascontiguousarray(np.asarray(image).transpose(2, 0, 1))  # HWC -> CHW
            )
        if tensor.dtype == torch.uint8:
            return tensor.float() / 255.0
        return tensor.float()

    @staticmethod
    def _to_long_mask(mask):
        """Convert a mask array or tensor to an int64 tensor (CrossEntropyLoss targets)."""
        if isinstance(mask, torch.Tensor):
            return mask.long()
        # ascontiguousarray copies views with negative strides (e.g. from flips), which
        # torch.from_numpy cannot wrap.
        return torch.from_numpy(np.ascontiguousarray(mask)).long()


# Define transforms for training and inference
def get_transforms(train=True, size=256):
    """Build the albumentations pipeline for training or inference.

    Args:
        train: When True (the default), include the random horizontal flip and
            brightness/contrast augmentations. Pass False for inference so predictions
            are deterministic and spatially aligned with the input image.
        size: Side length, in pixels, that the image (and mask) is resized to.

    Returns:
        An ``A.Compose`` that resizes, optionally augments, applies ``A.Normalize`` with
        its default settings, and converts to a tensor with ``ToTensorV2``.
    """
    steps = [A.Resize(size, size)]
    if train:
        steps += [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
        ]
    steps += [A.Normalize(), ToTensorV2()]
    return A.Compose(steps)


# Define model (e.g., DeepLabV3+ with ResNet encoder)
# Seed torch's global RNG before building the model, so any weight initialisation that
# draws from it is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    # NOTE(review): the dataset downloaded earlier is the DeepGlobe *road extraction*
    # set. Confirm its masks actually contain these 4 classes before training.
    classes=4  # background + roads + water + concrete/soil
)


# Define loss and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


# Training loop (simplified)
def train_epoch(dataloader, model, loss_fn, optimizer, device=None):
    """Run one optimisation pass over ``dataloader`` and return the mean training loss.

    Args:
        dataloader: Iterable of ``(images, masks)`` batches, e.g. a ``DataLoader``.
        model: Module mapping ``images`` to logits. It is switched to train mode.
        loss_fn: Criterion called as ``loss_fn(outputs, masks)``.
        optimizer: Optimizer over ``model``'s parameters, stepped once per batch.
        device: Optional device to move each batch to before the forward pass.
            ``None`` (the default) leaves batches where the dataloader put them.

    Returns:
        Sample-weighted mean of the batch losses, or ``nan`` if ``dataloader`` yields
        no batches.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for images, masks in dataloader:
        if device is not None:
            images = images.to(device)
            masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()
        # Assumes loss_fn returns a per-batch mean (e.g. reduction="mean"), so weight
        # it by batch size to get a per-sample average over the epoch.
        batch_size = images.shape[0]
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    return total_loss / num_samples if num_samples else float("nan")


# Visualization function to display segmented features
def visualize_segmentation(image, mask_pred, colors=None, alpha=0.3):
    """Show ``image`` blended with a color-coded rendering of ``mask_pred``.

    Args:
        image: RGB image, expected as an ``(H, W, 3)`` uint8 array so it can be blended
            with the uint8 color mask.
        mask_pred: 2-D array of integer class ids. If its height/width differ from
            ``image``'s (e.g. a mask predicted at the model's input resolution), the
            rendered color mask is resized to match with nearest-neighbour
            interpolation, so no blended in-between colors appear.
        colors: Optional mapping of class id -> RGB triple. Defaults to black
            background, red roads, blue water and gray concrete/soil. Ids without an
            entry are drawn black.
        alpha: Weight of the color mask in the blend. The image gets ``1 - alpha``.
    """
    if colors is None:
        colors = {
            0: [0, 0, 0],         # background black
            1: [255, 0, 0],       # roads red
            2: [0, 0, 255],       # water blue
            3: [128, 128, 128],   # concrete gray or soil brown as needed
        }
    color_mask = np.zeros((mask_pred.shape[0], mask_pred.shape[1], 3), dtype=np.uint8)
    for cls, color in colors.items():
        color_mask[mask_pred == cls] = color

    height, width = image.shape[:2]
    if color_mask.shape[:2] != (height, width):
        # cv2.addWeighted requires both inputs to have the same size.
        color_mask = cv2.resize(color_mask, (width, height), interpolation=cv2.INTER_NEAREST)
    overlay = cv2.addWeighted(image, 1 - alpha, color_mask, alpha, 0)

    plt.imshow(overlay)
    plt.axis('off')
    plt.show()


def predict_mask(model, image_path, transform):
    """Predict a per-pixel class map for the image stored at ``image_path``.

    Args:
        model: Segmentation model returning logits with the class axis at dim 1.
            It is put into eval mode.
        image_path: Path to an image readable by ``cv2.imread``.
        transform: Albumentations-style callable invoked as ``transform(image=...)``,
            whose ``'image'`` output is a CHW torch tensor (e.g. ending in
            ``ToTensorV2``). It should be deterministic. Random augmentations would
            make predictions vary between calls and could misalign them with the image.

    Returns:
        ``(image, pred_mask)``. ``image`` is the original RGB array. ``pred_mask`` is
        an integer class-id array with the same height and width as ``image``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    model.eval()
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising on failure
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Apply same preprocessing for inference
    input_tensor = transform(image=image)['image'].unsqueeze(0)  # batch dimension
    first_param = next(model.parameters(), None)
    if first_param is not None:
        # Run on whatever device the model lives on (e.g. GPU after model.to("cuda")).
        input_tensor = input_tensor.to(first_param.device)

    with torch.no_grad():
        output = model(input_tensor)
    # Index the single batch element explicitly. squeeze() would also drop the class
    # axis of a 1-class model, making argmax run over image rows instead.
    pred_mask = torch.argmax(output[0], dim=0).cpu().numpy()

    # The transform may resize the input. Map the prediction back to the original
    # resolution with nearest-neighbour index sampling, so every pixel keeps a real
    # class id and the mask lines up with the returned image.
    height, width = image.shape[:2]
    mask_h, mask_w = pred_mask.shape
    if (mask_h, mask_w) != (height, width):
        rows = np.arange(height) * mask_h // height
        cols = np.arange(width) * mask_w // width
        pred_mask = pred_mask[rows[:, None], cols]

    return image, pred_mask


def visualize_side_by_side(image, mask_pred, colors=None, alpha=0.3):
    """Plot ``image`` next to the same image blended with a color-coded ``mask_pred``.

    Args:
        image: RGB image, expected as an ``(H, W, 3)`` uint8 array so it can be blended
            with the uint8 color mask.
        mask_pred: 2-D array of integer class ids. If its height/width differ from
            ``image``'s, the rendered color mask is resized to match with
            nearest-neighbour interpolation.
        colors: Optional mapping of class id -> RGB triple. Defaults to black
            background, red roads, blue water and gray concrete. Ids without an entry
            are drawn black.
        alpha: Weight of the color mask in the blend. The image gets ``1 - alpha``.
    """
    # Create color overlay for mask
    if colors is None:
        colors = {
            0: [0, 0, 0],         # background black
            1: [255, 0, 0],       # roads red
            2: [0, 0, 255],       # water blue
            3: [128, 128, 128],   # concrete gray
        }
    color_mask = np.zeros((mask_pred.shape[0], mask_pred.shape[1], 3), dtype=np.uint8)
    for cls, color in colors.items():
        color_mask[mask_pred == cls] = color

    height, width = image.shape[:2]
    if color_mask.shape[:2] != (height, width):
        # cv2.addWeighted requires both inputs to have the same size.
        color_mask = cv2.resize(color_mask, (width, height), interpolation=cv2.INTER_NEAREST)
    overlay = cv2.addWeighted(image, 1 - alpha, color_mask, alpha, 0)

    # Plot side by side using explicit axes rather than the pyplot state machine.
    _, (ax_image, ax_overlay) = plt.subplots(1, 2, figsize=(12, 6))
    ax_image.set_title("Original Image")
    ax_image.imshow(image)
    ax_image.axis('off')

    ax_overlay.set_title("Segmentation Overlay")
    ax_overlay.imshow(overlay)
    ax_overlay.axis('off')

    plt.show()


# Example usage:
from pathlib import Path

# Inference must be deterministic and spatially aligned with the input. Reuse only the
# resize/normalize/tensor steps from training, with no random flips or color jitter.
test_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(),
    ToTensorV2(),
])

image_path = "path/to/test_image.png"  # Replace with your test image path
# NOTE: train_epoch is not called anywhere above, so `model` still has an untrained
# decoder. The overlay below only exercises the pipeline; it is not a meaningful result.
if Path(image_path).is_file():
    image, pred_mask = predict_mask(model, image_path, test_transform)
    visualize_side_by_side(image, pred_mask)
else:
    print(f"Test image not found at {image_path!r}; set image_path to run the example.")
