**Last Updated**: *29 May 2025*

**Application of the 2D CNN**

In [None]:
# === Modules ===
import numpy as np
from osgeo import gdal
import pandas as pd
import torch
import torch.nn as nn # nn models
import torch.nn.functional as F # activation functions (incl. ReLu)
from skimage.segmentation import flood
from tqdm import tqdm
from numpy.lib.stride_tricks import as_strided
import csv
from skimage.measure import label, regionprops
from scipy.ndimage import binary_erosion
from itertools import product

In [2]:
# === Functions ===
def load_data(file_path):
    """
    Read a raster file with GDAL.

    Input: DEM (raster) file path
    Output: (array, geotransform) - the pixel data as returned by
            Dataset.ReadAsArray() and the GDAL geotransform tuple
    Raises: RuntimeError if the file cannot be opened or read; the underlying
            error is chained as __cause__ so the original traceback is kept
    """
    try:
        # Make GDAL raise Python exceptions instead of silently returning None.
        gdal.UseExceptions()
        dataset = gdal.Open(file_path)
        if dataset is None:  # defensive fallback in case exceptions are disabled
            raise FileNotFoundError(f"Could not open file: {file_path}")
        data_array = dataset.ReadAsArray()
        geotransform = dataset.GetGeoTransform()
    except Exception as e:
        raise RuntimeError(f"Error loading DEM: {e}") from e
    return data_array, geotransform

def normalize_data(data, min, max):
    """
    Linearly rescale data so that `min` maps to 0 and `max` maps to 1.

    Input: DEM (array or scalar) and the lower / upper bounds of the rescaling
    Output: Normalized DEM (array); values outside [min, max] fall outside
            [0, 1], and NaN stays NaN
    Raises: ValueError if max == min (the scale would be undefined)
    """
    # The parameter names shadow the min/max builtins; they are kept so that
    # existing keyword callers keep working. Neither builtin is used here.
    span = max - min
    if np.any(np.equal(span, 0)):
        raise ValueError("normalize_data: max must differ from min")
    return (data - min) / span   # 0-1 range

def detect_local_minima(dem_norm, elevation_range, window_size=5):
    """
    Detect local minima based on elevation and relief criteria, then reduce
    each connected group of candidates to a single central pixel.

    Criteria, evaluated on every window_size x window_size moving window:
      crit1  the window's centre pixel is <= the window's 40th percentile;
      crit2  the window's relief (max - min) is >= 1.0 / elevation_range,
             i.e. one elevation unit when dem_norm was scaled by elevation_range;
      crit3  4-connected groups of candidates with fewer than 2 pixels are dropped.
    Windows containing NaN get NaN statistics and therefore never pass
    crit1/crit2. Groups larger than 30 pixels are eroded once with a
    cross-shaped structuring element; a group that erodes away completely is
    dropped. Each remaining sub-group (labelled with skimage's default full
    connectivity) contributes the pixel nearest its centroid.

    Input: normalized DEM (2D array), elevation range used for normalization
           and focal window size (odd or even)
    Output: boolean mask of shape (rows - window_size + 1, cols - window_size + 1);
            mask[i, j] refers to DEM pixel (i + window_size // 2, j + window_size // 2).
            The mask is empty when the DEM is smaller than the window.
    """
    rows, cols = dem_norm.shape
    out_rows = rows - window_size + 1
    out_cols = cols - window_size + 1
    if out_rows <= 0 or out_cols <= 0:
        # No complete window fits; as_strided would fail on a negative shape.
        return np.zeros((max(out_rows, 0), max(out_cols, 0)), dtype=bool)

    half = window_size // 2
    shape = (out_rows, out_cols, window_size, window_size)
    strides = (dem_norm.strides[0], dem_norm.strides[1],
               dem_norm.strides[0], dem_norm.strides[1])

    # View of every window without copying; read-only so it cannot corrupt dem_norm.
    dem_sliding_windows = as_strided(dem_norm, shape, strides, writeable=False)
    local_minima = np.percentile(dem_sliding_windows, 40, axis=(2, 3))

    # Centre pixel of each window. Slicing by out_rows/out_cols (rather than
    # trimming half from both ends) keeps the shape equal to local_minima for
    # even window sizes too.
    center_values = dem_norm[half:half + out_rows, half:half + out_cols]

    depth = np.ptp(dem_sliding_windows, axis=(2, 3))

    mask = (center_values <= local_minima) & (depth >= 1.0 / elevation_range)   # crit1 & crit2

    labeled_mask = label(mask, connectivity=1)
    central_mask = np.zeros_like(mask, dtype=bool)
    cross = np.array([[0, 1, 0],
                      [1, 1, 1],
                      [0, 1, 0]], dtype=bool)

    for region in regionprops(labeled_mask):
        if region.area < 2:
            continue    # crit3

        # Rebuild the region as a tight boolean patch inside its bounding box.
        minr, minc, maxr, maxc = region.bbox
        coords = region.coords
        submask = np.zeros((maxr - minr, maxc - minc), dtype=bool)
        submask[coords[:, 0] - minr, coords[:, 1] - minc] = True

        if region.area > 30:
            # Large groups: erode once so that weakly connected lobes split apart.
            eroded = binary_erosion(submask, structure=cross)
            if not np.any(eroded):
                continue
            submask = eroded

        for subregion in regionprops(label(submask)):
            cy, cx = subregion.centroid
            r, c = int(round(cy + minr)), int(round(cx + minc))
            if 0 <= r < central_mask.shape[0] and 0 <= c < central_mask.shape[1]:
                central_mask[r, c] = True

    return central_mask

