In [1]:
# Mount Google Drive inside the Colab VM so that the dataset directories under
# /content/drive/MyDrive referenced later in this notebook can be read.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2


In [3]:
class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights skip-connection features.

    The gating tensor and the skip tensor are each projected to ``F_int``
    channels with a 1x1 convolution followed by batch normalisation. The two
    projections are summed, passed through ReLU and reduced to a
    single-channel sigmoid map that multiplies the skip tensor.

    Args:
        F_g: number of channels of the gating tensor ``g``.
        F_l: number of channels of the skip tensor ``x``.
        F_int: number of channels of the intermediate projections.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Sub-modules are created in the order W_g, W_x, psi so parameter
        # registration (and random initialisation order) stays the same.
        self.W_g = self._project(F_g, F_int)
        self.W_x = self._project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _project(in_channels, out_channels):
        """Return a 1x1 convolution + batch-norm mapping to ``out_channels``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, g, x):
        """Scale ``x`` by the attention map computed from ``g`` and ``x``.

        The projections of ``g`` and ``x`` are added element-wise, so the two
        inputs must have compatible batch and spatial dimensions.

        Returns:
            ``x`` multiplied by a single-channel map with values in (0, 1).
        """
        combined = self.W_g(g) + self.W_x(x)
        attention_map = self.psi(F.relu(combined))
        return x * attention_map


class AttentionUNet(nn.Module):
    """U-Net with attention-gated skip connections and Monte-Carlo dropout.

    Layout (channel widths 64-128-256-512, bottleneck 1024):

    * Encoder: four ``conv_block`` stages, each followed by 2x2 max-pooling.
    * Center: one ``conv_block`` followed by ``Dropout2d`` (kept active at
      inference time by callers that want MC-dropout uncertainty).
    * Decoder: four stages; each upsamples by 2 with a transposed
      convolution, gates the matching encoder feature map with an
      ``AttentionBlock`` (decoder features act as the gating signal),
      concatenates ``(decoder, gated skip)`` along channels and fuses them
      with a ``conv_block``.
    * Head: 1x1 convolution followed by a sigmoid.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map.
        dropout_prob: drop probability of the bottleneck ``Dropout2d``.

    Input height and width must be at least 16 because the encoder halves the
    resolution four times. Sizes that are not multiples of 16 are supported:
    upsampled maps are resized to the skip connection's size before fusion.
    """

    def __init__(self, in_channels=6, out_channels=1, dropout_prob=0.3):
        super(AttentionUNet, self).__init__()

        # ModuleList rather than Sequential: the stages are invoked one by one
        # in forward() (their intermediate outputs are needed as skips), never
        # as a single pipeline.
        self.encoder = nn.ModuleList([
            self.conv_block(in_channels, 64),
            self.conv_block(64, 128),
            self.conv_block(128, 256),
            self.conv_block(256, 512)
        ])
        # Halves the spatial resolution after every encoder stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.center = self.conv_block(512, 1024)

        # Flat list of (upsample, attention gate, fusion conv) triples, deepest
        # stage first. The fusion conv takes 2x the skip width because the
        # upsampled map and the gated skip are concatenated.
        self.decoder = nn.ModuleList([
            self.up_conv(1024, 512),
            AttentionBlock(512, 512, 256),
            self.conv_block(1024, 512),

            self.up_conv(512, 256),
            AttentionBlock(256, 256, 128),
            self.conv_block(512, 256),

            self.up_conv(256, 128),
            AttentionBlock(128, 128, 64),
            self.conv_block(256, 128),

            self.up_conv(128, 64),
            AttentionBlock(64, 64, 32),
            self.conv_block(128, 64)
        ])

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

        # MC Dropout layers
        self.dropout = nn.Dropout2d(p=dropout_prob)

    def conv_block(self, in_c, out_c):
        """Return two 3x3 conv + batch-norm + ReLU layers mapping in_c -> out_c."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def up_conv(self, in_c, out_c):
        """Return a stride-2 transposed convolution that doubles height and width."""
        return nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)

    def forward(self, x):
        """Compute per-pixel probabilities for a batch of images.

        Args:
            x: tensor of shape (batch, in_channels, H, W) with H, W >= 16.

        Returns:
            Tensor of shape (batch, out_channels, H, W) with values in (0, 1).
        """
        # Contracting path: remember each stage's output before pooling.
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.center(x)
        x = self.dropout(x)  # MC Dropout

        # Expanding path: walk the decoder triples, pairing the deepest stage
        # with the deepest skip connection.
        for start, skip in zip(range(0, len(self.decoder), 3), reversed(skips)):
            upsample = self.decoder[start]
            gate = self.decoder[start + 1]
            fuse = self.decoder[start + 2]

            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                # Pooling floors odd sizes, so the upsampled map can be one
                # pixel short of the skip connection; resize to match.
                x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            gated_skip = gate(x, skip)
            x = fuse(torch.cat((x, gated_skip), dim=1))

        out = self.final(x)
        return torch.sigmoid(out)


