In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import json
import os
from PIL import Image
import numpy as np
import cv2

In [None]:
# --- Input locations (Kaggle dataset mounts) ---
CHECKPOINT_PATH = "/kaggle/input/culane-attention-module-trained-model/culane_attention_module_trained_model/best_model_culane_attention.pth.tar"
CULANE_ROOT = "/kaggle/input/culane"
EVAL_ANNOTATION_FILE = "/kaggle/input/culane/CULane/culane_val_annotations.json"

# --- Evaluation settings ---
# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
BATCH_SIZE = 16
# Resolution fed to the network; also the size of the blank fallback sample in LaneDataset.
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 320

print(f"Using device: {DEVICE}")
print(f"Model path: {CHECKPOINT_PATH}")
print(f"Dataset path: {CULANE_ROOT}")

Using device: cuda
Model path: /kaggle/input/culane-attention-module-trained-model/culane_attention_module_trained_model/best_model_culane_attention.pth.tar
Dataset path: /kaggle/input/culane


In [None]:
# DATASET AND MODEL DEFINITIONS 

class LaneDataset(Dataset):
    """Image / lane-mask pairs listed in a JSON annotation file.

    The annotation file is loaded with ``json.load`` and indexed by position; each
    record must provide ``'image'`` and ``'mask'`` paths relative to ``root_dir``
    (backslashes are normalised to ``/``). Each item is ``(image, mask)`` where
    ``mask`` is a float tensor with a leading channel axis of size 1 and value 1.0
    on every non-zero pixel of the source mask.

    Args:
        annotation_path: path to the JSON annotation list.
        root_dir: directory the relative image/mask paths are resolved against.
        transform: optional albumentations-style callable taking ``image=`` and
            ``mask=`` keywords and returning a dict with the same keys.
    """

    def __init__(self, annotation_path, root_dir, transform=None):
        with open(annotation_path, 'r') as f:
            self.annotations = json.load(f)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        img_path = os.path.join(self.root_dir, ann['image'].replace('\\', '/'))
        mask_path = os.path.join(self.root_dir, ann['mask'].replace('\\', '/'))
        try:
            with Image.open(img_path) as img:
                image = np.array(img.convert("RGB"))
            with Image.open(mask_path) as mask_img:
                mask = np.array(mask_img.convert("L"))
        except FileNotFoundError as err:
            # Best-effort: keep the run going, but make the substitution visible. A blank
            # sample has no ground-truth lanes, so silently using it biases the metrics.
            import warnings
            warnings.warn(
                f"LaneDataset: missing file for sample {idx} ({err.filename}); "
                "substituting a blank image and mask."
            )
            image = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3), dtype=np.uint8)
            mask = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH), dtype=np.uint8)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        if not torch.is_tensor(mask):
            # Without a ToTensorV2-style transform (including the default transform=None)
            # the mask is still a NumPy array, which has no .float(); convert it so the
            # binarisation below always works and always returns a tensor.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        return image, (mask > 0).float().unsqueeze(0)