def extract_coordinates(mask, geotransform, window_size):
    """
    Convert the True pixels of a local-minima mask into map coordinates.

    Mask index (i, j) is first shifted by window_size // 2 to get the DEM pixel
    (row, col) it represents, then mapped through the GDAL affine transform
    X = GT0 + col*GT1 + row*GT2, Y = GT3 + col*GT4 + row*GT5. With integer
    indices this gives the pixel's upper-left corner, not its centre.

    Input: Local minima (mask), geotransform information and focal window size
    Output: list of (id, x, y) tuples in row-major mask order, ids starting at 1
    """
    half = window_size // 2
    gt = geotransform

    def to_map(row, col):
        px, py = half + col, half + row
        return (gt[0] + px * gt[1] + py * gt[2],
                gt[3] + px * gt[4] + py * gt[5])

    return [(ident, *to_map(r, c))
            for ident, (r, c) in enumerate(zip(*np.nonzero(mask)), start=1)]

def save_to_csv(data, filename, headers):
    """
    Write a header row followed by the data rows to a CSV file.

    Input: Local minima (iterable of rows, e.g. (id, x, y) tuples), filename and headers
    Output: CSV file at `filename` (overwritten if it exists); prints a confirmation
    Raises: RuntimeError if the file cannot be opened or written; the original
            error is chained as __cause__
    """
    try:
        # newline='' lets the csv module control line endings itself.
        with open(filename, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(headers)
            writer.writerows(data)
    except Exception as e:
        raise RuntimeError(f"Error writing to CSV: {e}") from e
    print(f"Results saved to {filename}")

def extract_indices(depression_df, geotransform):
    """
    Add raster pixel indices ("col", "row") for each point of a DataFrame.

    Input: Depressions (dataframe) with "longitude" / "latitude" columns in the
           raster's map coordinates, and the raster's GDAL geotransform
    Output: the same DataFrame, modified in place, with integer "col" and "row"
            columns giving the pixel that contains each point. Points left of
            or above the raster get negative indices.

    Only the origin and pixel-size terms of the geotransform are used, so a
    north-up raster (no rotation terms) is assumed.
    """
    # Extract geotransform parameters
    x_origin = geotransform[0]
    y_origin = geotransform[3]
    pixel_width = geotransform[1]
    pixel_height = geotransform[5]   # typically negative (rows run southwards)

    # Convert to raster pixel indices. floor (not truncation toward zero) so
    # that points just outside the left/top edge are not folded into pixel 0.
    col_f = ((depression_df["longitude"] - x_origin) / pixel_width).to_numpy(dtype=float)
    row_f = ((depression_df["latitude"] - y_origin) / pixel_height).to_numpy(dtype=float)
    depression_df["col"] = np.floor(col_f).astype(int)
    depression_df["row"] = np.floor(row_f).astype(int)

    return depression_df

# CNN
class CNN(nn.Module):
    """
    Convolutional classifier for multi-channel raster patches.

    Architecture:
      - four 3x3 conv -> batch-norm -> ReLU stages with 32, 64, 128 and 256
        channels (stride 1, padding 1, so height and width are preserved);
      - 2x2 max pooling after the second and after the fourth stage;
      - global average pooling to a 256-vector;
      - fully connected head 256 -> 128 -> 64 -> num_classes, with ReLU and
        dropout (p=0.4) after each hidden layer.

    forward() returns raw logits (no sigmoid). Because of the global average
    pooling, any input of at least 4x4 pixels works: H x W becomes
    H//2 x W//2 after the first pooling and H//4 x W//4 after the second.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        widths = [in_channels, 32, 64, 128, 256]

        # conv1..conv4, created in order (keeps state_dict keys and the RNG
        # initialization sequence unchanged).
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))

        # bn1..bn4, one per convolution output.
        for idx, c_out in enumerate(widths[1:], start=1):
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))

        # Spatial down-sampling: 2x2 max pooling, then global average pooling.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head.
        self.fc1 = nn.Linear(256 * 1 * 1, 128)
        self.fc2 = nn.Linear(128, 64)
        self.output = nn.Linear(64, num_classes)

        # Shared dropout for both hidden layers (at least 0.4 for robustness).
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        # Stages 1-2 at full resolution, then halve H and W.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(bn(conv(x)))
        x = self.pool(x)

        # Stages 3-4, then halve H and W again.
        for conv, bn in ((self.conv3, self.bn3), (self.conv4, self.bn4)):
            x = F.relu(bn(conv(x)))
        x = self.pool(x)

        # (batch, 256, h, w) -> (batch, 256, 1, 1) -> (batch, 256)
        x = self.global_pool(x).flatten(1)

        for dense in (self.fc1, self.fc2):
            x = self.dropout(F.relu(dense(x)))

        return self.output(x)   # (batch, num_classes) logits

In [3]:
# === Load the model ===
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device will run on {device}")

model = CNN().to(device)
# map_location remaps any CUDA tensors stored in the checkpoint onto `device`,
# so a checkpoint saved on a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
model.load_state_dict(torch.load("cnn_2d_final_run.pt", map_location=device))

Device will run on cuda


<All keys matched successfully>

In [None]:
# === Loop over all files ===
# Tile identifiers "1" ... "45", kept as strings because they are spliced into file names.
tile_ids = [str(tile) for tile in range(1, 46)]
window_size_local_minima = 9    # focal window (pixels) for local-minimum detection on the 5 m DEM
window_size_model = 75          # patch size (pixels) cut from the 1 m rasters for the CNN

# New functions
def _chunk_offset_transform(geotransform, i_start, j_start):
    """
    Return the geotransform of a sub-array starting at row `i_start`, column
    `j_start` of the raster described by `geotransform`.

    Follows GDAL's affine model (X = GT0 + col*GT1 + row*GT2,
    Y = GT3 + col*GT4 + row*GT5), so rotation/shear terms are preserved.
    """
    x0, px_w, rot_x, y0, rot_y, px_h = geotransform[:6]
    return (
        x0 + j_start * px_w + i_start * rot_x,
        px_w,
        rot_x,
        y0 + j_start * rot_y + i_start * px_h,
        rot_y,
        px_h,
    )


def process_local_minima_tile(dem_path_5m, output_csv, window_size):
    """
    Detect local minima on a DEM tile and save their map coordinates to CSV.

    The DEM is masked (values < -50 become NaN), normalized over
    [-50, 233.860001] and processed in square chunks of side nrows // split.
    split is 2 when the raster has >= 25000 rows or >= 12500 columns, else 1.
    Each chunk is padded by window_size // 2 pixels so that windows centred on
    the chunk's own pixels are complete. Chunks that are entirely NaN, or
    smaller than the window, are skipped.

    Input: DEM file path, output CSV path, focal window size
    Output: CSV with columns id, longitude, latitude; ids run 1..n and are unique
    """
    dem_array, geotransform = load_data(dem_path_5m)
    dem_array = np.where(dem_array < -50, np.nan, dem_array)
    dem_array = normalize_data(dem_array, -50, 233.860001)
    # Span of the normalization above; detect_local_minima treats
    # 1 / elevation_range as its minimum window relief.
    elevation_range = 283.860001

    nrows, ncols = dem_array.shape
    if nrows < 25000 and ncols < 12500:
        print("Processing all at once!")
        split = 1
    else:
        print("Processing in minitiles!")
        split = 2
    chunk_size = max(nrows // split, 1)   # >= 1 so range() gets a valid step
    margin = window_size // 2
    results = []

    row_indices = range(0, nrows, chunk_size)
    col_indices = range(0, ncols, chunk_size)
    total_chunks = len(row_indices) * len(col_indices)

    for i, j in tqdm(product(row_indices, col_indices), total=total_chunks, desc="Processing minitiles", unit="minitile"):
        i_start = max(i - margin, 0)
        i_end = min(i + chunk_size + margin, nrows)
        j_start = max(j - margin, 0)
        j_end = min(j + chunk_size + margin, ncols)
        # Thin slivers at the raster edge cannot hold a complete window (and
        # would make the strided window view fail); they contain no candidates.
        if i_end - i_start < window_size or j_end - j_start < window_size:
            continue
        dem_chunk = dem_array[i_start:i_end, j_start:j_end]
        if np.all(np.isnan(dem_chunk)):
            continue
        mask = detect_local_minima(dem_chunk, elevation_range, window_size)
        offset_transform = _chunk_offset_transform(geotransform, i_start, j_start)
        results.extend(extract_coordinates(mask, offset_transform, window_size))

    # Drop exact duplicate coordinates (first-seen order is kept) and renumber
    # the ids: extract_coordinates restarts its ids at 1 for every chunk.
    unique_xy = dict.fromkeys((c[1], c[2]) for c in results)
    unique_results = [(k, x, y) for k, (x, y) in enumerate(unique_xy, start=1)]
    save_to_csv(unique_results, output_csv, ['id', 'longitude', 'latitude'])

def _window_bounds(center, window_size, limit):
    """
    Return [start, stop) bounds of a window_size-long window whose index
    window_size // 2 lies on `center`, clipped to [0, limit].

    start == stop when the window does not overlap [0, limit) at all.
    """
    start = center - window_size // 2
    stop = start + window_size
    start = min(max(start, 0), limit)
    stop = min(max(stop, start), limit)
    return start, stop


def run_model_on_minima(dem_path_1m, slope_path_1m, local_min_csv_path, output_csv, model, window_size):
    """
    Score each local-minimum candidate with the CNN and save the probabilities.

    For every candidate in `local_min_csv_path`:
      - a window_size x window_size patch is cut from the normalized DEM and
        slope rasters (clipped at the raster edges, so border patches are smaller);
      - a flood-fill mask (tolerance 0.005 in normalized DEM units) is grown
        from the candidate pixel;
      - the 3-channel stack (DEM, slope, flood) is passed through `model`.
    The sigmoid of the model output is stored in a new "prob" column and the
    table is written to `output_csv`. Candidates whose window lies entirely
    outside the raster get prob = NaN instead of being passed to the model.

    Input: DEM and slope raster paths, local-minima CSV (id, longitude, latitude),
           output CSV path, trained model, patch size in pixels
    Output: CSV at `output_csv` with the input columns plus col, row and prob
    """
    dem_array, dem_transform = load_data(dem_path_1m)
    dem_array = np.where(dem_array < -50, np.nan, dem_array)
    dem_array = normalize_data(dem_array, -50, 233.860001)
    slope_array, _ = load_data(slope_path_1m)
    slope_array = np.where(slope_array < 0, np.nan, slope_array)
    slope_array = normalize_data(slope_array, 0, 87.252808)

    # NOTE(review): columns 1 and 2 (longitude, latitude) are parsed as int,
    # which assumes integral map coordinates in the CSV - confirm for every tile.
    local_min_df = pd.read_csv(local_min_csv_path, dtype={1: int, 2: int})
    local_min_df = extract_indices(local_min_df, dem_transform)

    # Send inputs to whichever device holds the model's weights, instead of
    # relying on a module-level `device` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    n_rows, n_cols = dem_array.shape[0], dem_array.shape[1]
    all_probs = []
    model.eval()

    with torch.no_grad():
        centers = zip(local_min_df["row"].to_numpy(), local_min_df["col"].to_numpy())
        for row_idx, col_idx in tqdm(centers, total=len(local_min_df)):
            # Plain ints: slicing needs integer indices whatever dtype pandas chose.
            row_idx, col_idx = int(row_idx), int(col_idx)
            # Exactly window_size pixels (before edge clipping), with the
            # candidate at index window_size // 2.
            row_min, row_max = _window_bounds(row_idx, window_size, n_rows)
            col_min, col_max = _window_bounds(col_idx, window_size, n_cols)
            if row_max <= row_min or col_max <= col_min:
                # Window lies entirely outside the raster: nothing to classify.
                all_probs.append(np.nan)
                continue

            window_dem = dem_array[row_min:row_max, col_min:col_max]
            window_slope = slope_array[row_min:row_max, col_min:col_max]

            seed_row, seed_col = row_idx - row_min, col_idx - col_min
            if (0 <= seed_row < window_dem.shape[0]) and (0 <= seed_col < window_dem.shape[1]):
                window_flood = flood(window_dem, (seed_row, seed_col), tolerance=0.005)
            else:
                # Candidate pixel itself is off-raster: use an empty flood mask.
                window_flood = np.zeros_like(window_dem, dtype=np.float32)

            feature_stack = np.stack([window_dem, window_slope, window_flood])
            input_tensor = torch.from_numpy(feature_stack).unsqueeze(0).float().to(model_device)

            logit = model(input_tensor).squeeze(0)
            all_probs.append(torch.sigmoid(logit).item())

    local_min_df["prob"] = all_probs
    local_min_df.to_csv(output_csv, index=False)
    print(f"Results saved to {output_csv}")

# Main loop: for each tile, (1) detect local minima on the 5 m DEM and
# (2) score them with the CNN on the 1 m DEM + slope rasters.
for tile_id in tile_ids:
    print(f"\n=== Processing tile {tile_id} ===")

    # Input rasters for this tile
    dem_5m_path, dem_1m_path, slope_1m_path = (
        f"D:/all_tifs/dem_{tile_id}_5m.tif",
        f"D:/all_tifs/dem_{tile_id}_1m.tif",
        f"D:/all_tifs/slope_{tile_id}_1m.tif",
    )
    # Intermediate table (local minima) and final table (with CNN probability)
    local_minima_csv, output_csv = (
        f"D:/all_results/local_minima_{tile_id}.csv",
        f"D:/all_results/updated_local_minima_{tile_id}.csv",
    )

    process_local_minima_tile(dem_5m_path, local_minima_csv, window_size_local_minima)
    run_model_on_minima(
        dem_1m_path, slope_1m_path, local_minima_csv, output_csv,
        model, window_size_model,
    )

    # Release cached GPU memory between tiles.
    torch.cuda.empty_cache()