In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm 
from IPython.display import  clear_output
from torchvision import transforms
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort that can
# occur when several imported packages each bundle their own OpenMP library -- confirm it
# is still required in the target environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set training and testing dataset paths.
# Folders follow the pattern Datasets/Breach<i>/<j>; Breach 1/3/5/7/9 are used for training
# and Breach 2 is held out for testing, each with sub-folders 1-3.
train_folder_paths = [f"Datasets/Breach{i}/{j}" for i in [1, 3, 5, 7, 9] for j in [1, 2, 3]]
test_folder_paths = [f"Datasets/Breach{i}/{j}" for i in [2] for j in [1, 2, 3]]

# Training-time augmentation: random horizontal and vertical flips.
Transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

ROLL_LENGTH = 4  # number of steps for rolling prediction

Using device: cuda


In [3]:
# Load and display mask file
# The third column of Mask.dat is reshaped to a 474 x 320 grid and padded with 6 rows of
# zeros to 480 x 320.
# NOTE(review): the padding presumably exists so that the height is divisible by 8 (the
# networks below downsample three times by a factor of 2) -- confirm.
mask_data = np.loadtxt('Datasets/Mask.dat')
mask = mask_data[:, 2].reshape(474, 320)
mask = np.vstack([mask, np.full((6, 320), 0)])
# flipud only affects the displayed image; `mask` itself is left unflipped.
plt.imshow(np.flipud(mask), cmap='jet', interpolation='nearest')
plt.show()

FileNotFoundError: Datasets/Mask.dat not found.

In [None]:
def load_data_from_folder(folder_path, shape=(474, 320), length=24):
    """Load one simulation folder and return its fields as a flat list of 2-D arrays.

    Every ``.dat`` file in ``folder_path`` (at most the first ``length`` in sorted order)
    is read with ``np.loadtxt``. Column 2 of the first file is used as elevation,
    column 3 of every file as water level and column 4 as the source term. Each field is
    reshaped to ``shape`` and padded with extra rows at the bottom so that it matches the
    module-level ``mask`` array.

    Args:
        folder_path: Directory containing the ``.dat`` files.
        shape: ``(rows, cols)`` of the unpadded grid stored in each file.
        length: Maximum number of files (time steps) to read.

    Returns:
        A list with ``2 * n + 3`` arrays, where ``n`` is the number of files that were
        read successfully::

            [elevation,
             depth_init (zeros), depth_1, ..., depth_n,
             source_avg_1, ..., source_avg_n, zeros]

    Raises:
        ValueError: If no file could be read, or if ``shape`` is incompatible with the
            global ``mask``.
    """
    # NOTE(review): sorting is lexicographic, so file names must be zero-padded for the
    # time order to be right -- verify against the dataset.
    file_list = sorted(f for f in os.listdir(folder_path) if f.endswith('.dat'))
    data_list = []
    for file_name in file_list[:length]:
        file_path = os.path.join(folder_path, file_name)
        try:
            data_list.append(np.loadtxt(file_path))
        except ValueError as e:
            print(f"Cannot read file {file_name}: {e}")

    if not data_list:
        raise ValueError("No valid data found in folder!")

    # Use the number of files actually loaded. The original iterated over `length`, which
    # raised IndexError whenever a file was skipped or the folder held fewer files.
    n_steps = len(data_list)

    rows, cols = shape
    pad_rows = mask.shape[0] - rows  # rows appended so the grid matches the mask
    if pad_rows < 0 or mask.shape[1] != cols:
        raise ValueError(f"shape {shape} is incompatible with mask of shape {mask.shape}")
    padded_shape = (rows + pad_rows, cols)

    def _pad(field, fill_value):
        """Append `pad_rows` rows filled with `fill_value` below `field`."""
        return np.vstack([field, np.full((pad_rows, cols), fill_value)])

    dx = dy = dt = 1  # unit spacing and time
    fill_elevation = 35  # elevation assigned to padded rows and to zeroed-out cells

    # Extract elevation from the first file; cells that are zero after masking get the
    # fill elevation.
    elevation = _pad(data_list[0][:, 2].reshape(shape), fill_elevation)
    elevation_new = elevation * mask
    elevation_new[elevation_new == 0] = fill_elevation
    all_variables = [elevation_new]

    # Initial depth at t0 is zero everywhere.
    all_variables.append(np.full(padded_shape, 0))

    # Water depth = water level - (unmasked) elevation, then masked.
    bed = elevation[:rows].flatten()
    for data in data_list:
        depth = (data[:, 3] - bed).reshape(shape)
        all_variables.append(_pad(depth, 0) * mask)

    # Source terms, scaled by dt / (dx * dy).
    source_variables = [np.full(padded_shape, 0)]
    for data in data_list:
        source = data[:, 4].reshape(shape) * dt / (dx * dy)
        source_variables.append(_pad(source, 0))

    # Average consecutive source terms (with a leading zero field) for smoother input.
    source_averaged = [
        (source_variables[i] + source_variables[i + 1]) / 2 for i in range(n_steps)
    ]

    return all_variables + source_averaged + [np.full(padded_shape, 0)]

