In [2]:
# image_caption_retrieval.py
import os
import numpy as np
import pickle
from datetime import datetime
from numpy.linalg import norm
from PIL import Image
import torch
import matplotlib.pyplot as plt
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer

# Configuration
# Directory scanned by main() for .png/.jpg/.jpeg files to index.
IMAGE_DIR = "images"
# Passed as the cache directory when loading the models; also holds DB_FILE.
MODELS_DIR = "models"
# Pickle file in which the list of processed-image records is persisted.
DB_FILE = os.path.join(MODELS_DIR, "image_database.pkl")
# display_results() hides results whose confidence is below this value.
MIN_CONFIDENCE = 0.4
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_COLS = 5  # Maximum images per row

class ImageRetrievalSystem:
    """Caption images with BLIP and retrieve them by caption similarity.

    Every indexed image is stored as a dict with the keys ``path``,
    ``caption``, ``embedding`` (NumPy vector of the caption text) and
    ``last_modified`` (the file's mtime when it was indexed).  The list of
    records lives in ``self.database`` and is persisted with pickle in
    ``DB_FILE``.
    """

    def __init__(self):
        """Load the captioning model, the text embedder and the saved database."""
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base", cache_dir=MODELS_DIR
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            cache_dir=MODELS_DIR
        ).to(DEVICE)
        self.embedder = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=MODELS_DIR
        ).to(DEVICE)

        # Load existing database
        self.database = self.load_database()

    def load_database(self):
        """Load processed images from previous sessions.

        Returns:
            The records stored in ``DB_FILE`` that pass
            ``validate_database``; an empty list when the file does not
            exist or its content cannot be unpickled into a list.
        """
        if not os.path.exists(DB_FILE):
            return []

        # SECURITY: unpickling can execute arbitrary code.  Only point
        # DB_FILE at a file that this program wrote itself.
        try:
            with open(DB_FILE, "rb") as f:
                db = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError) as exc:
            # The database is only a cache of captions and embeddings, so a
            # truncated or corrupt file is recoverable: start empty and let
            # the caller re-index the images.
            print(f"Could not read {DB_FILE} ({exc}); rebuilding database")
            return []

        if not isinstance(db, list):
            print(f"Could not read {DB_FILE} (unexpected content); rebuilding database")
            return []

        print(f"Loaded {len(db)} existing image records")
        return self.validate_database(db)

    def save_database(self):
        """Save current state to disk, creating ``MODELS_DIR`` if needed."""
        os.makedirs(MODELS_DIR, exist_ok=True)
        with open(DB_FILE, "wb") as f:
            pickle.dump(self.database, f)

    def validate_database(self, db):
        """Remove entries for missing images or modified files.

        Args:
            db: Iterable of record dicts with ``path`` and ``last_modified``.

        Returns:
            A new list with the records whose file still exists and whose
            current mtime equals the stored ``last_modified``.
        """
        valid_entries = []
        for entry in db:
            # Ask for the mtime directly instead of checking existence
            # first: the file may disappear between the two calls.
            try:
                current_mtime = os.path.getmtime(entry['path'])
            except OSError:
                continue
            if current_mtime == entry['last_modified']:
                valid_entries.append(entry)
        return valid_entries

    def generate_caption(self, image_path):
        """Return the BLIP-generated caption for the image at ``image_path``."""
        # convert() returns an independent image, so the file can be closed
        # as soon as the RGB copy exists.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        inputs = self.processor(image, return_tensors="pt").to(DEVICE)
        output = self.caption_model.generate(**inputs)
        return self.processor.decode(output[0], skip_special_tokens=True)

    def process_image(self, image_path):
        """Process image only if not already in database.

        Args:
            image_path: Path of the image file to caption and embed.

        Returns:
            ``True`` when the image was (re)processed and the database
            changed, ``False`` when an up-to-date record already existed.

        Raises:
            OSError: If the file's modification time cannot be read.
        """
        current_mtime = os.path.getmtime(image_path)

        # Check if we already have this image
        already_indexed = any(
            entry['path'] == image_path and entry['last_modified'] == current_mtime
            for entry in self.database
        )
        if already_indexed:
            return False

        # Process new/changed image
        caption = self.generate_caption(image_path)
        embedding = self.embedder.encode(caption, convert_to_tensor=True).cpu().numpy()

        # Drop any outdated record of the same file first; otherwise a file
        # that changed during this session would be listed twice and show
        # up twice in search results.  Slice assignment keeps the list
        # object itself unchanged for anyone holding a reference to it.
        self.database[:] = [
            entry for entry in self.database if entry['path'] != image_path
        ]
        self.database.append({
            "path": image_path,
            "caption": caption,
            "embedding": embedding,
            "last_modified": current_mtime
        })
        print(f"Processed: {image_path}")
        return True

    def search(self, text_query):
        """Rank all indexed images by similarity of their caption to a query.

        Args:
            text_query: Free-text description of the wanted image.

        Returns:
            A list of dicts with ``path``, ``caption`` and ``confidence``
            (cosine similarity between the query embedding and the caption
            embedding, as a Python float), sorted by descending confidence.
            Empty when the database is empty.
        """
        if not self.database:
            # Nothing to rank, so skip the embedding model call entirely.
            return []

        query_embedding = self.embedder.encode(text_query, convert_to_tensor=True).cpu().numpy()
        query_norm = norm(query_embedding)  # loop-invariant
        results = []

        for item in self.database:
            denominator = query_norm * norm(item["embedding"])
            if denominator > 0:
                sim = float(np.dot(query_embedding, item["embedding"]) / denominator)
            else:
                # A zero vector has no direction; report no similarity
                # instead of NaN, which would also make the sort below
                # order-dependent.
                sim = 0.0
            results.append({
                "path": item["path"],
                "caption": item["caption"],
                "confidence": sim
            })

        results.sort(key=lambda x: x["confidence"], reverse=True)
        return results

