In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [9]:
# --- Hyperparameters & Configuration ---
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
NUM_EPOCHS = 25
NUM_WORKERS = 2
IMAGE_HEIGHT = 288
IMAGE_WIDTH = 512
PIN_MEMORY = True
ANNOTATION_FILE = "/kaggle/input/masked-dataset/processed/annotations.json"
CHECKPOINT_PATH = "my_checkpoint.pth.tar"

print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# Note: Augmentations are ONLY for the training set
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5), # Randomly flips the image horizontally
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # Randomly changes brightness, etc.
])

# Validation set should NOT have augmentations
val_transform = transforms.Compose([
    transforms.ToTensor(),
])


In [None]:
class LaneDataset(Dataset):
    # Add 'root_dir' to the init method
    def __init__(self, annotations, root_dir, transform=None):
        self.annotations = annotations
        self.root_dir = root_dir # Store the root directory
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # Get the RELATIVE paths from the json file
        img_relative_path = self.annotations[idx]['image']
        mask_relative_path = self.annotations[idx]['mask']
        
        # Create the FULL, ABSOLUTE path by joining the root_dir
        img_path = os.path.join(self.root_dir, img_relative_path)
        mask_path = os.path.join(self.root_dir, mask_relative_path)
        
        # The rest of the function is the same
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply transforms only to the image if they exist
        if self.transform:
            image = self.transform(image)
        
        mask = transforms.ToTensor()(mask)
        mask = torch.where(mask > 0, 1.0, 0.0)
        
        return image, mask

In [12]:
# U-Net Model Definition
class DoubleConv(nn.Module):
    """(Convolution => BatchNorm => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNET, self).__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.MaxPool2d(2)
        self.conv1 = DoubleConv(64, 128)
        self.down2 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(128, 256)
        self.down3 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(256, 512)
        self.down4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x2 = self.conv1(x2)
        x3 = self.down2(x2)
        x3 = self.conv2(x3)
        x4 = self.down3(x3)
        x4 = self.conv3(x4)
        x5 = self.down4(x4)
        x5 = self.bottleneck(x5)
        up1 = self.up1(x5)
        concat1 = torch.cat([up1, x4], dim=1)
        up1_conv = self.up_conv1(concat1)
        up2 = self.up2(up1_conv)
        concat2 = torch.cat([up2, x3], dim=1)
        up2_conv = self.up_conv2(concat2)
        up3 = self.up3(up2_conv)
        concat3 = torch.cat([up3, x2], dim=1)
        up3_conv = self.up_conv3(concat3)
        up4 = self.up4(up3_conv)
        concat4 = torch.cat([up4, x1], dim=1)
        up4_conv = self.up_conv4(concat4)
        logits = self.outc(up4_conv)
        return logits

In [None]:
# Training and Evaluation Functions
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Processes one epoch of training."""
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().to(device=DEVICE)
        
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        loop.set_postfix(loss=loss.item())