In [None]:
class AllVariablesDataset(Dataset):
    """Sliding-window dataset over one simulation for rolling (multi-step) prediction.

    ``all_variables_list`` is expected in the layout produced by
    ``load_data_from_folder``: element 0 is the elevation, and the remaining elements are
    split in two halves, depth fields followed by source fields.

    Each item is a tuple ``(sequence, source_roll, target_roll)``:

    * ``sequence``    -- ``[elevation, depth[idx], source[idx]]`` stacked on dim 0.
    * ``source_roll`` -- ``source[idx + 1 : idx + 1 + roll_length]``.
    * ``target_roll`` -- ``depth[idx + 1 : idx + 1 + roll_length]``.

    Args:
        all_variables_list: List of 2-D arrays as described above (copied on init).
        roll_length: Number of future steps returned per sample.
        transform: Optional callable applied to tensors laid out as ``(C, H, W)``. It is
            applied once to all fields stacked along the channel axis, so it must be a
            spatial transform that does not depend on the number of channels (such as
            the random flips used in this notebook). ``None`` disables augmentation.
    """

    def __init__(self, all_variables_list, roll_length=4, transform=Transform):
        self.all_variables_list = [np.copy(item) for item in all_variables_list]
        self.elevation = np.array(self.all_variables_list[0])
        # Everything after the elevation is split in two: depths first, then sources.
        mid_index = (len(self.all_variables_list) - 1) // 2 + 1
        self.depth = np.array(self.all_variables_list[1:mid_index])
        self.source = np.array(self.all_variables_list[mid_index:])
        self.roll_length = roll_length
        self.transform = transform
        # A simulation shorter than the roll length yields no samples (never negative).
        self.dataset_length = max(0, len(self.depth) - self.roll_length)

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, idx):
        # Model input: elevation + current depth + current source.
        sequence = np.stack([self.elevation, self.depth[idx], self.source[idx]], axis=0)
        sequence = torch.tensor(sequence, dtype=torch.float32)

        # Targets and sources for the following `roll_length` steps.
        window = slice(idx + 1, idx + 1 + self.roll_length)
        target_roll = torch.tensor(self.depth[window], dtype=torch.float32)
        source_roll = torch.tensor(self.source[window], dtype=torch.float32)

        if self.transform:
            # Apply the random transform once to all fields stacked along the channel
            # axis, so every field receives the same flips. The original reseeded the
            # global torch RNG three times per sample to achieve this.
            sizes = [sequence.shape[0], source_roll.shape[0], target_roll.shape[0]]
            stacked = self.transform(torch.cat([sequence, source_roll, target_roll], dim=0))
            sequence, source_roll, target_roll = torch.split(stacked, sizes, dim=0)

        return sequence, source_roll, target_roll

In [9]:
# Build one AllVariablesDataset per training folder (with flip augmentation) and
# concatenate them into a single training dataset.
datasets1 = []
for folder_path in train_folder_paths:
    all_variables = load_data_from_folder(folder_path)
    datasets1.append(AllVariablesDataset(all_variables,transform=Transform)) 
combined_dataset = ConcatDataset(datasets1)
# Same for the held-out folders, without augmentation.
datasets2 = []
for folder_path in test_folder_paths:
    all_variables = load_data_from_folder(folder_path)
    datasets2.append(AllVariablesDataset(all_variables,transform=None)) 
test_dataset = ConcatDataset(datasets2)

# NOTE(review): the training loader does not shuffle -- confirm this is intentional.
combined_dataloader = DataLoader(combined_dataset, batch_size=1, shuffle=False)

test_dataloader = DataLoader(test_dataset, batch_size=1,shuffle=False)

