In [13]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

## Helper functions

In [14]:
def load_image(image_path, target_size, interpolation=None):
    """Load one image file and return it as a resized float tensor.

    Args:
        image_path: Path to the image file.
        target_size: ``(H, W, C)`` tuple. ``(H, W)`` is the output size; ``C == 3``
            loads the file as RGB, any other value loads it as grayscale ('L').
        interpolation: Optional ``transforms.InterpolationMode`` used by the resize.
            ``None`` keeps torchvision's default (bilinear). Pass
            ``transforms.InterpolationMode.NEAREST`` for label masks so resizing
            does not blend class values into fractions.

    Returns:
        Tensor of shape ``(C, H, W)`` as produced by ``transforms.ToTensor``
        (documented to scale 8-bit images to [0, 1]).
    """
    resize_kwargs = {} if interpolation is None else {'interpolation': interpolation}
    transform = transforms.Compose([
        transforms.Resize(target_size[:2], **resize_kwargs),
        transforms.ToTensor(),
    ])
    mode = 'RGB' if target_size[2] == 3 else 'L'
    # The context manager closes the file handle; convert() loads the pixel
    # data first, so the returned image stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert(mode)
    return transform(image)

def _sorted_pngs(folder):
    """Return the ``.png`` file names in ``folder`` in sorted order.

    ``os.listdir`` order is filesystem-dependent; sorting makes the sample order,
    and therefore any seeded train/test split downstream, reproducible.
    """
    return sorted(name for name in os.listdir(folder) if name.endswith(".png"))


def _load_binary_mask(mask_path, target_size):
    """Load a label mask as a ``(1, H, W)`` float32 tensor containing only 0.0 and 1.0.

    Nearest-neighbour resizing keeps label values crisp (bilinear would blend
    edge pixels into fractional values). Any non-zero pixel becomes 1.0.
    NOTE(review): assumes background pixels are exactly 0 in the source PNGs
    (true for both 0/255 and 0/1 encodings) — confirm for the dataset.
    """
    resize = transforms.Resize(target_size, interpolation=transforms.InterpolationMode.NEAREST)
    with Image.open(mask_path) as img:
        mask = resize(img.convert('L'))
    return (transforms.ToTensor()(mask) > 0).float()


def load_real_data(data_dir, target_size=(256, 256)):
    """Load the Bijie landslide tiles into stacked tensors.

    Expected layout (as read below)::

        data_dir/landslide/{image,dem,mask}/*.png
        data_dir/non-landslide/{image,dem}/*.png

    DEM and mask files are looked up under the same file name as the image.
    Non-landslide tiles get an all-zero mask.

    Args:
        data_dir: Dataset root directory.
        target_size: ``(H, W)`` every tile is resized to (tuple or list).

    Returns:
        ``(images, dems, masks)`` with shapes ``(N, 3, H, W)``, ``(N, 1, H, W)``
        and ``(N, 1, H, W)``. Landslide tiles come first; within each group,
        tiles are in sorted file-name order. Mask values are 0.0 / 1.0.

    Raises:
        RuntimeError: if no ``.png`` images are found.
    """
    target_size = tuple(target_size)  # allow lists; `+ (C,)` below needs a tuple

    images = []
    dems = []
    masks = []

    for subset, has_mask in (('landslide', True), ('non-landslide', False)):
        subset_dir = os.path.join(data_dir, subset)
        for filename in _sorted_pngs(os.path.join(subset_dir, 'image')):
            images.append(load_image(os.path.join(subset_dir, 'image', filename), target_size + (3,)))
            dems.append(load_image(os.path.join(subset_dir, 'dem', filename), target_size + (1,)))
            if has_mask:
                masks.append(_load_binary_mask(os.path.join(subset_dir, 'mask', filename), target_size))
            else:
                masks.append(torch.zeros((1, *target_size), dtype=torch.float32))

    if not images:
        raise RuntimeError(f"No .png images found under {data_dir!r}")

    return torch.stack(images), torch.stack(dems), torch.stack(masks)

## Network components