In [4]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv, BatchNorm, ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``. Layers live in ``self.double_conv``
    at Sequential indices 0-5, which is what saved checkpoints refer to.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in->out channels, the second keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class AttentionGate(nn.Module):
    """Additive attention gate applied to a U-Net skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels (1x1 conv + BatchNorm), summed and passed through ReLU.
    The result is reduced to a single-channel sigmoid map that scales ``x``.
    Because the two projections are added, ``g`` and ``x`` must share their
    spatial dimensions.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()

        def conv1x1_bn(in_ch, out_ch):
            # Conv at index 0 and BatchNorm at index 1, matching stored checkpoint keys.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=True), nn.BatchNorm2d(out_ch)]

        self.W_g = nn.Sequential(*conv1x1_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv1x1_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention_map = self.psi(combined)  # single channel, sigmoid output
        return x * attention_map

class AttentionUNET(nn.Module):
    """U-Net with an additive attention gate on every skip connection.

    Encoder: one ``DoubleConv`` followed by 2x2 max-pooling per entry of ``features``.
    Bottleneck: ``DoubleConv`` doubling the deepest width.
    Decoder: for each level (deepest first), a stride-2 transposed conv upsamples,
    an ``AttentionGate`` re-weights the matching encoder map, and a ``DoubleConv``
    fuses the concatenation. A final 1x1 conv emits ``out_channels`` raw logits.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output logit map.
        features: encoder widths, shallowest first. The registered module names and
            order define the checkpoint keys, so they are kept stable.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(AttentionUNET, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # self.ups is a flat list of triples per level: [upsample, attention gate, fuse conv].
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first, matching self.ups order

        for level, idx in enumerate(range(0, len(self.ups), 3)):
            upsample, gate, fuse = self.ups[idx], self.ups[idx + 1], self.ups[idx + 2]
            x = upsample(x)
            skip = skip_connections[level]
            # Max-pooling floors odd sizes, so the upsampled map can be smaller than its
            # skip. Resize BEFORE gating: the gate adds its two projected inputs
            # element-wise and would fail on mismatched spatial shapes.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])
            gated_skip = gate(g=x, x=skip)
            x = torch.cat((gated_skip, x), dim=1)
            x = fuse(x)

        return self.final_conv(x)


In [None]:
# INSTANCE-LEVEL METRIC FUNCTIONS
def calculate_instance_iou(mask1, mask2, smooth=1e-6):
    """Return the smoothed intersection-over-union of two binary masks.

    Non-zero entries count as foreground. ``smooth`` is added to both the
    intersection and the union, so two empty masks score 1.0 instead of
    dividing by zero.
    """
    overlap_pixels = np.logical_and(mask1, mask2).sum()
    covered_pixels = np.logical_or(mask1, mask2).sum()
    numerator = overlap_pixels + smooth
    denominator = covered_pixels + smooth
    return numerator / denominator

def calculate_instance_metrics(pred_mask, gt_mask, iou_threshold=0.5):
    """Calculates instance-level TP, FP, FN based on IoU threshold.

    Every connected component of a mask (OpenCV's default 8-connectivity) is
    treated as one lane instance. A predicted and a ground-truth instance form
    a pair when their IoU is strictly greater than ``iou_threshold``. Each
    instance takes part in at most one pair.

    Args:
        pred_mask: single-channel uint8 mask, non-zero on predicted lane pixels.
        gt_mask: single-channel uint8 mask of the same shape, non-zero on GT lane pixels.
        iou_threshold: IoU a pair must exceed to count as a true positive.

    Returns:
        dict ``{'tp': int, 'fp': int, 'fn': int}`` where
        tp is the number of matched pairs, fp the unmatched predicted instances,
        and fn the unmatched ground-truth instances.

    Note:
        Lanes that touch in the mask merge into one component and are counted
        as a single instance.
    """
    num_pred_labels, pred_labels = cv2.connectedComponents(pred_mask)
    num_gt_labels, gt_labels = cv2.connectedComponents(gt_mask)
    num_pred = num_pred_labels - 1  # label 0 is background
    num_gt = num_gt_labels - 1

    if num_pred == 0 or num_gt == 0:
        # Nothing can match: every existing instance is a miss or a false alarm.
        return {'tp': 0, 'fp': num_pred, 'fn': num_gt}

    iou_matrix = _instance_iou_matrix(pred_labels, num_pred_labels, gt_labels, num_gt_labels)
    tp = _count_one_to_one_matches(iou_matrix, iou_threshold)
    return {'tp': tp, 'fp': num_pred - tp, 'fn': num_gt - tp}


def _instance_iou_matrix(pred_labels, num_pred_labels, gt_labels, num_gt_labels, smooth=1e-6):
    """Return the IoU of every foreground pred component (rows) vs GT component (cols)."""
    # A joint histogram of (pred label, gt label) over all pixels gives every pairwise
    # intersection in one pass, instead of building two full-size masks per pair.
    pair_codes = pred_labels.astype(np.int64).ravel() * num_gt_labels + gt_labels.ravel()
    joint = np.bincount(pair_codes, minlength=num_pred_labels * num_gt_labels)
    joint = joint.reshape(num_pred_labels, num_gt_labels)
    pred_area = joint.sum(axis=1)[1:]  # drop background row
    gt_area = joint.sum(axis=0)[1:]    # drop background column
    intersection = joint[1:, 1:]
    union = pred_area[:, None] + gt_area[None, :] - intersection
    # Same smoothing formula as calculate_instance_iou, so per-pair scores are unchanged.
    return (intersection + smooth) / (union + smooth)


def _count_one_to_one_matches(iou_matrix, iou_threshold):
    """Pair pred/GT instances greedily by descending IoU; return how many pairs were made."""
    # Consider pairs globally from highest IoU down. A GT is therefore not lost just
    # because its single best prediction was claimed earlier by a lower-IoU GT. With
    # iou_threshold >= 0.5 and disjoint components each instance can clear the
    # threshold against at most one counterpart, so the order only matters below 0.5.
    pred_idx, gt_idx = np.nonzero(iou_matrix > iou_threshold)
    order = np.argsort(-iou_matrix[pred_idx, gt_idx], kind='stable')
    used_pred, used_gt = set(), set()
    for k in order:
        p, g = int(pred_idx[k]), int(gt_idx[k])
        if p not in used_pred and g not in used_gt:
            used_pred.add(p)
            used_gt.add(g)
    return len(used_pred)


In [None]:
def main(prob_threshold=0.5, iou_threshold=0.5):
    """Run instance-level evaluation of the trained Attention U-Net on CULane.

    The function:
    - loads weights from ``CHECKPOINT_PATH``;
    - predicts a lane mask for every sample listed in ``EVAL_ANNOTATION_FILE``;
    - matches connected-component lane instances against the ground truth with
      ``calculate_instance_metrics``.

    It prints the summed TP / FP / FN counts and the derived accuracy
    (TP / (TP + FP + FN)), precision, recall and F1.

    Args:
        prob_threshold: sigmoid probability a pixel must exceed to be predicted as
            lane. The default 0.5 is the previously hard-coded value.
        iou_threshold: IoU a predicted instance must exceed to match a GT instance.
            The default 0.5 is the previously hard-coded value.
    """
    print("🚀 Starting CULane model evaluation (Instance-Level)...")

    model = AttentionUNET(in_channels=3, out_channels=1).to(DEVICE)
    print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    # Resize + ImageNet normalisation; presumably mirrors training preprocessing (TODO confirm).
    eval_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    eval_dataset = LaneDataset(annotation_path=EVAL_ANNOTATION_FILE, root_dir=CULANE_ROOT, transform=eval_transform)
    eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    print(f"Loaded {len(eval_dataset)} images for evaluation.")

    total_tp, total_fp, total_fn = 0, 0, 0
    with torch.no_grad():
        for images, targets in tqdm(eval_loader, desc="Evaluating Instances"):
            logits = model(images.to(DEVICE))
            # Binarise and copy the whole batch to the host once, rather than forcing
            # a device sync for every sample.
            pred_masks = (torch.sigmoid(logits) > prob_threshold).squeeze(1).cpu().numpy().astype(np.uint8)
            gt_masks = targets.squeeze(1).cpu().numpy().astype(np.uint8)

            for pred_mask, gt_mask in zip(pred_masks, gt_masks):
                metrics = calculate_instance_metrics(pred_mask, gt_mask, iou_threshold=iou_threshold)
                total_tp += metrics['tp']
                total_fp += metrics['fp']
                total_fn += metrics['fn']

    # Small smoothing term keeps every ratio defined even when all counts are zero.
    smooth = 1e-6
    precision = (total_tp + smooth) / (total_tp + total_fp + smooth)
    recall = (total_tp + smooth) / (total_tp + total_fn + smooth)
    f1_score = 2 * (precision * recall) / (precision + recall + smooth)
    accuracy = (total_tp + smooth) / (total_tp + total_fp + total_fn + smooth)

    print("\n" + "="*45)
    print("✅ CULANE - FINAL INSTANCE-LEVEL METRICS")
    print("="*45)
    print(f"Total True Positives (Correctly Detected Lanes):  {total_tp}")
    print(f"Total False Positives (Incorrect Detections): {total_fp}")
    print(f"Total False Negatives (Missed Lanes):         {total_fn}")
    print("-" * 45)
    print(f"Accuracy (Paper's Definition): {accuracy:.4f}")
    print(f"Precision:                     {precision:.4f}")
    print(f"Recall:                        {recall:.4f}")
    print(f"F1-Score:                      {f1_score:.4f}")
    print("="*45)

In [7]:
# Executing this cell in the notebook kernel (or running the exported script) starts the evaluation.
if "__main__" == __name__:
    main()

🚀 Starting CULane model evaluation (Instance-Level)...
Loading model from checkpoint: /kaggle/input/culane-attention-module-trained-model/culane_attention_module_trained_model/best_model_culane_attention.pth.tar
Loaded 9675 images for evaluation.


Evaluating Instances: 100%|██████████| 605/605 [02:09<00:00,  4.67it/s]


✅ CULANE - FINAL INSTANCE-LEVEL METRICS
Total True Positives (Correctly Detected Lanes):  13440
Total False Positives (Incorrect Detections): 32057
Total False Negatives (Missed Lanes):         16083
---------------------------------------------
Accuracy (Paper's Definition): 0.2183
Precision:                     0.2954
Recall:                        0.4552
F1-Score:                      0.3583