In [None]:
class MobileResidualBlock(nn.Module):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise (dilated) -> 1x1 project.

    The projected features are added to a shortcut (identity when the channel counts
    match, otherwise a 1x1 convolution + batch norm), passed through ReLU and then
    through ``Dropout2d``. Spatial size is preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        expansion: Width multiplier of the hidden (expanded) representation.
        dilation: Dilation (and padding) of the depthwise 3x3 convolution.
        dropout_prob: Probability used by the trailing ``Dropout2d``.
    """

    def __init__(self, in_channels, out_channels, expansion=4, dilation=1, dropout_prob=0.0):
        super().__init__()
        hidden = in_channels * expansion
        self.stride = 1

        def batch_norm(width):
            # Batch statistics are always used; no running averages are kept.
            return nn.BatchNorm2d(width, track_running_stats=False)

        expand = [
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            batch_norm(hidden),
            nn.ReLU6(),
        ]
        depthwise = [
            # groups == channels makes this a depthwise convolution; padding == dilation
            # keeps the spatial size for a 3x3 kernel.
            nn.Conv2d(hidden, hidden, 3,
                      stride=1, padding=dilation, dilation=dilation,
                      groups=hidden, bias=False),
            batch_norm(hidden),
            nn.ReLU6(),
        ]
        project = [
            nn.Conv2d(hidden, out_channels, 1, bias=False),
            batch_norm(out_channels),
        ]
        self.conv = nn.Sequential(*expand, *depthwise, *project)

        if in_channels != out_channels:
            # Channel counts differ: project the input so it can be added.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                batch_norm(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()  # identity

        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, x):
        main_path = self.conv(x)
        skip_path = self.shortcut(x)
        activated = F.relu(main_path + skip_path)
        return self.dropout(activated)

class LightSE(nn.Module):
    """Lightweight channel attention: global average pool -> 2-layer MLP -> sigmoid gate.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Factor by which the hidden width of the MLP is reduced.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        squeezed = channel // reduction
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(),
            nn.Linear(squeezed, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of the 4-D input ``x`` by its learned gate in (0, 1)."""
        batch, channels, _height, _width = x.shape
        pooled = self.avgpool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)

class FloodUNet(nn.Module):
    """U-Net built from inverted residual blocks with channel attention in the encoder.

    The network takes a 3-channel input and returns a single non-negative channel of the
    same spatial size. It downsamples three times with stride-2 convolutions and
    upsamples with bilinear interpolation, concatenating encoder features at each level,
    so the input height and width should be divisible by 8 for the skip connections to
    line up.

    Args:
        dropout_prob: Dropout probability forwarded to every ``MobileResidualBlock``.
            The original implementation accepted this argument but ignored it; the
            default of 0.0 keeps the previous behaviour.
    """

    def __init__(self, dropout_prob=0.0):
        super().__init__()
        ch = [16, 32, 64, 128]  # channel configuration

        def block(c_in, c_out, dilation=1):
            return MobileResidualBlock(c_in, c_out, dilation=dilation,
                                       dropout_prob=dropout_prob)

        def upsample(c_in, c_out):
            # Bilinear x2 upsampling followed by a 1x1 convolution to change the width.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(c_in, c_out, kernel_size=1)
            )

        # Encoder (downsampling path)
        self.encoder1 = nn.Sequential(
            block(3, ch[0], dilation=1),
            block(ch[0], ch[0], dilation=2),
            LightSE(ch[0])
        )
        self.down1 = nn.Conv2d(ch[0], ch[0], 3, stride=2, padding=1)

        self.encoder2 = nn.Sequential(
            block(ch[0], ch[1], dilation=1),
            block(ch[1], ch[1], dilation=2),
            LightSE(ch[1])
        )
        self.down2 = nn.Conv2d(ch[1], ch[1], 3, stride=2, padding=1)

        self.encoder3 = nn.Sequential(
            block(ch[1], ch[2], dilation=1),
            block(ch[2], ch[2], dilation=2),
            LightSE(ch[2])
        )
        self.down3 = nn.Conv2d(ch[2], ch[2], 3, stride=2, padding=1)

        # Bottleneck with large receptive field (dilations 3 and 5)
        self.bottom = nn.Sequential(
            block(ch[2], ch[3], dilation=3),
            block(ch[3], ch[3], dilation=5),
            LightSE(ch[3])
        )

        # Decoder (upsampling path); each decoder stage consumes skip + upsampled features
        self.up3 = upsample(ch[3], ch[2])
        self.decoder3 = nn.Sequential(
            block(ch[2] * 2, ch[2]),
            block(ch[2], ch[2])
        )

        self.up2 = upsample(ch[2], ch[1])
        self.decoder2 = nn.Sequential(
            block(ch[1] * 2, ch[1]),
            block(ch[1], ch[1])
        )

        self.up1 = upsample(ch[1], ch[0])
        self.decoder1 = nn.Sequential(
            block(ch[0] * 2, ch[0]),
            block(ch[0], ch[0])
        )

        # Output head; the final ReLU keeps the prediction non-negative
        self.output = nn.Sequential(
            nn.Conv2d(ch[0], 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 1, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, x):
        # Shape comments assume a 480 x 320 input.
        # Encode
        e1 = self.encoder1(x)      # [16, 480,320]
        d1 = self.down1(e1)        # [16, 240,160]

        e2 = self.encoder2(d1)     # [32, 240,160]
        d2 = self.down2(e2)        # [32, 120,80]

        e3 = self.encoder3(d2)     # [64, 120,80]
        d3 = self.down3(e3)        # [64, 60,40]

        # Bottleneck
        b = self.bottom(d3)        # [128, 60,40]

        # Decode
        u3 = self.up3(b)           # [64, 120,80]
        cat3 = torch.cat([e3, u3], dim=1)
        dec3 = self.decoder3(cat3) # [64, 120,80]

        u2 = self.up2(dec3)        # [32, 240,160]
        cat2 = torch.cat([e2, u2], dim=1)
        dec2 = self.decoder2(cat2) # [32, 240,160]

        u1 = self.up1(dec2)        # [16, 480,320]
        cat1 = torch.cat([e1, u1], dim=1)
        dec1 = self.decoder1(cat1) # [16, 480,320]

        return self.output(dec1)   # [1, 480,320]

In [None]:
class SimpleCNN(nn.Module):
    """Plain encoder-decoder CNN baseline without skip connections.

    Three ``conv5x5 -> BN -> ReLU -> maxpool`` stages reduce the resolution by 8, and
    three transposed convolutions restore it. Input has 3 channels, output has 1
    non-negative channel.
    """

    def __init__(self):
        super().__init__()

        # Encoder (downsampling): each stage halves the resolution.
        encoder_layers = []
        for c_in, c_out in [(3, 16), (16, 32), (32, 64)]:
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 5, padding=2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder (upsampling): each transposed convolution doubles the resolution.
        decoder_layers = []
        for c_in, c_out in [(64, 32), (32, 16)]:
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, 2, stride=2),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(16, 1, 2, stride=2),
            nn.ReLU(),  # ensure non-negative output
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

# Basic UNet-like model
class ConvBlock(nn.Module):
    """Two consecutive ``conv3x3 -> BatchNorm -> ReLU`` layers (spatial size preserved)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        width = in_channels
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
            width = out_channels  # the second convolution maps out -> out
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features

class SimpleUNet(nn.Module):
    """Basic three-level U-Net built from ``ConvBlock`` stages.

    Max pooling halves the resolution at each encoder level, transposed convolutions
    double it in the decoder, and encoder features are concatenated to the upsampled
    features at every level. Input has 3 channels, output has 1 non-negative channel.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(3, 16)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(16, 32)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(32, 64)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(64, 128)

        # Decoder: each ConvBlock receives skip + upsampled features (hence 2x channels)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec3 = ConvBlock(128, 64)
        self.up2 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec2 = ConvBlock(64, 32)
        self.up1 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec1 = ConvBlock(32, 16)

        # Output layer
        self.output = nn.Conv2d(16, 1, 1)
        self.final_act = nn.ReLU()  # ensure non-negative output

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        skip1 = self.enc1(x)                     # [16, H, W]
        skip2 = self.enc2(self.pool1(skip1))     # [32, H/2, W/2]
        skip3 = self.enc3(self.pool2(skip2))     # [64, H/4, W/4]

        # Bottleneck
        features = self.bottleneck(self.pool3(skip3))  # [128, H/8, W/8]

        # Decoder: upsample, concatenate the matching skip, refine.
        stages = (
            (self.up3, skip3, self.dec3),        # -> [64, H/4, W/4]
            (self.up2, skip2, self.dec2),        # -> [32, H/2, W/2]
            (self.up1, skip1, self.dec1),        # -> [16, H, W]
        )
        for upsample, skip, refine in stages:
            features = upsample(features)
            features = torch.cat([skip, features], dim=1)
            features = refine(features)

        logits = self.output(features)
        return self.final_act(logits)            # [1, H, W]


In [None]:
class NoResidualBlock(nn.Module):
    """Plain ``conv -> BatchNorm -> ReLU -> Dropout2d`` block with no skip connection.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, padding, dilation: Passed to the convolution.
        dropout_prob: Probability used by the trailing ``Dropout2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, dropout_prob=0.0):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        # Batch statistics are always used; no running averages are kept.
        normalisation = nn.BatchNorm2d(out_channels, track_running_stats=False)
        self.conv = nn.Sequential(convolution, normalisation, nn.ReLU(inplace=True))
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, x):
        activated = self.conv(x)
        return self.dropout(activated)

