# Mount Google Drive

In [None]:
# Mount Google Drive; every dataset, weight and output path used below lives under this mount point.
from google.colab import drive

drive_mount_point = '/content/gdrive'
drive.mount(drive_mount_point)

# Install Dependencies

In [None]:
!pip install torch torchvision tqdm
!pip install ultralytics
!pip install python-Levenshtein

print("✅ Dependencies installed")

# Import Libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import models
from PIL import Image, ImageDraw, ImageFont
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from ultralytics import YOLO
import cv2
import numpy as np
from IPython.display import Image as IPImage, display
import glob # For finding files
import re # For plausible plate filtering
from collections import Counter # For majority vote
import tensorflow as tf
from Levenshtein import distance as levenshtein_distance # For CER

# --- FORCE TENSORFLOW TO CPU (FOR DEBUGGING EFFICIENTDET) ---
# GPUs can only be hidden before TensorFlow initialises its runtime, so this must run
# before any other TF operation; otherwise set_visible_devices raises RuntimeError,
# which is reported below instead of aborting the cell.
print("Attempting to force TensorFlow operations to CPU for debugging EfficientDet...")
try:
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print("⚠️ No GPUs found. TensorFlow will use CPU.")
    else:
        # An empty visible-device list hides every GPU from TensorFlow (PyTorch is unaffected).
        tf.config.set_visible_devices([], 'GPU')
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"Physical GPUs: {len(gpus)}, Logical GPUs after hiding: {len(logical_gpus)}")
        print("✅ GPUs should now be hidden from TensorFlow. Operations will attempt CPU.")
except RuntimeError as device_error:
    print(f"⚠️ Could not set visible devices (might have already been set or other issue): {device_error}")
# --- END OF CPU FORCING BLOCK ---

print("✅ Libraries imported")

# Path Setup

In [None]:
# All project artefacts live under one Drive folder; every path is derived from it,
# so relocating the project only means editing project_root.
project_root = "/content/gdrive/MyDrive/cos30018-test"
data_root = os.path.join(project_root, "data")

output_path = os.path.join(project_root, "CRNN_CTC_Loss")
training_annotations_xml_path = os.path.join(data_root, "annotations_main.xml")
validation_annotations_xml_path = os.path.join(data_root, "annotations_val.xml")
training_image_path = os.path.join(data_root, "images", "train")
validation_image_path = os.path.join(data_root, "images", "val")
cropped_image_path = os.path.join(output_path, "cropped_images")
cropped_image_val_path = os.path.join(output_path, "cropped_images_val")
training_labels_csv_path = os.path.join(output_path, "labels.csv")
validation_labels_csv_path = os.path.join(output_path, "labels_val.csv")
best_model_path = os.path.join(output_path, "crnn_best_model.pth")
predictions_csv_path = os.path.join(output_path, "validation_predictions.csv")
yolov11_model_path = os.path.join(project_root, "yolov11", "train_170525", "runs", "detect", "train2", "weights", "best.pt")
yolov10_model_path = os.path.join(project_root, "PaddleOCR", "PaddleOCR", "weights", "best.pt")
# The trailing "" keeps the trailing separator of the SavedModel directory path.
efficientdet_model_path = os.path.join(project_root, "efficientdetd0", "saved_model", "")
annotated_output_path = os.path.join(output_path, "od_crnn_predictions")

# One name -> path mapping, so names and paths cannot drift apart.
# paths_name / paths_to_check are kept as the flat lists other cells may expect.
configured_paths = {
    "training_annotations_xml_path": training_annotations_xml_path,
    "validation_annotations_xml_path": validation_annotations_xml_path,
    "training_image_path": training_image_path,
    "validation_image_path": validation_image_path,
    "cropped_image_path": cropped_image_path,
    "cropped_image_val_path": cropped_image_val_path,
    "training_labels_csv_path": training_labels_csv_path,
    "validation_labels_csv_path": validation_labels_csv_path,
    "best_model_path": best_model_path,
    "predictions_csv_path": predictions_csv_path,
    "yolov11_model_path": yolov11_model_path,
    "yolov10_model_path": yolov10_model_path,
    "efficientdet_model_path": efficientdet_model_path,
    "annotated_output_path": annotated_output_path,
}
paths_name = list(configured_paths)
paths_to_check = list(configured_paths.values())

# Report which inputs/outputs are already present (outputs may legitimately be missing on a first run).
print("\nChecking paths...")
for path_name, path in configured_paths.items():
    if os.path.exists(path):
        print(f"✅ {path_name} exists: {path}")
    else:
        print(f"❌ {path_name} does not exist: {path}")

# Create the output directories, reporting whether each one was new or already there.
output_directories_name = ["output_path", "cropped_image_path", "cropped_image_val_path", "annotated_output_path"]
output_directories = [output_path, cropped_image_path, cropped_image_val_path, annotated_output_path]
print("\nCreating output directories...")
for dir_name, dir_path in zip(output_directories_name, output_directories):
    already_existed = os.path.isdir(dir_path)
    os.makedirs(dir_path, exist_ok=True)
    status = "already exists" if already_existed else "created"
    print(f"✅ {dir_name} {status}: {dir_path}")

print("\n✅ Paths setup")

# Parse CVAT XML

In [None]:
def parse_cvat_annotations(annotations_xml_path, target_label='carplate', text_attribute='plate_number'):
    """Parse a CVAT-for-images XML export into per-image plate boxes.

    Args:
        annotations_xml_path: path to the CVAT XML file.
        target_label: only <box> elements whose `label` equals this are kept.
        text_attribute: name of the <attribute> child that holds the plate text.

    Returns:
        dict mapping each <image> `name` to a (possibly empty) list of
        {'bbox': (xtl, ytl, xbr, ybr) as floats, 'plate_number': str or None}.
        `plate_number` is None when the attribute is missing, empty or whitespace-only.
    """
    root = ET.parse(annotations_xml_path).getroot()
    annotations = {}

    for image in tqdm(root.findall('image'), desc="Parsing CVAT images"):
        boxes = []
        for box in image.findall('box'):
            if box.attrib.get('label') != target_label:
                continue

            bbox = tuple(float(box.attrib[key]) for key in ('xtl', 'ytl', 'xbr', 'ybr'))

            plate_number = None
            for attr in box.findall('attribute'):
                if attr.attrib.get('name') == text_attribute:
                    # Blank text is normalised to None so callers only need one "no label" check.
                    plate_number = (attr.text or '').strip() or None
                    break

            boxes.append({'bbox': bbox, 'plate_number': plate_number})
        annotations[image.attrib['name']] = boxes
    return annotations

# Parse the training CVAT export into {image_name: [{'bbox', 'plate_number'}, ...]}.
annotations = parse_cvat_annotations(training_annotations_xml_path)
print("✅ CVAT annotations for training parsed")

In [None]:
# Parse the validation CVAT export with parse_cvat_annotations, which is defined in the
# previous cell. A second identical definition here would silently shadow it.
# NOTE: this rebinds `annotations` to the validation set, replacing the training
# annotations parsed in the previous cell.
annotations = parse_cvat_annotations(validation_annotations_xml_path)
print("✅ CVAT annotations for validation parsed")

# Process Cropped Images, Save All Labels and Bounding Boxes as CSV

In [None]:
# Crop every labelled training plate, save the crops, and record label + source bbox in labels.csv.
# NOTE: an existing labels.csv is reused as-is; delete it to force a rebuild.
if not os.path.exists(cropped_image_path) or not os.path.exists(training_labels_csv_path):
    print("Processing images and creating labels CSV...")
    # Parse the *training* XML here instead of reading the global `annotations`:
    # the validation-parsing cell rebinds that name to the validation annotations,
    # which would otherwise be cropped against the training image folder.
    train_annotations = parse_cvat_annotations(training_annotations_xml_path)
    data = []

    for image_name, boxes in tqdm(train_annotations.items()):
        image_path = os.path.join(training_image_path, image_name)

        if not os.path.exists(image_path):
            print(f"❌ Image {image_path} not found, skipping.")
            continue

        try:
            # convert() returns an in-memory copy, so the file handle can be closed immediately.
            with Image.open(image_path) as source_img:
                img = source_img.convert("RGB")
        except Exception as e:
            print(f"❌ Error opening image {image_path}: {e}")
            continue

        # Crops keep the original file extension (and hence image format).
        stem, original_ext = os.path.splitext(image_name)
        for i, box_info in enumerate(boxes):
            xtl, ytl, xbr, ybr = box_info['bbox']
            plate_number = box_info['plate_number']

            if not plate_number:
                continue
            # PIL rounds crop coordinates to whole pixels; a box with no area after
            # rounding (or with inverted corners) cannot be cropped and saved.
            if round(xbr) <= round(xtl) or round(ybr) <= round(ytl):
                print(f"⚠️ Degenerate box {box_info['bbox']} in {image_name}, skipping.")
                continue

            cropped = img.crop((xtl, ytl, xbr, ybr))
            cropped_filename = f"{stem}_{i}{original_ext}"
            cropped.save(os.path.join(cropped_image_path, cropped_filename))

            data.append({
                'image': cropped_filename,
                'xtl': xtl,
                'ytl': ytl,
                'xbr': xbr,
                'ybr': ybr,
                'plate_number': plate_number
            })

    print(f"\n✅ Cropped {len(data)} images and saved to {cropped_image_path}")

    # Explicit columns so an empty result still yields a readable, header-only CSV.
    df = pd.DataFrame(data, columns=['image', 'xtl', 'ytl', 'xbr', 'ybr', 'plate_number'])
    df.to_csv(training_labels_csv_path, index=False)

    print(f"✅ Saved labels and bounding boxes to {training_labels_csv_path}")
else:
    print("✅ Training labels CSV and cropped training images already exist.")

