In [2]:
import torch
import warnings
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import random
import torch
from transformers import ConditionalDetrFeatureExtractor, ConditionalDetrForObjectDetection

warnings.filterwarnings("ignore")

In [3]:
# One checkpoint identifier feeds both from_pretrained calls below, so the
# preprocessor and the detector are always loaded from the same source.
MODEL_NAME = "microsoft/conditional-detr-resnet-50"

# Detector first, then its matching preprocessor; the two loads are independent.
model = ConditionalDetrForObjectDetection.from_pretrained(MODEL_NAME)
feature_extractor = ConditionalDetrFeatureExtractor.from_pretrained(MODEL_NAME)

In [4]:
# Image.open is lazy and keeps the file handle open until the pixels are read.
# Open it in a context manager and materialise an in-memory RGB copy so that
# (a) the file is closed deterministically and (b) grayscale / palette / RGBA
# inputs are normalised to three channels before preprocessing.
with Image.open("000000039769.jpg") as _image_file:
    image = _image_file.convert("RGB")

In [5]:
# Turn the PIL image into the keyword arguments the model expects, as PyTorch
# tensors (return_tensors="pt").
inputs = feature_extractor(images=image, return_tensors="pt")

# This notebook only runs inference, so disable gradient tracking: no autograd
# graph is built, which saves memory and time, and the resulting tensors do not
# require grad.
with torch.no_grad():
    outputs = model(**inputs)

In [6]:
# image.size is indexed as (size[0], size[1]); post_process is handed the two
# values in swapped order, wrapped in a leading batch dimension of length 1.
size = image.size
target_sizes = torch.tensor([size[::-1]])

# A single image was passed in, so keep only the first entry of the result list.
results = feature_extractor.post_process(outputs, target_sizes)[0]

# Boolean mask selecting detections whose score is strictly above 0.5.
filters = results['scores'].gt(0.5)

In [7]:
# Apply the score mask to each result field and convert the surviving entries
# to plain Python lists for printing and plotting.
scores, labels, boxes = (
    results[field][filters].tolist() for field in ("scores", "labels", "boxes")
)

In [8]:
# Summary: number of kept detections, then each distinct class name once.
id2label = model.config.id2label
distinct_label_ids = set(labels)

print("Object Count: ", len(labels))
print("Found Labels:")
for label_id in distinct_label_ids:
    print(id2label[label_id])

Object Count:  5
Found Labels:
cat
remote
couch


In [9]:
def generate_random_color():
    """Return a random colour as ``[r, g, b]``.

    Each channel is drawn independently with ``random.uniform(0, 1)``, in the
    order red, green, blue, using the module-level ``random`` generator.
    """
    return [random.uniform(0, 1) for _channel in range(3)]

In [None]:
# Draw the image with one rectangle and one caption per kept detection.
fig, ax = plt.subplots()
ax.imshow(image)

for score, box, label in zip(scores, boxes, labels):
    # Each box is unpacked as corner coordinates; Rectangle wants origin + size.
    xmin, ymin, xmax, ymax = box
    color = generate_random_color()
    rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor=color)
    ax.add_patch(rect)
    text = f"{model.config.id2label[label]}: {score:.2f}"
    # Back the caption with a patch in the box colour: bare white text is
    # unreadable over light image regions, and the shared colour ties each
    # caption to its rectangle.
    ax.text(
        xmin,
        ymin,
        text,
        fontsize=8,
        color="white",
        verticalalignment="top",
        bbox={"facecolor": color, "alpha": 0.7, "edgecolor": "none", "pad": 1},
    )

# Use the explicit Axes object rather than pyplot's implicit "current axes".
ax.axis("off")
plt.show()