In [1]:
import os
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.amp 
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random  

In [None]:
class LaneDataset(Dataset):
    """Lane-segmentation dataset backed by a JSON list of image/mask records.

    The annotation file is read with ``json.load`` and indexed positionally;
    each record must provide ``"image"`` and ``"mask"`` paths relative to
    ``root_dir``.

    Args:
        annotation_path: Path to the JSON annotation file.
        root_dir: Directory the per-record paths are joined onto.
        transform: Optional callable used as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` entries
            (e.g. an albumentations ``Compose``).

    Each item is ``(image, mask)``: ``image`` is a channel-first tensor and
    ``mask`` is a float32 tensor of shape ``(1, H, W)`` holding 0/1 values.
    """

    # (H, W) of the all-zero placeholder returned for unreadable samples.
    FALLBACK_SIZE = (590, 1640)

    def __init__(self, annotation_path, root_dir, transform=None):
        with open(annotation_path, 'r') as f:
            self.annotations = json.load(f)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        # Normalise Windows-style separators that may appear in annotation paths.
        img_path = os.path.join(self.root_dir, ann['image']).replace('\\', '/')
        mask_path = os.path.join(self.root_dir, ann['mask']).replace('\\', '/')

        try:
            # Context managers close the underlying file handles deterministically.
            with Image.open(img_path) as img:
                image = np.array(img.convert("RGB"))
            with Image.open(mask_path) as msk:
                mask = np.array(msk.convert("L"))
        except OSError as e:
            # OSError covers missing files (FileNotFoundError) and unreadable /
            # corrupt images, which PIL reports via OSError subclasses.
            print(f"‚ö†Ô∏è Warning: Skipping invalid sample {idx}. Error: {e} (img: {img_path}, mask: {mask_path})")
            height, width = self.FALLBACK_SIZE
            image = np.zeros((height, width, 3), dtype=np.uint8)
            mask = np.zeros((height, width), dtype=np.uint8)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Without a tensor-producing transform the arrays are still numpy
        # (HWC image, HW mask); convert them so the binarisation below works.
        # The image keeps its dtype and becomes CHW, mirroring ToTensorV2.
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(np.ascontiguousarray(np.asarray(image).transpose(2, 0, 1)))
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        mask = (mask > 0).float().unsqueeze(0)
        return image, mask

In [None]:
class UNet(nn.Module):
    """U-Net that outputs per-pixel logits.

    The encoder applies one ``double_conv`` stage per width in ``features``,
    each followed by 2x2 max-pooling. A bottleneck doubles the deepest width.
    The decoder mirrors the encoder with stride-2 transposed convolutions and
    channel-wise concatenation of the matching encoder output. A final 1x1
    convolution returns raw logits (no activation).

    Args:
        in_channels: Number of input image channels.
        out_channels: Number of output logit channels.
        features: Encoder widths from shallowest to deepest. Any iterable of
            ints is accepted.

    Note:
        Attribute names and module registration order define the
        ``state_dict`` keys and the ``parameters()`` order. Saved model and
        optimizer checkpoints depend on both, so keep them unchanged.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default. Materialising lets any
        # iterable be passed, since ``features`` is traversed twice below.
        features = tuple(features)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        for feature in features:
            self.downs.append(self.double_conv(in_channels, feature))
            in_channels = feature
        self.bottleneck = self.double_conv(features[-1], features[-1] * 2)
        for feature in reversed(features):
            # Upsample 2x while halving channels, then fuse with the skip connection.
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(self.double_conv(feature * 2, feature))
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return logits with the same spatial size as ``x``."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than its skip. Resize spatially; channel counts already match.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:])
            x = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](x)
        return self.final_conv(x)

    @staticmethod
    def double_conv(in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding keeps H and W."""
        return nn.Sequential(
            # Conv bias is redundant because BatchNorm follows.
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, pooled over every element of the input.

    Computes ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)`` with
    ``p = sigmoid(preds)``. The sums run over the whole tensor, so a batch is
    treated as one region rather than averaging per-sample Dice.

    Args:
        smooth: Additive constant that keeps the ratio defined when both sums
            are zero (the loss is then 0).
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        # Reduced-precision logits (e.g. float16 produced under autocast) are
        # upcast. Summing a full batch of probabilities in float16 can exceed
        # its maximum (65504) and collapse the ratio. float64 is left as-is.
        preds = preds.to(torch.promote_types(preds.dtype, torch.float32))
        probs = torch.sigmoid(preds)
        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1 - dice

