In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import numpy as np
import os
import random


# ==========================================
# 1. THE ENET ARCHITECTURE (CORRECTED)
# ==========================================

class InitialBlock(nn.Module):
    """Initial ENet block: halve the resolution and widen the input.

    Two branches run in parallel on the same input and are concatenated
    along the channel axis:

    * a 3x3 convolution with stride 2 producing ``out_channels`` maps, and
    * a 2x2 max pooling with stride 2 that keeps the ``in_channels`` maps.

    With the defaults this yields 13 (conv) + 3 (pool) = 16 channels, which
    are then batch-normalised and passed through a PReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution branch.
            The block itself outputs ``out_channels + in_channels`` channels.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 13) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels + in_channels)
        # ceil_mode=True makes the pooled size ceil(H / 2), which is what the
        # conv branch (kernel 3, stride 2, padding 1) produces. Without it the
        # two branches disagree for odd heights/widths and the concatenation
        # fails; for even sizes the result is identical.
        self.pool = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.prelu = nn.PReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the concatenated conv/pool features after BN and PReLU."""
        conv = self.conv(x)
        pool = self.pool(x)
        out = torch.cat([conv, pool], dim=1)  # Concatenate -> 16 channels by default
        out = self.batch_norm(out)
        return self.prelu(out)


