In [11]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import json
import os
from PIL import Image
import numpy as np
import cv2

In [None]:
TUSIMPLE_RAW_ROOT = "/kaggle/input/tusimple/TUSimple" 
CHECKPOINT_PATH = "/kaggle/input/tusimple-attention-module-trained-model/Tusimple_attention_module_trained_model/my_tusimple_attention_model.pth.tar" 
TEST_LABEL_FILE = "/kaggle/input/tusimple/test_label_new.json"

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
IMAGE_HEIGHT, IMAGE_WIDTH = 288, 512

print(f"Using device: {DEVICE}")
print(f"Using device: {DEVICE}")
print(f"Model path: {CHECKPOINT_PATH}")
print(f"Dataset path: {TUSIMPLE_RAW_ROOT}")
print(f"Test labels path: {TEST_LABEL_FILE}")

Using device: cuda
Using device: cuda
Model path: /kaggle/input/tusimple-attention-module-trained-model/Tusimple_attention_module_trained_model/my_tusimple_attention_model.pth.tar
Dataset path: /kaggle/input/tusimple/TUSimple
Test labels path: /kaggle/input/tusimple/test_label_new.json


In [None]:
# DATASET AND MODEL DEFINITIONS
class TusimpleImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform if transform is not None else transforms.ToTensor()
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        image_tensor = self.transform(image)
        return image_tensor, img_path

# Model Architecture with Attention
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.double_conv(x)

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class AttentionUNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(AttentionUNET, self).__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.MaxPool2d(2); self.conv1 = DoubleConv(64, 128)
        self.down2 = nn.MaxPool2d(2); self.conv2 = DoubleConv(128, 256)
        self.down3 = nn.MaxPool2d(2); self.conv3 = DoubleConv(256, 512)
        self.down4 = nn.MaxPool2d(2); self.bottleneck = DoubleConv(512, 1024)
        
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.Att1 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.up_conv1 = DoubleConv(1024, 512)
        
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.Att2 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.up_conv2 = DoubleConv(512, 256)
        
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.Att3 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.up_conv3 = DoubleConv(256, 128)
        
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.Att4 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.up_conv4 = DoubleConv(128, 64)
        
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1); x2 = self.conv1(x2)
        x3 = self.down2(x2); x3 = self.conv2(x3)
        x4 = self.down3(x3); x4 = self.conv3(x4)
        x5 = self.down4(x4); x5 = self.bottleneck(x5)
        
        up1 = self.up1(x5)
        att_x4 = self.Att1(g=up1, x=x4) 
        concat1 = torch.cat([up1, att_x4], dim=1)
        up1_conv = self.up_conv1(concat1)
        
        up2 = self.up2(up1_conv)
        att_x3 = self.Att2(g=up2, x=x3)
        concat2 = torch.cat([up2, att_x3], dim=1)
        up2_conv = self.up_conv2(concat2)
        
        up3 = self.up3(up2_conv)
        att_x2 = self.Att3(g=up3, x=x2) 
        concat3 = torch.cat([up3, att_x2], dim=1)
        up3_conv = self.up_conv3(concat3)
        
        up4 = self.up4(up3_conv)
        att_x1 = self.Att4(g=up4, x=x1) 
        concat4 = torch.cat([up4, att_x1], dim=1)
        up4_conv = self.up_conv4(concat4)
        
        return self.outc(up4_conv)