def loss_fn(preds, targets):
    """Segmentation loss: mean BCE-with-logits plus soft Dice (``DiceLoss``).

    Args:
        preds: Raw logits from the model.
        targets: Binary target mask with the same shape as ``preds``.

    Returns:
        Scalar tensor ``bce + dice``.
    """
    # Callers may pass float16 logits (e.g. when invoked outside an autocast
    # region). Upcasting makes both terms run in at least float32.
    preds = preds.to(torch.promote_types(preds.dtype, torch.float32))
    bce = F.binary_cross_entropy_with_logits(preds, targets)
    dice = DiceLoss()(preds, targets)
    return bce + dice

def dice_score(preds, targets, threshold=0.5):
    """Hard Dice coefficient of thresholded predictions against ``targets``.

    Logits go through a sigmoid and are binarised with a strict ``> threshold``.
    The overlap is pooled over the entire tensor. The result is a scalar
    tensor; if both prediction and target are empty it is 0, because only the
    1e-8 epsilon remains in the denominator.
    """
    hard_preds = torch.sigmoid(preds).gt(threshold).float()
    overlap = torch.sum(hard_preds * targets)
    denominator = torch.sum(hard_preds) + torch.sum(targets) + 1e-8
    return 2. * overlap / denominator

def pixel_accuracy(preds, targets, threshold=0.5):
    """Fraction of elements where the thresholded prediction equals ``targets``.

    Logits go through a sigmoid and are binarised with a strict ``> threshold``.
    The result is a scalar tensor, averaged over every element (all batch items
    and pixels together).
    """
    predicted = torch.sigmoid(preds).gt(threshold).float()
    hits = predicted.eq(targets).float()
    element_count = hits.numel()
    return hits.sum() / element_count

In [None]:
def train_one_epoch(loader, model, optimizer, scaler, device):
    loop = tqdm(loader, leave=True)
    total_loss, total_dice, total_acc = 0, 0, 0
    for data, targets in loop:
        data, targets = data.to(device), targets.to(device)
        with torch.amp.autocast('cuda'):  
            preds = model(data)
            loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        total_dice += dice_score(preds, targets).item()
        total_acc += pixel_accuracy(preds, targets).item()
        loop.set_postfix(loss=loss.item())
    return total_loss / len(loader), total_dice / len(loader), total_acc / len(loader)

def evaluate(loader, model, device):
    """Compute mean metrics over ``loader`` without updating weights.

    The model is put in eval mode for the pass and afterwards restored to
    whatever mode it was in before, even if an exception escapes. On CUDA both
    the forward pass and the loss run under autocast, as in training. On other
    devices autocast is disabled.

    Args:
        loader: Iterable of ``(data, targets)`` batches with a ``len()``.
        model: Module mapping ``data`` to logits.
        device: Device string or ``torch.device`` to move batches to.

    Returns:
        ``(mean_loss, mean_dice, mean_pixel_accuracy)`` averaged over batches;
        all zeros for an empty loader.
    """
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    was_training = model.training
    model.eval()
    total_loss = total_dice = total_acc = 0.0
    try:
        with torch.no_grad():
            for data, targets in loader:
                data = data.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                # Keep the loss inside the autocast region so its precision
                # rules match the training loop.
                with torch.amp.autocast(device_type, enabled=use_amp):
                    preds = model(data)
                    loss = loss_fn(preds, targets)
                total_loss += loss.item()
                total_dice += dice_score(preds, targets).item()
                total_acc += pixel_accuracy(preds, targets).item()
    finally:
        model.train(was_training)
    num_batches = max(len(loader), 1)  # avoid ZeroDivisionError on an empty loader
    return total_loss / num_batches, total_dice / num_batches, total_acc / num_batches

