In [None]:
import os
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [26]:
# ===================== Dataset =====================
class LaneDataset(Dataset):
    """Segmentation dataset driven by a JSON-lines annotation file.

    Every non-blank line of the annotation file is a JSON object with the keys
    ``source`` (selects an entry of ``root_dirs``), ``image`` and ``mask``
    (relative paths joined onto that root directory).

    Args:
        annotation_path: Path to the JSON-lines annotation file.
        root_dirs: Mapping from a ``source`` value to its root directory.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, annotation_path, root_dirs, transform=None):
        with open(annotation_path, 'r') as f:
            # Drop blank / whitespace-only lines (e.g. a trailing newline) so
            # that __len__ is exact and json.loads never sees an empty record.
            self.lines = [line for line in f if line.strip()]
        self.root_dirs = root_dirs
        self.transform = transform

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, mask)`` where ``image`` is whatever the transform returns
            (or the raw RGB array when no transform is set) and ``mask`` is a
            float tensor of 0.0/1.0 values with a leading dimension added.
        """
        ann = json.loads(self.lines[idx])
        root_dir = self.root_dirs[ann['source']]

        # Annotation paths may contain backslashes; normalise after joining.
        img_path = os.path.join(root_dir, ann['image']).replace('\\', '/')
        mask_path = os.path.join(root_dir, ann['mask']).replace('\\', '/')

        # Context managers release the file handles as soon as pixels are copied.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # The mask may still be a NumPy array (no transform, or a transform
        # without a tensor conversion step); as_tensor accepts both arrays and
        # tensors, so the binarisation below works in every case.
        mask = (torch.as_tensor(mask) > 0).float().unsqueeze(0)  # binary mask
        return image, mask


