In [None]:
from pathlib import Path
from PIL import Image
import torch

from src.model import (
    load_blip_captioning,
    load_blip_vqa,
    load_blip_retrieval,
)

# Load the three BLIP variants used in this notebook. The captioning loader is
# called first because its third return value (bound to `device`) is passed on
# to the other two loaders.
caption_model, caption_processor, device = load_blip_captioning()
# The third value returned by the next two loaders is discarded
# (presumably the same device again — confirm in src.model).
vqa_model, vqa_processor, _ = load_blip_vqa(device)
retrieval_model, retrieval_processor, _ = load_blip_retrieval(device)

# Report the device so the reader knows where inference will run.
print("Using device:", device)


In [None]:
# Directory that holds the demo images (a relative path, so it is resolved
# against the notebook's current working directory).
DATA_DIR = Path("data")

# Demo images, in the order every later cell iterates over them.
image_paths = [
    DATA_DIR / file_name
    for file_name in ("cat.jpeg", "dog.jpeg", "girl_cat.jpeg")
]

# Sanity check: show, for each expected file, whether it is actually on disk.
for candidate in image_paths:
    print(candidate, "exists:", candidate.exists())


In [None]:
def _caption_image(image_file):
    """Return the generated caption for the image stored at ``image_file``.

    The file is opened with PIL and converted to RGB, preprocessed by
    ``caption_processor`` into PyTorch tensors moved to ``device``, and passed
    to ``caption_model.generate`` with gradients disabled and at most 30 new
    tokens. The first generated sequence is decoded with special tokens
    skipped.
    """
    rgb_image = Image.open(image_file).convert("RGB")
    model_inputs = caption_processor(images=rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated = caption_model.generate(**model_inputs, max_new_tokens=30)
    return caption_processor.decode(generated[0], skip_special_tokens=True)


# Caption every demo image and print one "<file name>: <caption>" line each.
for img_path in image_paths:
    print(f"{img_path.name}: {_caption_image(img_path)}")


In [None]:
# Questions to ask about each image, keyed by the image's file name.
questions_per_image = {
    "cat.jpeg": ["What animal is in the picture?"],
    "dog.jpeg": [
        "What animal is in the picture?",
        "Is the animal sitting on grass?",
    ],
    "girl_cat.jpeg": [
        "What animal is the girl holding?",
        "Is the girl smiling?",
        "Is this photo taken indoors?",
    ],
}


def _answer_question(rgb_image, question_text):
    """Return the VQA model's decoded answer to ``question_text`` about ``rgb_image``.

    The image and question are preprocessed together by ``vqa_processor`` into
    PyTorch tensors moved to ``device``; ``vqa_model.generate`` then runs with
    gradients disabled and at most 16 new tokens. The first generated sequence
    is decoded with special tokens skipped.
    """
    model_inputs = vqa_processor(
        images=rgb_image, text=question_text, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        generated = vqa_model.generate(**model_inputs, max_new_tokens=16)
    return vqa_processor.decode(generated[0], skip_special_tokens=True)


for img_path in image_paths:
    rgb_image = Image.open(img_path).convert("RGB")
    file_name = img_path.name

    print(f"\n=== {file_name} ===")

    # An image with no entry in the table still gets its header, just no Q/A.
    for question in questions_per_image.get(file_name, []):
        # Compute the answer before printing so Q and A always appear together.
        answer = _answer_question(rgb_image, question)
        print(f"Q: {question}")
        print(f"A: {answer}")


In [None]:
# Candidate captions; every image is scored against every one of them.
candidate_texts = [
    "A tabby cat sitting outside.",
    "A golden retriever puppy lying on the grass.",
    "A young woman holding a gray cat in her arms.",
]


def _itm_match_score(itm_output):
    """Reduce the ITM head output for one image-text pair to a single float.

    Hugging Face's ``BlipForImageTextRetrieval`` (presumably what
    ``load_blip_retrieval`` returns — confirm in src.model) emits two logits
    per pair, so calling ``float()`` on the squeezed tensor would raise because
    it still holds two elements. When the last dimension has size 2, the logits
    are turned into probabilities with softmax and the entry at index 1 is
    returned (taken to be the "match" class, as in the BLIP reference usage —
    verify against the model in use). Any other output is converted exactly as
    before, so a model that already returns a single score is unaffected.
    """
    if itm_output.dim() > 0 and itm_output.shape[-1] == 2:
        # Cast to float32 first so softmax also works for half-precision outputs.
        itm_output = torch.softmax(itm_output.float(), dim=-1)[..., 1]
    return float(itm_output.squeeze().cpu())


for img_path in image_paths:
    img = Image.open(img_path).convert("RGB")
    print(f"\n=== {img_path.name} ===")

    # One (text, itm_score, cosine) tuple per candidate caption.
    scores = []

    for text in candidate_texts:
        inputs = retrieval_processor(
            images=img,
            text=text,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            # Two forward passes: one through the ITM head (the higher the
            # score, the better the match), one for the cosine similarity.
            itm_scores = retrieval_model(**inputs)[0]
            cosine_score = retrieval_model(
                **inputs, use_itm_head=False
            )[0]

        score = _itm_match_score(itm_scores)
        cos = float(cosine_score.squeeze().cpu())
        scores.append((text, score, cos))

    # Sort by ITM score, best match first.
    scores.sort(key=lambda x: x[1], reverse=True)

    for text, s, cos in scores:
        print(f"'{text}'  ->  ITM Score={s:.3f}, Cosine={cos:.3f}")
