In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# --- Hyperparameters & Configuration ---
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
NUM_EPOCHS = 25
NUM_WORKERS = 2
IMAGE_HEIGHT = 288
IMAGE_WIDTH = 512
PIN_MEMORY = True
ANNOTATION_FILE = "/kaggle/input/masked-dataset/processed/annotations.json"
CHECKPOINT_PATH = "my_tusimple_model.pth.tar"

print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# Note: Augmentations are ONLY for the training set
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5), 
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
])


In [None]:
class LaneDataset(Dataset):
    def __init__(self, annotations, root_dir, transform=None):
        self.annotations = annotations
        self.root_dir = root_dir 
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_relative_path = self.annotations[idx]['image']
        mask_relative_path = self.annotations[idx]['mask']
        
        img_path = os.path.join(self.root_dir, img_relative_path)
        mask_path = os.path.join(self.root_dir, mask_relative_path)
        
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
        
        mask = transforms.ToTensor()(mask)
        mask = torch.where(mask > 0, 1.0, 0.0)
        
        return image, mask

In [5]:
# U-Net Model Definition
class DoubleConv(nn.Module):
    """(Convolution => BatchNorm => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNET, self).__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.MaxPool2d(2)
        self.conv1 = DoubleConv(64, 128)
        self.down2 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(128, 256)
        self.down3 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(256, 512)
        self.down4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x2 = self.conv1(x2)
        x3 = self.down2(x2)
        x3 = self.conv2(x3)
        x4 = self.down3(x3)
        x4 = self.conv3(x4)
        x5 = self.down4(x4)
        x5 = self.bottleneck(x5)
        up1 = self.up1(x5)
        concat1 = torch.cat([up1, x4], dim=1)
        up1_conv = self.up_conv1(concat1)
        up2 = self.up2(up1_conv)
        concat2 = torch.cat([up2, x3], dim=1)
        up2_conv = self.up_conv2(concat2)
        up3 = self.up3(up2_conv)
        concat3 = torch.cat([up3, x2], dim=1)
        up3_conv = self.up_conv3(concat3)
        up4 = self.up4(up3_conv)
        concat4 = torch.cat([up4, x1], dim=1)
        up4_conv = self.up_conv4(concat4)
        logits = self.outc(up4_conv)
        return logits

In [None]:
# Training and Evaluation Functions
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Processes one epoch of training."""
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().to(device=DEVICE)
        
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        loop.set_postfix(loss=loss.item())