In [None]:
# Crop every labelled validation plate, save the crops, and record label + source bbox in labels_val.csv.
# NOTE: an existing labels_val.csv is reused as-is; delete it to force a rebuild.
if not os.path.exists(cropped_image_val_path) or not os.path.exists(validation_labels_csv_path):
    print("Processing images and creating labels CSV...")
    # Parse the validation XML here so this cell does not depend on which
    # annotations the global `annotations` currently holds.
    val_annotations = parse_cvat_annotations(validation_annotations_xml_path)
    data = []

    for image_name, boxes in tqdm(val_annotations.items()):
        image_path = os.path.join(validation_image_path, image_name)

        if not os.path.exists(image_path):
            print(f"❌ Image {image_path} not found, skipping.")
            continue

        try:
            # convert() returns an in-memory copy, so the file handle can be closed immediately.
            with Image.open(image_path) as source_img:
                img = source_img.convert("RGB")
        except Exception as e:
            print(f"❌ Error opening image {image_path}: {e}")
            continue

        # Crops keep the original file extension (and hence image format).
        stem, original_ext = os.path.splitext(image_name)
        for i, box_info in enumerate(boxes):
            xtl, ytl, xbr, ybr = box_info['bbox']
            plate_number = box_info['plate_number']

            if not plate_number:
                continue
            # PIL rounds crop coordinates to whole pixels; a box with no area after
            # rounding (or with inverted corners) cannot be cropped and saved.
            if round(xbr) <= round(xtl) or round(ybr) <= round(ytl):
                print(f"⚠️ Degenerate box {box_info['bbox']} in {image_name}, skipping.")
                continue

            cropped = img.crop((xtl, ytl, xbr, ybr))
            cropped_filename = f"{stem}_{i}{original_ext}"
            cropped.save(os.path.join(cropped_image_val_path, cropped_filename))

            data.append({
                'image': cropped_filename,
                'xtl': xtl,
                'ytl': ytl,
                'xbr': xbr,
                'ybr': ybr,
                'plate_number': plate_number
            })

    print(f"\n✅ Cropped {len(data)} images and saved to {cropped_image_val_path}")

    # Explicit columns so an empty result still yields a readable, header-only CSV.
    df = pd.DataFrame(data, columns=['image', 'xtl', 'ytl', 'xbr', 'ybr', 'plate_number'])
    df.to_csv(validation_labels_csv_path, index=False)

    print(f"✅ Saved labels and bounding boxes to {validation_labels_csv_path}")
else:
    print("✅ Validation labels CSV and cropped validation images already exist.")

# Hyper-parameters

In [None]:
# Training hyper-parameters.
batch_size = 32
learning_rate = 0.0005
num_epochs = 50
patience_epochs = 10   # early stopping: epochs without validation-loss improvement before stopping
epochs_no_improve = 0  # early-stopping counter (initial value)

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and torchvision's
# random transforms repeat across runs (cuDNN kernels may still be non-deterministic).
random_seed = 42
torch.manual_seed(random_seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

print("\n✅ Hyper-parameters set")

# Charset

In [None]:
# CRNN output vocabulary: digits then uppercase letters. There is no space symbol;
# spaces are stripped from labels before encoding.
ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Class 0 is reserved for the CTC blank, so real characters are numbered from 1.
char_to_idx = dict(zip(ALPHABET, range(1, len(ALPHABET) + 1)))
idx_to_char = {position: symbol for symbol, position in char_to_idx.items()}
num_classes = len(ALPHABET) + 1  # alphabet + CTC blank

def encode_label(text):
    """Turn a plate string into a list of CTC class indices.

    The text is upper-cased and spaces are removed; any character outside ALPHABET
    is silently dropped.
    """
    normalized = str(text).upper().replace(" ", "")
    encoded = []
    for symbol in normalized:
        if symbol in char_to_idx:
            encoded.append(char_to_idx[symbol])
    return encoded

print("✅ Charset defined")

# Malaysian Car License Plate Dataset

In [None]:
class CRNNCarPlateDataset(Dataset):
    """Cropped plate images paired with their plate strings, read from a labels CSV.

    Each item is (image_tensor, label_seq, label_text):
      - image_tensor: the crop loaded as grayscale and passed through `transform`,
      - label_seq: LongTensor of charset indices from encode_label (the CTC target),
      - label_text: the plate string exactly as stored in the CSV.
    """

    def __init__(self, csv_path, image_folder, transform=None):
        # Read identifiers verbatim: with default NA parsing a plate such as "NA" or
        # "NULL" would become NaN (and str(NaN) == "nan"), and numeric-looking values
        # would be coerced to numbers.
        self.labels_df = pd.read_csv(
            csv_path,
            dtype={'image': str, 'plate_number': str},
            keep_default_na=False,
        )
        self.image_folder = image_folder
        self.transform = transform if transform else transforms.ToTensor()

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        img_path = os.path.join(self.image_folder, row['image'])
        label_text = str(row['plate_number'])

        # Load as grayscale; the context manager closes the file once converted.
        with Image.open(img_path) as source:
            image = source.convert('L')
        image = self.transform(image)

        label_seq = torch.tensor(encode_label(label_text), dtype=torch.long)

        return image, label_seq, label_text

print("✅ Dataset defined")

# CRNN Model Architecture

In [None]:
class CRNN(nn.Module):
    """Convolutional-recurrent network for plate text recognition, trained with CTC.

    Input:  (B, 1, H, W) grayscale batch (H=32, W=128 in this notebook).
    Output: (T, B, num_classes) unnormalised scores with T = W // 4, time-major as nn.CTCLoss expects.

    Args:
        img_height: nominal input height, stored as `self.img_height`. The final pooling
            layer collapses whatever height remains after the first three pools to 1,
            so any input height >= 8 is accepted.
        num_classes: alphabet size + 1 (class 0 is the CTC blank).
    """

    def __init__(self, img_height, num_classes):
        super(CRNN, self).__init__()
        self.img_height = img_height
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),      # (in_ch, out_ch, kernel, stride, padding): size-preserving
            nn.ReLU(),
            nn.MaxPool2d(2, 2),             # H/2, W/2   e.g. 32x128 -> 16x64
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),             # H/2, W/2   e.g. 16x64 -> 8x32
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),   # H/2 only   e.g. 8x32 -> 4x32
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # Collapse the remaining height to 1 and keep the width (the sequence axis).
            # For a 32-px input this equals MaxPool2d((4, 1), (4, 1)). It has no parameters,
            # so the Sequential indices and saved state_dicts are unchanged.
            nn.AdaptiveMaxPool2d((1, None)),
        )
        # After the CNN: (B, 512, 1, W // 4), so each of the W // 4 time steps carries 512 features.
        self.rnn = nn.LSTM(512, 256, bidirectional=True, num_layers=2, batch_first=True)
        # Bidirectional LSTM concatenates both directions: 2 * hidden_size features per step.
        self.fc = nn.Linear(2 * 256, num_classes)

    def forward(self, x):
        """Map a (B, 1, H, W) batch to time-major (T, B, num_classes) scores for CTC."""
        x = self.cnn(x)  # (B, 512, 1, T)
        b, c, h, w = x.size()
        # Width becomes the sequence axis, so the height must have been collapsed to 1.
        assert h == 1, f"Feature map height expected to be 1, but got {h}"
        x = x.squeeze(2)        # (B, 512, T)
        x = x.permute(0, 2, 1)  # (B, T, 512): batch-first sequence for the LSTM
        x, _ = self.rnn(x)      # (B, T, 2 * 256)
        x = self.fc(x)          # (B, T, num_classes)
        x = x.permute(1, 0, 2)  # (T, B, num_classes): the layout nn.CTCLoss expects
        return x

print("✅ CRNN model defined")

# Dataloaders, Training and Validation Sets

