# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# IMPORTANT: DataLoader worker processes (num_workers > 0) need to pickle the
# dataset class, and that failed when this class was defined inline here.
# Either move the dataset class definition into a separate importable file
# (e.g., dataset.py), or disable multiprocessing by keeping num_workers=0 in
# every DataLoader.

class LaneLineDataset(Dataset):
    """Dataset pairing images with per-image lane polygon label files.

    An image is included only when ``label_dir`` holds a ``.txt`` file with
    the same basename. Each non-empty label line is parsed as
    ``class_id x1 y1 x2 y2 ...`` (the original author describes this as YOLO
    format with normalized coordinates); the class id is not used.

    Args:
        image_dir: Directory containing ``.jpg`` / ``.png`` images.
        label_dir: Directory containing the matching ``.txt`` label files.
        transform: Optional callable applied to the loaded RGB PIL image.
        img_size: Stored on the instance but not used by this class.
        max_lanes: Number of lane slots in every returned target; extra
            lanes in a label file are dropped.
        points_per_lane: Number of coordinate values (x and y interleaved)
            stored per lane; longer lanes are truncated, shorter ones are
            padded with -1.
    """

    def __init__(self, image_dir, label_dir, transform=None, img_size=(256, 256),
                 max_lanes=5, points_per_lane=30):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.img_size = img_size
        self.max_lanes = max_lanes
        self.points_per_lane = points_per_lane

        image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))
        )
        # A set keeps the "has a label?" check O(1) per image.
        label_stems = {
            os.path.splitext(f)[0]
            for f in os.listdir(label_dir)
            if f.endswith('.txt')
        }

        # Keep only images that have a label; derive label names from the
        # surviving images so both lists stay index-aligned.
        self.image_files = [
            f for f in image_files if os.path.splitext(f)[0] in label_stems
        ]
        self.label_files = [
            os.path.splitext(f)[0] + '.txt' for f in self.image_files
        ]

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, lanes_tensor, lane_exists)`` where ``image`` is
            the (optionally transformed) RGB image, ``lanes_tensor`` has
            shape ``(max_lanes, points_per_lane)`` padded with -1, and
            ``lane_exists`` has shape ``(max_lanes,)`` with 1.0 for every
            slot that corresponds to a line of the label file.
        """
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        label_path = os.path.join(self.label_dir, self.label_files[idx])

        # convert() yields an independent in-memory image, so the underlying
        # file can be closed right away instead of waiting for GC.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")

        if self.transform:
            image = self.transform(image)

        lane_lines = self._read_label_file(label_path)
        lanes_tensor, lane_exists = self._encode_lanes(lane_lines)
        return image, lanes_tensor, lane_exists

    @staticmethod
    def _read_label_file(label_path):
        """Parse a label file into a list of 1-D coordinate tensors, one per line."""
        lanes = []
        with open(label_path, "r") as f:
            for line in f:
                values = [float(token) for token in line.split()]
                if not values:
                    continue  # blank line

                # values[0] is the class id; the rest are x/y coordinates.
                coords = values[1:]
                if len(coords) % 2 != 0:
                    coords = coords[:-1]  # drop a dangling x without its y
                lanes.append(torch.tensor(coords))
        return lanes

    def _encode_lanes(self, lane_lines):
        """Pack variable-length lanes into fixed-size ``(lanes, exists)`` tensors."""
        lanes_tensor = torch.full((self.max_lanes, self.points_per_lane), -1.0)
        lane_exists = torch.zeros(self.max_lanes)

        for i, lane in enumerate(lane_lines[:self.max_lanes]):
            lane_exists[i] = 1.0
            count = min(lane.shape[0], self.points_per_lane)
            lanes_tensor[i, :count] = lane[:count]

        return lanes_tensor, lane_exists

# Using ResNet-like blocks for better feature extraction
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in a skip connection.

    The skip path is the identity unless the block changes the spatial
    stride or the channel count, in which case a 1x1 conv + batch-norm
    projection is used so both paths can be added.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 conv (and of the projection, if any).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels

        # Main path.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: an empty Sequential passes its input through unchanged.
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main += self.shortcut(x)
        return self.relu(main)

class ImprovedLaneNet(nn.Module):
    """CNN that regresses lane coordinates and per-lane existence scores.

    A strided conv stem and three stages of residual blocks produce a
    256-channel feature map that feeds two heads:

    * ``exists_head`` - global average pool -> MLP -> sigmoid, one score per lane.
    * ``lane_head`` - 4x4 average pool -> MLP, ``max_lanes * points_per_lane``
      raw coordinate values.

    Args:
        max_lanes: Number of lane slots predicted per image.
        points_per_lane: Number of coordinate values predicted per lane.
    """

    def __init__(self, max_lanes=5, points_per_lane=30):
        super().__init__()

        self.max_lanes = max_lanes
        self.points_per_lane = points_per_lane

        # Stem: 7x7 stride-2 conv followed by a stride-2 max pool.
        stem = [
            nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.initial = nn.Sequential(*stem)

        # Backbone: two residual blocks per stage, widening 32 -> 64 -> 128 -> 256.
        self.layer1 = self._make_layer(32, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)

        # Head 1: does each lane slot hold a lane?
        existence_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, max_lanes),
            nn.Sigmoid(),
        ]
        self.exists_head = nn.Sequential(*existence_layers)

        # Head 2: flat vector of coordinates for every lane slot.
        coordinate_layers = [
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, max_lanes * points_per_lane),
        ]
        self.lane_head = nn.Sequential(*coordinate_layers)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        """Build a stage: one (possibly strided) block, then stride-1 blocks."""
        first = ResidualBlock(in_channels, out_channels, stride)
        rest = [
            ResidualBlock(out_channels, out_channels)
            for _ in range(num_blocks - 1)
        ]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Return ``(lanes, existence)`` for a batch of images.

        ``lanes`` is reshaped to ``(batch, max_lanes, points_per_lane)``;
        ``existence`` is the sigmoid output of shape ``(batch, max_lanes)``.
        """
        features = self.layer3(self.layer2(self.layer1(self.initial(x))))

        existence = self.exists_head(features)

        flat_lanes = self.lane_head(features)
        lanes = flat_lanes.view(-1, self.max_lanes, self.points_per_lane)

        return lanes, existence