class NoAttentionFloodUNet(nn.Module):
    """Ablation model: ``FloodUNet`` architecture with the channel attention removed.

    Takes a 3-channel input and returns one non-negative channel of the same spatial
    size. The input height and width should be divisible by 8 so that the upsampled
    features match the encoder features they are concatenated with.

    Args:
        dropout_prob: Dropout probability forwarded to every ``MobileResidualBlock``.
            The original implementation accepted this argument but ignored it; the
            default of 0.0 keeps the previous behaviour.
    """

    def __init__(self, dropout_prob=0.0):
        super().__init__()
        ch = [16, 32, 64, 128]

        def block(c_in, c_out, dilation=1):
            return MobileResidualBlock(c_in, c_out, dilation=dilation,
                                       dropout_prob=dropout_prob)

        def upsample(c_in, c_out):
            # Bilinear x2 upsampling followed by a 1x1 convolution to change the width.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(c_in, c_out, kernel_size=1)
            )

        # Encoder (LightSE removed)
        self.encoder1 = nn.Sequential(
            block(3, ch[0], dilation=1),
            block(ch[0], ch[0], dilation=2)
        )
        self.down1 = nn.Conv2d(ch[0], ch[0], 3, stride=2, padding=1)

        self.encoder2 = nn.Sequential(
            block(ch[0], ch[1], dilation=1),
            block(ch[1], ch[1], dilation=2)
        )
        self.down2 = nn.Conv2d(ch[1], ch[1], 3, stride=2, padding=1)

        self.encoder3 = nn.Sequential(
            block(ch[1], ch[2], dilation=1),
            block(ch[2], ch[2], dilation=2)
        )
        self.down3 = nn.Conv2d(ch[2], ch[2], 3, stride=2, padding=1)

        # Bottleneck
        self.bottom = nn.Sequential(
            block(ch[2], ch[3], dilation=3),
            block(ch[3], ch[3], dilation=5)
        )

        # Decoder (same as FloodUNet)
        self.up3 = upsample(ch[3], ch[2])
        self.decoder3 = nn.Sequential(
            block(ch[2] * 2, ch[2]),
            block(ch[2], ch[2])
        )

        self.up2 = upsample(ch[2], ch[1])
        self.decoder2 = nn.Sequential(
            block(ch[1] * 2, ch[1]),
            block(ch[1], ch[1])
        )

        self.up1 = upsample(ch[1], ch[0])
        self.decoder1 = nn.Sequential(
            block(ch[0] * 2, ch[0]),
            block(ch[0], ch[0])
        )

        # Output layer; the final ReLU keeps the prediction non-negative
        self.output = nn.Sequential(
            nn.Conv2d(ch[0], 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 1, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, x):
        # Encoder: e* are kept as skip connections, d* are the downsampled features.
        e1 = self.encoder1(x)
        d1 = self.down1(e1)
        e2 = self.encoder2(d1)
        d2 = self.down2(e2)
        e3 = self.encoder3(d2)
        d3 = self.down3(e3)
        b = self.bottom(d3)
        # Decoder: upsample, concatenate the matching skip, refine.
        u3 = self.up3(b)
        dec3 = self.decoder3(torch.cat([e3, u3], dim=1))
        u2 = self.up2(dec3)
        dec2 = self.decoder2(torch.cat([e2, u2], dim=1))
        u1 = self.up1(dec2)
        dec1 = self.decoder1(torch.cat([e1, u1], dim=1))
        return self.output(dec1)


class NoResidualFloodUNet(nn.Module):
    """Ablation model: plain convolutions instead of residual blocks, attention retained.

    Each encoder stage and the bottleneck is two ``conv3x3 -> BN -> ReLU`` layers
    followed by ``LightSE``; decoder stages are the same without attention. Takes a
    3-channel input and returns one channel of the same spatial size.

    Args:
        dropout_prob: Probability of the ``Dropout2d`` applied to the network output.
    """

    def __init__(self, dropout_prob=0.0):
        super().__init__()
        ch = [16, 32, 64, 128]

        def conv_bn_relu(c_in, c_out):
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
                nn.BatchNorm2d(c_out, track_running_stats=False),
                nn.ReLU(inplace=True),
            ]

        def stage(c_in, c_out, attention):
            """Two conv-BN-ReLU layers, optionally followed by channel attention."""
            layers = conv_bn_relu(c_in, c_out) + conv_bn_relu(c_out, c_out)
            if attention:
                layers.append(LightSE(c_out))
            return nn.Sequential(*layers)

        def upsample(c_in, c_out):
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(c_in, c_out, 1, bias=False),
            )

        def downsample(width):
            return nn.Conv2d(width, width, 3, stride=2, padding=1)

        # Encoder with attention (no residuals)
        self.enc1 = stage(3, ch[0], attention=True)
        self.down1 = downsample(ch[0])

        self.enc2 = stage(ch[0], ch[1], attention=True)
        self.down2 = downsample(ch[1])

        self.enc3 = stage(ch[1], ch[2], attention=True)
        self.down3 = downsample(ch[2])

        # Bottleneck
        self.bottom = stage(ch[2], ch[3], attention=True)

        # Decoder: each stage consumes skip + upsampled features (hence 2x channels)
        self.up3 = upsample(ch[3], ch[2])
        self.dec3 = stage(ch[2] * 2, ch[2], attention=False)

        self.up2 = upsample(ch[2], ch[1])
        self.dec2 = stage(ch[1] * 2, ch[1], attention=False)

        self.up1 = upsample(ch[1], ch[0])
        self.dec1 = stage(ch[0] * 2, ch[0], attention=False)

        # Output
        self.output = nn.Sequential(
            nn.Conv2d(ch[0], 8, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 1, 1),
            nn.ReLU(inplace=True),
        )
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.down1(skip1))
        skip3 = self.enc3(self.down2(skip2))
        features = self.bottom(self.down3(skip3))

        # Decoder: upsample, concatenate the matching skip, refine.
        features = self.dec3(torch.cat([skip3, self.up3(features)], dim=1))
        features = self.dec2(torch.cat([skip2, self.up2(features)], dim=1))
        features = self.dec1(torch.cat([skip1, self.up1(features)], dim=1))

        prediction = self.output(features)
        return self.dropout(prediction)

    
