In [1]:
import torch
from torchvision.transforms import transforms
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
import json
from ultralytics import YOLO
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import os


  warn(


In [2]:
# Directory that later cells read the OCR checkpoint, classes.json and the YOLO weights from.
base_dir = "."

In [3]:
# Presumably a folder of test inputs; no other cell visible in this notebook reads it.
dataset_dir = "./test/"

In [4]:
# Load OCR Model
class CNN(torch.nn.Module):
    """
    Convolutional-recurrent OCR network.

    Pipeline: two conv + ReLU + 2x2 max-pool stages, a linear layer applied to the
    channel vector of every pixel, two bidirectional LSTMs that read the feature
    map column by column, and a linear classifier with log-softmax.

    Args:
        in_size: (width, height, channels) of the expected input image. Only the
            height and channel count are used to size the layers.
        conv_out_1, conv_kern_1: output channels / kernel size of the first conv.
        conv_out_2, conv_kern_2: output channels / kernel size of the second conv.
        fc1_out: features produced per pixel by the first linear layer.
        lstm1_out, lstm2_out: hidden sizes (per direction) of the two LSTMs.
        fc2_out: number of output classes per time step.

    NOTE: a checkpoint saved by pickling the whole module restores the instance
    attributes without calling ``__init__``. Attribute names are therefore kept
    unchanged and no new sub-modules are introduced here.
    """

    def __init__(self, in_size=(200, 50, 1), conv_out_1=32, conv_kern_1=3, conv_out_2=64,
                 conv_kern_2=3, fc1_out=64, lstm1_out=128, lstm2_out=64, fc2_out=38):
        super().__init__()
        _, h, c = in_size
        self.conv1 = torch.nn.Conv2d(c, conv_out_1, kernel_size=conv_kern_1, padding="same")
        self.conv2 = torch.nn.Conv2d(conv_out_1, conv_out_2, kernel_size=conv_kern_2, padding="same")
        self.fc1 = torch.nn.Linear(conv_out_2, fc1_out)
        # Each LSTM step sees one image column: (h // 4) rows x fc1_out features.
        self.lstm1 = torch.nn.LSTM(fc1_out * (h // 4), lstm1_out, bidirectional=True, batch_first=True)
        self.lstm2 = torch.nn.LSTM(lstm1_out * 2, lstm2_out, bidirectional=True, batch_first=True)
        self.fc2 = torch.nn.Linear(lstm2_out * 2, fc2_out)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, channels, height, width).

        Returns:
            Log-probabilities of shape (width // 4, batch, fc2_out), i.e. time-major.
        """
        batch_size = x.shape[0]
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        # (batch, channels, h/4, w/4) -> (batch, w/4, h/4, channels) so width becomes the sequence axis.
        x = x.permute(0, 3, 2, 1)
        x = torch.relu(self.fc1(x))
        # Dropout follows the module mode: active after .train(), a no-op after .eval().
        x = torch.dropout(x, p=0.2, train=self.training)
        # Flatten rows and features of each column into one vector per time step.
        x = x.reshape(batch_size, x.shape[1], -1)
        x, _ = self.lstm1(x)
        x = torch.dropout(x, p=0.25, train=self.training)
        x, _ = self.lstm2(x)
        x = torch.dropout(x, p=0.25, train=self.training)
        x = torch.log_softmax(self.fc2(x), dim=-1)
        return x.permute(1, 0, 2)

In [5]:
# Load OCR Model and Classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint is a fully pickled module, so the CNN class must already be defined.
# NOTE(security): unpickling executes arbitrary code - only load checkpoints you trust.
# weights_only=False is explicit because torch >= 2.6 defaults to weights_only=True,
# which refuses to unpickle a whole module.
ocr_model = torch.load(base_dir + "/best-ocr-model.pt", map_location=torch.device('cpu'), weights_only=False)
ocr_model = ocr_model.to(device)
ocr_model.eval()

# JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
with open(base_dir + "/classes.json", "r", encoding="utf-8") as f:
    classes = json.load(f)

# Define OCR transforms: PIL image -> float32 tensor
ocr_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32)
])

# Decode OCR predictions
def decode_predictions(predictions, classes):
    """
    Greedy CTC decoding of a sequence of per-time-step class indices.

    Index 0 is the CTC blank; index ``k > 0`` maps to ``classes[k - 1]``.
    Consecutive repeats of the same index are collapsed to one emission and
    blanks are dropped. A blank between two identical indices keeps both.

    The collapse is done on the indices, not on the rendered characters, so
    labels longer than one character are emitted intact.

    Args:
        predictions: iterable of integer class indices (list or 1-D array).
        classes: sequence of label strings, without the blank.

    Returns:
        The decoded string ('' when there are no non-blank predictions).

    Raises:
        IndexError: if an index is larger than ``len(classes)``.
    """
    pieces = []
    previous = None
    for raw_index in predictions:
        index = int(raw_index)
        if index != 0 and index != previous:
            pieces.append(classes[index - 1])
        previous = index
    return ''.join(pieces)

# Load YOLO Model
# Weights are read from base_dir; the detection functions below call this model on a frame
# and read the plate boxes from results[0].boxes.
yolo_model = YOLO(base_dir + "/yolo-best.pt")

In [83]:
import PILasOPENCV as CImage
import PILasOPENCV as CImageDraw
import PILasOPENCV as CImageFont

# Font used by detect_and_visualize_cv to rasterise the plate label.
font = CImageFont.truetype("arial.ttf", 24)
np.bool = np.bool_ # library uses the deprecated np.bool alias, so restore it

# Persian plate characters and digits mapped to single Latin/ASCII stand-ins.
persian_to_equivalent = {
    "ب": "B",
    "ج": "J",
    "د": "D",
    "س": "S",
    "ص": "Ṣ",
    "ط": "Ṭ",
    "ق": "Q",
    "ل": "L",
    "م": "M",
    "و": "V",
    "ه": "H",
    "ه‍": "H",
    "ن": "N",
    "ی": "Y",
    "الف": "A",
    "پ": "P",
    "ت": "T",
    "ث": "Ṯ",
    "ز": "Z",
    "ژ": "Ž",
    "ش": "Š",
    "ع": "O",
    "ف": "F",
    "ك": "K",
    "گ": "G",
    "۰": "0",
    "۱": "1",
    "۲": "2",
    "۳": "3",
    "۴": "4",
    "۵": "5",
    "۶": "6",
    "۷": "7",
    "۸": "8",
    "۹": "9"
}

# Inverse map (stand-in -> Persian) used to render decoded plate text.
# Two keys above share the value "H"; the later entry wins in this inversion.
pers = {v: k for k,v in persian_to_equivalent.items()}

def _spans_overlap(lo, hi, p_lo, p_hi):
    """Return True when the closed 1-D spans [lo, hi] and [p_lo, p_hi] touch or overlap."""
    return (p_lo <= lo <= p_hi or p_lo <= hi <= p_hi
            or lo <= p_lo <= hi or lo <= p_hi <= hi)


def _find_overlapping_box(box, candidates):
    """Return the index of the first box in ``candidates`` that overlaps ``box`` on both axes, else None."""
    x_min, y_min, x_max, y_max = box
    for i, (p_x_min, p_y_min, p_x_max, p_y_max) in enumerate(candidates):
        if (_spans_overlap(x_min, x_max, p_x_min, p_x_max)
                and _spans_overlap(y_min, y_max, p_y_min, p_y_max)):
            return i
    return None


def detect_and_visualize_cv(image, yolo_model, ocr_model, ocr_transforms, classes, device, next_state, save_path=None, prev_state=None):
    """
    Detect license plates in one BGR frame, read them with the OCR model and draw
    the bounding boxes and labels onto the frame.

    Each detection that overlaps a box from ``prev_state`` is smoothed towards the
    previous box position (see ``interp``); the smoothed boxes and their per-corner
    displacement are stored in ``next_state`` for the following frame.

    Relies on the notebook globals ``font``, ``pers``, ``CImageFont`` and
    ``decode_predictions``.

    Args:
        image: BGR frame as a numpy array (H, W, 3). It is annotated in place.
        yolo_model: detector; ``yolo_model(image)[0].boxes`` yields the plate boxes.
        ocr_model: OCR network returning time-major log-probabilities.
        ocr_transforms: callable turning a PIL image into a tensor.
        classes: label list passed to ``decode_predictions``.
        device: torch device the OCR input is moved to.
        next_state: object whose ``boxes``, ``velocities`` and ``ages`` lists are
            reset and filled with this frame's results.
        save_path: accepted for signature compatibility; not used.
        prev_state: state produced for the previous frame, or None for the first frame.

    Returns:
        The annotated frame (the same array that was passed in).
    """
    # Run YOLO detection
    results = yolo_model(image)
    boxes = results[0].boxes

    next_state.boxes = []
    next_state.velocities = []
    next_state.ages = []

    def interp(x, px, v):
        # One part new detection, two parts previous coordinate advanced by 1/8 of its velocity.
        return (x + 2 * (px + v // 8)) // 3

    # Smooth every detection against the previous frame and record the new state.
    processed_boxes = []
    for box in boxes:
        x_min, y_min, x_max, y_max = map(int, box.xyxy[0].cpu().numpy())
        prev_i = None
        if prev_state is not None:
            prev_i = _find_overlapping_box((x_min, y_min, x_max, y_max), prev_state.boxes)
        if prev_i is not None:
            p_x_min, p_y_min, p_x_max, p_y_max = prev_state.boxes[prev_i]
            vx1, vy1, vx2, vy2 = prev_state.velocities[prev_i]
            x_min = interp(x_min, p_x_min, vx1)
            x_max = interp(x_max, p_x_max, vx2)
            y_min = interp(y_min, p_y_min, vy1)
            y_max = interp(y_max, p_y_max, vy2)
            # Mark the previous box as matched.
            prev_state.ages[prev_i] = -1
            next_state.velocities.append((x_min - p_x_min, y_min - p_y_min, x_max - p_x_max, y_max - p_y_max))
        else:
            next_state.velocities.append((0, 0, 0, 0))
        next_state.ages.append(0)
        next_state.boxes.append((x_min, y_min, x_max, y_max))
        processed_boxes.append((x_min, y_min, x_max, y_max, float(box.conf[0])))

    # Carrying unmatched previous boxes forward was already switched off in the
    # original code (an unconditional `continue`), so only current detections are drawn.

    print("boxes:", len(processed_boxes))
    height, width = image.shape[:2]
    # Crop plates from unannotated pixels so a neighbour's rectangle or label cannot
    # leak into the OCR input. With a single box the crop always precedes drawing.
    pristine = image.copy() if len(processed_boxes) > 1 else image

    for x_min, y_min, x_max, y_max, conf in processed_boxes:
        # Keep the box inside the frame.
        x_min = max(x_min, 0)
        x_max = min(x_max, width - 1)
        y_min = max(y_min, 0)
        y_max = min(y_max, height - 1)
        plate_image = pristine[y_min:y_max, x_min:x_max]

        # Preprocess plate image for OCR
        try:
            plate_image_pil = Image.fromarray(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB)).convert('L')
        except cv2.error:
            # Degenerate (empty) crop: log it and fall back to a 1x1 black placeholder.
            print(x_min, y_min, x_max, y_max, conf)
            plate_image_pil = Image.new("RGB", (1, 1), "black").convert('L')
        plate_image_resized = plate_image_pil.resize((200, 50))
        plate_tensor = ocr_transforms(plate_image_resized).unsqueeze(0).to(device)

        # Perform OCR: greedy argmax per time step, batch element 0.
        with torch.no_grad():
            outputs = ocr_model(plate_tensor)
        predictions = outputs.argmax(dim=-1).cpu().numpy()[:, 0]
        plate_text = decode_predictions(predictions, classes)

        image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)
        # Characters without a Persian counterpart are shown unchanged instead of raising KeyError.
        plate_text = ''.join(pers.get(c, c) for c in plate_text)
        confidence_text = f"{plate_text} ({conf:.2f})"

        # Rasterise the label and blend it over the frame.
        mask = CImageFont.getmask(confidence_text, font)
        # Start 15 px above the box, but never above the frame: a negative start
        # index would select an empty slice and the label would silently vanish.
        y1 = max(y_min - 15, 0)
        y2 = y1 + mask.shape[0]
        x1 = x_min
        x2 = x_min + mask.shape[1]
        region = image[y1:y2, x1:x2]
        # The slice may be clipped by the frame border; trim the mask to match.
        y_amt, x_amt = region.shape[:2]
        image[y1:y2, x1:x2] = np.where(mask[:y_amt, :x_amt, None] != 0, (0, 0, 255), region // 2)
        cv2.imshow('image', image)
        cv2.waitKey(1)

    return image


def detect_and_visualize(image_path, yolo_model, ocr_model, ocr_transforms, classes, device, save_path=None):
    """
    Detect license plates in an image file and plot the bounding boxes, the decoded
    plate text and the detection confidence with matplotlib.

    Relies on the notebook global ``decode_predictions``.

    Args:
        image_path: path of the image to read with ``cv2.imread``.
        yolo_model: detector; ``yolo_model(image)[0].boxes`` yields the plate boxes.
        ocr_model: OCR network returning time-major log-probabilities.
        ocr_transforms: callable turning a PIL image into a tensor.
        classes: label list passed to ``decode_predictions``.
        device: torch device the OCR input is moved to.
        save_path: if given, the figure is saved there and closed. Otherwise the
            figure is left open so the active matplotlib backend can display it.

    Raises:
        ValueError: if the image cannot be read.
    """
    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image at {image_path}")

    # Run YOLO detection
    results = yolo_model(image)
    boxes = results[0].boxes

    fig, ax = plt.subplots(figsize=(12, 8))
    leave_open = False
    try:
        # matplotlib expects RGB, OpenCV delivers BGR.
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        for box in boxes:
            # Extract bounding box coordinates
            x_min, y_min, x_max, y_max = map(int, box.xyxy[0].cpu().numpy())
            # A negative start index would wrap around to the far side of the image.
            plate_image = image[max(y_min, 0):y_max, max(x_min, 0):x_max]
            if plate_image.size == 0:
                # Degenerate box: nothing to read or draw.
                continue

            # Preprocess plate image for OCR
            plate_image_pil = Image.fromarray(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB)).convert('L')
            plate_image_resized = plate_image_pil.resize((200, 50))
            plate_tensor = ocr_transforms(plate_image_resized).unsqueeze(0).to(device)

            # Perform OCR: greedy argmax per time step, batch element 0.
            with torch.no_grad():
                outputs = ocr_model(plate_tensor)
            predictions = outputs.argmax(dim=-1).cpu().numpy()[:, 0]
            plate_text = decode_predictions(predictions, classes)

            # Draw bounding box
            ax.add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   linewidth=2, edgecolor='red', facecolor='none'))

            # Add text above the bounding box
            confidence_text = f"{plate_text} ({box.conf[0]:.2f})"
            ax.text(x_min, y_min - 15, confidence_text, color='red', fontsize=12,
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='red'))

        if save_path:
            fig.savefig(save_path, bbox_inches='tight')
        else:
            leave_open = True
    finally:
        # Close after saving, and also on any error, so figures do not accumulate.
        if not leave_open:
            plt.close(fig)


In [7]:
# # Example Usage
# image_path = "/Users/panda/IR-ALPR/iran-dashcam-already.mp4"  # Replace with the path to your test image
# save_path = "output_with_bounding_boxes1.jpg"  # Replace or set to None to only display
# detect_and_visualize(image_path, yolo_model, ocr_model, ocr_transforms, classes, device)

In [8]:
class MyState:
    """Per-frame tracking state: parallel lists of boxes, their velocities and their ages."""

    def __init__(self):
        # Three independent, initially empty lists; entry i of each describes the same box.
        self.boxes, self.velocities, self.ages = [], [], []

In [9]:
import io
import numpy as np


def process_input(input_path, yolo_model, ocr_model, ocr_transforms, classes, device, output_path=None):
    """
    Dispatch an input file to the image or the video pipeline based on its extension.

    A missing file or an unsupported extension is reported with ``print``; nothing is raised.

    Args:
        input_path (str): Path to the input file (image or video).
        yolo_model: YOLO model for license plate detection.
        ocr_model: OCR model for character recognition.
        ocr_transforms: Transformations for the OCR model.
        classes: Character classes used in OCR.
        device: Torch device for inference.
        output_path (str): Path to save the output (optional).
    """
    if not os.path.isfile(input_path):
        print(f"File not found: {input_path}")
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')

    # Extension comparison is case-insensitive.
    _, suffix = os.path.splitext(input_path)
    ext = suffix.lower()

    if ext in image_extensions:
        detect_and_visualize(input_path, yolo_model, ocr_model, ocr_transforms, classes, device,
                             save_path=output_path)
    elif ext in video_extensions:
        process_video(input_path, yolo_model, ocr_model, ocr_transforms, classes, device,
                      output_video_path=output_path)
    else:
        print(f"Unsupported file type: {ext}")

def process_video(video_path, yolo_model, ocr_model, ocr_transforms, classes, device, output_video_path=None):
    """
    Run ``detect_and_visualize_cv`` on every frame of a video.

    With ``output_video_path`` the annotated frames are written to an mp4v-encoded
    file; without it they are shown in an OpenCV window until the video ends or
    'q' is pressed. A KeyboardInterrupt stops processing cleanly.

    Args:
        video_path: path of the video to read.
        yolo_model, ocr_model, ocr_transforms, classes, device: forwarded to
            ``detect_and_visualize_cv``.
        output_video_path: destination file, or None to display instead of saving.

    Raises:
        ValueError: if the input video or the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video {video_path}")

    out = None
    frame_count = 0
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate (e.g. 29.97); truncating it changes playback speed.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0/NaN; fall back to a common default so the writer can open.
            fps = 30.0

        if output_video_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
            if not out.isOpened():
                raise ValueError(f"Could not open video writer for {output_video_path}")

        prev_state = None
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Each frame fills a fresh state; it becomes prev_state for the next frame.
            state = MyState()
            frame = detect_and_visualize_cv(frame, yolo_model, ocr_model, ocr_transforms, classes, device,
                                            state, prev_state=prev_state)
            prev_state = state

            if out is not None:
                out.write(frame)

            frame_count += 1
            print("Frame: ", frame_count)

            # Display the processed frame
            if out is None:
                cv2.imshow("Processed Frame", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    except KeyboardInterrupt:
        pass
    finally:
        # Always release the capture/writer and close windows, even on unexpected errors.
        cap.release()
        if out is not None:
            out.release()
        print("Total Frames:", frame_count)
        cv2.destroyAllWindows()


In [84]:
image_path = "./iran-dashcam.mp4"  # Input file; despite the name this is a video - process_input dispatches on the extension
save_path = "./hq-out-6.mp4"  # Output file; set to None to display instead of saving
# Switch matplotlib to the non-interactive Agg backend before running the pipeline.
import matplotlib
matplotlib.use('Agg')

# Run the whole pipeline on the input with the models and settings loaded by the earlier cells.
process_input(
    input_path=image_path,
    yolo_model=yolo_model,
    ocr_model=ocr_model,
    ocr_transforms=ocr_transforms,
    classes=classes,
    device=device,
    output_path=save_path
)


0: 384x640 (no detections), 47.0ms
Speed: 2.0ms preprocess, 47.0ms inference, 1.0ms postprocess per image at shape (1, 3, 384, 640)
boxes: 0
Frame:  1

0: 384x640 1   , 61.4ms
Speed: 2.0ms preprocess, 61.4ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)
boxes: 1
Frame:  2

0: 384x640 1   , 60.3ms
Speed: 5.0ms preprocess, 60.3ms inference, 2.5ms postprocess per image at shape (1, 3, 384, 640)
boxes: 1
Frame:  3

0: 384x640 1   , 52.0ms
Speed: 3.0ms preprocess, 52.0ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)
boxes: 1
Frame:  4

0: 384x640 1   , 61.2ms
Speed: 2.0ms preprocess, 61.2ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)
boxes: 1
Frame:  5

0: 384x640 1   , 55.2ms
Speed: 2.0ms preprocess, 55.2ms inference, 4.0ms postprocess per image at shape (1, 3, 384, 640)
boxes: 1
Frame:  6

0: 384x640 1   , 64.3ms
Speed: 1.0ms preprocess, 64.3ms inference, 1.0ms postprocess per image at shape (1, 3, 384, 640)
boxes: 1
Frame:  7

0

In [73]:
# Close any OpenCV windows that are still open (presumably after an interrupted run).
cv2.destroyAllWindows()