# Combined loss function for lane detection
class CombinedLaneLoss(nn.Module):
    """Weighted sum of a lane-existence BCE loss and a masked coordinate MSE loss.

    Args:
        existence_weight: Multiplier for the existence (BCE) term.
        coordinate_weight: Multiplier for the coordinate (MSE) term.
    """

    def __init__(self, existence_weight=0.5, coordinate_weight=1.0):
        super().__init__()
        self.existence_weight = existence_weight
        self.coordinate_weight = coordinate_weight
        # BCELoss expects probabilities, i.e. pred_exists already in [0, 1].
        self.bce = nn.BCELoss()
        # Unreduced so padded positions can be masked out before averaging.
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, pred_coords, pred_exists, target_coords, target_exists):
        """Compute the combined loss.

        Args:
            pred_coords: Predicted coordinates, same shape as ``target_coords``.
            pred_exists: Predicted existence probabilities, same shape as
                ``target_exists``.
            target_coords: Ground-truth coordinates; entries equal to -1 are
                treated as padding and ignored.
            target_exists: Ground-truth existence flags; its shape must match
                ``target_coords`` without the last dimension.

        Returns:
            Tuple ``(total_loss, existence_loss, coordinate_loss)`` of scalar
            tensors. ``coordinate_loss`` is 0 when no valid point exists.
        """
        existence_loss = self.bce(pred_exists, target_exists)

        # A coordinate counts only if its lane exists AND it is not padding.
        lane_mask = target_exists.unsqueeze(-1).expand_as(target_coords)
        point_mask = (target_coords != -1).float()
        mask = lane_mask * point_mask

        masked_sum = (self.mse(pred_coords, target_coords) * mask).sum()
        valid_points = mask.sum()

        # Swap a non-positive count for 1 instead of branching in Python: the
        # value is identical (masked_sum is 0 whenever the mask is all zero),
        # no tensor has to be turned into a Python bool, and the result stays
        # attached to pred_coords' autograd graph even for empty batches.
        safe_count = torch.where(
            valid_points > 0, valid_points, torch.ones_like(valid_points)
        )
        coordinate_loss = masked_sum / safe_count

        total_loss = (self.existence_weight * existence_loss +
                      self.coordinate_weight * coordinate_loss)

        return total_loss, existence_loss, coordinate_loss