def check_accuracy(loader, model, device="cuda"):
    """
    Checks accuracy on a validation set and returns the average Dice score.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval() 
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            
    # Calculate final metrics
    avg_dice_score = dice_score / len(loader)
    accuracy = num_correct / num_pixels * 100
    
    model.train() 
    
    return avg_dice_score, accuracy

In [None]:
KAGGLE_ROOT_DIR = "/kaggle/input/masked-dataset/"

print("ðŸ“‚ Loading all TuSimple annotations...")
with open(os.path.join(KAGGLE_ROOT_DIR, "processed/annotations.json"), "r") as f:
    all_annotations = json.load(f)

# Split into Main Training Set (80%) and a Final, held-out Test Set (20%)
main_train_ann, final_test_ann = train_test_split(all_annotations, test_size=0.2, random_state=42)

print(f"Total samples: {len(all_annotations)}")
print(f"Main training set size: {len(main_train_ann)}")
print(f"Final hold-out test set size: {len(final_test_ann)}")

# Split the Main Training Set again into Sub-Training and Sub-Validation sets
sub_train_ann, sub_val_ann = train_test_split(main_train_ann, test_size=0.2, random_state=42)

print(f"\nSub-training set size: {len(sub_train_ann)}")
print(f"Sub-validation set size (for monitoring during training): {len(sub_val_ann)}")


# raining Phase
print("\n--- Starting Training Phase ---")
train_dataset = LaneDataset(annotations=sub_train_ann, root_dir=KAGGLE_ROOT_DIR, transform=train_transform)
val_dataset = LaneDataset(annotations=sub_val_ann, root_dir=KAGGLE_ROOT_DIR, transform=val_transform)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=PIN_MEMORY, shuffle=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=PIN_MEMORY, shuffle=False
)

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()

# Training Loop
best_dice_score = -1.0 
for epoch in range(NUM_EPOCHS):
    print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    
    current_dice, current_acc = check_accuracy(val_loader, model, device=DEVICE)
    print(f"Validation Accuracy (on sub-validation set): {current_acc:.2f}")
    print(f"Validation Dice Score (on sub-validation set): {current_dice:.4f}")
    
    if current_dice > best_dice_score:
        best_dice_score = current_dice
        print(f"âœ… New best model! Saving to {CHECKPOINT_PATH}")
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(checkpoint, CHECKPOINT_PATH)


print("\n--- Starting Final Evaluation on the Hold-Out Test Set ---")

final_model = UNET(in_channels=3, out_channels=1).to(DEVICE)

print(f"Loading best model from {CHECKPOINT_PATH}...")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
final_model.load_state_dict(checkpoint["state_dict"])

final_test_dataset = LaneDataset(annotations=final_test_ann, root_dir=KAGGLE_ROOT_DIR, transform=val_transform)
final_test_loader = DataLoader(
    final_test_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=PIN_MEMORY, shuffle=False
)

print("Running final evaluation on the unseen test set...")
final_dice, final_acc = check_accuracy(final_test_loader, final_model, device=DEVICE)

print("     Final Unbiased Model Performance")
print(f"Dice Score on Final Test Set: {final_dice:.4f}")
print(f"Pixel Accuracy on Final Test Set: {final_acc:.2f}")


ðŸ“‚ Loading all TuSimple annotations...
Total samples: 3626
Main training set size: 2900
Final hold-out test set size: 726

Sub-training set size: 2320
Sub-validation set size (for monitoring during training): 580

--- Starting Training Phase ---


  scaler = torch.cuda.amp.GradScaler()


--- Epoch 1/25 ---


  with torch.cuda.amp.autocast():
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [03:18<00:00,  1.46it/s, loss=0.272]


Validation Accuracy (on sub-validation set): 95.75
Validation Dice Score (on sub-validation set): 0.0625
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 2/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.196]


Validation Accuracy (on sub-validation set): 96.21
Validation Dice Score (on sub-validation set): 0.2559
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 3/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.158]


Validation Accuracy (on sub-validation set): 96.97
Validation Dice Score (on sub-validation set): 0.5230
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 4/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.151]


Validation Accuracy (on sub-validation set): 97.34
Validation Dice Score (on sub-validation set): 0.6412
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 5/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.129]


Validation Accuracy (on sub-validation set): 97.39
Validation Dice Score (on sub-validation set): 0.6345
--- Epoch 6/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.116] 


Validation Accuracy (on sub-validation set): 97.53
Validation Dice Score (on sub-validation set): 0.6844
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 7/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.108] 


Validation Accuracy (on sub-validation set): 97.51
Validation Dice Score (on sub-validation set): 0.6525
--- Epoch 8/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.104] 


Validation Accuracy (on sub-validation set): 97.75
Validation Dice Score (on sub-validation set): 0.7163
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 9/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.0994]


Validation Accuracy (on sub-validation set): 97.82
Validation Dice Score (on sub-validation set): 0.7285
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 10/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.0876]


Validation Accuracy (on sub-validation set): 97.70
Validation Dice Score (on sub-validation set): 0.7027
--- Epoch 11/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.0841]


Validation Accuracy (on sub-validation set): 97.83
Validation Dice Score (on sub-validation set): 0.7301
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 12/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.74it/s, loss=0.0944]


Validation Accuracy (on sub-validation set): 97.88
Validation Dice Score (on sub-validation set): 0.7378
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 13/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.0676]


Validation Accuracy (on sub-validation set): 97.82
Validation Dice Score (on sub-validation set): 0.7195
--- Epoch 14/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.066] 


Validation Accuracy (on sub-validation set): 97.88
Validation Dice Score (on sub-validation set): 0.7441
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 15/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.74it/s, loss=0.0728]


Validation Accuracy (on sub-validation set): 97.87
Validation Dice Score (on sub-validation set): 0.7359
--- Epoch 16/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.0775]


Validation Accuracy (on sub-validation set): 97.75
Validation Dice Score (on sub-validation set): 0.6922
--- Epoch 17/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.74it/s, loss=0.0833]


Validation Accuracy (on sub-validation set): 97.90
Validation Dice Score (on sub-validation set): 0.7414
--- Epoch 18/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.74it/s, loss=0.0739]


Validation Accuracy (on sub-validation set): 97.92
Validation Dice Score (on sub-validation set): 0.7458
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 19/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.0806]


Validation Accuracy (on sub-validation set): 98.00
Validation Dice Score (on sub-validation set): 0.7590
âœ… New best model! Saving to my_tusimple_model.pth.tar
--- Epoch 20/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.0511]


Validation Accuracy (on sub-validation set): 97.93
Validation Dice Score (on sub-validation set): 0.7505
--- Epoch 21/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.0672]


Validation Accuracy (on sub-validation set): 97.99
Validation Dice Score (on sub-validation set): 0.7552
--- Epoch 22/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.74it/s, loss=0.0576]


Validation Accuracy (on sub-validation set): 97.99
Validation Dice Score (on sub-validation set): 0.7574
--- Epoch 23/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:45<00:00,  1.75it/s, loss=0.0448]


Validation Accuracy (on sub-validation set): 97.91
Validation Dice Score (on sub-validation set): 0.7423
--- Epoch 24/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.75it/s, loss=0.0553]


Validation Accuracy (on sub-validation set): 97.97
Validation Dice Score (on sub-validation set): 0.7576
--- Epoch 25/25 ---


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 290/290 [02:46<00:00,  1.74it/s, loss=0.0366]


Validation Accuracy (on sub-validation set): 97.90
Validation Dice Score (on sub-validation set): 0.7541

--- Starting Final Evaluation on the Hold-Out Test Set ---
Loading best model from my_tusimple_model.pth.tar...
Running final evaluation on the unseen test set...
     Final Unbiased Model Performance
Dice Score on Final Test Set: 0.7559
Pixel Accuracy on Final Test Set: 97.97