In [None]:
def main():
    """Train (or resume training) the CULane lane-segmentation U-Net.

    Steps:
    1. Build datasets and loaders.
    2. Resume from the first checkpoint found (the writable working copy
       first, then the read-only Kaggle input copy).
    3. Train up to ``NUM_EPOCHS``, keeping the checkpoint with the best
       validation Dice.
    4. Reload the best checkpoint and report validation metrics.
    """
    IMAGE_HEIGHT, IMAGE_WIDTH = 160, 320
    LEARNING_RATE = 2e-4
    BATCH_SIZE = 32
    NUM_EPOCHS = 20
    NUM_WORKERS = 2
    PIN_MEMORY = True
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # Kaggle mounts /kaggle/input read-only, so the attached model there can
    # only be a resume source. New best checkpoints are written to /kaggle/working.
    RESUME_CHECKPOINT_PATH = "/kaggle/input/culane-trained-model/culane_trainedModel/best_model_culane.pth.tar"
    CHECKPOINT_PATH = "/kaggle/working/best_model_culane.pth.tar"
    SUBSAMPLE_TRAIN = 1  # fraction of training annotations to keep (1 = all)
    SEED = 42

    culane_root = "/kaggle/input/culane"
    train_annotations = "/kaggle/input/culane/CULane/culane_train_annotations.json"
    val_annotations = "/kaggle/input/culane/CULane/culane_val_annotations.json"

    def first_existing_checkpoint():
        """Return the first existing checkpoint path (working copy preferred), else None."""
        for path in (CHECKPOINT_PATH, RESUME_CHECKPOINT_PATH):
            if os.path.exists(path):
                return path
        return None

    torch.manual_seed(SEED)

    # Normalisation uses the ImageNet channel mean/std.
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    # Each annotation file is parsed exactly once.
    train_dataset = LaneDataset(train_annotations, culane_root, transform=train_transform)
    val_dataset = LaneDataset(val_annotations, culane_root, transform=val_transform)
    full_train_size = len(train_dataset)
    train_size = int(full_train_size * SUBSAMPLE_TRAIN)
    print(f"üìä Dataset Sizes: Full Train={full_train_size}, Subsampled Train={train_size}, Val={len(val_dataset)}")
    if SUBSAMPLE_TRAIN < 1.0:
        # A dedicated RNG draws the same sample as seeding the global `random`
        # module, without mutating global state.
        train_dataset.annotations = random.Random(SEED).sample(train_dataset.annotations, train_size)
        print(f"‚úÖ Subsampled train to {len(train_dataset.annotations)} samples")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    model = UNet(in_channels=3, out_channels=1).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.amp.GradScaler('cuda')
    scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

    best_dice_score = -1.0
    start_epoch = 0  # train from scratch unless a checkpoint says otherwise

    resume_path = first_existing_checkpoint()
    if resume_path is not None:
        print("üîÑ Resuming from checkpoint...")
        checkpoint = torch.load(resume_path, map_location=DEVICE)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Older checkpoints hold only model/optimizer state. An empty scaler
        # dict (saved from a disabled scaler) cannot be loaded, so skip it.
        if "scheduler" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler"])
        if checkpoint.get("scaler"):
            scaler.load_state_dict(checkpoint["scaler"])
        best_dice_score = checkpoint.get("best_dice", -1.0)
        start_epoch = checkpoint.get("epoch", 0)
        print(f"Resuming from epoch {start_epoch + 1}")

    for epoch in range(start_epoch, NUM_EPOCHS):
        train_loss, train_dice, train_acc = train_one_epoch(train_loader, model, optimizer, scaler, DEVICE)
        val_loss, val_dice, val_acc = evaluate(val_loader, model, DEVICE)
        print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}]")
        print(f"Train: Loss={train_loss:.4f}, Dice={train_dice:.4f}, Acc={train_acc:.4f}")
        print(f"Val:   Loss={val_loss:.4f}, Dice={val_dice:.4f}, Acc={val_acc:.4f}")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Step before checkpointing so the saved scheduler/optimizer state
        # already includes this epoch's observation.
        scheduler.step(val_dice)

        if val_dice > best_dice_score:
            best_dice_score = val_dice
            print(f"‚úÖ New best model! Saving to {CHECKPOINT_PATH} (Val Dice={val_dice:.4f})")
            os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
            torch.save({
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
                "best_dice": best_dice_score,
                "epoch": epoch + 1,
            }, CHECKPOINT_PATH)

    best_path = first_existing_checkpoint()
    if best_path is not None:
        print(f"\nüèÅ Training completed. Loading best model from {best_path} for final validation evaluation.")
        best_checkpoint = torch.load(best_path, map_location=DEVICE)
        model.load_state_dict(best_checkpoint["state_dict"])
        best_epoch = best_checkpoint.get("epoch", "N/A")
        print(f"üìà Best model from epoch {best_epoch} with Val Dice={best_checkpoint['best_dice']:.4f}")
    else:
        print("\nüèÅ Training completed.")
        print("‚ö†Ô∏è No checkpoint found; using final trained model for evaluation.")

    model.eval()
    final_loss, final_dice, final_acc = evaluate(val_loader, model, DEVICE)
    print("\nüîé Final Evaluation on Validation Set (using best model):")
    print(f"Val: Loss={final_loss:.4f}, Dice={final_dice:.4f}, Acc={final_acc:.4f}")