In [None]:
# CRNN Training Transform (can add more augmentation here if retraining)
crnn_train_transform = transforms.Compose([
    transforms.Resize((32, 128)),
    # Add more augmentations here for CRNN training if desired:
    # transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.9, 1.1), shear=5),
    # Apply geometric transformations
    transforms.RandomAffine(degrees=(-5, 5),      # Rotate by -5 to +5 degrees
                            translate=(0.05, 0.05), # Translate by up to 5% of width/height
                            scale=(0.9, 1.1),     # Scale by 90% to 110%
                            shear=(-5, 5)),       # Shear by -5 to +5 degrees
    # transforms.ColorJitter(brightness=0.3, contrast=0.3),
    # Apply color/pixel-level transformations (your images are grayscale, so brightness/contrast are most relevant)
    transforms.ColorJitter(brightness=(0.7, 1.3), # Randomly change brightness (e.g., 70% to 130% of original)
                           contrast=(0.7, 1.3)),   # Randomly change contrast
    # transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    # Optional: Gaussian Blur (can sometimes make characters harder to read, use with caution)
    # If you use it, start with small values:
    # transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# CRNN Inference Transform (simpler, no augmentation)
crnn_inference_transform = transforms.Compose([
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = CRNNCarPlateDataset(training_labels_csv_path, cropped_image_path, transform=crnn_train_transform)
# Training dataset split (5353 cropped training images): 80% for training, 20% for validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)

print(f"Train set size: {len(train_set)}")
print(f"Validation set size: {len(val_set)}")
print("\n✅ Dataloaders, Training and Validation Sets defined")

# Training and Validation Functions

In [None]:
def decode_output(outputs_from_model, blank=0, charset=None):
    """Greedy (best-path) CTC decoding of CRNN output.

    Takes the most likely class at every time step, collapses consecutive repeats,
    then drops blanks and any index missing from the character map.

    Args:
        outputs_from_model: (T, B, C) tensor of per-class scores, as returned by CRNN.forward.
            Raw logits are fine: softmax is monotonic, so the argmax path is unchanged.
        blank: class index of the CTC blank (0 for this notebook's charset).
        charset: index -> character mapping; defaults to the global `idx_to_char`.

    Returns:
        list of B decoded strings.
    """
    index_to_char = idx_to_char if charset is None else charset
    # (T, B) best class per step -> (B, T) nested Python lists, using one transfer
    # instead of one .item() call per element.
    best_paths = outputs_from_model.argmax(dim=2).permute(1, 0).tolist()
    pred_texts = []
    for path in best_paths:
        chars = []
        last_idx = blank
        for char_idx in path:
            # Emit only non-blank symbols that differ from the previous time step.
            if char_idx != blank and char_idx != last_idx and char_idx in index_to_char:
                chars.append(index_to_char[char_idx])
            last_idx = char_idx
        pred_texts.append("".join(chars))
    return pred_texts

def train_epoch(model, loader, optimizer, criterion, device=device):
    """Run one optimisation pass over `loader` and return the mean per-batch loss.

    Each batch is expected to be a list of (image_tensor, label_seq, label_text) triples
    (an identity collate_fn). Images are stacked into one tensor, and labels are
    concatenated into the flat target layout that CTCLoss accepts.
    """
    model.train()
    loss_sum = 0
    for batch in tqdm(loader, desc="Training Epoch", leave=True):
        images, targets, _texts = zip(*batch)
        images = torch.stack(images).to(device)
        flat_targets = torch.cat(targets).to(device)
        target_lengths = torch.tensor([len(target) for target in targets], dtype=torch.long).to(device)

        logits = model(images)  # (T, B, C)
        steps, batch_len = logits.size(0), logits.size(1)
        # Every sample uses all T time steps produced by the network.
        input_lengths = torch.full(size=(batch_len,), fill_value=steps, dtype=torch.long).to(device)

        optimizer.zero_grad()
        # CTCLoss expects log-probabilities over the class axis.
        batch_loss = criterion(logits.log_softmax(2), flat_targets, input_lengths, target_lengths)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item()
    return loss_sum / len(loader)

def validate_epoch(model, loader, criterion, device=device):
    """Evaluate `model` on `loader` without gradient updates.

    Args:
        device: where batches are moved; defaults to the notebook-level `device`,
            matching train_epoch.

    Returns:
        (mean per-batch CTC loss, sequence accuracy, character accuracy); accuracies are in [0, 1].
        Ground truth is upper-cased and stripped of spaces before comparison.
        Sequence accuracy counts exact matches. Character accuracy is edit-distance based,
        sum(max(len(true) - edit_distance, 0)) / sum(len(true)), so one inserted or dropped
        character costs one character instead of misaligning the rest of the plate, as a
        position-by-position comparison would.
    """
    model.eval()
    running_loss = 0
    total_correct_seq = 0
    total_seq = 0
    total_correct_chars = 0
    total_chars = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating Epoch", leave=True):
            imgs, labels, label_texts = zip(*batch)
            imgs = torch.stack(imgs).to(device)
            labels_concat = torch.cat(labels).to(device)

            outputs = model(imgs)  # (T, B, C)
            output_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)
            target_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long).to(device)

            loss = criterion(outputs.log_softmax(2), labels_concat, output_lengths, target_lengths)
            running_loss += loss.item()

            for pred, true_text in zip(decode_output(outputs), label_texts):
                true_text_norm = true_text.upper().replace(" ", "")
                if pred.strip() == true_text_norm.strip():
                    total_correct_seq += 1
                total_seq += 1
                # Clamp per sample so an over-long prediction cannot make accuracy negative.
                total_correct_chars += max(len(true_text_norm) - levenshtein_distance(pred, true_text_norm), 0)
                total_chars += len(true_text_norm)
    seq_acc = (total_correct_seq / total_seq) if total_seq > 0 else 0
    char_acc = (total_correct_chars / total_chars) if total_chars > 0 else 0
    return running_loss / len(loader), seq_acc, char_acc

print("✅ Training Model defined")

# Initialize Model, Loss, Optimizer

In [None]:
# Fresh CRNN for the fixed 32-px input height, placed on the selected device.
model = CRNN(img_height=32, num_classes=num_classes).to(device)
# Blank is class 0, matching the charset encoding. zero_infinity=True zeroes infinite
# losses (inputs too short to align with their targets) instead of letting them
# poison the gradients.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# Adam with a small L2 penalty (weight_decay) as light regularisation.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)

print("✅ Model, Loss, Optimizer initialized")

In [None]:
# Train with early stopping on validation CTC loss; the best weights are written to best_model_path.
# All bookkeeping, including the early-stopping counter, is initialised here so that
# re-running this cell starts a clean history.
# NOTE: `model` and `optimizer` are NOT recreated here. Re-run the initialisation cell
# first to train from scratch rather than continue training.
best_val_loss = float('inf')
best_epoch = None
epochs_no_improve = 0
train_losses = []
val_losses = []
val_seq_accuracies = []
val_char_accuracies = []

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_seq_acc, val_char_acc = validate_epoch(model, val_loader, criterion)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_seq_accuracies.append(val_seq_acc)
    val_char_accuracies.append(val_char_acc)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"Val Seq Acc: {val_seq_acc*100:.2f}% | Val Char Acc: {val_char_acc*100:.2f}%")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), best_model_path)
        epochs_no_improve = 0
        print(f"✅ Saved Best Model at epoch {epoch}")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience_epochs:
        print(f"Validation loss did not improve for {patience_epochs} epochs. Stopping early")
        break

print("\n🎉 CRNN training complete!")
if best_epoch is None:
    # Only possible if every validation loss was NaN/inf (or no epoch ran).
    print(f"⚠️ No validation loss below inf was recorded; nothing was saved to {best_model_path}")
else:
    print(f"Best validation loss {best_val_loss:.4f} at epoch {best_epoch}")

# Plot Loss Curves and Accuracy Curves

