## Milestone 1: Problem Definition, Dataset Selection, and Data Exploration

LIS 640 - Introduction to Applied Deep Learning

Due 2/14/25

## **Overview**
In this milestone, you will:
1. **Define a deep learning problem** where AI can make a meaningful impact.
2. **Identify three datasets** that fit your topic and justify their relevance.
3. **Explore and visualize** the datasets to understand their structure.
4. **Implement a PyTorch Dataset class** to prepare data for deep learning.

This notebook provides an example of **fuel-efficient car usage** to illustrate what is expected.


## **Step 1: Define Your Deep Learning Problem**
Write a paragraph explaining:
- **Why your chosen topic is important.**
- **How deep learning can help solve the problem.**

### **Example Problem Statement: Fuel-Efficient Car Usage**
*Fuel efficiency is a major factor in reducing carbon emissions and lowering fuel costs. Drivers often adopt inefficient driving patterns, wasting fuel through unnecessary acceleration, braking, or idling. A deep learning model could analyze driving behavior and suggest optimizations in real-time, helping individuals improve their fuel economy.*

➡ **Write your problem statement below:**  

### **Lane Line Detection**
Lane line detection is a crucial aspect of most autonomous driving systems. It involves identifying and tracking the lane markings on the road to ensure the vehicle stays within its designated lane. Accurate lane detection is essential for maintaining safe driving conditions, enabling features like lane-keeping assistance, adaptive cruise control, and autonomous navigation. However, challenges such as varying lighting conditions, occlusions (e.g., by other vehicles or debris), and poorly marked or faded lane lines can make this task complex. A deep learning model trained on annotated road images can be used to detect lane lines in real-time, providing the vehicle with the necessary information to make informed decisions and navigate safely. By improving the robustness and accuracy of lane detection systems, we can enhance the safety and reliability of autonomous vehicles, ultimately contributing to safer roads and more efficient transportation systems.

## **Step 2: Identify and Justify Three Relevant Datasets**
Find three datasets that provide useful information for solving your problem.  
For each dataset, include:
1. A **short description** of what it contains.
2. A **link to the dataset**.
3. **Why this dataset is useful for your problem.**

### **Example Datasets for Fuel Efficiency**

- **Dataset 1: Vehicle Trajectory Data (NGSIM US 101 Dataset)**
	- Description: This dataset contains detailed vehicle trajectory data collected on a segment of the U.S. Highway 101 in Los Angeles, California. It includes precise location information of each vehicle within the study area every one-tenth of a second, capturing detailed lane-changing and car-following behaviors.
	- Source: [U.S. Department of Transportation - NGSIM Program](https://data.transportation.gov/stories/s/Next-Generation-Simulation-NGSIM-Open-Data/i5zb-xe34/)
	- Justification: Analyzing this data can help identify driving patterns that affect fuel efficiency, such as frequent lane changes or abrupt braking.

- **Dataset 2: Climate & Air Quality Data**  
  - Description: Contains CO2 emissions and climate-related metrics across different regions.
  - Source: [U.S. Historical Climatology Network](https://www.ncei.noaa.gov/products/land-based-station/us-historical-climatology-network)
  - Justification: Can correlate driving behavior with environmental impact. Provides environmental context to fuel consumption.

- **Dataset 3: Automobile Dataset (UCI Machine Learning Repository)**
  - Description: This dataset includes various characteristics of automobiles, such as engine size, horsepower, weight, and fuel consumption. It also provides information on insurance risk ratings and normalized losses in use as compared to other cars.
  - Source: [UCI Machine Learning Repository - Automobile Dataset](https://archive.ics.uci.edu/dataset/10/automobile)
  - Justification: The dataset’s detailed vehicle specifications and performance metrics can be used to analyze how different factors influence fuel efficiency, aiding in the development of predictive models.

➡ **Find and document three datasets for your problem below:**

- **Dataset 1: Lane Line Detection Dataset**
	- Description: This dataset contains 100 images of German roads, each with annotated lane lines. The dataset is diverse and includes curved roads.
	- Source: [New-Lane detection Computer Vision](https://universe.roboflow.com/maanasa-prasad/new-lane-detection)
	- Justification: Provides lane line data from different environments and includes curved lane lines.

- **Dataset 2: Indian Roads Dataset**
	- Description: This dataset contains over 650 labeled lane images of various road environments, such as curves, traffic, and more. It is collected from real scenarios across multiple cities in India, and includes images with lane lines that are manually annotated with polygons.
	- Source: [Lane Detection Computer Vision Project](https://universe.roboflow.com/autonomous-umjvo/lane-detection-2-qpx6p)
	- Justification: Provides lane line data from a different driving environment, which gives us more variety in the data (e.g., left- vs. right-hand drive, different road markings).

- **Dataset 3: US Roads Dataset**
	- Description: This dataset contains over 100 labeled lane images from highway driving in the United States in clear conditions. It includes images with the relevant lanes and lane lines annotated with polygons.
	- Source: [Lane Detection Computer Vision Group](https://universe.roboflow.com/computer-vision-controls-research-group/lane-detection-tjaa0)
	- Justification: Provides lane line data from a more local and potentially more relevant setting where lane line detection might be more necessary – ADAS/hands-off cruise control and steering for long highway drives.



### We also found the following datasets but have not analyzed them yet because of their size; we may explore them later if possible
- **Dataset 1: CurveLanes Dataset from Kaggle**
	- Description: This dataset has 150k lane images of difficult scenarios such as curves and multi-lanes in traffic. It is collected from real urban and highway scenarios in multiple cities in China. The dataset includes images with lane lines which are manually annotated with natural cubic splines. The labels include two key x, y coordinates of the lane marking.
	- Source: [Kaggle CurveLanes Dataset](https://www.kaggle.com/datasets/bnyadmohammed/curvelanes/data), uploaded from [Github CurveLanes Dataset](https://github.com/SoulmateB/CurveLanes)
	- Justification: The dataset includes lane lines that are harder to detect, in a wider variety of more complex scenarios.

- **Dataset 2: Waymo Open Dataset - Motion Dataset**
	- Description: This dataset contains lane line data that Waymo used internally for training and has since open-sourced. It includes lane connections, lane boundaries, and lane neighbors. It provides multiple x, y coordinates along each lane line as labels.
	- Source: [Waymo Open Dataset](https://github.com/waymo-research/waymo-open-dataset?tab=readme-ov-file)
	- Justification: In addition to detailed lane line data, this dataset includes features such as lane neighbors and lane connections.

- **Dataset 3: TuSimple Lane Line Dataset**
	- Description: The dataset consists of 6,408 road images on US highways and includes images from different weather conditions. Dataset includes annotated lane lines.
	- Source: [TuSimple Dataset on Kaggle](https://www.kaggle.com/datasets/manideep1108/tusimple)
	- Justification: This dataset emphasizes variation in weather conditions, which the other datasets do not mention. Including it should help our model generalize better.


## **Step 3: Explore and Visualize Your Data**
Understanding the structure of your dataset is crucial. Perform the following tasks:
1. **Summarize dataset statistics:**
   - Number of samples
   - Number of features
   - Data types (numerical, categorical, text, etc.)
   - Ranges of values (min/max)
   - Missing values

2. **Create visualizations:**
   - Histograms: Show feature distributions.
   - Scatter plots: Explore relationships between key variables.
   - (Optional) PCA: Visualize high-dimensional data in 2D.

### **Example Exploration Code**
Modify this code to work with your dataset.


In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random

datasets = {
    "German": {
        "image_dir": "../Maanasa/train/images",
        "label_dir": "../Maanasa/train/labels"
    },
    "Indian": {
        "image_dir": "../Autonomous/train/images",
        "label_dir": "../Autonomous/train/labels"
    },
    "US": {
        "image_dir": "../ComputerVisionGroup/train/images",
        "label_dir": "../ComputerVisionGroup/train/labels"
    }
}

total_images = 0
total_labels = 0
all_dims = []

for name, paths in datasets.items():
    image_dir = paths["image_dir"]
    label_dir = paths["label_dir"]

    image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
    label_files = [f for f in os.listdir(label_dir) if f.endswith('.txt')]

    total_images += len(image_files)
    total_labels += len(label_files)

    print(f"{name} Dataset:")
    print(f"  Total images: {len(image_files)}")
    print(f"  Total label files: {len(label_files)}")


This tells us how much data each dataset contributes. The Indian roads dataset is the primary source of data, so if each dataset has its own characteristics, our model may be better suited for inference in those conditions than in the others. This suggests that we may need more data, such as the large datasets listed above (Waymo, TuSimple, and CurveLanes from Kaggle), and that we should incorporate it into our combined dataset to improve accuracy across the different conditions the vehicle may encounter.
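
Step 3 also asks about missing values. For image datasets like these, one reasonable analogue is an image without a label file, or a label file without an image. The following is a minimal sketch of such a check, not part of the original analysis: `find_unpaired` is a hypothetical helper, and it assumes (as the code above does) that an image and its label file share the same filename stem.

```python
import os

IMAGE_EXTENSIONS = (".jpg", ".png")  # same extensions the cells above look for


def find_unpaired(image_dir, label_dir):
    """Return (images_without_labels, labels_without_images) as sorted stem lists.

    Hypothetical helper: assumes an image and its label share a filename stem,
    e.g. "frame_01.jpg" <-> "frame_01.txt".
    """
    image_stems = {
        os.path.splitext(f)[0]
        for f in os.listdir(image_dir)
        if f.endswith(IMAGE_EXTENSIONS)
    }
    label_stems = {
        os.path.splitext(f)[0]
        for f in os.listdir(label_dir)
        if f.endswith(".txt")
    }
    # Set differences give the stems that appear on only one side.
    return sorted(image_stems - label_stems), sorted(label_stems - image_stems)
```

Calling `find_unpaired(paths["image_dir"], paths["label_dir"])` for each entry in `datasets` would list any files that the image and label counts alone cannot reveal as mismatched.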

In [None]:
def plot_lane_points(ax, image_path, label_path):
    """Plot lane lines on an image using the given axes."""
    img = Image.open(image_path)
    img = np.array(img)

    lanes = []
    with open(label_path, "r") as f:
        for line in f:
            parts = list(map(float, line.strip().split()))
            if len(parts) < 3:
                continue
            points = np.array(parts[1:]).reshape(-1, 2)
            lanes.append(points)

    ax.imshow(img)
    for lane in lanes:
        ax.plot(lane[:, 0] * img.shape[1], lane[:, 1] * img.shape[0], 
                 marker='o', linestyle='-', markersize=4, label="Lane")

fig, axes = plt.subplots(1, 3, figsize=(12, 5))

for i, (name, dataset) in enumerate(datasets.items()):
    image_dir = dataset["image_dir"]
    label_dir = dataset["label_dir"]
    image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
    random_file = random.choice(image_files)
    # Swap only the final extension; str.replace would also rewrite '.jpg'/'.png' elsewhere in the name
    label_file = os.path.splitext(random_file)[0] + '.txt'
    plot_lane_points(axes[i], os.path.join(image_dir, random_file), os.path.join(label_dir, label_file))
    axes[i].set_title(name)

plt.tight_layout()
plt.show()


This visualization shows what the labels look like in each of our datasets. The Indian roads dataset appears to use a different label format: it uses one large polygon for the area inside the lane lines, while the other two datasets use a polygon for each lane line or lane line segment. This may cause issues when we train our model, so we may need to convert the Indian roads labels.
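
One possible way to quantify this difference, sketched here under assumptions, is to compare polygon areas: a polygon covering the whole lane region should enclose a much larger share of the image than a thin polygon traced around a single lane line. The helpers `parse_label_line` and `polygon_area` are hypothetical names, and the sketch assumes each label line has the form `class x1 y1 x2 y2 ...` with coordinates normalized to [0, 1], which is how the plotting code above reads the files.

```python
def parse_label_line(line):
    """Split 'class x1 y1 x2 y2 ...' into (class_id, [(x, y), ...]).

    Assumes the normalized polygon format read by the plotting code above.
    """
    values = [float(v) for v in line.split()]
    coords = values[1:]
    if len(coords) < 6 or len(coords) % 2 != 0:
        raise ValueError("expected a class id followed by at least three x y pairs")
    points = list(zip(coords[0::2], coords[1::2]))
    return int(values[0]), points


def polygon_area(points):
    """Shoelace formula: area enclosed by the polygon, in squared coordinate units."""
    twice_area = 0.0
    # Pair each vertex with the next one, wrapping around to close the polygon.
    for (x0, y0), (x1, y1) in zip(points, points[1:] + points[:1]):
        twice_area += x0 * y1 - x1 * y0
    return abs(twice_area) / 2.0
```

Because the coordinates are normalized, the returned area is the fraction of the image the polygon covers. Any threshold separating "lane region" from "lane line" polygons would have to be chosen by inspecting the actual data.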

In [None]:
all_lane_counts = []

for name, paths in datasets.items():
    label_dir = paths["label_dir"]
    label_files = [f for f in os.listdir(label_dir) if f.endswith('.txt')]

    lane_counts = []
    for label_file in label_files:
        with open(os.path.join(label_dir, label_file), "r") as f:
            lane_counts.append(len(f.readlines()))

    all_lane_counts.extend(lane_counts)

min_value = min(all_lane_counts)
max_value = max(all_lane_counts)

bins = np.arange(min_value, max_value + 2) - 0.5

plt.hist(all_lane_counts, bins=bins, edgecolor="black", align="mid")
plt.xlabel("Number of lanes per image")
plt.ylabel("Frequency")
plt.title("Lane Count Distribution Across All Datasets")
plt.xticks(np.arange(min_value, max_value + 1))
plt.show()

This shows how many lanes each image has. Most images contain a single lane, which fits our goal well: we are trying to develop a model for lane-keeping assistance, so focusing on the lane we are in is more important than picking up adjacent lanes or lanes in the opposing direction.
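
The histogram above pools all three datasets, so the largest dataset dominates its shape. A per-dataset breakdown may be a useful cross-check. The following is a minimal sketch of one; `polygons_per_image` is a hypothetical helper, and unlike the cell above it counts only lines that hold a class id plus at least one x, y pair, so blank lines are ignored.

```python
import os
from collections import Counter


def polygons_per_image(label_dir):
    """Map 'polygons in a label file' -> 'number of label files with that count'."""
    histogram = Counter()
    for name in os.listdir(label_dir):
        if not name.endswith(".txt"):
            continue
        with open(os.path.join(label_dir, name), "r") as f:
            # Count only lines with a class id plus at least one x y pair.
            count = sum(1 for line in f if len(line.split()) >= 3)
        histogram[count] += 1
    return histogram
```

Looping over `datasets.items()` and printing `sorted(polygons_per_image(paths["label_dir"]).items())` for each entry would show whether the single-lane pattern holds in every dataset or mainly in the largest one.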

## **Step 4: Implement a PyTorch Dataset Class**
Follow [this tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) to prepare data for deep learning by creating a PyTorch Dataset class that:
- Loads data from a CSV or another source.
- Applies preprocessing (e.g., normalization, missing value handling).
- Returns samples in a PyTorch-compatible format.

### **Example PyTorch Dataset Implementation**
Modify this template for your dataset.


In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from torchvision import transforms

class LaneDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        """
        Args:
            image_dir (str): Path to the directory with images.
            label_dir (str): Path to the directory with lane line annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

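        # Use os.path.splitext so names containing extra dots (e.g. "img.rf.abc.jpg") keep their full stem
        image_filenames = {os.path.splitext(f)[0]: f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))}
        label_filenames = {os.path.splitext(f)[0]: f for f in os.listdir(label_dir) if f.endswith('.txt')}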
        self.filenames = sorted(image_filenames.keys() & label_filenames.keys())

        self.image_files = [image_filenames[f] for f in self.filenames]
        self.label_files = [label_filenames[f] for f in self.filenames]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")

        label_path = os.path.join(self.label_dir, self.label_files[idx])
        lanes = []
        with open(label_path, "r") as f:
            for line in f:
                parts = list(map(float, line.strip().split()))
                if len(parts) < 3:
                    continue
                class_id = int(parts[0])
                points = parts[1:]
                lanes.append((class_id, torch.tensor(points).view(-1, 2)))
        if self.transform:
            image = self.transform(image)

        return image, lanes

transform = transforms.Compose([
    transforms.Resize((360, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

# German roads dataset
german_dataset = LaneDataset(
    image_dir="../Maanasa/train/images",
    label_dir="../Maanasa/train/labels",
    transform=transform
)

# Indian roads dataset
indian_dataset = LaneDataset(
    image_dir="../Autonomous/train/images",
    label_dir="../Autonomous/train/labels",
    transform=transform
)

# US roads dataset
us_dataset = LaneDataset(
    image_dir="../ComputerVisionGroup/train/images",
    label_dir="../ComputerVisionGroup/train/labels",
    transform=transform
)

print(len(german_dataset))
print(len(indian_dataset))
print(len(us_dataset))


## **Final Submission**
Upload your submission for Milestone 1 to Canvas. 
Submit this notebook with:

✅ A **clear problem statement**.  
✅ Three **documented datasets** with justification.  
✅ **Exploratory analysis** with summary statistics & visualizations.  
✅ A **PyTorch Dataset class** for preparing data.  

📌 Use the provided example to guide your work. Happy Deep Learning! 🚀

In [None]:
import os
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# --- Paths to datasets ---
datasets = {
    "Dataset1": {"image_dir": "../Maanasa/train/images", "label_dir": "../Maanasa/train/labels"},
    "Dataset2": {"image_dir": "../Autonomous/train/images", "label_dir": "../Autonomous/train/labels"},
    "Dataset3": {"image_dir": "../ComputerVisionGroup/train/images", "label_dir": "../ComputerVisionGroup/train/labels"},
}

# --- Pad or truncate lane points to exactly 20 ---
def fix_lane_points(points, num_points=20):
    points = np.array(points)  # Convert list to numpy array
    if points.shape[0] < num_points:  # If fewer than 20, pad with last point
        pad_amount = num_points - points.shape[0]
        pad_values = np.tile(points[-1], (pad_amount, 1))  # Repeat last point
        points = np.vstack((points, pad_values))
    elif points.shape[0] > num_points:  # If more than 20, truncate
        points = points[:num_points]
    return points

# --- Pad lane tensors ---
def pad_lanes(lanes, max_lanes=5, num_points=20):
    padded = np.zeros((max_lanes, num_points, 2))  # Default: all zeros
    num_lanes = min(len(lanes), max_lanes)
    
    for i in range(num_lanes):
        padded[i] = fix_lane_points(lanes[i], num_points)  # Fix lane size

    return torch.tensor(padded, dtype=torch.float32)

# --- Dataset Class ---
class LaneDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None, max_lanes=5, num_points=20):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.max_lanes = max_lanes
        self.num_points = num_points
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)

        # Load lane labels
        # Swap only the final extension so '.jpg'/'.png' elsewhere in the name is left alone
        label_file = os.path.splitext(self.image_files[idx])[0] + '.txt'
        label_path = os.path.join(self.label_dir, label_file)
        lanes = []

        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    parts = list(map(float, line.strip().split()))
                    if len(parts) < 3:
                        continue  # Skip invalid lines
                    points = np.array(parts[1:]).reshape(-1, 2)  # Convert to (N, 2)
                    lanes.append(points)

        # Convert to tensor and pad
        lanes = pad_lanes(lanes, self.max_lanes, self.num_points)
        return image, lanes

# --- Custom Collate Function ---
def custom_collate(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)  
    return images, torch.stack(labels)  

# --- DataLoader ---
transform = transforms.Compose([
    transforms.Resize((640, 640)), 
    transforms.ToTensor(),
])

datasets_list = []
for dataset in datasets.values():
    datasets_list.append(LaneDataset(dataset["image_dir"], dataset["label_dir"], transform=transform))

full_dataset = torch.utils.data.ConcatDataset(datasets_list)

# --- Create Train Loader ---
train_loader = DataLoader(full_dataset, batch_size=8, collate_fn=custom_collate, shuffle=True)

# --- Verify Batch ---
batch_images, batch_labels = next(iter(train_loader))
print(f"Batch Image Shape: {batch_images.shape}")  # [8, 3, 640, 640]
print(f"Example Label Shape: {batch_labels.shape}")  # [8, 5, 20, 2]
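
`fix_lane_points` above truncates any lane with more than 20 points, which drops the tail of the annotation. An alternative, shown here only as a sketch and not as the method used in this notebook, is to resample each polyline to a fixed number of points spaced evenly along its length. `resample_polyline` is a hypothetical helper written in plain Python; it treats the points as an open polyline in the order they appear in the label file.

```python
import math


def resample_polyline(points, num_points=20):
    """Return num_points (x, y) tuples spaced evenly by arc length along points."""
    if not points:
        raise ValueError("points must not be empty")
    if len(points) == 1 or num_points == 1:
        return [tuple(points[0])] * num_points

    # Cumulative distance travelled at each input vertex.
    cumulative = [0.0]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        cumulative.append(cumulative[-1] + math.hypot(x1 - x0, y1 - y0))
    total = cumulative[-1]
    if total == 0.0:
        # All vertices coincide, so every sample is that same point.
        return [tuple(points[0])] * num_points

    resampled = []
    seg = 0
    for i in range(num_points):
        target = total * i / (num_points - 1)
        # Advance to the segment that contains the target distance.
        while seg < len(points) - 2 and cumulative[seg + 1] < target:
            seg += 1
        seg_len = cumulative[seg + 1] - cumulative[seg]
        t = 0.0 if seg_len == 0.0 else (target - cumulative[seg]) / seg_len
        (x0, y0), (x1, y1) = points[seg], points[seg + 1]
        # Linear interpolation inside the segment.
        resampled.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0)))
    return resampled
```

The result always has exactly `num_points` entries, so it could stand in for the pad-or-truncate step. Whether that helps training would need to be checked on the actual data, especially for labels that are closed polygons rather than centerlines.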


In [None]:
import torch
import torch.nn as nn

class LightweightLaneDetection(nn.Module):
    def __init__(self, debug=False):
        super(LightweightLaneDetection, self).__init__()
        self.debug = debug
        
        # Input: (batch_size, 3, 640, 640)
        self.conv_layers = nn.Sequential(
            # Layer 1: 640 -> 320
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            # Layer 2: 320 -> 160
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            # Layer 3: 160 -> 80
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            # Layer 4: 80 -> 40
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            # Layer 5: 40 -> 20
            nn.MaxPool2d(2)
        )
        
        # Final conv output: (batch_size, 32, 20, 20)
        self.conv_output_size = 32 * 20 * 20  # = 12800
        
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.conv_output_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 5 * 20 * 2)  # 5 lanes, 20 points, 2 coordinates each
        )
        
    def forward(self, x):
        if self.debug:
            print(f"Input shape: {x.shape}")
        
        x = self.conv_layers(x)
        if self.debug:
            print(f"After conv shape: {x.shape}")
            print(f"Flattened size should be: {x.shape[0]} x {x.shape[1] * x.shape[2] * x.shape[3]}")
        
        x = self.fc_layers(x)
        if self.debug:
            print(f"Output shape before view: {x.shape}")
        
        return x.view(-1, 5, 20, 2)

def initialize_training(learning_rate=0.001, debug=True):
    model = LightweightLaneDetection(debug=debug)
    
    if debug:
        print("\nModel Architecture:")
        print(model)
        print("\nModel Parameters:")
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Total parameters: {total_params:,}")
    
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS device")
    else:
        device = torch.device("cpu")
        print("Using CPU device")
    
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    return model, criterion, optimizer, device

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        if batch_idx == 0 and model.debug:
            print(f"\nBatch shapes:")
            print(f"Images: {images.shape}")
            print(f"Labels: {labels.shape}")
        
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        if batch_idx == 0 and model.debug:
            print(f"Outputs: {outputs.shape}")
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Memory cleanup
        images = images.cpu()
        labels = labels.cpu()
        outputs = outputs.cpu()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')
            
    return total_loss / len(train_loader)

# Note: full_dataset was created above with its own Resize((640, 640)) transform,
# so redefining `transform` here does not change the images it returns.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Small batch size for M3 MacBook
train_loader = DataLoader(full_dataset, batch_size=4, collate_fn=custom_collate, shuffle=True)

# Initialize with debug
model, criterion, optimizer, device = initialize_training(debug=True)

# Train
num_epochs = 1
for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    avg_loss = train_epoch(model, train_loader, criterion, optimizer, device)
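    print(f'Average Loss: {avg_loss:.4f}')

# Save the trained weights so the inference cell below can load them
torch.save(model.state_dict(), 'lane_detection_model.pth')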
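
The value `32 * 20 * 20` used for `conv_output_size` in the model above can be checked with a few lines of arithmetic. The sketch below applies the standard output-size formula for convolution and pooling layers, assuming dilation 1 and floor rounding (the PyTorch defaults); `out_size` is a hypothetical helper, not a PyTorch function.

```python
def out_size(size, kernel, stride, padding=0):
    """Spatial size after a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 640
for _ in range(4):
    size = out_size(size, kernel=3, stride=1, padding=1)  # Conv2d(kernel_size=3, padding=1) keeps the size
    size = out_size(size, kernel=2, stride=2)             # MaxPool2d(2) halves it
size = out_size(size, kernel=2, stride=2)                 # final MaxPool2d(2)

flattened = 32 * size * size  # 32 channels times the spatial grid
print(size, flattened)  # 20 12800
```

The same arithmetic shows why the input size matters. A 360x640 input, as produced by the `Resize((360, 640))` transform in the first dataset cell, would reach the linear layer as 32 x 11 x 20 = 7040 features instead of 12800, so the first `nn.Linear` would reject it.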


In [None]:
import os
import cv2
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

# Path to the image directory and model weights
image_directory = '../Autonomous/valid/images'  # Replace with your image directory
model_path = 'lane_detection_model.pth'  # Replace with your model weights file path

# Load the model (your model architecture)
class LightweightLaneDetection(torch.nn.Module):
    def __init__(self, debug=False):
        super(LightweightLaneDetection, self).__init__()
        self.debug = debug
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.MaxPool2d(2),
            torch.nn.MaxPool2d(2)
        )
        
        self.conv_output_size = 32 * 20 * 20  # 12800
        
        self.fc_layers = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(self.conv_output_size, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(1024, 5 * 20 * 2)  # 5 lanes, 20 points, 2 coordinates each
        )
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x.view(-1, 5, 20, 2)

# Load the pre-trained model weights
model = LightweightLaneDetection(debug=False)
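# map_location="cpu" lets weights saved on an MPS or GPU device load on a CPU-only machine
model.load_state_dict(torch.load(model_path, map_location="cpu"))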
model.eval()  # Set model to evaluation mode

# Transform to resize and normalize the image to the model's expected input size
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((640, 640)),  # Resize to 640x640
    transforms.ToTensor(),  # Converts to tensor and normalizes to [0, 1]
])

# Get all image files from the directory
image_files = [f for f in os.listdir(image_directory) if f.endswith('.png') or f.endswith('.jpg')]

# Process each image
for image_file in image_files:
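    # cv2.imread returns a uint8 BGR array for both JPG and PNG, or None if the file cannot be read
    img = cv2.imread(os.path.join(image_directory, image_file))
    if img is None:
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # The model and matplotlib expect RGB

    # Preprocess the image
    img_resized = img.copy()  # Writable copy to draw the predictions on
    img_tensor = transform(img)  # Apply transform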
    img_tensor = img_tensor.unsqueeze(0)  # Add batch dimension (1, 3, 640, 640)
    
    # Run inference
    with torch.no_grad():
        predicted_output = model(img_tensor)  # Get the prediction

    # Reshape the output: (5, 20, 2) -> 5 lanes, each with 20 points (x, y) coordinates
    predicted_output = predicted_output.squeeze().cpu().numpy()  # Shape: (5, 20, 2)

    # Get image dimensions (height, width)
    img_height, img_width, _ = img.shape

    # Extract lane coordinates from predicted output and draw them
    for lane_idx in range(predicted_output.shape[0]):
        lane_coords = predicted_output[lane_idx]  # Shape: (20, 2) for 20 points (x, y)

        # Convert normalized coordinates (0 to 1) into pixel values
        points = [(int(x * img_width), int(y * img_height)) for x, y in lane_coords]

        # Draw lane lines on the image
        for i in range(len(points) - 1):
            img_resized = cv2.line(img_resized, points[i], points[i + 1], (0, 255, 0), 3)  # Green line

    # Show the annotated image with matplotlib (without saving)
    plt.imshow(img_resized)
    plt.axis('off')  # Turn off axis labels
    plt.show()