In [15]:
class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels`` (in the first conv).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in->out channels, the second maps out->out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Kept under the name `conv` with the same layer order so state_dict keys match.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class AttentionBlock(nn.Module):
    """Attention gate that re-weights a skip connection using a coarser gating tensor.

    The gating tensor ``x`` (``in_channels``) is projected to ``skip_channels`` by
    a 1x1 conv and bilinearly resized to the skip tensor's spatial size. The
    skip tensor gets its own 1x1 projection. ``sigmoid(psi(relu(g + s)))`` gives
    a one-channel weight map in (0, 1) that multiplies ``skip``.
    """

    def __init__(self, in_channels, skip_channels):
        super(AttentionBlock, self).__init__()
        self.g1 = nn.Conv2d(in_channels, skip_channels, kernel_size=1)
        self.x1 = nn.Conv2d(skip_channels, skip_channels, kernel_size=1)
        self.psi = nn.Conv2d(skip_channels, 1, kernel_size=1)
        # Parameter-free and kept for backward compatibility; forward() now
        # resizes to the skip's exact size instead of a fixed 2x.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, skip):
        """Return ``skip`` (N, C_skip, H, W) scaled by per-pixel attention weights."""
        # Resize to skip's exact (H, W). With align_corners=True this matches
        # 2x upsampling when skip is exactly twice as large, and it also handles
        # the odd sizes left behind by flooring max-pool.
        g1 = nn.functional.interpolate(
            self.g1(x), size=skip.shape[-2:], mode='bilinear', align_corners=True
        )
        x1 = self.x1(skip)
        psi = self.relu(g1 + x1)  # element-wise addition (a new tensor, so in-place ReLU is safe)
        psi = self.sigmoid(self.psi(psi))  # (N, 1, H, W) attention weights
        return skip * psi  # broadcast over the channel dimension

class UpSampleConcat(nn.Module):
    """Upsample a decoder tensor to the skip tensor's spatial size, then concatenate on channels.

    Output channels = ``x`` channels followed by ``skip`` channels.
    """

    def __init__(self):
        super(UpSampleConcat, self).__init__()
        # Parameter-free and kept for backward compatibility; forward() sizes to the skip.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, skip):
        # Match skip's exact (H, W). With align_corners=True this equals 2x
        # upsampling for exact-2x sizes, and it also handles odd sizes after pooling.
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