class Bottleneck(nn.Module):
    """The main building block of ENet.

    Every block has a *main* branch and an *extension* branch whose outputs
    are merged and passed through a final PReLU:

    * regular / dilated / asymmetric: main is the identity, merged by
      addition;
    * ``downsample=True``: main is a 2x2 max pooling (its indices are kept in
      ``main_branch_pool_indices``), the extension branch outputs
      ``out_channels - in_channels`` maps, and the two are concatenated so
      the result has ``out_channels`` maps;
    * ``upsample=True``: main is 1x1 conv -> BN -> max unpooling driven by the
      indices of a matching downsampling block, merged by addition.

    The extension branch is: projection conv -> BN -> PReLU, then the middle
    conv (regular, dilated, asymmetric pair, or transposed), then a 1x1
    expansion -> BN -> PReLU -> 2D dropout.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        internal_ratio: Reduction factor; the extension branch works on
            ``in_channels // internal_ratio`` channels internally.
        kernel_size: Kernel size of the middle convolution.
        padding: Padding of the middle convolution.
        dilation: Dilation of the middle convolution.
        asymmetric: Use a (k x 1) followed by a (1 x k) convolution, each with
            its own BN and PReLU, as the middle convolution.
        dropout_prob: Probability used by the ``Dropout2d`` regulariser.
        downsample: Build a downsampling block (halves the resolution).
        upsample: Build an upsampling block (doubles the resolution).

    Raises:
        ValueError: If both ``downsample`` and ``upsample`` are requested.
    """

    def __init__(self, in_channels: int, out_channels: int, internal_ratio: int = 4, kernel_size: int = 3,
                 padding: int = 1, dilation: int = 1, asymmetric: bool = False, dropout_prob: float = 0.1,
                 downsample: bool = False, upsample: bool = False) -> None:
        super().__init__()

        # The two modes build branches with incompatible output sizes, so the
        # combination could only fail later, at forward time.
        if downsample and upsample:
            raise ValueError("Bottleneck cannot be both downsample and upsample")

        self.downsample = downsample
        self.upsample = upsample
        # Filled in by forward() of a downsampling block; read by the caller
        # and handed to the matching upsampling block.
        self.main_branch_pool_indices = None

        # Calculate internal channels
        internal_channels = in_channels // internal_ratio

        # In downsampling, the extension branch must output the *difference*
        # in channels so that e.g. Concat(Main(16), Ext(48)) == 64.
        if downsample:
            self.ext_channels = out_channels - in_channels
        else:
            self.ext_channels = out_channels

        # --- Main Branch ---
        if downsample:
            self.main_pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        elif upsample:
            self.main_unpool = nn.MaxUnpool2d(kernel_size=2)
            self.main_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
            self.main_bn = nn.BatchNorm2d(out_channels)

        # --- Extension Branch ---
        # 1. Projection (1x1, or 2x2 with stride 2 when downsampling)
        if downsample:
            self.ext_conv1 = nn.Conv2d(in_channels, internal_channels, kernel_size=2, stride=2, bias=False)
        else:
            self.ext_conv1 = nn.Conv2d(in_channels, internal_channels, kernel_size=1, bias=False)
        self.ext_bn1 = nn.BatchNorm2d(internal_channels)
        self.ext_prelu1 = nn.PReLU()

        # 2. Main Conv (Regular, Dilated, Asymmetric, or Transposed)
        if asymmetric:
            # The BN/PReLU pairs live inside the Sequential, so no separate
            # ext_bn2 / ext_prelu2 is created for asymmetric blocks.
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(internal_channels, internal_channels, kernel_size=(kernel_size, 1), padding=(padding, 0),
                          dilation=dilation, bias=False),
                nn.BatchNorm2d(internal_channels),
                nn.PReLU(),
                nn.Conv2d(internal_channels, internal_channels, kernel_size=(1, kernel_size), padding=(0, padding),
                          dilation=dilation, bias=False),
                nn.BatchNorm2d(internal_channels),
                nn.PReLU()
            )
        elif upsample:
            self.ext_conv2 = nn.ConvTranspose2d(internal_channels, internal_channels, kernel_size=kernel_size,
                                                padding=padding, output_padding=1, stride=2, bias=False)
            self.ext_bn2 = nn.BatchNorm2d(internal_channels)
            self.ext_prelu2 = nn.PReLU()
        else:
            self.ext_conv2 = nn.Conv2d(internal_channels, internal_channels, kernel_size=kernel_size, padding=padding,
                                       dilation=dilation, bias=False)
            self.ext_bn2 = nn.BatchNorm2d(internal_channels)
            self.ext_prelu2 = nn.PReLU()

        # 3. Expansion (1x1) to self.ext_channels calculated above
        self.ext_conv3 = nn.Conv2d(internal_channels, self.ext_channels, kernel_size=1, bias=False)
        self.ext_bn3 = nn.BatchNorm2d(self.ext_channels)
        self.ext_prelu3 = nn.PReLU()
        self.ext_dropout = nn.Dropout2d(p=dropout_prob)

        self.final_prelu = nn.PReLU()

    def forward(self, x: torch.Tensor, max_indices=None) -> torch.Tensor:
        """Run the block.

        Args:
            x: Input feature map with ``in_channels`` channels.
            max_indices: Pooling indices from the matching downsampling
                block. Required when the block was built with
                ``upsample=True``; ignored otherwise.

        Returns:
            The merged main/extension output after the final PReLU.

        Raises:
            ValueError: If this is an upsampling block and ``max_indices`` is
                missing.
        """
        # Main Branch
        main = x
        if self.downsample:
            main, self.main_branch_pool_indices = self.main_pool(main)
        elif self.upsample:
            if max_indices is None:
                raise ValueError(
                    "upsampling Bottleneck requires max_indices from the matching downsampling block"
                )
            # Conv (reduce channels) -> BN -> Unpool, so the channel count
            # matches that of the stored indices before unpooling.
            main = self.main_conv(main)
            main = self.main_bn(main)
            main = self.main_unpool(main, max_indices)

        # Extension Branch
        ext = self.ext_conv1(x)
        ext = self.ext_bn1(ext)
        ext = self.ext_prelu1(ext)

        ext = self.ext_conv2(ext)
        # NOTE(review): ext_bn2 / ext_prelu2 are also created for regular,
        # dilated and downsampling blocks but are only applied on the
        # upsampling path. This looks unintended; it is kept as is because
        # applying them would change the output of checkpoints trained with
        # this forward pass. Confirm (and retrain) before changing it.
        if self.upsample and hasattr(self, 'ext_bn2'):
            ext = self.ext_bn2(ext)
            ext = self.ext_prelu2(ext)

        ext = self.ext_conv3(ext)
        ext = self.ext_bn3(ext)
        ext = self.ext_prelu3(ext)
        ext = self.ext_dropout(ext)

        # Combine
        if self.downsample:
            # Concatenate main and ext for downsampling
            out = torch.cat((main, ext), 1)
        else:
            # Add for regular residual blocks
            out = main + ext

        return self.final_prelu(out)


