In [None]:
import os
import torch

from PIL import Image, ImageDraw, ImageFont
from transformers import  ConditionalDetrImageProcessor, ConditionalDetrForObjectDetection
import matplotlib.pyplot as plt

In [None]:
# All inference in this notebook runs on the CPU.
device = torch.device("cpu")

# Checkpoint location handed to both `from_pretrained` calls below.
ckpt = "./model_ckpt"

# Pre/post-processing helper that belongs to the detector.
image_processor = ConditionalDetrImageProcessor.from_pretrained(ckpt)

# Load the detector and place it on the chosen device in a single step.
model = ConditionalDetrForObjectDetection.from_pretrained(ckpt).to(device)

In [None]:
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Args:
        box1: Four numbers ``(x1, y1, x2, y2)``; its area is computed as
            ``(x2 - x1) * (y2 - y1)``.
        box2: Four numbers in the same layout as ``box1``.

    Returns:
        The intersection area divided by the union area, as a float. When the
        union area is not positive (e.g. both boxes are zero-area points or
        lines) there is nothing to overlap with, so ``0.0`` is returned instead
        of dividing by zero.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # Width/height of the overlap rectangle, clamped to 0 for disjoint boxes.
    overlap_w = max(0, min(x2, x4) - max(x1, x3))
    overlap_h = max(0, min(y2, y4) - max(y1, y3))
    intersection_area = overlap_w * overlap_h

    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x4 - x3) * (y4 - y3)
    union_area = box1_area + box2_area - intersection_area

    # Degenerate boxes give a zero (or negative) union; avoid ZeroDivisionError.
    if union_area <= 0:
        return 0.0

    return intersection_area / float(union_area)

def non_max_suppression(items, iou_threshold=0.7):
    """Greedily drop detections that overlap a higher-scoring detection.

    Args:
        items: Iterable of tuples whose element 0 is the score used for
            ranking and whose element 2 is the box passed to ``calculate_iou``.
        iou_threshold: A candidate survives only while its IoU with every
            already-kept item is strictly below this value.

    Returns:
        A new list with the surviving items, highest score first.
    """
    ranked = sorted(items, key=lambda entry: entry[0], reverse=True)

    survivors = []
    for candidate in ranked:
        # Compare against the kept items in score order; stop at the first
        # one that overlaps too much.
        is_distinct = all(
            calculate_iou(winner[2], candidate[2]) < iou_threshold
            for winner in survivors
        )
        if is_distinct:
            survivors.append(candidate)

    return survivors

def detect_objects(image_path, score_threshold=0.4, iou_threshold=0.7):
    """Detect objects in one image, show the annotated image, and return crops.

    The image is run through the module-level ``image_processor`` and
    ``model``. Every detection returned by the post-processor is printed, the
    list is de-duplicated with ``non_max_suppression``, and the survivors are
    drawn on a copy of the image that is displayed with matplotlib.

    Args:
        image_path: Path of the image file to open.
        score_threshold: Value passed as ``threshold`` to
            ``post_process_object_detection``.
        iou_threshold: Value passed to ``non_max_suppression``.

    Returns:
        A list of ``(crop, label_name, score)`` tuples, one per kept
        detection. Each ``crop`` is cut from the un-annotated image.
    """
    # Release the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))
        # PIL's size is (width, height); the post-processor is given (height, width).
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=score_threshold, target_sizes=target_sizes
        )[0]

    items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        score = score.item()
        label = label.item()
        box = [i.item() for i in box]
        print(f"{model.config.id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))

    items = non_max_suppression(items, iou_threshold=iou_threshold)

    # Annotate a copy so that crops never contain rectangles or captions that
    # were drawn for previously processed detections.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = ImageFont.load_default(size=20)
    result = []
    for score, label, bbox in items:
        label_name = model.config.id2label[label]
        result.append((image.crop(bbox), label_name, score))
        draw.rectangle(tuple(bbox), outline="green", width=3)
        # Place the caption above the box, but keep it inside the image when
        # the box sits at the top edge.
        text_position = (bbox[0], max(0, bbox[1] - 25))
        draw.text(text_position, f"{label_name} {score:.2f}", fill="red", font=font)

    fig, ax = plt.subplots(figsize=(12, 9))
    ax.imshow(annotated)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()
    # This function is called once per image; close the figure so they do not pile up.
    plt.close(fig)

    return result

In [None]:
# os.listdir returns entries in arbitrary order, so sort the names to make the
# notebook output reproducible, and keep only regular files so that a
# sub-directory inside the sample folder does not reach Image.open.
image_fnames = sorted(
    f for f in os.listdir('./sample_images')
    if not f.startswith('.') and os.path.isfile(f'./sample_images/{f}')
)
image_paths = [f'./sample_images/{f}' for f in image_fnames]
for image_path in image_paths:
    detect_objects(image_path)