In [5]:
import torch
from torchvision.transforms import transforms
from PIL import Image
import cv2
import numpy as np
import json
from ultralytics import YOLO 

In [6]:
# Load OCR Model
class CNN(torch.nn.Module):
    def __init__(self, in_size=(200, 50, 1), conv_out_1=32, conv_kern_1=3, conv_out_2=64,
                 conv_kern_2=3, fc1_out=64, lstm1_out=128, lstm2_out=64, fc2_out=38):
        super(CNN, self).__init__()
        w, h, c = in_size
        self.conv1 = torch.nn.Conv2d(c, conv_out_1, kernel_size=conv_kern_1, padding="same")
        self.conv2 = torch.nn.Conv2d(conv_out_1, conv_out_2, kernel_size=conv_kern_2, padding="same")
        self.fc1 = torch.nn.Linear(conv_out_2, fc1_out)
        self.lstm1 = torch.nn.LSTM(fc1_out * (h // 4), lstm1_out, bidirectional=True, batch_first=True)
        self.lstm2 = torch.nn.LSTM(lstm1_out * 2, lstm2_out, bidirectional=True, batch_first=True)
        self.fc2 = torch.nn.Linear(lstm2_out * 2, fc2_out)

    def forward(self, x):
        batch_size, c, h, w = x.shape
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = x.permute(0, 3, 2, 1)
        x = torch.relu(self.fc1(x))
        x = torch.dropout(x, p=0.2, train=False)
        x = x.reshape(batch_size, w // 4, -1)
        x, _ = self.lstm1(x)
        x = torch.dropout(x, p=0.25, train=False)
        x, _ = self.lstm2(x)
        x = torch.dropout(x, p=0.25, train=False)
        x = torch.log_softmax(self.fc2(x), dim=-1)
        return x.permute(1, 0, 2)

In [7]:
import pandas as pd
import os
# Load character classes
with open("/home/infres/lotfi-23/notebooks/classes.json", "r") as f:
    classes = json.load(f)

# Load OCR Model and Classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ocr_model = CNN(fc2_out=(len(classes) + 1))  # Adjust `fc2_out` based on your trained model
# ocr_model.load_state_dict(torch.load('/home/infres/lotfi-23/training/weights_last.pth', map_location=device))
ocr_model = torch.load("/home/infres/lotfi-23/notebooks/saved_models/model_epoch_200.pth")
ocr_model = ocr_model.to(device)
ocr_model.eval()



# Define transforms for the OCR model
ocr_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32)
])

# Decode predictions
def decode_predictions(predictions, classes):
    pred_str = ''.join([classes[c - 1] if c != 0 else '_' for c in predictions])
    pred_str = ''.join(ch for i, ch in enumerate(pred_str) if ch != '_' and (i == 0 or ch != pred_str[i - 1]))
    return pred_str

# Load YOLOv8 Model
yolo_model = YOLO("/home/infres/lotfi-23/notebooks/yolo-best.pt")  # Replace with your YOLOv8 model path

# Full Inference Pipeline
def infer_license_plate(image_path, yolo_model, ocr_model, ocr_transforms, classes, device):
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image at {image_path}")

    results = yolo_model(image)
    if len(results[0].boxes) == 0:
        raise ValueError("No license plate detected.")

    bbox = results[0].boxes[0].xyxy[0].cpu().numpy()
    x_min, y_min, x_max, y_max = map(int, bbox)
    plate_image = image[y_min:y_max, x_min:x_max]

    plate_image_pil = Image.fromarray(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB)).convert('L')
    plate_image_resized = plate_image_pil.resize((200, 50))
    plate_tensor = ocr_transforms(plate_image_resized).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = ocr_model(plate_tensor)
        predictions = outputs.argmax(dim=-1).cpu().numpy()[:, 0]

    return decode_predictions(predictions, classes)

# Filter Ground Truth
def filter_ground_truth(csv_path):
    """
    Filter the ground truth based on:
    - No spaces in the plate string.
    - Plate string length equals 8.
    - Plate string starts with two digits followed by a non-digit.
    """
    df = pd.read_csv(csv_path)
    df = df[~df['plateString'].str.contains(' ', na=True)]
    df = df[df['plateString'].str.len().eq(8)]
    df = df[df['plateString'].str.match(r'^\d{2}[^0-9]')]

    df['nameOfTheFile'] = df['nameOfTheFile'] + ".jpg"
    return df


In [9]:
import pandas as pd
import os
import time

# ground_truth_csv = "/home/infres/lotfi-23/test/string_plate_test.csv"
# filtered_ground_truth = filter_ground_truth(ground_truth_csv)
# filtered_images = dict(zip(filtered_ground_truth['nameOfTheFile'], filtered_ground_truth['plateString']))