class ENet(nn.Module):
    """ENet-style encoder/decoder producing per-pixel class scores.

    Layout (channel counts in parentheses):

    * initial block (3 -> 16), halves the resolution;
    * stage 1 (16 -> 64): one downsampling bottleneck + four regular ones;
    * stage 2 (64 -> 128): one downsampling bottleneck followed by regular,
      dilated (2), asymmetric (5) and dilated (4) bottlenecks;
    * stage 3 (128): regular, dilated (2) and asymmetric (5) bottlenecks;
    * stage 4 (128 -> 64): upsampling bottleneck + two regular ones;
    * stage 5 (64 -> 16): upsampling bottleneck + one regular one;
    * a stride-2 transposed convolution (16 -> ``num_classes``).

    The resolution is halved three times and doubled three times, so the
    input height and width are expected to be divisible by 8.

    Args:
        num_classes: Number of output channels (one score map per class).
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Stem: 3 -> 16 channels
        self.initial = InitialBlock(in_channels=3, out_channels=13)

        # Encoder stage 1 (16 -> 64); every block here uses the lighter dropout.
        light_dropout = 0.01
        self.bottleneck1_0 = Bottleneck(16, 64, dropout_prob=light_dropout, downsample=True)
        self.bottleneck1_1 = Bottleneck(64, 64, dropout_prob=light_dropout)
        self.bottleneck1_2 = Bottleneck(64, 64, dropout_prob=light_dropout)
        self.bottleneck1_3 = Bottleneck(64, 64, dropout_prob=light_dropout)
        self.bottleneck1_4 = Bottleneck(64, 64, dropout_prob=light_dropout)

        # Encoder stage 2 (64 -> 128)
        self.bottleneck2_0 = Bottleneck(64, 128, downsample=True)
        self.bottleneck2_1 = Bottleneck(128, 128)
        self.bottleneck2_2 = Bottleneck(128, 128, padding=2, dilation=2)
        self.bottleneck2_3 = Bottleneck(128, 128, kernel_size=5, padding=2, asymmetric=True)
        self.bottleneck2_4 = Bottleneck(128, 128, padding=4, dilation=4)

        # Encoder stage 3: a shortened repeat of stage 2 without downsampling
        self.bottleneck3_1 = Bottleneck(128, 128)
        self.bottleneck3_2 = Bottleneck(128, 128, padding=2, dilation=2)
        self.bottleneck3_3 = Bottleneck(128, 128, kernel_size=5, padding=2, asymmetric=True)

        # Decoder stage 4 (128 -> 64)
        self.bottleneck4_0 = Bottleneck(128, 64, upsample=True)
        self.bottleneck4_1 = Bottleneck(64, 64)
        self.bottleneck4_2 = Bottleneck(64, 64)

        # Decoder stage 5 (64 -> 16)
        self.bottleneck5_0 = Bottleneck(64, 16, upsample=True)
        self.bottleneck5_1 = Bottleneck(16, 16)

        # Final transposed conv: 16 -> num_classes, doubles the resolution
        self.fullconv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, bias=False)

    def forward(self, x):
        """Return the ``num_classes``-channel score map for input ``x``."""
        features = self.initial(x)  # 16 channels

        # Stage 1: downsample, remember the pooling indices for stage 5.
        features = self.bottleneck1_0(features)  # 64 channels
        stage1_indices = self.bottleneck1_0.main_branch_pool_indices
        for block in (self.bottleneck1_1, self.bottleneck1_2,
                      self.bottleneck1_3, self.bottleneck1_4):
            features = block(features)

        # Stage 2: downsample, remember the pooling indices for stage 4.
        features = self.bottleneck2_0(features)  # 128 channels
        stage2_indices = self.bottleneck2_0.main_branch_pool_indices
        for block in (self.bottleneck2_1, self.bottleneck2_2,
                      self.bottleneck2_3, self.bottleneck2_4):
            features = block(features)

        # Stage 3
        for block in (self.bottleneck3_1, self.bottleneck3_2, self.bottleneck3_3):
            features = block(features)

        # Stage 4: unpool with the stage-2 indices.
        features = self.bottleneck4_0(features, stage2_indices)  # 64 channels
        for block in (self.bottleneck4_1, self.bottleneck4_2):
            features = block(features)

        # Stage 5: unpool with the stage-1 indices.
        features = self.bottleneck5_0(features, stage1_indices)  # 16 channels
        features = self.bottleneck5_1(features)

        # Output
        return self.fullconv(features)

In [10]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2

import torch.nn.functional as F

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Rebuild the network and restore the trained weights for inference.
# NOTE(security): torch.load unpickles the checkpoint, so only load files
# from a trusted source (recent PyTorch versions also accept
# weights_only=True to restrict what may be unpickled).
model = ENet(num_classes=2)
model.load_state_dict(torch.load("enet_best.pth", map_location=device))
model.to(device)
model.eval()  # inference mode: dropout off, batch norm uses running stats

print("Model Loaded Successfully!")

video_path = "test_video.mp4"
cap = cv2.VideoCapture(video_path)
# VideoCapture does not raise on a missing/unreadable file; without this
# check the processing loop would be skipped and an empty video written.
if not cap.isOpened():
    raise IOError(f"Could not open input video: {video_path}")

# Video output writer. Reuse the source frame rate so the result plays at
# the original speed; fall back to 30 when the container does not report one.
source_fps = cap.get(cv2.CAP_PROP_FPS)
fps = source_fps if source_fps > 0 else 30
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter("test_output_video.mp4", fourcc, fps, (1280, 720))

def preprocess(frame):
    """Turn one video frame into a normalised network input.

    The frame is resized with ``cv2.resize(frame, (512, 256))``, scaled to
    [0, 1], normalised per channel, and returned as a float32 tensor with a
    leading batch axis, channels first, on ``device``.

    Assumes ``frame`` is an H x W x 3 uint8 image — TODO confirm.
    NOTE(review): the mean/std constants are the usual ImageNet RGB values;
    if ``frame`` comes straight from OpenCV it is in BGR order — confirm this
    matches how the model was trained.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    scaled = cv2.resize(frame, (512, 256)) / 255.0
    normalized = (scaled - channel_mean) / channel_std

    # HWC -> CHW, then add the batch axis expected by the model.
    chw = torch.tensor(normalized, dtype=torch.float32).permute(2, 0, 1)
    return chw.unsqueeze(0).to(device)