In [None]:
def main():
    print("ðŸš€ Starting Tusimple model evaluation on TEST SET (Point-Based)...")
    
    print(f"Loading TEST SET ground truth annotations from {TEST_LABEL_FILE}...")
    gt_annotations = []
    with open(TEST_LABEL_FILE) as f:
        gt_annotations.extend([json.loads(line) for line in f])
    gt_lookup = {ann['raw_file']: ann for ann in gt_annotations}

    model = AttentionUNET(in_channels=3, out_channels=1).to(DEVICE)
    print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    
    test_image_paths = [os.path.join(TUSIMPLE_RAW_ROOT, path) for path in gt_lookup.keys()]
    
    eval_transform = transforms.Compose([transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)), transforms.ToTensor()])
    eval_dataset = TusimpleImageDataset(image_paths=test_image_paths, transform=eval_transform)
    eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    print(f"Loaded {len(eval_dataset)} test images for evaluation.")

    total_correct_points, total_gt_points, total_pred_points = 0, 0, 0
    total_fp_lanes, total_fn_lanes = 0, 0

    with torch.no_grad():
        for images, img_paths in tqdm(eval_loader, desc="Evaluating Test Set"):
            images = images.to(device=DEVICE)
            preds_masks = torch.sigmoid(model(images))
            preds_masks = (preds_masks > 0.5)

            for i in range(len(img_paths)):
                pred_mask = preds_masks[i].squeeze().cpu().numpy()
                relative_path = os.path.relpath(img_paths[i], TUSIMPLE_RAW_ROOT).replace('\\', '/')
                gt_ann = gt_lookup.get(relative_path)
                if not gt_ann: continue
                
                gt_lanes = gt_ann['lanes']
                h_samples = gt_ann['h_samples']

                pred_lanes = []
                num_labels, labels = cv2.connectedComponents(pred_mask.astype(np.uint8))
                for label_idx in range(1, num_labels):
                    lane_pixels = np.where(labels == label_idx)
                    lane_points = []
                    for h in h_samples:
                        scaled_h = int(h * (IMAGE_HEIGHT / 720.0))
                        row_pixels = np.where(lane_pixels[0] == scaled_h)[0]
                        if row_pixels.size > 0:
                            x_coord = int(np.mean(lane_pixels[1][row_pixels]) * (1280.0 / IMAGE_WIDTH))
                            lane_points.append(x_coord)
                            total_pred_points += 1
                        else:
                            lane_points.append(-2)
                    pred_lanes.append(lane_points)
                
                img_correct_points, gt_points_count = 0, 0
                for gt_lane in gt_lanes:
                    for j, gt_x in enumerate(gt_lane):
                        if gt_x == -2: continue
                        gt_points_count += 1
                        best_dist = float('inf')
                        for pred_lane in pred_lanes:
                            if j < len(pred_lane) and pred_lane[j] != -2:
                                best_dist = min(best_dist, abs(gt_x - pred_lane[j]))
                        if best_dist <= 20:
                            img_correct_points += 1
                
                total_correct_points += img_correct_points
                total_gt_points += gt_points_count
                total_fp_lanes += max(0, len(pred_lanes) - len(gt_lanes))
                total_fn_lanes += max(0, len(gt_lanes) - len(pred_lanes))

    tp = total_correct_points
    fp = total_pred_points - total_correct_points
    fn = total_gt_points - total_correct_points
    smooth = 1e-6

    accuracy = (tp + smooth) / (total_gt_points + smooth)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)
    f1_score = 2 * (precision * recall) / (precision + recall + smooth)
    fp_rate_lanes = total_fp_lanes / len(eval_dataset)
    fn_rate_lanes = total_fn_lanes / len(eval_dataset)

    print("\n" + "="*50)
    print("âœ… TUSIMPLE TEST SET - FINAL POINT-BASED METRICS")
    print("="*50)
    print(f"Official Accuracy (Correct Points / GT Points): {accuracy:.4f}")
    print(f"Point-Level Precision:                        {precision:.4f}")
    print(f"Point-Level Recall:                           {recall:.4f}")
    print(f"Point-Level F1-Score:                         {f1_score:.4f}")
    print("-" * 50)
    print(f"Lane FP Rate (per image):                     {fp_rate_lanes:.4f}")
    print(f"Lane FN Rate (per image):                     {fn_rate_lanes:.4f}")
    print("="*50)

In [22]:
if __name__ == "__main__":
    main()

ðŸš€ Starting Tusimple model evaluation on TEST SET (Point-Based)...
Loading TEST SET ground truth annotations from /kaggle/input/tusimple/test_label_new.json...
Loading model from checkpoint: /kaggle/input/tusimple-attention-module-trained-model/Tusimple_attention_module_trained_model/my_tusimple_attention_model.pth.tar
Loaded 2782 test images for evaluation.


Evaluating Test Set: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 348/348 [01:24<00:00,  4.13it/s]


âœ… TUSIMPLE TEST SET - FINAL POINT-BASED METRICS
Official Accuracy (Correct Points / GT Points): 0.7563
Point-Level Precision:                        0.8413
Point-Level Recall:                           0.7563
Point-Level F1-Score:                         0.7966
--------------------------------------------------
Lane FP Rate (per image):                     2.4860
Lane FN Rate (per image):                     0.1373