# Evaluate on filtered data with timing
def infer_and_evaluate_filtered_with_timing(yolo_model, ocr_model, ocr_transforms, classes, device, image_dir, filtered_images):
    correct = 0
    total = len(filtered_images)
    results = []
    total_time = 0  # To accumulate inference times

    for img_name, gt_text in filtered_images.items():
        img_path = os.path.join(image_dir, img_name)
        if not os.path.isfile(img_path):
            print(f"File not found: {img_path}")
            continue

        try:
            # Start timing
            start_time = time.time()
            
            # Run end-to-end inference
            plate_text = infer_license_plate(img_path, yolo_model, ocr_model, ocr_transforms, classes, device)
            
            # End timing
            end_time = time.time()
            inference_time = end_time - start_time
            total_time += inference_time

            match = plate_text == gt_text  # Case-insensitive comparison
            correct += int(match)
            print(f"Image: {img_name}, GT: {gt_text}, Pred: {plate_text}, Match: {match}, Time: {inference_time:.4f}s")

            results.append({
                'image_name': img_name,
                'ground_truth': gt_text,
                'prediction': plate_text,
                'match': match,
                'inference_time': inference_time
            })

        except ValueError as e:
            print(f"Skipping {img_name}: {e}")
            results.append({
                'image_name': img_name,
                'ground_truth': gt_text,
                'prediction': None,
                'match': False,
                'inference_time': None
            })

    accuracy = correct / total * 100 if total > 0 else 0
    average_time = total_time / total if total > 0 else 0
    print(f"Accuracy on filtered data: {accuracy:.2f}%")
    print(f"Average inference time per image: {average_time:.4f}s")
    return results, accuracy, average_time

# Main execution
ground_truth_csv = "/home/infres/lotfi-23/test/string_plate_test.csv"
test_images_dir = "/home/infres/lotfi-23/test/Images"

filtered_ground_truth = filter_ground_truth(ground_truth_csv)
print(f"filtered_ground_truth: {filtered_ground_truth}")
filtered_images = dict(zip(filtered_ground_truth['nameOfTheFile'], filtered_ground_truth['plateString']))

results, accuracy, avg_inference_time = infer_and_evaluate_filtered_with_timing(
    yolo_model, ocr_model, ocr_transforms, classes, device, test_images_dir, filtered_images
)

results_df = pd.DataFrame(results)
results_df.to_csv("filtered_inference_results.csv", index=False)
print("Results saved to 'filtered_inference_results.csv'")
print(f"Average inference time per image: {avg_inference_time:.4f}s")

filtered_ground_truth:          nameOfTheFile plateString
0        day_15564.jpg    58B39218
1        day_08611.jpg    61Q48444
2        day_16010.jpg    87Ṣ57418
3        day_08270.jpg    47M37153
4        day_09842.jpg    25S63677
...                ...         ...
4171  night (1019).jpg    18J85320
4172  night (4029).jpg    93Ṭ87516
4173     day_11265.jpg    32Ṭ29999
4174     day_02622.jpg    35M84724
4175     day_10425.jpg    18D39966

[3574 rows x 2 columns]



0: 384x640 1 کل ناحیه پلاک, 9.0ms
Speed: 1.3ms preprocess, 9.0ms inference, 2.6ms postprocess per image at shape (1, 3, 384, 640)
Image: day_15564.jpg, GT: 58B39218, Pred: 58B39218, Match: True, Time: 0.0285s

0: 480x640 1 کل ناحیه پلاک, 8.4ms
Speed: 1.5ms preprocess, 8.4ms inference, 1.1ms postprocess per image at shape (1, 3, 480, 640)
Image: day_08611.jpg, GT: 61Q48444, Pred: 61Q4844, Match: False, Time: 0.0317s

0: 640x480 1 کل ناحیه پلاک, 8.3ms
Speed: 1.5ms preprocess, 8.3ms inference, 1.1ms postprocess per image at shape (1, 3, 640, 480)
Image: day_16010.jpg, GT: 87Ṣ57418, Pred: 87Ṣ57418, Match: True, Time: 0.0287s

0: 480x640 1 کل ناحیه پلاک, 8.1ms
Speed: 0.9ms preprocess, 8.1ms inference, 1.1ms postprocess per image at shape (1, 3, 480, 640)
Image: day_08270.jpg, GT: 47M37153, Pred: 47M37153, Match: True, Time: 0.0174s

0: 640x480 1 کل ناحیه پلاک, 8.6ms
Speed: 1.4ms preprocess, 8.6ms inference, 1.1ms postprocess per image at shape (1, 3, 640, 480)
Image: day_09842.jpg, GT: 25S6