In [None]:
# Per-epoch CTC loss curves. The x-axis uses the same 1-based epoch numbers as the training log.
loss_epochs = range(1, len(train_losses) + 1)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(loss_epochs, train_losses, label="Train Loss")
ax.plot(loss_epochs, val_losses, label="Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("CTC Loss")
ax.set_title("Training & Validation Loss")
ax.legend()
ax.grid(True)
fig.savefig(os.path.join(output_path, 'LossCurves.png'))
plt.show()
print("\n✅ Loss curves plotted and saved")

In [None]:
# Per-epoch validation accuracies (fractions in [0, 1]). The x-axis uses the same
# 1-based epoch numbers as the training log.
acc_epochs = range(1, len(val_seq_accuracies) + 1)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(acc_epochs, val_seq_accuracies, label="Val Seq Accuracy")
ax.plot(acc_epochs, val_char_accuracies, label="Val Char Accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_title("Validation Accuracy (Sequence & Character)")
ax.grid(True)
ax.legend()
fig.savefig(os.path.join(output_path, 'AccuracyCurves.png'))
plt.show()
print("\n✅ Accuracy curves plotted and saved")

# Inference on Validation Set and Save Predictions to CSV

In [None]:
# Load Best Model for Inference
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

In [None]:
def show_and_save_plate_predictions_with_individual_metrics( # Renamed for clarity
    cropped_image_path,
    csv_path,
    output_csv_path,
    model,
    transform,
    device,
    decode_output_fn,
    display_images=True,
):
    """
    Run the recognition model over every cropped plate listed in `csv_path`. Optionally
    display each image with its true vs predicted label and individual outcome (exact
    match, edit distance). Save all prediction results to CSV and compute overall metrics.

    Args:
        cropped_image_path: folder holding the cropped plate images named in the CSV.
        csv_path: labels CSV with (at least) `image` and `plate_number` columns.
        output_csv_path: destination of the per-image prediction CSV (overwritten).
        model: recognition model; it is switched to eval mode here.
        transform: maps a grayscale PIL image to a tensor the model accepts
            (the batch dimension is added here).
        device: device the input tensor is moved to (should match the model's device).
        decode_output_fn: maps model output to a list of strings, or to a tuple whose
            first element is that list (e.g. (texts, confidences)).
        display_images: when False, skip all plotting; metrics and CSV are unchanged.

    Returns:
        dict with `total_plates_processed`, `overall_exact_match_accuracy` (%),
        `overall_cer` (%) and `overall_normalized_char_accuracy` (%). The last two
        stay NaN when no ground-truth characters were scored.
    """
    # Inference must not use training-mode behaviour (e.g. BatchNorm batch statistics).
    model.eval()

    def predict_plate(image_path):
        """Decode the plate text of one image file ('' if the decoder returns nothing)."""
        with Image.open(image_path) as source:
            image = source.convert('L')
        img = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            decoded_results = decode_output_fn(model(img))
        # Decoders may return either texts or (texts, confidences).
        texts = decoded_results[0] if isinstance(decoded_results, tuple) else decoded_results
        return texts[0] if texts else ""

    def show_missing_image(image_name):
        """Draw a placeholder figure for a CSV row whose image file is absent."""
        fig = plt.figure(figsize=(8, 4))
        plt.text(0.5, 0.5, f"Image Not Found:\n{image_name}", ha='center', va='center', fontsize=12, color='red')
        plt.title(f"File: {image_name}", fontsize=10)
        plt.axis('off')
        plt.show()
        plt.close(fig)

    def show_prediction(image_path, image_name, true_label_raw, predicted_label_raw,
                        true_label_norm, is_exact_match, edit_distance):
        """Plot one crop, titled with true/predicted text, with a per-image metrics box below it."""
        fig = None
        try:
            with Image.open(image_path) as image_pil:
                fig, ax = plt.subplots(figsize=(6, 4))
                ax.imshow(image_pil, cmap='gray')

            title_color = 'green' if is_exact_match else 'red'
            title_text = f"File: {image_name}\nTrue: {true_label_raw} | Predicted: {predicted_label_raw}"
            ax.set_title(title_text, fontsize=12, color=title_color, pad=20)

            info_text = f"Exact Match: {'YES' if is_exact_match else 'NO'}\n"
            if true_label_norm:  # edit distance only makes sense against a ground truth
                char_err_rate_instance = (edit_distance / len(true_label_norm)) * 100
                info_text += f"Edit Distance: {edit_distance} (Instance CER: {char_err_rate_instance:.2f}%)"
            else:
                info_text += "Edit Distance: N/A (No ground truth for comparison)"

            # Metrics box just below the image, positioned in axes coordinates.
            ax.text(0.5, -0.15, info_text, ha='center', va='top', transform=ax.transAxes,
                    fontsize=10, bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
            ax.axis('off')
            fig.tight_layout(rect=[0, 0.05, 1, 0.95])  # leave room for the title and the metrics box
            plt.show()
        except Exception as e:
            print(f"Error displaying image {image_name}: {e}")
        finally:
            # Close explicitly so hundreds of figures do not accumulate in memory.
            if fig is not None:
                plt.close(fig)

    # Read identifiers verbatim: default NA parsing would turn plates like "NA" into NaN ("nan").
    df = pd.read_csv(csv_path, dtype={'image': str, 'plate_number': str}, keep_default_na=False)
    predictions_data = []

    # --- Overall metric accumulators ---
    total_plates_processed = 0
    overall_correct_exact_matches = 0
    overall_total_edit_distance = 0
    overall_total_true_char_length = 0

    print(f"Processing and displaying all {len(df)} images with individual metrics...\n")

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating & Displaying All Images"):
        image_name = row['image']
        true_label_raw = str(row['plate_number'])
        image_path = os.path.join(cropped_image_path, image_name)

        if not os.path.exists(image_path):
            print(f"⚠️ Image {image_path} not found for {image_name}, skipping display and metrics for this entry.")
            predictions_data.append({
                'image': image_name,
                'true_label': true_label_raw,
                'predicted_label': "IMAGE_NOT_FOUND",
                'is_exact_match': False,
                'edit_distance': -1
            })
            if display_images:
                show_missing_image(image_name)
            continue

        predicted_label_raw = predict_plate(image_path)
        total_plates_processed += 1

        # Compare case- and space-insensitively.
        true_label_norm = true_label_raw.upper().replace(" ", "")
        predicted_label_norm = predicted_label_raw.upper().replace(" ", "")

        is_exact_match_current_image = true_label_norm == predicted_label_norm
        if is_exact_match_current_image:
            overall_correct_exact_matches += 1

        if true_label_norm:
            edit_distance_current_image = levenshtein_distance(predicted_label_norm, true_label_norm)
            overall_total_edit_distance += edit_distance_current_image
            overall_total_true_char_length += len(true_label_norm)
        else:
            # No ground truth: reported as 0 and excluded from the CER totals.
            edit_distance_current_image = 0

        if display_images:
            show_prediction(image_path, image_name, true_label_raw, predicted_label_raw,
                            true_label_norm, is_exact_match_current_image, edit_distance_current_image)

        predictions_data.append({
            'image': image_name,
            'true_label': true_label_raw,
            'predicted_label': predicted_label_raw,
            'is_exact_match': is_exact_match_current_image,
            'edit_distance': edit_distance_current_image,
            'true_label_normalized': true_label_norm,
            'predicted_label_normalized': predicted_label_norm
        })

    # Save all predictions to CSV
    pred_df = pd.DataFrame(predictions_data)
    pred_df.to_csv(output_csv_path, index=False)
    print(f"\n✅ Saved all predictions to {output_csv_path}")

    # --- Final overall metrics (percentages) ---
    overall_exact_match_accuracy = 0
    overall_cer = float('nan')
    overall_normalized_char_accuracy = float('nan')

    print("\n-------------- Overall CRNN Inference Performance Metrics --------------")
    if total_plates_processed > 0:
        overall_exact_match_accuracy = (overall_correct_exact_matches / total_plates_processed) * 100
        print(f"Total Plates Processed & Displayed: {total_plates_processed}")
        print(f"Overall Exact Match Accuracy (Sequence Accuracy): {overall_exact_match_accuracy:.2f}% ({overall_correct_exact_matches}/{total_plates_processed})")

        if overall_total_true_char_length > 0:
            overall_cer = (overall_total_edit_distance / overall_total_true_char_length) * 100
            overall_normalized_char_accuracy = (1 - (overall_total_edit_distance / overall_total_true_char_length)) * 100
            print(f"Overall Average Character Error Rate (CER): {overall_cer:.2f}%")
            print(f"Overall Normalized Character Accuracy: {overall_normalized_char_accuracy:.2f}%")
        else:
            print("Overall Character Error Rate (CER) and Normalized Character Accuracy could not be calculated (no true characters found or processed).")
    else:
        print("\n❌ No plates were processed to calculate overall metrics.")
    print("-----------------------------------------------------------------------")

    return {
        "total_plates_processed": total_plates_processed,
        "overall_exact_match_accuracy": overall_exact_match_accuracy,
        "overall_cer": overall_cer,
        "overall_normalized_char_accuracy": overall_normalized_char_accuracy
    }

print("✅ show_and_save_plate_predictions_with_individual_metrics defined")

In [None]:
print("Starting CRNN inference evaluation with individual display and metrics...")
# Note: this displays ALL images listed in validation_labels_csv_path (the crops built
# from annotations_val.xml), one figure per image.
inference_metrics_individual_display = show_and_save_plate_predictions_with_individual_metrics(
    cropped_image_path=cropped_image_val_path,
    csv_path=validation_labels_csv_path,
    output_csv_path=predictions_csv_path,  # overwrites any previous predictions CSV
    model=model,
    transform=crnn_inference_transform,  # no augmentation at evaluation time
    device=device,
    decode_output_fn=decode_output  # greedy CTC decoding; a decoder returning (texts, confidences) also works
)

print("\nReturned Overall Metrics Dictionary:")
print(inference_metrics_individual_display)
print("\n✅ CRNN inference evaluation with individual display and metrics completed")

# Object Detection and CRNN Inference

In [None]:
def is_plausible_malaysian_plate(text, min_len=3, max_len=10, alphabet=None):
    """Heuristic check that `text` looks like a Malaysian licence-plate reading.

    The text is upper-cased and stripped of all spaces first. Every rule, including
    the length limits, is applied to that normalized form.

    Args:
        text: Candidate plate text, typically a CRNN prediction.
        min_len: Minimum number of non-space characters (inclusive).
        max_len: Maximum number of non-space characters (inclusive).
        alphabet: Container of allowed characters. When None, the notebook-level
            ``ALPHABET`` (the CRNN character set) is looked up at call time.

    Returns:
        bool: True if every rule below passes, False otherwise.
    """
    if not text or not isinstance(text, str):
        return False
    if alphabet is None:
        alphabet = ALPHABET  # ALPHABET here does not have space

    # Normalize: uppercase, remove ALL spaces for pattern matching and consistency.
    text_norm = text.upper().replace(" ", "")
    if not text_norm:
        return False  # Empty after removing spaces
    # Length is checked after normalization so padding spaces cannot make a short reading pass.
    if not (min_len <= len(text_norm) <= max_len):
        return False

    # Rule 1: most plates mix letters and digits. Purely alphabetic text is only
    # accepted when reasonably long (special plates such as "PUTRAJAYA", "SUKOM").
    has_letter = any(c.isalpha() for c in text_norm)
    has_digit = any(c.isdigit() for c in text_norm)
    if not (has_letter and has_digit):
        if not (text_norm.isalpha() and len(text_norm) >= 5):
            return False  # Likely not a standard plate

    # Rule 2: 4+ identical characters in a row is a typical CRNN misread ("AAAAA", "11111").
    if re.search(r'(.)\1{3,}', text_norm):
        return False

    # Rule 3: every character must belong to the recognizer's alphabet.
    if not all(ch in alphabet for ch in text_norm):
        return False

    # Rule 4 (loose structure check): plates start with a letter and end in a digit or letter.
    # A strict pattern (1-3 letters, 1-4 digits, 0-2 letters) would reject valid special
    # plates such as G1M, IM4U or PATRIOT, so it is deliberately not enforced.
    if not text_norm[0].isalpha():
        return False
    if not (text_norm[-1].isdigit() or text_norm[-1].isalpha()):
        return False

    return True
print("✅ Plausible Malaysian plate filter function defined.")

In [None]:
def decode_output_with_confidence(outputs_from_model, char_map=None):
    """Greedy CTC decoding of CRNN output, with a simple per-sequence confidence.

    At each time step the most probable class is taken. Index 0 is the CTC blank,
    and consecutive repeats of the same index collapse to one character.

    Args:
        outputs_from_model: Raw CRNN scores shaped (Sequence_Length, Batch_Size, Num_Classes).
        char_map: Mapping from class index to character. Indices missing from the map are
            skipped. When None, the notebook-level ``idx_to_char`` is used, looked up at call
            time so that re-running the alphabet cell is picked up.

    Returns:
        (pred_texts, pred_confidences): one decoded string per batch item, and the arithmetic
        mean probability of the characters that were emitted (0.0 when nothing was decoded).
    """
    if char_map is None:
        char_map = idx_to_char

    # Move to CPU and drop the graph so this also works on GPU outputs outside no_grad.
    probs = torch.softmax(outputs_from_model.detach().cpu(), dim=2)  # (T, B, C)
    best_probs, best_indices = probs.max(dim=2)  # each (T, B)

    # (B, T) Python lists: iterate batch items, then time steps.
    paths = best_indices.permute(1, 0).tolist()
    step_probs = best_probs.permute(1, 0).tolist()

    pred_texts = []
    pred_confidences = []
    for path, path_probs in zip(paths, step_probs):
        chars = []
        char_probs = []
        prev_idx = 0  # start as if preceded by a blank
        for idx, prob in zip(path, path_probs):
            # Emit only non-blank indices that start a new run and are known characters.
            if idx != 0 and idx != prev_idx and idx in char_map:
                chars.append(char_map[idx])
                char_probs.append(prob)
            prev_idx = idx
        pred_texts.append(''.join(chars))
        # Arithmetic mean of per-character probabilities (simple, not a true sequence probability).
        pred_confidences.append(np.mean(char_probs) if char_probs else 0.0)

    return pred_texts, pred_confidences

print("✅ decode_output_with_confidence defined")

In [None]:
def calculate_iou(boxA, boxB):
    """Intersection-over-Union of two axis-aligned boxes given as (xmin, ymin, xmax, ymax).

    Returns:
        float in [0, 1]. Returns 0.0 when the union area is not positive (e.g. both
        boxes are degenerate), instead of raising ZeroDivisionError.
    """
    # Width/height of the overlap rectangle (0 when the boxes do not overlap).
    inter_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    inter_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    inter_area = inter_w * inter_h

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union_area = float(area_a + area_b - inter_area)
    if union_area <= 0:
        return 0.0

    return inter_area / union_area

print("✅ IoU calculation function defined")

In [None]:
def load_od_model(model_path, device):
    """Build an Ultralytics YOLO detector from `model_path` and move it to `device`.

    Returns:
        The YOLO model, or None if construction or device placement raised
        (the error is printed rather than propagated).
    """
    try:
        detector = YOLO(model_path)
        detector.to(device)
    except Exception as err:
        print(f"❌ Error loading object detection model: {err}")
        return None
    print(f"✅ Object detection model loaded successfully from {model_path} on {device}")
    return detector

def load_tf_saved_model(model_dir_path):
    """Load a TensorFlow SavedModel and return its "serving_default" signature.

    Args:
        model_dir_path: Directory containing the SavedModel.

    Returns:
        The ``serving_default`` signature function (called later with keyword inputs),
        or None if loading failed or the signature is missing (the error is printed).
    """
    try:
        loaded_model = tf.saved_model.load(model_dir_path)
        # Inspect loaded_model.signatures if the export used a different signature name.
        infer = loaded_model.signatures["serving_default"]
    except Exception as e:
        print(f"❌ Error loading TensorFlow SavedModel: {e}")
        return None
    # The message no longer mentions a torch `device`: that name is not a parameter here.
    # If it was undefined, the old NameError was swallowed and a loaded model reported as a failure.
    print(f"✅ TensorFlow SavedModel loaded successfully from {model_dir_path}")
    return infer

def load_crnn_for_inference(crnn_model_path, img_h, num_cls, device_to_use):
    """Build a CRNN, load weights from `crnn_model_path`, and put it in eval mode on `device_to_use`.

    Args:
        crnn_model_path: Path to a saved ``state_dict``.
        img_h: Input image height passed to ``CRNN(img_height=...)``.
        num_cls: Number of output classes passed to ``CRNN(num_classes=...)``.
        device_to_use: torch device (or device string) for the weights and model.

    Returns:
        The ready-to-use model, or None if the file is missing or loading fails
        (the reason is printed).
    """
    # Check the checkpoint before building the network, so a missing file costs nothing.
    if not os.path.exists(crnn_model_path):
        print(f"❌ CRNN model file not found at {crnn_model_path}. Cannot load for inference.")
        return None
    try:
        model_crnn_inf = CRNN(img_height=img_h, num_classes=num_cls)
        # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
        state_dict = torch.load(crnn_model_path, map_location=device_to_use)
        model_crnn_inf.load_state_dict(state_dict)
        model_crnn_inf.to(device_to_use)
        model_crnn_inf.eval()
        print(f"✅ CRNN model loaded successfully from {crnn_model_path} for inference on {device_to_use}")
        return model_crnn_inf
    except Exception as e:
        print(f"❌ Error loading CRNN model for inference: {e}")
        return None

def predict_plate_with_crnn_and_conf(cropped_pil_image, model_crnn_inf, transform_crnn, device_to_use):
    """Read the characters on a cropped plate image with the CRNN.

    The crop is converted to grayscale ("L") when needed, transformed, batched as a
    single item and decoded with decode_output_with_confidence.

    Returns:
        (text, confidence): text upper-cased with spaces removed, plus the decoder's
        confidence; ("N/A", 0.0) for a zero-width or zero-height crop.
    """
    gray = cropped_pil_image if cropped_pil_image.mode == 'L' else cropped_pil_image.convert('L')
    if 0 in (gray.width, gray.height):  # empty crop: nothing to recognise
        return "N/A", 0.0

    batch = transform_crnn(gray).unsqueeze(0).to(device_to_use)
    with torch.no_grad():
        texts, confidences = decode_output_with_confidence(model_crnn_inf(batch))  # model output is (T, 1, C)

    text = texts[0] if texts else "N/A"
    conf = confidences[0] if confidences else 0.0
    # Same normalization as the plausibility filter: uppercase, no spaces.
    return text.upper().replace(" ", ""), conf

print("✅ Helper functions for OD and CRNN inference pipeline defined")

In [None]:
# NEW/UPDATED FUNCTION: Run EfficientDet Inference and Parse Output
class _EfficientDetBox:
    """One EfficientDet detection shaped like an Ultralytics box entry.

    ``xyxy``, ``conf`` and ``cls`` are 1-element lists of float32 tensors, so callers can
    read them with ``box.xyxy[0]`` / ``box.conf[0]`` exactly as they do for YOLO output.
    """

    def __init__(self, xyxy_val, conf_val, cls_val):
        self.xyxy = [torch.tensor(xyxy_val, dtype=torch.float32)]
        self.conf = [torch.tensor(conf_val, dtype=torch.float32)]
        self.cls = [torch.tensor(cls_val, dtype=torch.float32)]


class _EfficientDetResults:
    """Container mirroring only the ``.boxes`` attribute of an Ultralytics results object."""

    def __init__(self, boxes_list):
        self.boxes = boxes_list


def run_efficientdet_inference(tf_infer_function, pil_image,
                               expected_input_size=(512, 512),
                               score_threshold=0.40, # << INCREASED DEFAULT THRESHOLD
                               license_plate_class_id=1, # VERIFY THIS
                               min_plate_width=30,
                               min_plate_height=10,
                               max_aspect_ratio_dev=1.5):
    """Run a TF EfficientDet signature on one image and return YOLO-like results.

    The image is letterboxed to ``expected_input_size`` (aspect-preserving resize, centred
    on grey (128,128,128) padding). It is passed to ``tf_infer_function(input_tensor=...)``
    as a uint8 batch of one. Each detection with ``score >= score_threshold`` and class
    ``license_plate_class_id`` is mapped back to pixel coordinates of ``pil_image``.

    Args:
        tf_infer_function: Loaded serving signature. Its output must contain
            ``num_detections``, ``detection_boxes`` (normalized ymin, xmin, ymax, xmax
            relative to the padded input), ``detection_scores`` and ``detection_classes``.
        pil_image: Source PIL image.
        expected_input_size: (width, height) expected by the model.
        score_threshold: Minimum detection score to keep.
        license_plate_class_id: Class id of the licence-plate class — VERIFY against the label map.
        min_plate_width, min_plate_height, max_aspect_ratio_dev: Currently unused and kept only
            for backward compatibility. No size or aspect-ratio filtering is applied.

    Returns:
        A one-element list whose item has ``.boxes``, a (possibly empty) list of box objects,
        so callers can index it exactly like YOLO results.
    """
    in_w, in_h = expected_input_size[0], expected_input_size[1]
    orig_w, orig_h = pil_image.size
    if orig_w == 0 or orig_h == 0:
        return [_EfficientDetResults([])]  # nothing to detect in an empty image

    # Letterbox: scale to fit while keeping the aspect ratio (at least 1 px per side).
    orig_aspect = orig_w / orig_h
    if orig_aspect > in_w / in_h:  # wider than the target: width-limited
        new_w = in_w
        new_h = max(1, int(new_w / orig_aspect))
    else:  # taller or same aspect: height-limited
        new_h = in_h
        new_w = max(1, int(new_h * orig_aspect))

    resized = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    padded = Image.new("RGB", expected_input_size, (128, 128, 128))
    paste_x = (in_w - new_w) // 2
    paste_y = (in_h - new_h) // 2
    padded.paste(resized, (paste_x, paste_y))

    input_tensor = tf.convert_to_tensor(np.array(padded, dtype=np.uint8))[tf.newaxis, ...]
    detections = tf_infer_function(input_tensor=input_tensor)  # serving signatures take named inputs

    num_detections = int(detections['num_detections'][0])
    boxes_norm = detections['detection_boxes'][0, :num_detections].numpy()
    scores = detections['detection_scores'][0, :num_detections].numpy()
    classes = detections['detection_classes'][0, :num_detections].numpy().astype(np.int32)

    # Factors that undo the resize (padded-input pixels -> original-image pixels).
    scale_x = orig_w / new_w
    scale_y = orig_h / new_h

    plate_boxes = []
    for (ymin, xmin, ymax, xmax), score, class_id in zip(boxes_norm, scores, classes):
        if not (score >= score_threshold and class_id == license_plate_class_id):
            continue
        # Normalized -> padded-input pixels -> remove padding offset -> original pixels.
        x1 = int((xmin * in_w - paste_x) * scale_x)
        y1 = int((ymin * in_h - paste_y) * scale_y)
        x2 = int((xmax * in_w - paste_x) * scale_x)
        y2 = int((ymax * in_h - paste_y) * scale_y)

        # Clamp to the image (max index is size - 1).
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(orig_w - 1, x2)
        y2 = min(orig_h - 1, y2)

        plate_boxes.append(_EfficientDetBox([x1, y1, x2, y2], score, class_id))

    return [_EfficientDetResults(plate_boxes)]

print("✅ EfficientDet inference and parsing function defined")

# Image Processing

In [None]:
# Label fonts already loaded, keyed by (path, size). The font lookup (and the apt-get
# fallback) therefore runs at most once per session instead of once per image.
_ANNOTATION_FONT_CACHE = {}


def _load_annotation_font(font_size, font_path="/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"):
    """Return (and cache) a label font, installing fonts-dejavu once if missing; falls back to PIL's default."""
    cache_key = (font_path, font_size)
    if cache_key not in _ANNOTATION_FONT_CACHE:
        try:
            font = ImageFont.truetype(font_path, font_size)
        except IOError:
            print(f"⚠️ Font not found at {font_path}. Attempting to install fonts-dejavu...")
            os.system('apt-get update -qq > /dev/null')
            os.system('apt-get install -qq -y fonts-dejavu > /dev/null')
            print("✅ DejaVu fonts installation attempt finished. Retrying font loading.")
            try:
                font = ImageFont.truetype(font_path, font_size)
                print(f"✅ Using installed font: {font_path}")
            except IOError:
                print("❌ Failed to load DejaVuSans-Bold.ttf even after installation. Using default font.")
                font = ImageFont.load_default()
        _ANNOTATION_FONT_CACHE[cache_key] = font
    return _ANNOTATION_FONT_CACHE[cache_key]


def _label_bbox(draw, xy, label, font):
    """Return the (left, top, right, bottom) box `label` occupies at `xy`, or a rough estimate if measuring fails."""
    try:
        return draw.textbbox(xy, label, font=font)
    except Exception as e:
        print(f"⚠️ Warning: Could not calculate textbbox with font, using default. {e}")
        default_font_size = 10  # rough size of PIL's default bitmap font
        x, y = xy
        return (x, y, x + len(label) * default_font_size * 0.6, y + default_font_size * 1.2)


def _draw_box_label(draw, label, font, font_size, box, img_w, img_h):
    """Draw `label` on a green background next to `box`, keeping it inside the img_w x img_h image where possible."""
    xmin, ymin, xmax, ymax = box
    text_x = xmin
    text_y = ymin - font_size - 5  # preferred: just above the box
    text_bbox = _label_bbox(draw, (text_x, text_y), label, font)
    if text_bbox[1] < 0:
        text_y = ymax + 5  # no room above: just below the box
        text_bbox = _label_bbox(draw, (text_x, text_y), label, font)
        if text_bbox[3] > img_h:
            # Box spans (almost) the full height: draw the label just inside its top edge.
            text_y = ymin
            text_bbox = _label_bbox(draw, (text_x, text_y), label, font)

    # Shift left when the label would run past the right edge (never past x = 0).
    overflow = text_bbox[2] - img_w
    if overflow > 0:
        shifted_x = max(0, text_x - overflow)
        dx = shifted_x - text_x
        text_x = shifted_x
        text_bbox = (text_bbox[0] + dx, text_bbox[1], text_bbox[2] + dx, text_bbox[3])

    background = [
        max(0, text_bbox[0]),
        max(0, text_bbox[1]),
        min(img_w, text_bbox[2]),
        min(img_h, text_bbox[3]),
    ]
    # Pillow rejects inverted rectangles; skip the background rather than crash.
    if background[0] <= background[2] and background[1] <= background[3]:
        draw.rectangle(background, fill="green")
    draw.text((text_x, text_y), label, fill="black", font=font)


def process_single_image(image_path, od_model_inf_callable, current_od_model_type: str,
                         crnn_model_inf, crnn_image_transform_inf, device_to_use, output_dir_img):
    """Detect and read every licence plate in one image, then display and save an annotated copy.

    Args:
        image_path: Image file to process.
        od_model_inf_callable: YOLO model (called as ``model(img, verbose=False, conf=0.25)``)
            or, for EfficientDet, the loaded TF signature passed to run_efficientdet_inference.
        current_od_model_type: Any string starting with "yolo", or "efficientdet".
        crnn_model_inf, crnn_image_transform_inf, device_to_use: Forwarded to
            predict_plate_with_crnn_and_conf for each detected crop.
        output_dir_img: Directory for "<name>_annotated<ext>" (created if missing).

    Each plate gets a green box and a "TEXT (C:crnn_conf | OD:det_conf)" label, kept inside
    the image where possible. Errors while loading or saving are printed, not raised.
    """
    try:
        with Image.open(image_path) as source_img:  # closes the file handle once pixels are copied
            img_pil = source_img.convert("RGB")
    except Exception as e:
        print(f"❌ Error loading image {image_path}: {e}")
        return

    # --- Object Detection Call ---
    if current_od_model_type.startswith("yolo"):
        results_od = od_model_inf_callable(img_pil, verbose=False, conf=0.25)
    elif current_od_model_type == "efficientdet":
        # od_model_inf_callable IS the TF infer function; pre/post-processing and the
        # score threshold live in run_efficientdet_inference.
        results_od = run_efficientdet_inference(
            tf_infer_function=od_model_inf_callable,
            pil_image=img_pil,
            expected_input_size=(512, 512),
            score_threshold=0.35,
            license_plate_class_id=1             # VERIFY THIS
        )
    else:
        print(f"❌ Unknown OD model type for inference: {current_od_model_type}")
        return

    annotated_img_pil = img_pil.copy()  # draw on a copy so crops come from the clean image
    draw = ImageDraw.Draw(annotated_img_pil)
    font_size = 28
    font = _load_annotation_font(font_size)

    num_detections_drawn = 0
    # results_od[0].boxes holds either Ultralytics boxes or EfficientDet stand-ins;
    # both expose xyxy[0] / conf[0] tensors.
    if results_od and results_od[0] and hasattr(results_od[0], 'boxes') and results_od[0].boxes:
        for i, box_obj in enumerate(results_od[0].boxes):
            xyxy = box_obj.xyxy[0].cpu().numpy().astype(int)
            od_conf = box_obj.conf[0].cpu().numpy().item()

            xmin, ymin, xmax, ymax = xyxy
            xmin = max(0, xmin)
            ymin = max(0, ymin)
            xmax = min(img_pil.width, xmax)
            ymax = min(img_pil.height, ymax)
            if xmax <= xmin or ymax <= ymin:
                continue

            cropped_plate_pil = img_pil.crop((xmin, ymin, xmax, ymax))
            plate_text, plate_conf = "N/A", 0.0
            if cropped_plate_pil.width > 5 and cropped_plate_pil.height > 5:  # skip slivers too small to read
                plate_text, plate_conf = predict_plate_with_crnn_and_conf(cropped_plate_pil, crnn_model_inf, crnn_image_transform_inf, device_to_use)

            label = f"{plate_text} (C:{plate_conf:.2f} | OD:{od_conf:.2f})"
            draw.rectangle([xmin, ymin, xmax, ymax], outline="green", width=3)
            _draw_box_label(draw, label, font, font_size, (xmin, ymin, xmax, ymax), img_pil.width, img_pil.height)

            print(f"✅ Plate {i+1}: Text='{plate_text}', CRNN_Conf={plate_conf:.2f}, OD_Conf={od_conf:.2f}, BBox={xyxy}")
            num_detections_drawn += 1

    if num_detections_drawn == 0:
        print(f"❌ No license plates detected by OD in {os.path.basename(image_path)}.")
    else:
        print(f"✅ Processed {num_detections_drawn} detections.")

    # Display inline; close the figure explicitly so folder runs do not accumulate figures.
    fig, ax = plt.subplots(figsize=(20, 15))
    ax.imshow(annotated_img_pil)
    ax.axis('off')
    ax.set_title(f"Annotated: {os.path.basename(image_path)}")
    plt.show()
    plt.close(fig)

    name, ext = os.path.splitext(os.path.basename(image_path))
    annotated_image_path = os.path.join(output_dir_img, f"{name}_annotated{ext}")
    try:
        if output_dir_img:
            os.makedirs(output_dir_img, exist_ok=True)
        annotated_img_pil.save(annotated_image_path)
        print(f"\n✅ Annotated image saved to {annotated_image_path}")
    except Exception as e:
        print(f"❌ Error saving annotated image {annotated_image_path}: {e}")

def process_image_input(input_path, od_model_inf_callable, current_od_model_type_str,
                        crnn_model_inf, crnn_img_transform_inf, device_to_use, output_dir_imgs_base):
    """Run plate detection + recognition on one image file, or on every image in a folder.

    A file path is processed directly (no extension check). For a directory, every
    non-hidden regular file with extension .jpg/.jpeg/.png/.bmp/.tiff/.webp is processed
    in sorted path order, behind a tqdm progress bar. The extension comparison is
    case-insensitive, so ".Jpg" is included. Each image is handed to process_single_image,
    which writes into ``output_dir_imgs_base``.
    """
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

    if os.path.isfile(input_path):
        print(f"\n--- Processing single image: {os.path.basename(input_path)} ---")
        process_single_image(input_path, od_model_inf_callable, current_od_model_type_str,
                             crnn_model_inf, crnn_img_transform_inf, device_to_use, output_dir_imgs_base)
    elif os.path.isdir(input_path):
        print(f"\n--- Processing images in folder: {input_path} ---")
        image_files = sorted(
            os.path.join(input_path, entry)
            for entry in os.listdir(input_path)
            # Skip dotfiles as glob("*") does (this also avoids macOS "._photo.jpg" metadata files).
            if not entry.startswith(".")
            and os.path.splitext(entry)[1].lower() in image_extensions
            and os.path.isfile(os.path.join(input_path, entry))
        )

        if not image_files:
            print(f"❌ No images found in {input_path}")
            return

        for image_file in tqdm(image_files, desc="Processing images in folder", unit="image"):
            print(f"\n--- Processing {os.path.basename(image_file)} ---")
            process_single_image(image_file, od_model_inf_callable, current_od_model_type_str,
                                 crnn_model_inf, crnn_img_transform_inf, device_to_use, output_dir_imgs_base)
        print("\n✅ Finished processing all images in the folder.")
    else:
        print(f"❌ Input path {input_path} is not a valid file or directory.")

print("✅ Image processing functions defined")

# Video Processing

In [None]:
def _read_plates_in_frame(od_results, frame_pil, crnn_model_inf, crnn_img_transform_inf, device_to_use):
    """Clip each detected box to the frame, OCR it with the CRNN, and return one detection dict per valid box."""
    detections = []
    if not od_results or len(od_results[0].boxes) == 0:
        return detections
    for box_data in od_results[0].boxes:
        xmin, ymin, xmax, ymax = (int(v) for v in box_data.xyxy[0].cpu().numpy().astype(int))
        od_conf = box_data.conf[0].cpu().numpy().item()
        xmin, ymin = max(0, xmin), max(0, ymin)
        xmax, ymax = min(frame_pil.width, xmax), min(frame_pil.height, ymax)
        if xmax <= xmin or ymax <= ymin:
            continue
        crop = frame_pil.crop((xmin, ymin, xmax, ymax))
        raw_text, raw_conf = "N/A", 0.0
        if crop.width > 5 and crop.height > 5:  # crops this small are left as "N/A"
            raw_text, raw_conf = predict_plate_with_crnn_and_conf(crop, crnn_model_inf, crnn_img_transform_inf, device_to_use)
        detections.append({
            'bbox': (xmin, ymin, xmax, ymax), 'raw_text': raw_text,
            'raw_conf': raw_conf, 'od_conf': od_conf,
            'assigned_track_id': None,  # filled in by track matching / creation
        })
    return detections


def _match_detections_to_tracks(detections, active_tracks, iou_threshold):
    """Greedily give each detection (in order) the unclaimed track with the highest IoU above `iou_threshold`."""
    matched_ids = set()
    for det in detections:
        best_iou, best_id = 0, None
        for track_id, track in active_tracks.items():
            if track_id in matched_ids:
                continue  # a track can follow only one plate per frame
            iou = calculate_iou(det['bbox'], track['last_bbox'])
            if iou > iou_threshold and iou > best_iou:
                best_iou, best_id = iou, track_id
        if best_id is not None:
            det['assigned_track_id'] = best_id
            matched_ids.add(best_id)
    return matched_ids


def _update_track_with_detection(track, det, smoothing_window_size, hysteresis_count):
    """Refresh a matched track with `det` and update its majority-vote smoothed text (with hysteresis)."""
    track['last_bbox'] = det['bbox']
    track['age'] = 0
    track['visible_this_frame'] = True

    history = track['text_conf_history']
    # Matched tracks record every plausible reading, whatever its CRNN confidence.
    if is_plausible_malaysian_plate(det['raw_text']):
        history.append((det['raw_text'], det['raw_conf']))
    excess = len(history) - smoothing_window_size
    if excess > 0:
        del history[:excess]  # keep only the most recent readings
    if not history:
        return

    most_common_text, _ = Counter(text for text, _ in history).most_common(1)[0]
    confs_for_winner = [conf for text, conf in history if text == most_common_text]
    if most_common_text != track['smoothed_text']:
        # A new majority must win hysteresis_count successive updates before it is shown.
        if track['candidate_text'] == most_common_text:
            track['candidate_persistence'] += 1
        else:
            track['candidate_text'] = most_common_text
            track['candidate_persistence'] = 1
        if track['candidate_persistence'] >= hysteresis_count:
            track['smoothed_text'] = most_common_text
            track['smoothed_conf'] = np.mean(confs_for_winner) if confs_for_winner else 0.0
            track['candidate_persistence'] = 0
    else:
        track['candidate_persistence'] = 0
        track['smoothed_conf'] = np.mean(confs_for_winner) if confs_for_winner else track['smoothed_conf']


def process_single_video(
    video_path, od_model_inf_callable, current_od_model_type_str, crnn_model_inf, crnn_img_transform_inf, device_to_use, output_dir_vid,
    display_frames=True, display_interval=1,
    iou_threshold=0.3, smoothing_window_size=7,
    max_track_age=10,
    plausible_plate_threshold=0.3,
    hysteresis_count=2
):
    """Detect, read, track and annotate licence plates in every frame of a video.

    Frames are read until the stream ends. In each frame, plates are detected (YOLO or
    EfficientDet) and read with the CRNN. Detections are linked to tracks across frames
    by IoU, so each track's text can be smoothed by majority vote. Annotated frames are
    written to ``<output_dir_vid>/<name>_annotated.mp4`` (mp4v, source FPS or 30 if unknown).

    Args:
        video_path: Input video file.
        od_model_inf_callable: YOLO model, or the EfficientDet TF signature.
        current_od_model_type_str: Strings starting with "yolo", or exactly "efficientdet".
            Any other value copies frames to the output unannotated.
        crnn_model_inf, crnn_img_transform_inf, device_to_use: CRNN model, its transform, and device.
        output_dir_vid: Output directory (created if missing).
        display_frames: Show annotated frames inline while processing.
        display_interval: Show every N-th frame when ``display_frames`` is True.
        iou_threshold: IoU that must be exceeded for a detection to continue an existing track.
        smoothing_window_size: Number of most recent plausible readings kept per track for voting.
        max_track_age: Number of consecutive unmatched frames a track survives before it is dropped.
        plausible_plate_threshold: Minimum CRNN confidence for the reading that seeds a new
            track's history (matched tracks record every plausible reading).
        hysteresis_count: Successive updates a new majority text must win before it is displayed.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"❌ Error opening video file {video_path}")
        return

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) if cap.get(cv2.CAP_PROP_FPS) > 0 else 30.0
    # The container's frame count is only a hint (may be 0 or wrong); frames are read until EOF.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    base_name = os.path.basename(video_path)
    name, _ = os.path.splitext(base_name)
    if output_dir_vid:
        os.makedirs(output_dir_vid, exist_ok=True)  # VideoWriter fails silently on a missing directory
    annotated_video_path = os.path.join(output_dir_vid, f"{name}_annotated.mp4")
    out = cv2.VideoWriter(annotated_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
    if not out.isOpened():
        print(f"❌ Error creating output video {annotated_video_path}")
        cap.release()
        return

    print(f"Processing video: {video_path} ({total_frames} frames)")
    print(f"Parameters: IoU={iou_threshold}, Age={max_track_age}, Window={smoothing_window_size}, PlausConf={plausible_plate_threshold}, Hyst={hysteresis_count}")
    print(f"Output will be saved to: {annotated_video_path}")

    font_cv = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    box_color = (0, 255, 0)
    thickness = 2
    text_bg_color = (0, 255, 0)
    text_font_color = (0, 0, 0)

    is_yolo = current_od_model_type_str.startswith("yolo")
    is_effdet = current_od_model_type_str == "efficientdet"
    if not (is_yolo or is_effdet):
        print(f"❌ Unknown OD model type for inference in video: {current_od_model_type_str}")

    if display_frames:
        from IPython.display import display

    active_tracks = {}
    next_track_id = 0
    processed_frames_count_for_log = 0
    progress = tqdm(total=total_frames if total_frames > 0 else None, desc=f"Processing {base_name}", unit="frame")
    try:
        while True:
            ret, frame_cv = cap.read()
            if not ret:
                if processed_frames_count_for_log < total_frames:
                    print(f"\n⚠️ Warning: Could not read frame {processed_frames_count_for_log + 1}/{total_frames}. Ending video processing early")
                break
            frame_idx = processed_frames_count_for_log
            processed_frames_count_for_log += 1
            progress.update(1)

            if not (is_yolo or is_effdet):
                out.write(frame_cv)  # unknown detector: keep the frame unannotated
                continue

            frame_pil = Image.fromarray(cv2.cvtColor(frame_cv, cv2.COLOR_BGR2RGB))
            if is_yolo:
                od_results = od_model_inf_callable(frame_pil, verbose=False, conf=0.25)
            else:
                od_results = run_efficientdet_inference(
                    tf_infer_function=od_model_inf_callable,
                    pil_image=frame_pil,
                    expected_input_size=(512, 512),
                    score_threshold=0.35,
                    license_plate_class_id=1 # VERIFY THIS
                )

            annotated_frame_cv = frame_cv.copy()
            detections = _read_plates_in_frame(od_results, frame_pil, crnn_model_inf, crnn_img_transform_inf, device_to_use)

            # --- Track management: match, age out stale tracks, then open new tracks ---
            _match_detections_to_tracks(detections, active_tracks, iou_threshold)
            for det in detections:
                if det['assigned_track_id'] is not None:
                    _update_track_with_detection(active_tracks[det['assigned_track_id']], det,
                                                 smoothing_window_size, hysteresis_count)

            for track_id in list(active_tracks):
                track = active_tracks[track_id]
                if not track.get('visible_this_frame', False):
                    track['age'] += 1
                track['visible_this_frame'] = False  # reset for the next frame
                if track['age'] > max_track_age:
                    del active_tracks[track_id]

            for det in detections:
                if det['assigned_track_id'] is not None:
                    continue
                is_plausible = is_plausible_malaysian_plate(det['raw_text'])
                initial_history = []
                if is_plausible and det['raw_conf'] >= plausible_plate_threshold:
                    initial_history.append((det['raw_text'], det['raw_conf']))
                active_tracks[next_track_id] = {
                    'last_bbox': det['bbox'],
                    'text_conf_history': initial_history,
                    'smoothed_text': det['raw_text'] if is_plausible else "N/A",
                    'smoothed_conf': det['raw_conf'] if is_plausible else 0.0,
                    # Aging for this frame already ran, so start "unseen" like every other track;
                    # otherwise new tracks would get one extra frame before aging starts.
                    'age': 0, 'visible_this_frame': False,
                    'candidate_text': None, 'candidate_persistence': 0,
                }
                det['assigned_track_id'] = next_track_id
                next_track_id += 1

            # --- Drawing: current bbox, the track's smoothed text, the current OD confidence ---
            for det in detections:
                track = active_tracks.get(det['assigned_track_id'])
                if track is None:
                    continue
                xmin, ymin, xmax, ymax = det['bbox']
                label = f"{track['smoothed_text']} (C:{track['smoothed_conf']:.2f} | OD:{det['od_conf']:.2f})"

                cv2.rectangle(annotated_frame_cv, (xmin, ymin), (xmax, ymax), box_color, thickness)
                (txt_w, txt_h), base = cv2.getTextSize(label, font_cv, font_scale, thickness)
                # txt_y is the top of the text; putText below uses the baseline txt_y + txt_h.
                txt_y = ymin - txt_h - base - 5
                bg_y1 = txt_y - base
                bg_y2 = ymin - base + 5
                if bg_y1 < 0:
                    # No room above: place the label directly below the box, inside its background.
                    txt_y = ymax + 5
                    bg_y1 = ymax + 5
                    bg_y2 = ymax + txt_h + base + 5
                cv2.rectangle(annotated_frame_cv, (xmin, bg_y1), (xmin + txt_w + 5, bg_y2), text_bg_color, -1)
                cv2.putText(annotated_frame_cv, label, (xmin + 2, txt_y + txt_h), font_cv, font_scale, text_font_color, thickness, cv2.LINE_AA)

            out.write(annotated_frame_cv)
            if display_frames and (frame_idx + 1) % display_interval == 0:
                print(f"\nDisplaying frame {frame_idx + 1}/{total_frames} of {base_name}")
                display(IPImage(data=cv2.imencode('.jpeg', annotated_frame_cv)[1].tobytes()))
    finally:
        # Release the progress bar, reader and writer even if a frame raised.
        progress.close()
        cap.release()
        out.release()

    print(f"\n✅ Video processing complete for {base_name}. {processed_frames_count_for_log} frames processed")
    print(f"✅ Annotated video saved to {annotated_video_path}")

def process_video_input(input_path, od_model, current_od_model_name, crnn_model, crnn_image_transform, device, output_dir,
                        display_frames=True, display_interval=15,
                        iou_threshold=0.3, smoothing_window_size=5, max_track_age=10,
                        plausible_plate_threshold=0.3, hysteresis_count=2):
    """
    Process a single video, or every video in a folder, with the OD + CRNN pipeline.

    Each video is passed to `process_single_video` together with all remaining
    arguments unchanged; selected frames are displayed sequentially.

    Args:
        input_path: Path to one video file or to a folder of video files. Video
            extensions are matched case-insensitively in both modes.
        od_model: Loaded object detector, forwarded as-is.
        current_od_model_name: OD model name (e.g. "yolov11"), forwarded as-is.
        crnn_model: Loaded CRNN recognizer, forwarded as-is.
        crnn_image_transform: CRNN preprocessing transform, forwarded as-is.
        device: Inference device, forwarded as-is.
        output_dir: Output directory, forwarded as-is.
        display_frames, display_interval: Frame-display options, forwarded as-is.
        iou_threshold, smoothing_window_size, max_track_age,
        plausible_plate_threshold, hysteresis_count: Tracking/smoothing options,
            forwarded as-is.

    Returns:
        None. A non-video file, an empty folder or an invalid path is reported with
        a printed message instead of raising.
    """
    video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm', '.mpg', '.mpeg', '.ts']
    # str.endswith accepts a tuple; lower-casing the name first makes the check
    # case-insensitive (".TS", ".Mp4", ...).
    video_suffixes = tuple(video_extensions)

    def _is_video_name(name):
        """Return True if `name` ends with a known video extension (any case)."""
        return name.lower().endswith(video_suffixes)

    # Everything after the video path is identical for every video processed.
    shared_args = (od_model, current_od_model_name, crnn_model, crnn_image_transform, device, output_dir,
                   display_frames, display_interval,
                   iou_threshold, smoothing_window_size, max_track_age,
                   plausible_plate_threshold, hysteresis_count)

    if os.path.isfile(input_path):
        if _is_video_name(input_path):
            print(f"\n--- Processing single video: {os.path.basename(input_path)} ---")
            process_single_video(input_path, *shared_args)
        else:
            print(f"❌ {input_path} is not a recognized video file type (checked: {', '.join(video_extensions)}).")
    elif os.path.isdir(input_path):
        print(f"\n--- Processing videos in folder: {input_path} ---")
        # Filter the directory listing rather than globbing "*.ext" and "*.EXT": glob is
        # case-sensitive on Linux, so mixed-case names such as "clip.Mp4" were skipped in
        # folder mode even though single-file mode accepts them. Dot-files are still
        # excluded (as glob's "*" did) so macOS "._clip.mp4" metadata files are not fed
        # to OpenCV, and directories that merely look like videos are ignored.
        video_files = sorted(
            os.path.join(input_path, entry_name)
            for entry_name in os.listdir(input_path)
            if not entry_name.startswith('.')
            and _is_video_name(entry_name)
            and os.path.isfile(os.path.join(input_path, entry_name))
        )

        if not video_files:
            print(f"❌ No videos found in {input_path}")
            return

        for video_file in tqdm(video_files, desc="Processing video folder", unit="video"):
            print(f"\n--- Processing video: {os.path.basename(video_file)} ---")
            process_single_video(video_file, *shared_args)
        print("\n✅ Finished processing all videos in the folder.")
    else:
        print(f"❌ Input path {input_path} is not a valid file or directory.")

print("✅ Video processing functions defined")

# Load Models and Run Processing

In [None]:
# --- Choose which OD model to use ---
# Keep exactly one option active: `current_od_model_name` selects the loader below
# and, in the next cell, the output sub-directory.
# Option 1: YOLOv11
od_model_to_use_path = yolov11_model_path
current_od_model_name = "yolov11"

# Option 2: YOLOv10
# od_model_to_use_path = yolov10_model_path
# current_od_model_name = "yolov10"

# Option 3: EfficientDet
# od_model_to_use_path = efficientdet_model_path
# current_od_model_name = "efficientdet"

# --- Load the selected OD model ---
# Stays None when the path is missing or the model name is not recognised; the
# pipeline cell checks this before running anything.
object_detector_callable = None
if os.path.exists(od_model_to_use_path):
    print(f"Attempting to load OD model '{current_od_model_name}' from: {od_model_to_use_path}")
    if current_od_model_name == "efficientdet":
        object_detector_callable = load_tf_saved_model(od_model_to_use_path)
    elif current_od_model_name.startswith("yolo"):
        object_detector_callable = load_od_model(od_model_to_use_path, device)
    else:
        print(f"❌ Unknown OD model type specified: {current_od_model_name}")
else:
    print(f"❌ FATAL: Chosen OD model path does not exist: {od_model_to_use_path}")

# CRNN input image height, passed to load_crnn_for_inference.
# NOTE(review): presumably must match the height used during CRNN training — confirm.
img_height = 32

# --- Load CRNN model ---
# Stays None when the weights file is missing.
crnn_recognizer = None
if os.path.exists(best_model_path):
    print(f"\nUsing CRNN model from: {best_model_path}")
    crnn_recognizer = load_crnn_for_inference(best_model_path, img_height, num_classes, device)
else:
    print(f"❌ FATAL: CRNN Model weights not found: {best_model_path}")

In [None]:
# Run the image/video test cases only when both models loaded in the previous cell.
# Compare against None instead of relying on truthiness: a loaded model object is not
# guaranteed to be truthy (e.g. container modules that define __len__).
if object_detector_callable is not None and crnn_recognizer is not None:
    print("✅ All models loaded successfully for pipeline.")

    # --- Output directories ---
    # Each known OD model writes into a sub-directory of annotated_output_path named
    # after the model; the values are the display names used in the log message.
    od_model_display_names = {
        "yolov11": "YOLOv11",
        "yolov10": "YOLOv10",
        "efficientdet": "EfficientDet",
    }
    if current_od_model_name in od_model_display_names:
        model_specific_output_base = os.path.join(annotated_output_path, current_od_model_name)
        print(f"Outputting results to {od_model_display_names[current_od_model_name]} specific directory: {model_specific_output_base}")
    else:
        # Fallback for any other model name or if current_od_model_name is not set properly
        model_specific_output_base = os.path.join(annotated_output_path, 'unknown_od_model')
        print(f"⚠️ Warning: Unknown OD model name '{current_od_model_name}'. Outputting to: {model_specific_output_base}")

    # Model-specific subdirectories for annotated images and videos
    pipeline_output_images_path = os.path.join(model_specific_output_base, 'images')
    pipeline_output_videos_path = os.path.join(model_specific_output_base, 'videos')
    os.makedirs(model_specific_output_base, exist_ok=True)
    os.makedirs(pipeline_output_images_path, exist_ok=True)
    os.makedirs(pipeline_output_videos_path, exist_ok=True)

    # --- Test cases: toggle which inputs to run ---
    # The single-image and video-folder cases used to be commented out; they are kept
    # available behind flags whose defaults reproduce the same on/off selection.
    run_single_image = False
    run_image_folder = True
    run_single_video = True
    run_video_folder = False

    # Tracking/display settings shared by the single-video and video-folder cases.
    video_pipeline_kwargs = dict(
        display_frames=True,
        display_interval=30,
        iou_threshold=0.3,
        smoothing_window_size=10,
        max_track_age=12,
        plausible_plate_threshold=0.25,
        hysteresis_count=4,
    )

    # Test with a single image
    sample_image_path = "/content/gdrive/MyDrive/cos30018/NO20250513-132446-130571F_frame_1500.jpg"
    if run_single_image:
        if os.path.isfile(sample_image_path):
            process_image_input(sample_image_path, object_detector_callable, current_od_model_name, crnn_recognizer, crnn_inference_transform, device, pipeline_output_images_path)
        else:
            print(f"❌ Image {sample_image_path} not found. Skipping.")

    # Test with an image folder. isdir (not exists) so a file at this path is reported
    # as "not found" instead of making os.listdir raise NotADirectoryError.
    sample_image_folder = "/content/gdrive/MyDrive/cos30018/test_image"
    if run_image_folder:
        if os.path.isdir(sample_image_folder) and os.listdir(sample_image_folder):
            process_image_input(sample_image_folder, object_detector_callable, current_od_model_name, crnn_recognizer, crnn_inference_transform, device, pipeline_output_images_path)
        else:
            print(f"❌ Image folder {sample_image_folder} not found or is empty. Skipping.")

    # Test with a single video
    sample_video_path = "/content/gdrive/MyDrive/cos30018/20250426070025_041707.TS"
    if run_single_video:
        if os.path.exists(sample_video_path):
            process_video_input(
                sample_video_path,
                object_detector_callable,
                current_od_model_name,
                crnn_recognizer,
                crnn_inference_transform,
                device,
                output_dir=pipeline_output_videos_path,
                **video_pipeline_kwargs,
            )
        else:
            print(f"❌ Video {sample_video_path} not found. Skipping.")

    # Test with a video folder (only run when it contains at least one video file)
    sample_video_folder = "/content/gdrive/MyDrive/cos30018/test_video"
    video_extensions_check = ('.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm', '.mpg', '.mpeg', '.ts')
    if run_video_folder:
        if os.path.isdir(sample_video_folder) and any(f.lower().endswith(video_extensions_check) for f in os.listdir(sample_video_folder)):
            process_video_input(
                sample_video_folder,
                object_detector_callable,
                current_od_model_name,
                crnn_recognizer,
                crnn_inference_transform,
                device,
                output_dir=pipeline_output_videos_path,
                **video_pipeline_kwargs,
            )
        else:
            print(f"❌ Video folder {sample_video_folder} not found or is empty or contains no videos. Skipping.")

    print("\n✅ Pipeline processing finished for OD model:", current_od_model_name)

else:
    print("❌ One or more models failed to load. Cannot proceed with processing")
    if object_detector_callable is None:
        print(f"❌ Object Detector failed to load. Check path related to '{globals().get('current_od_model_name', 'chosen OD model')}'.")
    if crnn_recognizer is None:
        print(f"❌ CRNN Recognizer failed to load. Check path: {best_model_path}")