In [27]:
# ===================== U-Net Model =====================
class UNet(nn.Module):
    """U-Net style encoder/decoder built from double 3x3 convolution blocks.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel width of each encoder level, shallowest first. The
            bottleneck uses twice the last width; the decoder mirrors the list.
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: one double-conv block per level.
        previous_width = in_channels
        for width in features:
            self.downs.append(self.double_conv(previous_width, width))
            previous_width = width

        # Bottleneck doubles the deepest encoder width.
        deepest = features[-1]
        self.bottleneck = self.double_conv(deepest, deepest * 2)

        # Decoder: `ups` alternates [transposed conv, double conv] per level,
        # deepest level first.
        for width in features[::-1]:
            self.ups.append(nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2))
            self.ups.append(self.double_conv(width * 2, width))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network and return raw (un-activated) output of the last conv."""
        encoder_outputs = []
        for encoder in self.downs:
            x = encoder(x)
            encoder_outputs.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder from the deepest level, pairing each level with the
        # encoder output saved at the same depth.
        for level, skip in enumerate(reversed(encoder_outputs)):
            upsample = self.ups[2 * level]
            fuse = self.ups[2 * level + 1]

            x = upsample(x)
            if x.shape != skip.shape:
                # Pooling floors odd sizes, so the upsampled map can be smaller
                # than the skip tensor; resize to the skip's spatial size.
                x = F.interpolate(x, size=skip.shape[2:])
            x = fuse(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

    @staticmethod
    def double_conv(in_channels, out_channels):
        """Build two consecutive (3x3 conv, no bias) -> BatchNorm -> ReLU stages."""
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(stage_in, out_channels, 3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)


In [28]:
# ===================== Loss Functions =====================
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, computed over all elements of the batch at once.

    Args:
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both prediction and target sums are zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - dice`` where ``preds`` are raw logits."""
        probabilities = preds.sigmoid()
        overlap = torch.sum(probabilities * targets)
        combined = torch.sum(probabilities) + torch.sum(targets)
        dice = (2. * overlap + self.smooth) / (combined + self.smooth)
        return 1 - dice


def loss_fn(preds, targets):
    """Return the sum of BCE-with-logits loss and Dice loss for raw logits."""
    criteria = (nn.BCEWithLogitsLoss(), DiceLoss())
    bce_term, dice_term = (criterion(preds, targets) for criterion in criteria)
    return bce_term + dice_term



In [29]:
# ===================== Metrics =====================
def dice_score(preds, targets, threshold=0.5, eps=1e-8):
    """Dice coefficient between thresholded predictions and binary targets.

    Args:
        preds: Raw logits; a sigmoid is applied before thresholding.
        targets: Tensor of 0/1 values with the same shape as ``preds``.
        threshold: Probability above which a pixel counts as positive.
        eps: Smoothing constant added to numerator and denominator.

    Returns:
        Scalar tensor. When neither the prediction nor the target contains a
        positive pixel the score is 1.0 (a perfect match) rather than 0.
    """
    pred_mask = (torch.sigmoid(preds) > threshold).float()
    intersection = (pred_mask * targets).sum()
    # eps appears in both terms: with eps only in the denominator an
    # empty-vs-empty comparison would be reported as a complete miss.
    return (2. * intersection + eps) / (pred_mask.sum() + targets.sum() + eps)


def pixel_accuracy(preds, targets, threshold=0.5):
    """Fraction of elements whose thresholded sigmoid(preds) equals ``targets``."""
    binarized = torch.sigmoid(preds).gt(threshold).float()
    hits = torch.eq(binarized, targets).float()
    return hits.sum() / hits.numel()


In [30]:
# ===================== Training =====================
def train_one_epoch(loader, model, optimizer, scaler, device):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Args:
        loader: Iterable of ``(data, targets)`` batches that supports ``len()``.
        model: Network to train; switched to training mode on entry.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler used to scale the loss and step the optimizer.
        device: Device (string or ``torch.device``) the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice, mean_pixel_accuracy)`` averaged over
        the number of batches.
    """
    # Do not rely on a previous caller having left the model in training mode.
    model.train()

    # torch.cuda.amp.autocast is deprecated in favour of torch.autocast with an
    # explicit device type. Mixed precision is only enabled on CUDA devices.
    amp_enabled = torch.device(device).type == "cuda"
    amp_device_type = "cuda" if amp_enabled else "cpu"

    loop = tqdm(loader, leave=True)
    total_loss, total_dice, total_acc = 0, 0, 0

    for data, targets in loop:
        data, targets = data.to(device), targets.to(device)

        with torch.autocast(device_type=amp_device_type, enabled=amp_enabled):
            preds = model(data)
            loss = loss_fn(preds, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once and reuse it for both the running total and
        # the progress bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        total_dice += dice_score(preds, targets).item()
        total_acc += pixel_accuracy(preds, targets).item()

        loop.set_postfix(loss=batch_loss)

    return total_loss / len(loader), total_dice / len(loader), total_acc / len(loader)


def validate(loader, model, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Args:
        loader: Iterable of ``(data, targets)`` batches that supports ``len()``.
        model: Network to evaluate. It is put in eval mode for the duration of
            the call and returned to whichever mode it was in beforehand.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice, mean_pixel_accuracy)`` averaged over
        the number of batches.
    """
    was_training = model.training
    model.eval()
    total_loss, total_dice, total_acc = 0, 0, 0

    try:
        with torch.no_grad():
            for data, targets in loader:
                data, targets = data.to(device), targets.to(device)
                preds = model(data)
                loss = loss_fn(preds, targets)

                total_loss += loss.item()
                total_dice += dice_score(preds, targets).item()
                total_acc += pixel_accuracy(preds, targets).item()
    finally:
        # Restore the caller's mode even if a batch raises, instead of
        # unconditionally forcing training mode.
        model.train(was_training)

    return total_loss / len(loader), total_dice / len(loader), total_acc / len(loader)


In [35]:
def main():
    """Train the U-Net on the configured datasets and checkpoint the best model.

    Builds the transforms, datasets and loaders, creates the model, optimizer,
    gradient scaler and LR scheduler, optionally resumes from
    ``CHECKPOINT_PATH``, then trains for ``NUM_EPOCHS`` epochs. A checkpoint is
    written whenever the validation Dice score improves.
    """
    # Config
    IMAGE_HEIGHT, IMAGE_WIDTH = 160, 320
    LEARNING_RATE = 2e-4
    BATCH_SIZE = 24
    NUM_EPOCHS = 20
    NUM_WORKERS = 4
    PIN_MEMORY = True
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    USE_AMP = DEVICE == "cuda"  # gradient scaling is only meaningful on CUDA
    CHECKPOINT_PATH = "best_model.pth.tar"

    # Paths
    root_dirs = {
        "culane": "/kaggle/input/culane/CULane",
        "tusimple": "/kaggle/input/masked-dataset/processed",
    }
    train_annotations = "/kaggle/input/annotations/annotations/final_train_annotations.jsonl"
    val_annotations = "/kaggle/input/annotations/annotations/final_val_annotations.jsonl"

    # Transforms
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    # Datasets & Loaders
    train_dataset = LaneDataset(train_annotations, root_dirs, transform=train_transform)
    val_dataset = LaneDataset(val_annotations, root_dirs, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # Model, Optimizer, AMP
    model = UNet(in_channels=3, out_channels=1).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Disable the scaler explicitly when not on CUDA instead of letting it
    # warn and disable itself.
    scaler = GradScaler(enabled=USE_AMP)

    # Scheduler: halve the LR when validation Dice stops improving. The
    # deprecated `verbose` flag is not passed; the LR is printed every epoch.
    scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5,
                                  patience=3)

    # Resume training if checkpoint exists
    best_dice_score = -1.0
    if os.path.exists(CHECKPOINT_PATH):
        print("🔄 Resuming from checkpoint...")
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # written by this script.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        best_dice_score = checkpoint.get("best_dice", -1.0)
        # Older checkpoints have no scheduler/scaler entries, and a disabled
        # scaler saves an empty dict; skip restoring in both cases.
        if checkpoint.get("scheduler"):
            scheduler.load_state_dict(checkpoint["scheduler"])
        if checkpoint.get("scaler"):
            scaler.load_state_dict(checkpoint["scaler"])

    # Training loop
    for epoch in range(NUM_EPOCHS):
        train_loss, train_dice, train_acc = train_one_epoch(
            train_loader, model, optimizer, scaler, DEVICE)
        val_loss, val_dice, val_acc = validate(val_loader, model, DEVICE)

        print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}]")
        print(f"Train: Loss={train_loss:.4f}, Dice={train_dice:.4f}, Acc={train_acc:.4f}")
        print(f"Val:   Loss={val_loss:.4f}, Dice={val_dice:.4f}, Acc={val_acc:.4f}")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Step the scheduler before saving so the checkpoint captures the
        # scheduler state that already accounts for this epoch.
        scheduler.step(val_dice)

        # Save only if validation improves
        if val_dice > best_dice_score:
            best_dice_score = val_dice
            print(f"✅ New best model! Saving to {CHECKPOINT_PATH}")
            torch.save({
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
                "best_dice": best_dice_score,
            }, CHECKPOINT_PATH)


In [36]:
# Entry point: start training when this code runs as the main module.
if __name__ == "__main__":
    main()

  scaler = GradScaler()
  with autocast():
100%|██████████| 3855/3855 [31:17<00:00,  2.05it/s, loss=0.482]



Epoch [1/20]
Train: Loss=0.5909, Dice=0.5649, Acc=0.9736
Val:   Loss=0.5383, Dice=0.5789, Acc=0.9733
Current LR: 0.000200
✅ New best model! Saving to best_model.pth.tar


100%|██████████| 3855/3855 [31:17<00:00,  2.05it/s, loss=0.467]



Epoch [2/20]
Train: Loss=0.3999, Dice=0.6887, Acc=0.9813
Val:   Loss=0.5142, Dice=0.5978, Acc=0.9753
Current LR: 0.000200
✅ New best model! Saving to best_model.pth.tar


100%|██████████| 3855/3855 [31:18<00:00,  2.05it/s, loss=0.336]



Epoch [3/20]
Train: Loss=0.3596, Dice=0.7203, Acc=0.9832
Val:   Loss=0.5052, Dice=0.6046, Acc=0.9761
Current LR: 0.000200
✅ New best model! Saving to best_model.pth.tar


100%|██████████| 3855/3855 [31:18<00:00,  2.05it/s, loss=0.391]



Epoch [4/20]
Train: Loss=0.3349, Dice=0.7396, Acc=0.9843
Val:   Loss=0.5062, Dice=0.6049, Acc=0.9758
Current LR: 0.000200
✅ New best model! Saving to best_model.pth.tar


100%|██████████| 3855/3855 [31:18<00:00,  2.05it/s, loss=0.483]



Epoch [5/20]
Train: Loss=0.3172, Dice=0.7536, Acc=0.9851
Val:   Loss=0.5110, Dice=0.6021, Acc=0.9764
Current LR: 0.000200


100%|██████████| 3855/3855 [31:18<00:00,  2.05it/s, loss=0.252]



Epoch [6/20]
Train: Loss=0.3030, Dice=0.7648, Acc=0.9858
Val:   Loss=0.4942, Dice=0.6149, Acc=0.9761
Current LR: 0.000200
✅ New best model! Saving to best_model.pth.tar


100%|██████████| 3855/3855 [31:20<00:00,  2.05it/s, loss=0.391]



Epoch [7/20]
Train: Loss=0.2911, Dice=0.7742, Acc=0.9863
Val:   Loss=0.4972, Dice=0.6114, Acc=0.9767
Current LR: 0.000200


100%|██████████| 3855/3855 [31:19<00:00,  2.05it/s, loss=0.229]



Epoch [8/20]
Train: Loss=0.2805, Dice=0.7826, Acc=0.9868
Val:   Loss=0.4975, Dice=0.6126, Acc=0.9762
Current LR: 0.000200


100%|██████████| 3855/3855 [31:19<00:00,  2.05it/s, loss=0.241]



Epoch [9/20]
Train: Loss=0.2711, Dice=0.7900, Acc=0.9873
Val:   Loss=0.5009, Dice=0.6113, Acc=0.9765
Current LR: 0.000200


100%|██████████| 3855/3855 [31:19<00:00,  2.05it/s, loss=0.255]



Epoch [10/20]
Train: Loss=0.2622, Dice=0.7970, Acc=0.9877
Val:   Loss=0.5030, Dice=0.6110, Acc=0.9762
Current LR: 0.000200


 56%|█████▌    | 2142/3855 [17:25<13:55,  2.05it/s, loss=0.299]


KeyboardInterrupt: 