In [3]:
from pathlib import Path
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm

# Directory that holds the .jpg images to be captioned.
IMG_DIR = Path("D:\\work_space\\projects\\deep_learning\\data_set\\final_data_set\\processed")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Load the pretrained CLIP checkpoint, move it to the chosen device, and load
# the processor that belongs to the same checkpoint.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


# Candidate captions scored against each image by rank_captions(); the
# highest-scoring entry is the one returned for that image.
CAPTION_TEMPLATES = [
    "a speed bump on a road",
    "a yellow speed bump on asphalt",
    "a speed bump on a residential street",
    "a speed bump in daylight",
    "a speed bump on a suburban road",
    "a close view of a road speed bump",
    "a traffic control speed bump",
    "a concrete speed bump",
    "a wide view of a speed bump on the road",
    "a speed bump with cars nearby",
    "a speed bump on a street with road signs",
    "a speed bump seen from above",
    "a speed bump on a rainy road",
]

def rank_captions(img_path, captions=None):
    """Return the candidate caption that CLIP scores highest for an image.

    Args:
        img_path: Path of the image file to caption.
        captions: Optional sequence of candidate caption strings. When
            omitted, the module-level ``CAPTION_TEMPLATES`` list is used.

    Returns:
        The candidate string with the largest image-to-text logit.

    Raises:
        ValueError: If the candidate sequence is empty.
    """
    candidates = CAPTION_TEMPLATES if captions is None else list(captions)
    if not candidates:
        raise ValueError("captions must contain at least one candidate")

    # Use a context manager so the underlying file handle is released even if
    # decoding fails; convert() returns a new in-memory image that stays
    # usable after the file is closed.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(
        text=candidates,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
        # A single image was passed, so index 0 selects its scores against
        # every candidate caption.
        logits = outputs.logits_per_image[0]

    return candidates[logits.argmax().item()]


# Caption every .jpg in IMG_DIR and store each caption in a .txt file that
# sits next to its image (same stem, different suffix).
image_paths = sorted(IMG_DIR.glob("*.jpg"))

# Stays 0 when no images are found, so the summary below remains valid.
count = 0

for count, img_path in enumerate(
    tqdm(image_paths, desc="Generating image captions"), start=1
):
    best_caption = rank_captions(img_path)

    # Save caption next to image; the explicit encoding keeps the output
    # independent of the platform's locale default.
    txt_path = img_path.with_suffix(".txt")
    txt_path.write_text(best_caption, encoding="utf-8")

print(f"\nImage captioning Generated {count} captions.")
if count:
    # img_path / best_caption only exist if the loop ran at least once; they
    # hold the last image processed.
    print(f"Sample image: {img_path.name}")
    print(f"Sample caption: {best_caption}")
else:
    print(f"No .jpg images found in {IMG_DIR}")

Using device: cuda


Generating image captions: 100%|██████████| 365/365 [00:12<00:00, 29.21it/s]


Image captioning Generated 365 captions.
Sample image: speedbump_00364.jpg
Sample caption: a speed bump on a rainy road