class PlainFloodUNet(nn.Module):
    """Ablation model: U-Net of plain ``NoResidualBlock`` pairs (no residuals, no attention).

    Takes a 3-channel input and returns one non-negative channel of the same spatial
    size.

    Args:
        dropout_prob: Dropout probability passed to every ``NoResidualBlock``.
    """

    def __init__(self, dropout_prob=0.0):
        super().__init__()
        ch = [16, 32, 64, 128]

        def pair(c_in, c_out):
            """Two plain conv blocks: c_in -> c_out -> c_out."""
            return nn.Sequential(
                NoResidualBlock(c_in, c_out, dropout_prob=dropout_prob),
                NoResidualBlock(c_out, c_out, dropout_prob=dropout_prob),
            )

        def downsample(width):
            return nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1)

        def upsample(c_in, c_out):
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(c_in, c_out, kernel_size=1),
            )

        # Encoder
        self.enc1 = pair(3, ch[0])
        self.down1 = downsample(ch[0])

        self.enc2 = pair(ch[0], ch[1])
        self.down2 = downsample(ch[1])

        self.enc3 = pair(ch[1], ch[2])
        self.down3 = downsample(ch[2])

        # Bottleneck
        self.bottom = pair(ch[2], ch[3])

        # Decoder: each stage consumes skip + upsampled features (hence 2x channels)
        self.up3 = upsample(ch[3], ch[2])
        self.dec3 = pair(ch[2] * 2, ch[2])

        self.up2 = upsample(ch[2], ch[1])
        self.dec2 = pair(ch[1] * 2, ch[1])

        self.up1 = upsample(ch[1], ch[0])
        self.dec1 = pair(ch[0] * 2, ch[0])

        # Output
        self.output = nn.Sequential(
            nn.Conv2d(ch[0], 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 1, kernel_size=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.down1(skip1))
        skip3 = self.enc3(self.down2(skip2))
        features = self.bottom(self.down3(skip3))

        # Decoder: upsample, concatenate the matching skip, refine.
        features = self.dec3(torch.cat([skip3, self.up3(features)], dim=1))
        features = self.dec2(torch.cat([skip2, self.up2(features)], dim=1))
        features = self.dec1(torch.cat([skip1, self.up1(features)], dim=1))

        return self.output(features)

In [None]:
class FloodLoss(nn.Module):
    """RMSE-based loss, optionally with a wet/dry penalty, normalised by flood coverage.

    ``loss = (mse_weight * RMSE + alpha * penalty) / sqrt(flooded fraction of targets)``

    where a cell counts as flooded when its target depth exceeds 0.1, and ``penalty``
    punishes predictions below 0.1 on flooded cells and above 0.1 on dry cells.

    Args:
        mse_weight: Weight of the RMSE term.
        flood_threshold: Stored but not used by ``forward``.
            NOTE(review): the wet/dry cut-off is hard-coded to 0.1 while this defaults
            to 0.05 -- decide which value is intended before wiring it in.
        alpha: Weight of the wet/dry penalty term (0 disables it).
        beta: Stored but not used by ``forward``.
        eps: Lower bound applied to the mean squared error before the square root, so
            that a perfect prediction does not produce NaN gradients.
    """

    def __init__(self, mse_weight=1.0, flood_threshold=0.05, alpha=0, beta=0, eps=1e-12):
        super().__init__()
        self.mse_weight = mse_weight
        self.flood_threshold = flood_threshold
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    def forward(self, preds, targets):
        # Penalty for under-predicted flood area
        flood_mask = targets > 0.1
        under_pred_penalty = torch.clamp(0.1 - preds, min=0)
        flood_penalty = (under_pred_penalty * flood_mask.float()).mean()

        # Penalty for over-predicted non-flood area. The original compared the boolean
        # mask itself to 0.1, which happens to equal its negation; written explicitly.
        dry_mask = ~flood_mask
        over_pred_penalty = torch.clamp(preds - 0.1, min=0)
        flood_penalty = flood_penalty + (over_pred_penalty * dry_mask.float()).mean()

        # Root mean square error. The lower bound only matters when the error is
        # (numerically) zero, where sqrt would otherwise have an infinite gradient.
        mse = torch.mean((preds - targets) ** 2)
        rmse = torch.sqrt(torch.clamp(mse, min=self.eps))

        # Final loss weighted by flood coverage. A target without any flooded cell is
        # treated as having one, which avoids a division by zero and leaves every
        # target with at least one flooded cell unchanged.
        flooded_cells = torch.clamp(flood_mask.float().sum(), min=1.0)
        denominator = (flooded_cells / flood_mask.numel()) ** 0.5
        total_loss = (self.mse_weight * rmse + self.alpha * flood_penalty) / denominator
        return total_loss

In [None]:
def train_one_epoch(model, dataloader, optimizer, criterion, scaler, device):
    """Train ``model`` for one epoch with rolling (autoregressive) prediction.

    For every batch the model is unrolled ``ROLL_LENGTH`` steps. After each step the
    detached prediction replaces the depth channel of the next input, so gradients do
    not flow across steps. Step losses are combined with exponentially decaying
    weights.

    Args:
        model: Network mapping ``[elevation, depth, source]`` to the next depth.
        dataloader: Yields ``(inputs, source_roll, target_roll)`` batches.
        optimizer: Optimiser updating ``model``'s parameters.
        criterion: Loss called as ``criterion(prediction, target)``.
        scaler: ``GradScaler`` used for mixed-precision training.
        device: Device (``torch.device`` or string) on which to run.

    Returns:
        Mean loss over the batches of the epoch.
    """
    # Training mode, so dropout and batch-norm statistics behave as intended. The
    # original called model.eval() here.
    model.train()
    running_loss = 0.0
    time_steps = ROLL_LENGTH

    # Apply exponential time-decay weighting (normalised to sum to 1)
    decay_rate = 0.8
    time_indices = torch.arange(ROLL_LENGTH, device=device)
    weights = torch.exp(-decay_rate * time_indices)
    weights = weights / weights.sum()

    # Mixed precision only on CUDA; on other devices autocast is switched off.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    for inputs, source_roll, target_roll in dataloader:
        inputs = inputs.to(device)
        source_roll = source_roll.to(device)
        target_roll = target_roll.to(device)
        optimizer.zero_grad()

        with torch.autocast(device_type=device_type, enabled=use_amp):
            current_input = inputs
            total_loss = 0.0

            for step in range(time_steps):
                outputs = model(current_input)
                current_target = target_roll[:, step].unsqueeze(1)
                step_loss = criterion(outputs, current_target)
                total_loss += weights[step] * step_loss

                if step < time_steps - 1:
                    # NOTE(review): the first step pairs depth[idx] with source[idx],
                    # but here depth[idx + step + 1] is paired with
                    # source_roll[:, step + 1] (= source[idx + step + 2]), and
                    # source_roll[:, 0] is never used. This looks off by one; validate()
                    # does the same, so change both together if it is unintended.
                    next_source = source_roll[:, step + 1].unsqueeze(1)
                    current_input = torch.cat([
                        inputs[:, 0:1],    # elevation
                        outputs.detach(),  # predicted depth
                        next_source        # next source input
                    ], dim=1)

            # Weights already sum to 1; the extra division is kept to match validate().
            final_loss = total_loss / time_steps

        scaler.scale(final_loss).backward()
        # Gradients must be unscaled before clipping, otherwise max_norm is compared
        # against gradients multiplied by the loss scale.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.5)
        scaler.step(optimizer)
        scaler.update()
        running_loss += final_loss.item()

    return running_loss / len(dataloader)

In [None]:
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` with the same rolling prediction scheme used for training.

    The model is put in eval mode and unrolled ``ROLL_LENGTH`` steps per batch without
    gradient tracking; each prediction becomes the depth channel of the next input.
    Step losses are combined with exponentially decaying, normalised weights.

    Args:
        model: Network mapping ``[elevation, depth, source]`` to the next depth.
        dataloader: Yields ``(inputs, source_roll, target_roll)`` batches.
        criterion: Loss called as ``criterion(prediction, target)``.
        device: Device on which to run.

    Returns:
        Mean loss over the batches.
    """
    model.eval()
    n_steps = ROLL_LENGTH
    decay = 0.8
    step_weights = torch.exp(-decay * torch.arange(ROLL_LENGTH, device=device))
    step_weights = step_weights / step_weights.sum()

    accumulated = 0.0
    with torch.no_grad():
        for inputs, source_roll, target_roll in dataloader:
            inputs = inputs.to(device)
            source_roll = source_roll.to(device)
            target_roll = target_roll.to(device)

            elevation = inputs[:, 0:1]
            model_input = inputs
            rollout_loss = 0.0

            for k in range(n_steps):
                prediction = model(model_input)
                expected = target_roll[:, k].unsqueeze(1)
                rollout_loss += step_weights[k] * criterion(prediction, expected)

                if k + 1 < n_steps:
                    # NOTE(review): uses source_roll[:, k + 1]; source_roll[:, 0] is
                    # never consumed. Possibly off by one -- train_one_epoch does the
                    # same, so change both together if it is unintended.
                    upcoming_source = source_roll[:, k + 1].unsqueeze(1)
                    model_input = torch.cat(
                        [elevation, prediction, upcoming_source], dim=1
                    )

            accumulated += (rollout_loss / n_steps).item()

    return accumulated / len(dataloader)

In [None]:
# Model, loss function and optimiser initialisation
model = PlainFloodUNet().to(device)
criterion = FloodLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0002, weight_decay=5e-8)
scaler = torch.cuda.amp.GradScaler()
num_epochs = 200
best_val_loss = float('inf')
early_stop_counter = 0
patience = 200  # equal to num_epochs, so early stopping never triggers as configured
save_path = "bestmodle.pth"
plot_every = 1  # redraw the loss curves every `plot_every` epochs

# Training history
train_losses = []
val_losses = []
best_epoch = 0

with tqdm(total=num_epochs, desc='Training Progress', unit='epoch') as pbar:
    for epoch in range(num_epochs):
        # Training phase
        train_loss = train_one_epoch(model, combined_dataloader,
                                     optimizer, criterion, scaler, device)

        # Validation phase
        val_loss = validate(model, test_dataloader, criterion, device)

        # Save the best model so far
        if val_loss <= best_val_loss:
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'loss': val_loss
            }, save_path)
            best_val_loss = val_loss
            best_epoch = epoch
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Early stopping check
        if early_stop_counter >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch}! Best epoch: {best_epoch}")
            break

        current_lr = optimizer.param_groups[0]['lr']

        # Live visualisation. The cell output is cleared first, then the progress bar
        # is advanced exactly once; the original advanced it twice per epoch.
        should_plot = epoch % plot_every == 0 or epoch == num_epochs - 1
        if should_plot:
            clear_output(wait=True)

        # Update the progress bar
        pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'Best Val': f'{best_val_loss:.4f}',
            'LR': current_lr
        })
        pbar.update(1)

        if should_plot:
            pbar.refresh()  # redraw the bar that clear_output just removed
            plt.figure(figsize=(12, 6))

            ax1 = plt.subplot(121)
            ax1.semilogy(train_losses, label='Train', marker='.', alpha=0.8)
            ax1.semilogy(val_losses, label='Val', marker='.', alpha=0.8)
            ax1.scatter(best_epoch, best_val_loss, c='r', label=f'Best Val: {best_val_loss:.4f}')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.legend()

            # Plot the learning rate actually used by the optimiser; the original drew
            # a hard-coded 0.00005 that did not match lr=0.0002.
            ax2 = plt.subplot(122)
            ax2.axhline(y=current_lr, color='g', linestyle='--', label='Fixed LR')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Learning Rate')
            ax2.set_title('Constant Learning Rate')
            ax2.legend()

            plt.tight_layout()
            plt.show()

print("\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")