class UNet(nn.Module):
    """Attention U-Net that fuses an RGB image with a single-channel DEM for binary segmentation.

    Inputs:
        x_rgb: (N, 3, H, W) tensor.
        x_dem: (N, 1, H, W) tensor with the same H, W as ``x_rgb`` (the stems are
            concatenated on channels).
    Output:
        (N, 1, H, W) tensor of per-pixel probabilities in (0, 1) (ends in Sigmoid).

    The encoder max-pools three times by 2, so H and W divisible by 8 is the
    safe choice. Each decoder stage gates its skip connection with an
    ``AttentionBlock`` before concatenating it with the upsampled features.
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Encoder: separate stems for RGB (3 -> 16) and DEM (1 -> 8), fused to 24 channels.
        self.conv1_rgb = ConvBlock(3, 16)
        self.conv1_dem = ConvBlock(1, 8)
        self.conv2 = ConvBlock(24, 32)
        self.conv3 = ConvBlock(32, 64)
        self.conv4 = ConvBlock(64, 128)
        self.bottleneck = ConvBlock(128, 256)
        # Registered once (no parameters, so state_dict is unchanged) instead of
        # building a new MaxPool2d on every forward call.
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Decoder: conv in-channels = upsampled channels + gated skip channels.
        self.att4 = AttentionBlock(256, 128)
        self.up4 = UpSampleConcat()
        self.conv5 = ConvBlock(384, 128)  # 256 + 128

        self.att3 = AttentionBlock(128, 64)
        self.up3 = UpSampleConcat()
        self.conv6 = ConvBlock(192, 64)  # 128 + 64

        self.att2 = AttentionBlock(64, 32)
        self.up2 = UpSampleConcat()
        self.conv7 = ConvBlock(96, 32)  # 64 + 32

        self.conv8 = ConvBlock(56, 16)  # 32 (c7) + 24 (fused stems)

        self.out = nn.Sequential(
            nn.Conv2d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x_rgb, x_dem):
        """Return per-pixel probabilities of shape (N, 1, H, W)."""
        c1_rgb = self.conv1_rgb(x_rgb)
        c1_dem = self.conv1_dem(x_dem)
        combined = torch.cat([c1_rgb, c1_dem], dim=1)  # (N, 24, H, W)

        c2 = self.conv2(combined)            # (N, 32, H, W)
        c3 = self.conv3(self.pool(c2))       # (N, 64, H/2, W/2)
        c4 = self.conv4(self.pool(c3))       # (N, 128, H/4, W/4)
        bn = self.bottleneck(self.pool(c4))  # (N, 256, H/8, W/8)

        a4 = self.att4(bn, c4)
        c5 = self.conv5(self.up4(bn, a4))    # (N, 128, H/4, W/4)

        a3 = self.att3(c5, c3)
        c6 = self.conv6(self.up3(c5, a3))    # (N, 64, H/2, W/2)

        a2 = self.att2(c6, c2)
        c7 = self.conv7(self.up2(c6, a2))    # (N, 32, H, W)

        # c7 is already at full resolution, so no upsampling before fusing with the stems.
        c8 = self.conv8(torch.cat([c7, combined], dim=1))  # (N, 16, H, W)

        return self.out(c8)

In [16]:
# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [17]:
# Dataset root (must contain 'landslide/' and 'non-landslide/'). The relative
# default depends on the notebook's working directory; set BIJIE_DATA_DIR to override.
data_dir = os.environ.get('BIJIE_DATA_DIR', 'Bijie_dataset/Bijie_dataset')
images, dems, masks = load_real_data(data_dir)

In [18]:
# Fraction of each split to keep; set e.g. 0.1 for quick local runs (1.0 = full data).
DATA_FRACTION = 1.0

# Stratify on "tile contains any landslide pixel" so the train and test splits
# keep the same landslide / non-landslide ratio.
has_landslide = (masks.flatten(start_dim=1).amax(dim=1) > 0).numpy()

# train-test split
X_train_img, X_test_img, X_train_dem, X_test_dem, y_train, y_test = train_test_split(
    images, dems, masks, test_size=0.2, random_state=42, stratify=has_landslide
)

print(X_train_img.shape, X_train_dem[0].shape, y_train[0].shape)


def _head_fraction(*arrays, fraction):
    """Keep the first ``fraction`` of each array (all arrays assumed equally long)."""
    n_keep = int(len(arrays[0]) * fraction)
    return tuple(arr[:n_keep] for arr in arrays)


if DATA_FRACTION < 1.0:
    X_train_img, X_train_dem, y_train = _head_fraction(
        X_train_img, X_train_dem, y_train, fraction=DATA_FRACTION
    )
    X_test_img, X_test_dem, y_test = _head_fraction(
        X_test_img, X_test_dem, y_test, fraction=DATA_FRACTION
    )

print("Training data size:", len(X_train_img))
print("Testing data size:", len(X_test_img))

torch.Size([2218, 3, 256, 256]) torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
Training data size: 2218
Testing data size: 555


In [19]:
class CustomDataset(Dataset):
    """Map-style dataset that pairs images, DEMs and masks by index.

    Each item is a dict with keys ``'image'``, ``'dem'`` and ``'mask'`` holding
    ``images[idx]``, ``dems[idx]`` and ``masks[idx]``.

    Raises:
        ValueError: if ``images``, ``dems`` and ``masks`` differ in length
            (otherwise a mismatch would surface later as an IndexError mid-epoch).
    """

    def __init__(self, images, dems, masks):
        if not len(images) == len(dems) == len(masks):
            raise ValueError(
                "images, dems and masks must have equal length; got "
                f"{len(images)}, {len(dems)}, {len(masks)}"
            )
        self.images = images
        self.dems = dems
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            'image': self.images[idx],
            'dem': self.dems[idx],
            'mask': self.masks[idx]
        }

# NOTE(review): the held-out split also serves as the validation set used for
# checkpoint selection in the training loop, so metrics on it are optimistically biased.
train_dataset, val_dataset = (
    CustomDataset(X_train_img, X_train_dem, y_train),
    CustomDataset(X_test_img, X_test_dem, y_test),
)

# Only the training loader reshuffles each epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=16)

In [20]:
SEED = 42  # same value as the train/test split's random_state
# Seeds torch's global RNG (CPU and CUDA), so weight init and DataLoader shuffling
# are repeatable. cuDNN kernels may still be nondeterministic.
torch.manual_seed(SEED)

model = UNet().to(device)
epochs = 30
best_val_loss = float('inf')
best_model_path = 'best_unet_model.pth'

In [21]:
criterion = nn.BCELoss()  # UNet ends in a Sigmoid, so its outputs are already probabilities
optimizer = optim.Adam(model.parameters(), lr=0.001)


def _confusion_counts(probs, targets, threshold=0.5):
    """Count (tp, fp, fn) pixels after thresholding both predictions and targets.

    Targets are thresholded too, so any fractional mask values are binarised
    consistently instead of being truncated to 0 by an integer cast.
    """
    pred_pos = probs > threshold
    true_pos = targets > threshold
    tp = int((pred_pos & true_pos).sum())
    fp = int((pred_pos & ~true_pos).sum())
    fn = int((~pred_pos & true_pos).sum())
    return tp, fp, fn


def _ratio(num, den):
    """Return num / den, or 0.0 when den is 0 (sklearn's zero_division fallback value)."""
    return num / den if den else 0.0


# Train loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    # Batch tensors use *_batch names so they do not overwrite the full-dataset
    # `images` / `dems` / `masks` tensors created in earlier cells.
    for batch in train_loader:
        img_batch = batch['image'].to(device)
        dem_batch = batch['dem'].to(device)
        mask_batch = batch['mask'].to(device)

        optimizer.zero_grad()
        outputs = model(img_batch, dem_batch)  # (N, 1, H, W), same shape as mask_batch
        loss = criterion(outputs, mask_batch)

        loss.backward()
        optimizer.step()

        # BCELoss is a per-batch mean; weight it by batch size so the smaller
        # final batch does not skew the epoch average.
        running_loss += loss.item() * img_batch.size(0)

    train_loss = running_loss / len(train_loader.dataset)

    # Validation loop: stream TP/FP/FN counts instead of materialising every
    # pixel of the split as numpy arrays for sklearn.
    model.eval()
    running_loss = 0.0
    tp = fp = fn = 0
    with torch.no_grad():
        for batch in val_loader:
            img_batch = batch['image'].to(device)
            dem_batch = batch['dem'].to(device)
            mask_batch = batch['mask'].to(device)

            outputs = model(img_batch, dem_batch)
            running_loss += criterion(outputs, mask_batch).item() * img_batch.size(0)

            batch_tp, batch_fp, batch_fn = _confusion_counts(outputs, mask_batch)
            tp += batch_tp
            fp += batch_fp
            fn += batch_fn

    val_loss = running_loss / len(val_loader.dataset)

    # Binary (positive class = 1) metrics, equivalent to the sklearn scores.
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * tp, 2 * tp + fp + fn)
    iou = _ratio(tp, tp + fp + fn)

    print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}, "
          f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}, IoU: {iou}")

    # Save the model if val loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved best model with val loss: {best_val_loss:.4f}")

