In [2]:
import cv2
from ultralytics import YOLO
import os
import time
import numpy as np
import faiss
import torch
from PIL import Image
import json
from datetime import datetime
from explain import explain
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

import clip_encoder

model = YOLO("yolov8n.pt")

alerts_file = "alerts/alerts.json"
os.makedirs("alerts", exist_ok=True)

d = 512

image_index_path = "data/image_index.faiss"
text_index_path = "data/text_index.faiss"
image_paths_path = "data/image_paths.txt"
text_paths_path = "data/text_paths.txt"

if os.path.exists(image_index_path) and os.path.exists(image_paths_path):
    print("[INFO] Loading existing image FAISS index...")
    image_index = faiss.read_index(image_index_path)
    with open(image_paths_path, "r") as f:
        image_gallery_paths = f.read().splitlines()
else:
    print("[INFO] No image FAISS index found, creating a new one...")
    image_index = faiss.IndexFlatIP(d)
    image_gallery_paths = []

if os.path.exists(text_index_path) and os.path.exists(text_paths_path):
    print("[INFO] Loading existing text FAISS index...")
    text_index = faiss.read_index(text_index_path)
    with open(text_paths_path, "r") as f:
        text_gallery_paths = f.read().splitlines()
else:
    print("[INFO] No text FAISS index found, creating a new one...")
    text_index = faiss.IndexFlatIP(d)
    text_gallery_paths = []

output_dir = "data/cropped_objects"
os.makedirs(output_dir, exist_ok=True)

new_queries_file = "data/new_queries.txt"

IMAGE_SIM_THRESHOLD = 0.7
TEXT_SIM_THRESHOLD = 0.6

def embed_text(text: str):
    text_tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()


Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt to 'yolov8n.pt'...


100%|██████████| 6.25M/6.25M [00:05<00:00, 1.21MB/s]


[INFO] No image FAISS index found, creating a new one...
[INFO] No text FAISS index found, creating a new one...


In [None]:

cap = cv2.VideoCapture(0)
frame_count = 0

last_detected_image = None
last_detected_text = None

IMAGE_SIM_THRESHOLD = 1
TEXT_SIM_THRESHOLD = 0.2

while True:
    ret, frame = cap.read()
    if not ret:
        break

    results = model(frame)[0]

    for i, box in enumerate(results.boxes.xyxy):
        x1, y1, x2, y2 = map(int, box[:4])
        cropped = frame[y1:y2, x1:x2]

        timestamp = int(time.time() * 1000)
        filename = f"{output_dir}/crop_{frame_count}_{i}_{timestamp}.jpg"
        cv2.imwrite(filename, cropped)

        emb = clip_encoder.embed_image(filename).astype("float32")
        faiss.normalize_L2(emb)

        image_match = None
        text_match = None
        
        if image_index.ntotal > 0:
            D_img, I_img = image_index.search(emb, k=1)
            img_score = D_img[0][0]
            img_idx = I_img[0][0]
            
            if img_score >= IMAGE_SIM_THRESHOLD:
                image_match = {
                    'path': image_gallery_paths[img_idx],
                    'score': img_score,
                    'type': 'image'
                }

        if text_index.ntotal > 0:
            D_txt, I_txt = text_index.search(emb, k=1)
            txt_score = D_txt[0][0]
            txt_idx = I_txt[0][0]
            
            if txt_score >= TEXT_SIM_THRESHOLD:
                text_match = {
                    'path': text_gallery_paths[txt_idx],
                    'score': txt_score,
                    'type': 'text'
                }

        y_offset = y1 - 10
        
        if image_match:
            person_id = os.path.basename(image_match['path'])
            text = f"IMG: {person_id} ({image_match['score']:.2f})"
            color = (0, 0, 255)
            
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, text, (x1, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
            y_offset -= 25
            
            if person_id != last_detected_image:
                last_detected_image = person_id
                
                filename_only = f"img_{person_id}_{int(time.time())}.jpg"
                explained_img = os.path.join("alerts", filename_only)
                
                explain(filename, image_match['path'], explained_img)
                
                alert = {
                    "type": "image",
                    "person": person_id,
                    "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "image": filename_only,
                    "score": float(image_match['score'])
                }
                
                if os.path.exists(alerts_file):
                    with open(alerts_file, "r") as f:
                        alerts = json.load(f)
                else:
                    alerts = []
                
                alerts.append(alert)
                with open(alerts_file, "w") as f:
                    json.dump(alerts, f, indent=4)
        
        if text_match:
            text_desc = text_match['path']
            text = f"TXT: {text_desc} ({text_match['score']:.2f})"
            color = (255, 0, 0)
            
            if not image_match:
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            
            cv2.putText(frame, text, (x1, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
            y_offset -= 25
            
            if text_desc != last_detected_text:
                last_detected_text = text_desc
                
                filename_only = f"txt_{int(time.time())}.jpg"
                saved_img = os.path.join("alerts", filename_only)
                cv2.imwrite(saved_img, cropped)
                
                alert = {
                    "type": "text",
                    "person": text_desc,
                    "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "image": filename_only,
                    "score": float(text_match['score'])
                }
                
                if os.path.exists(alerts_file):
                    with open(alerts_file, "r") as f:
                        alerts = json.load(f)
                else:
                    alerts = []
                
                alerts.append(alert)
                with open(alerts_file, "w") as f:
                    json.dump(alerts, f, indent=4)
        
        if not image_match and not text_match:
            text = "No match"
            color = (0, 255, 0)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    if os.path.exists(new_queries_file):
        with open(new_queries_file, "r") as f:
            new_queries = f.read().splitlines()

        if new_queries:
            for q in new_queries:
                if q.startswith("img:"):
                    q_path = q.replace("img:", "").strip()
                    if os.path.exists(q_path):
                        print(f"Adding new IMAGE query: {q_path}")
                        emb = clip_encoder.embed_image(q_path).astype("float32")
                        faiss.normalize_L2(emb)
                        image_index.add(emb)
                        image_gallery_paths.append(q_path)

                elif q.startswith("txt:"):
                    q_text = q.replace("txt:", "").strip()
                    if q_text:
                        print(f"Adding new TEXT query: '{q_text}'")
                        emb = embed_text(q_text).astype("float32")
                        faiss.normalize_L2(emb)
                        text_index.add(emb)
                        text_gallery_paths.append(q_text)

            faiss.write_index(image_index, image_index_path)
            faiss.write_index(text_index, text_index_path)
            
            with open(image_paths_path, "w") as f:
                f.write("\n".join(image_gallery_paths))
            
            with open(text_paths_path, "w") as f:
                f.write("\n".join(text_gallery_paths))

            open(new_queries_file, "w").close()

    cv2.imshow("Live-Feed", frame)
    frame_count += 1

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()


0: 480x640 1 person, 1 bottle, 133.7ms
Speed: 16.8ms preprocess, 133.7ms inference, 3.1ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 1 bottle, 122.7ms
Speed: 8.7ms preprocess, 122.7ms inference, 2.0ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 1 bottle, 1 bowl, 87.5ms
Speed: 1.3ms preprocess, 87.5ms inference, 1.8ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 2 bottles, 88.1ms
Speed: 1.5ms preprocess, 88.1ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 1 bowl, 75.0ms
Speed: 1.9ms preprocess, 75.0ms inference, 1.1ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 1 bottle, 78.5ms
Speed: 1.3ms preprocess, 78.5ms inference, 1.8ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 1 bottle, 1 bowl, 72.7ms
Speed: 2.2ms preprocess, 72.7ms inference, 1.1ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 1 person, 79.1ms