In [7]:
# Entry point: in a Jupyter kernel __name__ is "__main__", so running this
# cell starts the full training run.
if __name__ == "__main__":
    main()

📊 Dataset Sizes: Full Train=88880, Subsampled Train=88880, Val=9675
🔄 Resuming from checkpoint...
Resuming from epoch 7


100%|█████████▉| 3689/3704 [29:38<00:07,  2.08it/s, loss=0.335]


Epoch [7/20]
Train: Loss=0.2857, Dice=0.7780, Acc=0.9868
Val:   Loss=0.4949, Dice=0.6144, Acc=0.9767
Current LR: 0.000200


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.331]



Epoch [8/20]
Train: Loss=0.2749, Dice=0.7865, Acc=0.9873
Val:   Loss=0.4928, Dice=0.6150, Acc=0.9767
Current LR: 0.000200


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.285]



Epoch [9/20]
Train: Loss=0.2647, Dice=0.7947, Acc=0.9878
Val:   Loss=0.4997, Dice=0.6130, Acc=0.9769
Current LR: 0.000200


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.199]



Epoch [10/20]
Train: Loss=0.2557, Dice=0.8018, Acc=0.9882
Val:   Loss=0.4983, Dice=0.6133, Acc=0.9764
Current LR: 0.000200


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.253]



Epoch [11/20]
Train: Loss=0.2465, Dice=0.8091, Acc=0.9886
Val:   Loss=0.4975, Dice=0.6142, Acc=0.9764
Current LR: 0.000200


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.3]  



Epoch [12/20]
Train: Loss=0.2389, Dice=0.8151, Acc=0.9890
Val:   Loss=0.5096, Dice=0.6062, Acc=0.9764
Current LR: 0.000200


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.351]



Epoch [13/20]
Train: Loss=0.2146, Dice=0.8345, Acc=0.9901
Val:   Loss=0.4932, Dice=0.6179, Acc=0.9771
Current LR: 0.000100


100%|██████████| 3704/3704 [29:44<00:00,  2.08it/s, loss=0.275] 



Epoch [14/20]
Train: Loss=0.2030, Dice=0.8437, Acc=0.9907
Val:   Loss=0.5045, Dice=0.6116, Acc=0.9765
Current LR: 0.000100


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.227] 



Epoch [15/20]
Train: Loss=0.1948, Dice=0.8502, Acc=0.9910
Val:   Loss=0.5006, Dice=0.6142, Acc=0.9768
Current LR: 0.000100


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.217] 



Epoch [16/20]
Train: Loss=0.1882, Dice=0.8554, Acc=0.9913
Val:   Loss=0.5086, Dice=0.6104, Acc=0.9763
Current LR: 0.000100


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.225] 



Epoch [17/20]
Train: Loss=0.1816, Dice=0.8606, Acc=0.9916
Val:   Loss=0.5110, Dice=0.6100, Acc=0.9764
Current LR: 0.000100


100%|██████████| 3704/3704 [29:40<00:00,  2.08it/s, loss=0.251] 



Epoch [18/20]
Train: Loss=0.1672, Dice=0.8720, Acc=0.9923
Val:   Loss=0.5154, Dice=0.6069, Acc=0.9765
Current LR: 0.000050


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.121] 



Epoch [19/20]
Train: Loss=0.1601, Dice=0.8776, Acc=0.9927
Val:   Loss=0.5152, Dice=0.6090, Acc=0.9764
Current LR: 0.000050


100%|██████████| 3704/3704 [29:41<00:00,  2.08it/s, loss=0.132] 



Epoch [20/20]
Train: Loss=0.1550, Dice=0.8817, Acc=0.9929
Val:   Loss=0.5201, Dice=0.6078, Acc=0.9762
Current LR: 0.000050

🏁 Training completed. Loading best model from /kaggle/input/culane-trained-model/culane_trainedModel/best_model_culane.pth.tar for final validation evaluation.
📈 Best model from epoch 6 with Val Dice=0.6183

🔎 Final Evaluation on Validation Set (using best model):
Val: Loss=0.4864, Dice=0.6183, Acc=0.9770