def display_results(results, query, num_results):
    """Plot the best matches in a grid of at most ``MAX_COLS`` columns.

    Only results whose ``confidence`` is at least ``MIN_CONFIDENCE`` are
    kept, and of those only the first ``num_results`` (``results`` is used
    in the order given).  Each image is titled with its confidence and a
    caption shortened to at most 35 characters.  When nothing qualifies a
    message is printed and no figure is created.
    """
    confident = []
    for candidate in results:
        if candidate["confidence"] >= MIN_CONFIDENCE:
            confident.append(candidate)
    shown = confident[:num_results]

    if len(shown) == 0:
        print(f"No images found with confidence ≥ {MIN_CONFIDENCE}")
        return

    total_images = len(shown)
    # Ceiling division: number of grid rows needed for total_images cells.
    rows = -(-total_images // MAX_COLS)
    cols = total_images if total_images < MAX_COLS else MAX_COLS

    plt.figure(figsize=(15, 3.5 * rows))
    heading = f"Found {total_images} results for: '{query}'\n(Minimum confidence: {MIN_CONFIDENCE})"
    plt.suptitle(heading, y=1.02 + (rows * 0.03), fontsize=12)

    for position, match in enumerate(shown):
        # Subplot indices are 1-based.
        plt.subplot(rows, cols, position + 1)
        picture = Image.open(match["path"])
        plt.imshow(picture)

        # Keep titles short so neighbouring cells do not overlap.
        text = match["caption"]
        if len(text) > 35:
            text = f"{text[:32]}..."

        label = f"Confidence: {match['confidence']:.2f}\n{text}"
        plt.title(label, fontsize=9, pad=2)
        plt.axis('off')

    plt.tight_layout(pad=1.5)
    plt.show()

def get_positive_integer(prompt):
    """Prompt repeatedly until the user enters an integer greater than 0.

    Args:
        prompt: Text shown to the user on every attempt.

    Returns:
        The first entered value that parses as an ``int`` and is positive.
    """
    while True:
        answer = input(prompt)
        try:
            value = int(answer)
        except ValueError:
            print("Invalid input. Please enter a whole number")
            continue
        if value <= 0:
            print("Please enter a number greater than 0")
            continue
        return value

def main():
    """Index the images in ``IMAGE_DIR`` and run an interactive search loop.

    If ``IMAGE_DIR`` does not exist it is created and the function returns
    so the user can add images.  Otherwise every .png/.jpg/.jpeg file is
    passed to ``process_image``, the database is saved when at least one
    file was new or changed, and queries are read until the user types
    ``exit``.
    """
    # Check the image directory before constructing the system: loading
    # the models is slow and pointless when there is nothing to index.
    if not os.path.exists(IMAGE_DIR):
        os.makedirs(IMAGE_DIR)
        print(f"Add images to '{IMAGE_DIR}' and restart")
        return

    system = ImageRetrievalSystem()

    print("\nChecking for new/changed images...")
    # Snapshot of what is already indexed, so that only files that really
    # needed processing are counted (and the database is only rewritten
    # when something changed).
    known = {(entry["path"], entry["last_modified"]) for entry in system.database}
    new_files = 0
    for fname in os.listdir(IMAGE_DIR):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        img_path = os.path.join(IMAGE_DIR, fname)
        needs_processing = (img_path, os.path.getmtime(img_path)) not in known
        system.process_image(img_path)
        if needs_processing:
            new_files += 1

    if new_files > 0:
        system.save_database()
        print(f"Processed {new_files} new/changed images")

    print("\nImage retrieval ready! (type 'exit' to quit)")
    while True:
        query = input("\nSearch query: ").strip()
        if query.lower() == "exit":
            break
        if not query:
            # An empty query carries no meaning to match against; ask again.
            continue

        results = system.search(query)
        if not results:
            print("No matches found!")
            continue

        num_images = get_positive_integer("How many images would you like to display? ")
        display_results(results, query, num_images)

# Script entry point: report the selected compute device, then start the
# indexing and interactive search session.
if __name__ == "__main__":
    print(f"Running on device: {DEVICE.upper()}")
    main()

Running on device: CPU

Checking for new/changed images...

Image retrieval ready! (type 'exit' to quit)
No matches found!
No matches found!