print("Training complete. Best model saved to", best_model_path)

KeyboardInterrupt: 

In [15]:
from torchinfo import summary

# Summarise a throwaway UNet instead of rebinding `model`, so running this cell
# after training does not replace the trained weights with a fresh model.
summary(UNet().to(device), input_size=[(1, 3, 256, 256), (1, 1, 256, 256)])

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 1, 256, 256]          --
├─ConvBlock: 1-1                         [1, 16, 256, 256]         --
│    └─Sequential: 2-1                   [1, 16, 256, 256]         --
│    │    └─Conv2d: 3-1                  [1, 16, 256, 256]         448
│    │    └─BatchNorm2d: 3-2             [1, 16, 256, 256]         32
│    │    └─ReLU: 3-3                    [1, 16, 256, 256]         --
│    │    └─Conv2d: 3-4                  [1, 16, 256, 256]         2,320
│    │    └─BatchNorm2d: 3-5             [1, 16, 256, 256]         32
│    │    └─ReLU: 3-6                    [1, 16, 256, 256]         --
├─ConvBlock: 1-2                         [1, 8, 256, 256]          --
│    └─Sequential: 2-2                   [1, 8, 256, 256]          --
│    │    └─Conv2d: 3-7                  [1, 8, 256, 256]          80
│    │    └─BatchNorm2d: 3-8             [1, 8, 256, 256]          16
│    │    └