# Custom collate function
def custom_collate_fn(batch):
    """Stack a list of ``(image, lanes, lane_exists)`` samples into batch tensors.

    Args:
        batch: Sequence of 3-tuples of tensors; the tensors in each position
            must share a shape so they can be stacked.

    Returns:
        Tuple ``(images, lanes, lane_exists)``, each with a new leading
        batch dimension.
    """
    stacked_images = torch.stack([image for image, _, _ in batch])
    stacked_lanes = torch.stack([lane for _, lane, _ in batch])
    stacked_exists = torch.stack([exists for _, _, exists in batch])
    return stacked_images, stacked_lanes, stacked_exists

# Training function with learning rate scheduler
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch total loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0

    with torch.set_grad_enabled(training):
        for images, lane_coords, lane_exists in loader:
            images = images.to(device)
            lane_coords = lane_coords.to(device)
            lane_exists = lane_exists.to(device)

            pred_coords, pred_exists = model(images)
            loss, _, _ = criterion(pred_coords, pred_exists, lane_coords, lane_exists)

            if training:
                optimizer.zero_grad()
                loss.backward()
                # Clip the global gradient norm to keep updates bounded.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            running_loss += loss.item()

    return running_loss / len(loader)


def train_model(model, train_loader, val_loader, num_epochs=20, lr=0.001,
                checkpoint_path='best_lane_model.pth'):
    """Train ``model`` and keep the weights with the lowest validation loss.

    Uses Adam with weight decay 1e-4, gradient-norm clipping at 1.0, and a
    ReduceLROnPlateau scheduler (factor 0.5, patience 3) driven by the
    validation loss. Runs on CUDA when available, otherwise on CPU.

    Args:
        model: Network returning ``(pred_coords, pred_exists)``.
        train_loader: Loader yielding ``(images, lane_coords, lane_exists)``.
        val_loader: Loader with the same batch layout, used for validation.
        num_epochs: Number of epochs to run.
        lr: Initial learning rate for Adam.
        checkpoint_path: Where the best ``state_dict`` is saved.

    Returns:
        Dict with per-epoch ``'train_loss'`` and ``'val_loss'`` lists.

    Raises:
        ValueError: If either loader yields no batches (the epoch mean would
            otherwise fail with a division by zero).
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; cannot compute validation loss")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = CombinedLaneLoss(existence_weight=0.5, coordinate_weight=1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # The scheduler's `verbose` flag is deprecated in recent PyTorch releases,
    # so learning-rate changes are reported explicitly below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        history['train_loss'].append(train_loss)

        val_loss = _run_epoch(model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)

        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != previous_lr:
            print(f'Reducing learning rate from {previous_lr:.2e} to {current_lr:.2e}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch {epoch+1}/{num_epochs}: '
              f'Train Loss: {train_loss:.4f}, '
              f'Validation Loss: {val_loss:.4f}')

    return history

# Visualization functions
def visualize_lane_predictions(model, dataloader, num_samples=5,
                               save_path='lane_predictions.png'):
    """Plot ground-truth and predicted lanes side by side for a few samples.

    Draws one row per collected sample (ground truth left, prediction
    right), saves the figure to ``save_path`` and shows it.

    Args:
        model: Network returning ``(pred_coords, pred_exists)``.
        dataloader: Loader yielding ``(images, lane_coords, lane_exists)``.
        num_samples: Maximum number of samples to draw; fewer rows are drawn
            if the loader runs out first.
        save_path: File the figure is written to.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Collect up to num_samples individual samples from the loader.
    samples = []
    for images, lane_coords, lane_exists in dataloader:
        remaining = num_samples - len(samples)
        for i in range(min(images.shape[0], remaining)):
            samples.append((
                images[i].clone(),
                lane_coords[i].clone(),
                lane_exists[i].clone()
            ))
        if len(samples) >= num_samples:
            break

    if not samples:
        # A figure with zero rows cannot be created.
        print("No samples available for visualization.")
        return

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, j]
    # works for any number of samples.
    num_rows = len(samples)
    fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4 * num_rows), squeeze=False)

    # Inverse of the Normalize(mean, std) step used by the transforms in
    # main(), so the image displays with its original colours.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        for i, (image, ground_truth, gt_exists) in enumerate(samples):
            pred_coords, pred_exists = model(image.unsqueeze(0).to(device))
            pred_coords = pred_coords[0].cpu()
            pred_exists = pred_exists[0].cpu()

            # CHW tensor -> HWC array in [0, 1] for imshow.
            img_np = image.permute(1, 2, 0).numpy()
            img_np = np.clip(img_np * std + mean, 0, 1)
            height, width = img_np.shape[0], img_np.shape[1]

            axes[i, 0].imshow(img_np)
            axes[i, 0].set_title("Ground Truth")
            plot_lanes(axes[i, 0], ground_truth, gt_exists, width, height)
            axes[i, 0].axis('off')

            axes[i, 1].imshow(img_np)
            axes[i, 1].set_title("Prediction")
            plot_lanes(axes[i, 1], pred_coords, pred_exists, width, height)
            axes[i, 1].axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()