In [4]:
import rasterio
from glob import glob
import os

class SARFloodDataset(Dataset):
    """Paired SAR image / flood mask dataset read from two directories.

    Images are every ``*.tif`` file in ``image_dir`` and masks every ``*.png``
    file in ``mask_dir``. Both lists are sorted by path and paired by
    position, so file names must sort in the same order in both directories.

    Args:
        image_dir: directory containing the ``.tif`` images.
        mask_dir: directory containing the ``.png`` masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``'image'`` and ``'mask'`` entries
            (the albumentations convention).

    Raises:
        ValueError: if the number of images and masks differ, because
            positional pairing would then mis-align or run out of masks.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
        self.mask_paths = sorted(glob(os.path.join(mask_dir, "*.png")))
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {image_dir!r} but "
                f"{len(self.mask_paths)} masks in {mask_dir!r}; they are paired "
                "by sorted position and must match one-to-one."
            )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample ``idx``.

        The image is channel-first; the mask is a float tensor of zeros and
        ones with shape (H, W).

        Raises:
            OSError: if the mask file cannot be decoded.
        """
        # Load SAR image (TIF)
        with rasterio.open(self.image_paths[idx]) as src:
            img = src.read()  # (C, H, W) format

        img = np.moveaxis(img, 0, -1)  # Convert to (H, W, C) for Albumentations

        # Load mask (PNG). cv2.imread signals failure by returning None rather
        # than raising, so check explicitly.
        mask_path = self.mask_paths[idx]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # (H, W)
        if mask is None:
            raise OSError(f"Could not read mask image: {mask_path}")
        mask = (mask > 127).astype(np.float32)  # Convert to binary mask

        # Apply transformations
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        # Without a transform (or with one that does not convert to tensors)
        # the arrays are still numpy; convert them so the return type is the
        # same on every path.
        if not torch.is_tensor(img):
            chw = np.moveaxis(np.asarray(img), -1, 0).astype(np.float32)
            img = torch.from_numpy(np.ascontiguousarray(chw))
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        return img, mask.float()


In [5]:
class FocalLoss(nn.Module):
    """Focal loss on probabilities (inputs must already be in [0, 1]).

    Each element contributes ``alpha * (1 - p_t) ** gamma * BCE`` where
    ``p_t`` is the probability assigned to the true class. ``alpha`` is a
    uniform scale applied to both classes. The result is the mean over all
    elements.

    Args:
        gamma: focusing exponent; 0 reduces the modulation to a constant.
        alpha: scalar weight applied to every element.
    """

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: predicted probabilities, e.g. shape (B, 1, H, W).
            targets: binary targets with the same shape as ``inputs``; a
                target that only lacks the singleton channel axis
                (e.g. (B, H, W)) is accepted and expanded.

        Raises:
            ValueError: if the shapes still differ after that alignment.
        """
        # binary_cross_entropy needs floating-point targets of the input dtype.
        targets = targets.to(dtype=inputs.dtype)

        # Segmentation masks commonly arrive without the channel axis that a
        # single-channel prediction carries; add it instead of letting the
        # element-wise products below broadcast incorrectly.
        if inputs.dim() >= 2 and targets.dim() == inputs.dim() - 1 and inputs.size(1) == 1:
            targets = targets.unsqueeze(1)
        if targets.shape != inputs.shape:
            raise ValueError(
                f"FocalLoss: target shape {tuple(targets.shape)} does not match "
                f"input shape {tuple(inputs.shape)}"
            )

        BCE = F.binary_cross_entropy(inputs, targets, reduction='none')
        p_t = inputs * targets + (1 - inputs) * (1 - targets)
        loss = self.alpha * (1 - p_t) ** self.gamma * BCE
        return loss.mean()

class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once."""

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - Dice`` between ``inputs`` and ``targets``.

        Args:
            inputs: predicted probabilities (any shape).
            targets: binary targets with the same number of elements.
            smooth: constant added to numerator and denominator so the ratio
                is defined when both tensors are all zeros.
        """
        # reshape (not view) so non-contiguous tensors, e.g. channels_last or
        # transposed ones, are flattened instead of raising.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1).to(dtype=inputs.dtype)
        intersection = (inputs * targets).sum()
        return 1 - ((2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth))

def combined_loss(inputs, targets):
    """Return the unweighted sum of focal loss and Dice loss.

    Both criteria are created with their default settings on every call.
    """
    focal_term = FocalLoss()(inputs, targets)
    dice_term = DiceLoss()(inputs, targets)
    return focal_term + dice_term


In [6]:
def train_model(model, dataloader, optimizer, num_epochs=20, device=None):
    """Train ``model`` with ``combined_loss`` and report the mean loss per epoch.

    Args:
        model: network mapping an image batch to per-pixel probabilities.
        dataloader: iterable yielding ``(images, masks)`` batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter, so the function does not rely on a
            notebook-global ``device`` variable being defined beforehand.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    epoch_means = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0  # counted while iterating; avoids needing len(dataloader)
        for images, masks in tqdm(dataloader):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # Masks may come as (B, H, W) while a single-channel model emits
            # (B, 1, H, W); add the channel axis so the loss compares
            # element-for-element.
            if masks.dim() == outputs.dim() - 1 and outputs.size(1) == 1:
                masks = masks.unsqueeze(1)

            loss = combined_loss(outputs, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader produced no batches; nothing to train on")

        epoch_means.append(epoch_loss / num_batches)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_means[-1]}")

    return epoch_means