def predict_mask(input_tensor):
    """Run the global ``model`` and return a binary mask for class 1.

    The softmax over the channel axis is thresholded at 0.5 on channel 1, and
    the first batch element is returned as a float NumPy array of 0.0 / 1.0
    values. Gradients are disabled for the whole computation.
    """
    with torch.no_grad():
        logits = model(input_tensor)
        class1_prob = F.softmax(logits, dim=1)[:, 1]
        binary = (class1_prob > 0.5).float()
        # Only the first sample of the batch is used.
        return binary[0].cpu().numpy()

# Process the video frame by frame: predict the mask, tint the masked pixels
# and append the blended frame to the output video.
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream (or read failure): stop processing.
            break

        # cv2.resize returns a new array, so the captured frame is untouched.
        original_frame = cv2.resize(frame, (1280, 720))

        mask = predict_mask(preprocess(frame))

        # Bring the mask to the output frame size before using it as an index.
        mask_resized = cv2.resize(mask, (1280, 720))

        # Paint the masked pixels with a solid highlight colour, then blend
        # 70 % original / 30 % overlay so the scene stays visible underneath.
        overlay = original_frame.copy()
        overlay[mask_resized > 0.5] = [57, 255, 20]

        result = cv2.addWeighted(original_frame, 0.7, overlay, 0.3, 0)

        out.write(result)
finally:
    # Always release the capture and finalise the output file, even when a
    # frame fails to process.
    cap.release()
    out.release()

# Report the file the writer was actually opened on.
print("Video saved as test_output_video.mp4")

FileNotFoundError: [Errno 2] No such file or directory: 'enet_best.pth'