def check_accuracy(loader, model, device="cuda"):
    """
    Checks accuracy on a validation set and returns the average Dice score.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval() # Set model to evaluation mode
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            
    # Calculate final metrics
    avg_dice_score = dice_score / len(loader)
    accuracy = num_correct / num_pixels * 100
    
    model.train() # Set model back to training mode
    
    # Return the key metric
    return avg_dice_score, accuracy

In [14]:
# --- Data Loading ---
KAGGLE_ROOT_DIR = "/kaggle/input/masked-dataset/"
with open(os.path.join(KAGGLE_ROOT_DIR, "processed/annotations.json"), "r") as f:
    all_annotations = json.load(f)

train_ann, val_ann = train_test_split(all_annotations, test_size=0.2, random_state=42)

train_dataset = LaneDataset(annotations=train_ann, root_dir=KAGGLE_ROOT_DIR, transform=train_transform)
val_dataset = LaneDataset(annotations=val_ann, root_dir=KAGGLE_ROOT_DIR, transform=val_transform)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=PIN_MEMORY, shuffle=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=PIN_MEMORY, shuffle=False
)

# --- Model, Loss, Optimizer ---
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()

# --- NEW: Keep track of the best score ---
best_dice_score = -1.0 

# --- Training Loop ---
for epoch in range(NUM_EPOCHS):
    print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    
    # --- MODIFIED: Get scores from the function and print them ---
    current_dice, current_acc = check_accuracy(val_loader, model, device=DEVICE)
    print(f"Validation Accuracy: {current_acc:.2f}")
    print(f"Validation Dice Score: {current_dice:.4f}")
    
    # --- NEW: Conditional Saving Logic ---
    if current_dice > best_dice_score:
        best_dice_score = current_dice
        print(f"✅ New best score! Saving model to {CHECKPOINT_PATH}")
        
        # Save the checkpoint only if the score is better
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(checkpoint, CHECKPOINT_PATH)

  scaler = torch.cuda.amp.GradScaler()


--- Epoch 1/25 ---


  with torch.cuda.amp.autocast():
100%|██████████| 363/363 [03:59<00:00,  1.51it/s, loss=0.266]


Validation Accuracy: 96.34
Validation Dice Score: 0.3867
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 2/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.19] 


Validation Accuracy: 96.86
Validation Dice Score: 0.4851
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 3/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.168]


Validation Accuracy: 96.78
Validation Dice Score: 0.4609
--- Epoch 4/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.143]


Validation Accuracy: 97.02
Validation Dice Score: 0.5382
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 5/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.119] 


Validation Accuracy: 97.63
Validation Dice Score: 0.7066
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 6/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.117] 


Validation Accuracy: 97.50
Validation Dice Score: 0.6871
--- Epoch 7/25 ---


100%|██████████| 363/363 [03:29<00:00,  1.73it/s, loss=0.103] 


Validation Accuracy: 97.68
Validation Dice Score: 0.6926
--- Epoch 8/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.119] 


Validation Accuracy: 97.79
Validation Dice Score: 0.7223
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 9/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0706]


Validation Accuracy: 97.86
Validation Dice Score: 0.7380
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 10/25 ---


100%|██████████| 363/363 [03:29<00:00,  1.74it/s, loss=0.107] 


Validation Accuracy: 97.87
Validation Dice Score: 0.7455
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 11/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.111] 


Validation Accuracy: 97.87
Validation Dice Score: 0.7424
--- Epoch 12/25 ---


100%|██████████| 363/363 [03:29<00:00,  1.73it/s, loss=0.0491]


Validation Accuracy: 97.88
Validation Dice Score: 0.7390
--- Epoch 13/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0601]


Validation Accuracy: 97.87
Validation Dice Score: 0.7261
--- Epoch 14/25 ---


100%|██████████| 363/363 [03:29<00:00,  1.74it/s, loss=0.0883]


Validation Accuracy: 97.92
Validation Dice Score: 0.7494
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 15/25 ---


100%|██████████| 363/363 [03:29<00:00,  1.73it/s, loss=0.0519]


Validation Accuracy: 97.96
Validation Dice Score: 0.7495
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 16/25 ---


100%|██████████| 363/363 [03:29<00:00,  1.74it/s, loss=0.0753]


Validation Accuracy: 97.99
Validation Dice Score: 0.7571
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 17/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0631]


Validation Accuracy: 97.99
Validation Dice Score: 0.7578
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 18/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0581]


Validation Accuracy: 97.96
Validation Dice Score: 0.7586
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 19/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0658]


Validation Accuracy: 97.93
Validation Dice Score: 0.7480
--- Epoch 20/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0722]


Validation Accuracy: 98.01
Validation Dice Score: 0.7565
--- Epoch 21/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.0375]


Validation Accuracy: 97.99
Validation Dice Score: 0.7586
✅ New best score! Saving model to my_checkpoint.pth.tar
--- Epoch 22/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.044] 


Validation Accuracy: 97.95
Validation Dice Score: 0.7510
--- Epoch 23/25 ---


100%|██████████| 363/363 [03:28<00:00,  1.74it/s, loss=0.044] 


Validation Accuracy: 97.86
Validation Dice Score: 0.7429
--- Epoch 24/25 ---


100%|██████████| 363/363 [03:27<00:00,  1.75it/s, loss=0.0334]


Validation Accuracy: 97.95
Validation Dice Score: 0.7561
--- Epoch 25/25 ---


100%|██████████| 363/363 [03:27<00:00,  1.75it/s, loss=0.0394]


Validation Accuracy: 97.95
Validation Dice Score: 0.7523