In [7]:
def predict_with_uncertainty(f_model, images, n_iter=10):
    """Monte-Carlo dropout prediction: mean and variance over stochastic passes.

    Only dropout layers are switched to training mode. Every other module
    (notably batch normalisation) runs in eval mode, so the passes use the
    learned running statistics and do not modify them. Gradients are disabled
    and each module's original train/eval flag is restored before returning.

    Args:
        f_model: the network to sample from.
        images: input batch passed unchanged to ``f_model``.
        n_iter: number of stochastic forward passes (>= 1).

    Returns:
        ``(mean_pred, uncertainty)``: the element-wise mean and sample
        variance of the ``n_iter`` predictions. With ``n_iter == 1`` the
        variance is defined as zero.

    Raises:
        ValueError: if ``n_iter`` is smaller than 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    # Remember every module's mode so the caller's model is left as found.
    saved_modes = [(module, module.training) for module in f_model.modules()]

    f_model.eval()
    for module in f_model.modules():
        # _DropoutNd is the common base class of torch's dropout layers.
        if isinstance(module, nn.modules.dropout._DropoutNd):
            module.train()  # Keep dropout layers active

    try:
        with torch.no_grad():
            preds = torch.stack([f_model(images) for _ in range(n_iter)], dim=0)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    mean_pred = preds.mean(dim=0)
    if n_iter > 1:
        uncertainty = preds.var(dim=0)
    else:
        # Sample variance of a single draw is undefined (NaN); report zero.
        uncertainty = torch.zeros_like(mean_pred)
    return mean_pred, uncertainty


In [8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Training-time pipeline applied jointly to each image and its mask:
# resize, random 90-degree rotations and flips, normalisation, then
# conversion to tensors.
_train_steps = [
    A.Resize(512, 512),
    A.RandomRotate90(),
    A.HorizontalFlip(p=0.5),  # separate horizontal/vertical flips are used instead of A.Flip()
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=[0], std=[1]),  # placeholder statistics; adjust SAR normalization as needed
    ToTensorV2(),
]
train_transform = A.Compose(_train_steps)


In [None]:
import torch

# --- Run configuration: everything a reader might want to tune ---
TRAIN_IMAGE_DIR = "/content/drive/MyDrive/train/images"
TRAIN_MASK_DIR = "/content/drive/MyDrive/train/labels"
# Samples are resized to 512x512 with 6 channels; a batch of 100 through a
# network of this width needs far more accelerator memory than is normally
# available, so keep the batch small and raise it only if the hardware allows.
BATCH_SIZE = 4
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
SEED = 42

# Seed torch so weight initialisation and shuffling order repeat across runs.
torch.manual_seed(SEED)

# Auto-detect device (use CPU if no GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize dataset and DataLoader
train_dataset = SARFloodDataset(
    image_dir=TRAIN_IMAGE_DIR,
    mask_dir=TRAIN_MASK_DIR,
    transform=train_transform
)
if len(train_dataset) == 0:
    # Fail with a clear message instead of an obscure sampler error.
    raise FileNotFoundError(
        f"No '*.tif' images found in {TRAIN_IMAGE_DIR}; is Google Drive mounted "
        "and the path correct?"
    )
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model & optimizer on correct device
model = AttentionUNet(in_channels=6).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train model
train_model(model, train_loader, optimizer, num_epochs=NUM_EPOCHS)


  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
