In [23]:
import torch
import torch.nn.functional as F
from torchmetrics.functional import dice, jaccard_index
from skimage.metrics import hausdorff_distance
from medpy.metric.binary import precision as medpy_precision, recall as medpy_recall
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from torchmetrics.classification import Dice


from levee_hunter.augmentations import (
    no_deformations_transform,
    normalize_only,
    train_transform,
)

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model architecture (must match the saved model)
model = smp.DeepLabV3Plus(
    encoder_name="resnet34",  # Same backbone as before
    encoder_weights=None,  # No pre-trained weights (we load our own)
    in_channels=1,  # Input is grayscale (1 channel)
    classes=1,  # Binary segmentation (1 class output)
).to(device)

In [16]:
model_path = "../models/w4-based-models/DeepLabV3Plus-1m_512.pth"
model.load_state_dict(torch.load(model_path, map_location=device))

# Set model to evaluation mode (important for inference)
model.eval()

  model.load_state_dict(torch.load(model_path, map_location=device))


DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

# Datasets

In [17]:
# Load the datasets
train_dataset_1m = torch.load(
    "../data/processed/w3-4-based-datasets/train_dataset_1m_512.pth"
)
val_dataset_1m = torch.load(
    "../data/processed/w3-4-based-datasets/val_dataset_1m_512.pth"
)
bad_dataset_1m = torch.load(
    "../data/intermediate/w3-4-based-datasets/bad_dataset_1m_512.pth"
)

train_dataset_13 = torch.load(
    "../data/processed/w3-4-based-datasets/train_dataset_13_512.pth"
)
val_dataset_13 = torch.load(
    "../data/processed/w3-4-based-datasets/val_dataset_13_512.pth"
)
bad_dataset_13 = torch.load(
    "../data/intermediate/w3-4-based-datasets/bad_dataset_13_512.pth"
)

print(len(train_dataset_1m))
print(len(val_dataset_1m))
print(len(bad_dataset_1m))
print(len(train_dataset_13))
print(len(val_dataset_13))
print(len(bad_dataset_13))

  train_dataset_1m = torch.load(
  val_dataset_1m = torch.load(
  bad_dataset_1m = torch.load(
  train_dataset_13 = torch.load(
  val_dataset_13 = torch.load(
  bad_dataset_13 = torch.load(


704
177
112
343
86
190


In [18]:
val_dataset_13.transform

Compose([
  Normalize(p=1.0, mean=0.0, std=1.0, max_pixel_value=255.0, normalization='standard'),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)

In [19]:
val_dataset_1m.transform

Compose([
  Normalize(p=1.0, mean=0.0, std=1.0, max_pixel_value=255.0, normalization='standard'),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)

In [20]:
val_loader_13 = DataLoader(
    val_dataset_13, batch_size=len(val_dataset_13), shuffle=False
)
val_loader_1m = DataLoader(
    val_dataset_1m, batch_size=len(val_dataset_1m), shuffle=False
)

In [None]:
# Store all predictions & targets for final evaluation
all_preds = []
all_targets = []

with torch.no_grad():
    for images, masks in val_loader_13:
        images, masks = images.to(device), masks.to(device)

        # Get model predictions
        outputs = model(images)
        preds = torch.sigmoid(outputs)  # Convert logits to probabilities
        preds = (preds > 0.5).float()  # Convert to binary mask

        # Store results for evaluation
        all_preds.append(preds.cpu())
        all_targets.append(masks.cpu())

# Concatenate all batches into single tensors
all_preds = torch.cat(all_preds, dim=0)
all_targets = torch.cat(all_targets, dim=0)

# Compute IoU (Jaccard Index)
iou_score = jaccard_index(all_preds, all_targets, task="binary").item()


# Compute Precision & Recall using MedPy
precision_score = medpy_precision(
    all_preds.numpy().squeeze(), all_targets.numpy().squeeze()
)
recall_score = medpy_recall(all_preds.numpy().squeeze(), all_targets.numpy().squeeze())

# Compute Hausdorff Distance (for line structures)
hausdorff_distances = []
for i in range(all_preds.shape[0]):  # Loop through each image in batch
    hausdorff_distances.append(
        hausdorff_distance(
            all_preds[i].numpy().squeeze(), all_targets[i].numpy().squeeze()
        )
    )

hausdorff_dist = sum(hausdorff_distances) / len(
    hausdorff_distances
)  # Average over batch

# Print Results
print(f"IoU Score: {iou_score:.4f}")
# print(f"Dice Score: {dice_score:.4f}")
print(f"Precision: {precision_score:.4f}")
print(f"Recall: {recall_score:.4f}")
print(f"Hausdorff Distance: {hausdorff_dist:.4f}")

IoU Score: 0.9747
Precision: 0.9753
Recall: 0.9994
Hausdorff Distance: 7.1570