def plot_lanes(ax, lane_data, lane_exists, img_width, img_height):
    """Draw each existing lane on ``ax`` as a coloured polyline.

    Args:
        ax: Object with a matplotlib-style ``plot`` method.
        lane_data: Iterable of 1-D tensors holding interleaved x/y values.
        lane_exists: Per-lane scores; lanes scoring below 0.5 (or with no
            score at all) are skipped.
        img_width: Factor applied to x values before plotting.
        img_height: Factor applied to y values before plotting.
    """
    palette = ['r', 'g', 'b', 'c', 'm']

    for lane_idx, lane in enumerate(lane_data):
        has_flag = lane_idx < len(lane_exists)
        if not has_flag or lane_exists[lane_idx] < 0.5:
            continue

        flat = lane.detach().cpu().numpy()

        xs, ys = [], []
        # Even positions are x, odd positions are y; zip drops an unpaired x.
        for x_norm, y_norm in zip(flat[0::2], flat[1::2]):
            if x_norm == -1 or y_norm == -1:
                continue  # padding marker
            xs.append(x_norm * img_width)
            ys.append(y_norm * img_height)

        if xs:
            ax.plot(xs, ys, color=palette[lane_idx % len(palette)], linewidth=3, marker='.')

def main():
    """Train the lane network, then visualize predictions on the test images."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training augmentation is photometric only. A geometric transform such as
    # RandomAffine would move the pixels without moving the lane coordinates
    # read from the label files, leaving images and targets misaligned.
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        normalize,
    ])

    train_image_dir = "../ComputerVisionGroup/train/images"
    train_label_dir = "../ComputerVisionGroup/train/labels"

    # Two views of the same files: augmented for training, plain for
    # validation. Both list their files in the same sorted order, so one set
    # of indices addresses the same samples in either view.
    augmented_dataset = LaneLineDataset(
        image_dir=train_image_dir,
        label_dir=train_label_dir,
        transform=train_transform
    )
    plain_dataset = LaneLineDataset(
        image_dir=train_image_dir,
        label_dir=train_label_dir,
        transform=test_transform
    )

    # Random 90/10 split into train and validation indices.
    num_items = len(augmented_dataset)
    val_size = int(0.1 * num_items)
    if val_size == 0 and num_items > 1:
        val_size = 1  # keep validation non-empty for small datasets
    shuffled = torch.randperm(num_items).tolist()
    val_dataset = torch.utils.data.Subset(plain_dataset, shuffled[:val_size])
    train_dataset = torch.utils.data.Subset(augmented_dataset, shuffled[val_size:])

    test_dataset = LaneLineDataset(
        image_dir="../Maanasa/train/images",
        label_dir="../Maanasa/train/labels",
        transform=test_transform
    )

    # num_workers=0 keeps data loading in the main process, which avoids the
    # worker pickling error seen when the dataset class is defined in this file.
    batch_size = 16
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=custom_collate_fn,
        num_workers=0
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate_fn,
        num_workers=0
    )

    # Visualization loader: one shuffled image per batch.
    test_dataloader_viz = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=custom_collate_fn,
        num_workers=0
    )

    model = ImprovedLaneNet(max_lanes=5, points_per_lane=30)

    train_model(
        model,
        train_dataloader,
        val_dataloader,
        num_epochs=30
    )

    # Reload the best checkpoint written during training. map_location lets
    # the file load regardless of the device it was saved from;
    # load_state_dict then copies the values into the model's parameters.
    model.load_state_dict(torch.load('best_lane_model.pth', map_location="cpu"))

    visualize_lane_predictions(model, test_dataloader_viz, num_samples=5)

# Run the full train-and-visualize pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()